import os
import sys
import numpy as np
from numpy import sqrt, exp, log, sin, cos, tan, arcsin, arccos, arctan, pi
from matplotlib import pyplot as plt


from tklib.tkutils import colors, minmax_xy, get_last_directory, split_file_path
from tklib.tkapplication import tkApplication
from tklib.tkfile import tkFile
import tklib.tkre as tkre
from tklib.tkutils import terminate, pint, pfloat, getarg, getintarg, getfloatarg
from tklib.tksci.tkconvolution import convolution, convolve_func
from tklib.tkcrystal.tkvasp import tkVASP
import tklib.tkcsv

import filter.vasp2dos as v2d


"""
Plot DOS calculated by VASP
"""


def initialize(app, cparams):
    """Populate *cparams* with this script's default settings.

    Parameters
    ----------
    app : tkApplication
        Not used here; kept so every setup hook shares the (app, cparams) signature.
    cparams : parameter object
        Each key of the table below is assigned onto it as an attribute, in order.
    """
    defaults = {
        'debug': 0,
        'mode': 'DOS',
        'infile': '.',
        # Energy used when searching for the band-edge energies
        'EF0': 0.1,
        # DOS threshold used as the criterion for locating band edges
        'DOSth': 1.0e-5,
        # Maximum number of electrons occupying one state
        'Nemax': 1.0,
        # Width of the Gaussian function used for convolution
        'width': 0.0,
        # Plot configuration
        'band_marker_size': 4,
        'band_marker_edge_width': 0.5,
        'Emin': -10.0,  # eV
        'Emax': 10.0,   # eV
        # Occupancy threshold separating HOMO and LUMO
        'occ_th': 0.5,
        'figsize': (8, 6),
        'fontsize': 16,
        'labelfontsize': 12,
        'legend_fontsize': 8,
        'save_figure': 1,
        'plot_figure': 1,
        'plot_figure_Ne': 1,
    }
    for attr_name, attr_value in defaults.items():
        setattr(cparams, attr_name, attr_value)

#=============================
# Treat arguments
#=============================
def usage(app = None, cparams = None):
    """Print the command-line usage and, if *cparams* is given, an example.

    Parameters
    ----------
    app : tkApplication, optional
        Not used; kept for the uniform (app, cparams) signature.
    cparams : parameter object, optional
        Current settings used to fill in the example command line. When it is
        None (the default), the example line is skipped rather than raising
        AttributeError.
    """
    print("")
    print("Usage:")
    print("  (a)  python {} mode infile Emin Emax occ_th Gaussian_width save_figure plot_figure plot_figure_Ne".format(sys.argv[0]))
    # The example needs concrete values; without cparams there are none to show.
    if cparams is None:
        return
    print("     ex: python {} {} {} {} {} {} {} {} {} {}"
                .format(sys.argv[0], cparams.mode, cparams.infile, cparams.Emin, cparams.Emax, cparams.occ_th, cparams.width,
                        cparams.save_figure, cparams.plot_figure, cparams.plot_figure_Ne))

def updatevars(app, cparams):
    """Override *cparams* settings with positional command-line arguments.

    Argument positions: 1 mode, 2 infile, 3 Emin, 4 Emax, 5 occ_th,
    6 width (Gaussian width), 7 save_figure, 8 plot_figure, 9 plot_figure_Ne.
    Each getarg/getfloatarg/getintarg call receives the current value as its
    second argument (presumably the fallback when the argument is absent).

    Calls terminate() with a usage hint when mode is not 'DOS'.
    """
    cparams.mode    = getarg     (1, cparams.mode)
    cparams.infile  = getarg     (2, cparams.infile)
    cparams.Emin    = getfloatarg(3, cparams.Emin)
    cparams.Emax    = getfloatarg(4, cparams.Emax)
    cparams.occ_th  = getfloatarg(5, cparams.occ_th)
    cparams.width   = getfloatarg(6, cparams.width)
    cparams.save_figure    = getintarg(7, cparams.save_figure)
    cparams.plot_figure    = getintarg(8, cparams.plot_figure)
    cparams.plot_figure_Ne = getintarg(9, cparams.plot_figure_Ne)

    # 'DOS' is the only supported mode. The bad value lives in cparams.mode;
    # the previous code referenced an undefined local `mode` (NameError).
    if cparams.mode != 'DOS':
        terminate("Error: Invalid mode [{}]".format(cparams.mode), usage = lambda: usage(app, cparams))


def _draw_edge_lines(ax, EV, EC, EF0, ylim):
    """Draw vertical dashed markers at E_V, E_C (red) and E_F0 (blue) on *ax*.

    ylim is the (bottom, top) pair giving the vertical extent of each marker.
    """
    ax.plot([EV, EV],   ylim, label = '$E_V$',    linestyle = 'dashed', linewidth = 0.5, color = 'red')
    ax.plot([EC, EC],   ylim, label = '$E_C$',    linestyle = 'dashed', linewidth = 0.5, color = 'red')
    ax.plot([EF0, EF0], ylim, label = '$E_{F0}$', linestyle = 'dashed', linewidth = 0.5, color = 'blue')


def plot_dos(app, cparams):
    """Read VASP DOS/EIGENVAL data from cparams.infile, save it, and plot it.

    Steps, in order:
      1. Set cparams.outfile from v2d.get_output_path() and print the settings.
      2. Validate the input with v2d.check_file_type(); call app.terminate()
         if it returns None or an 'Error' message.
      3. Read, print, convert and save the data through v2d.read_data,
         v2d.print_data, v2d.convert and v2d.save_data.
      4. Print the EIGENVAL k points and the band-edge energies.
      5. Plot the DOS, and optionally the electron number N_e, with E_V, E_C
         and E_F0 markers. Optionally save the figure to 'dos.png' and show it.
      6. Call terminate() at the end.

    Parameters
    ----------
    app : tkApplication
        Its terminate() is used to report input-file errors.
    cparams : parameter object
        Run settings (see initialize()/updatevars()). This function writes
        cparams.outfile, cparams.xmin and cparams.xmax. It fills in EF0, DOSth
        and Nemax only when they are unset.
    """
    cparams.outfile = v2d.get_output_path(cparams.infile)

    print("")
    print(f"Plot E range: {cparams.Emin} - {cparams.Emax} eV")
    print(f"Occupancy threshold to seprate HOMO and LUMO: {cparams.occ_th}")
    print(f"Gaussian width for convolution: {cparams.width} eV")
    print(f"save_figure   : {cparams.save_figure}")
    print(f"plot_figure   : {cparams.plot_figure}")
    print(f"plot_figure_Ne: {cparams.plot_figure_Ne}")
    print("Output file:")
    print(f"  xlsx file: {cparams.outfile}")

    # Band-edge search parameters:
    #   EF0   : energy used to search for the edge energies
    #   DOSth : DOS threshold used as the band-edge criterion
    #   Nemax : maximum number of electrons occupying one state
    # Keep values already configured (e.g. by initialize()). Fall back to the
    # defaults only when unset, instead of silently overwriting configuration.
    for name, default in (('EF0', 0.1), ('DOSth', 1.0e-5), ('Nemax', 1.0)):
        if getattr(cparams, name, None) is None:
            setattr(cparams, name, default)
    cparams.xmin = cparams.Emin
    cparams.xmax = cparams.Emax

    file_type = v2d.check_file_type(cparams.infile, print_level = 1)
    if file_type is None:
        app.terminate(f"Error: [{cparams.infile}] invalid file type", usage = lambda: usage(app, cparams))
    if 'Error' in file_type:
        app.terminate(file_type, usage = lambda: usage(app, cparams))

    inf = v2d.read_data(cparams.infile, cparams = cparams, print_level = 2)
    v2d.print_data(inf)
    inf = v2d.convert(inf, cparams = cparams, print_level = 1)
    v2d.save_data([cparams.outfile], inf, cparams = cparams, print_level = 1)

    doscarinf   = inf["doscarinf"]
    eigenvalinf = inf["eigenvalinf"]
    bandedgeinf = inf["bandedgeinf"]

    ISPIN = doscarinf["ISPIN"]

    data_list = inf["data_list"][0]
    labels    = inf["labels"][0]
    # NOTE(review): E is data_list[0], and the DOS plotting loop below starts
    # at i = 0, so data_list[0] is also drawn against E. Confirm against
    # v2d.convert() whether that curve is intended.
    E     = data_list[0]
    nE    = inf["ndata"]
    Edmin = inf["xmin"]
    Edmax = inf["xmax"]
    Neup  = doscarinf["Neup"]
    Nedn  = doscarinf["Nedn"]

    if nE > 0:
        print("")
        print("DOS E range: {:10.6g} - {:10.6g} eV, {} points".format(Edmin, Edmax, nE))

    nk      = eigenvalinf["nk"]
    nLevels = eigenvalinf["nLevels"]
    print("k points in EIGENVAL:")
    print("nk=", nk)
    print("nLevels=", nLevels)
    for i in range(nk):
        kx, ky, kz, wk, dk, ktot, Eups, occups, Edns, occdns = eigenvalinf['EList'][i]
        print("  ({:8.4f} {:8.4f} {:8.4f}) w={:8.4g}  {:8.4f} {:12.4f}".format(kx, ky, kz, wk, dk, ktot))

    EV = bandedgeinf["EV"]
    EC = bandedgeinf["EC"]
    Eg = bandedgeinf["Eg"]
    print(f"Band edge from EIGENVAL:")
    print(f"EF0={cparams.EF0:10.6f} eV")
    print(f"  band_edges: EV={EV:10.6f}  EC={EC:10.6f}  Eg={Eg:10.6f} eV")
    print(f"              EHOMO={bandedgeinf['EHOMO']:10.6f}  ELUMO={bandedgeinf['ELUMO']:10.6f} eV")

#=============================
# Plot graphs
#=============================
    print("")

    # minmax_xy() needs a uniform grid step taken from the first two points.
    # Fail with a clear message instead of an IndexError on E[1].
    if len(E) < 2:
        app.terminate(f"Error: [{cparams.infile}] needs at least 2 energy points to plot DOS",
                      usage = lambda: usage(app, cparams))
    Estep = E[1] - E[0]

    fig = plt.figure(figsize = cparams.figsize)
    if cparams.plot_figure_Ne:
        ax1 = fig.add_subplot(2, 1, 1)
        ax2 = fig.add_subplot(2, 1, 2)
    else:
        ax1 = fig.add_subplot(1, 1, 1)

    dirname, lastdir, filebody, ext = split_file_path(cparams.infile)
    # Title the top (DOS) panel explicitly. plt.title() titles the current
    # axes, which is the lower N_e panel when plot_figure_Ne is set.
    ax1.set_title(lastdir)

    # DOS panel: every column of data_list against E. The y range is taken
    # from the data inside [Emin, Emax].
    ncolors = len(colors)
    ymin =  1.0e300
    ymax = -1.0e300
    for i, ydata in enumerate(data_list):
        ax1.plot(E, ydata, label = labels[i], linestyle = '-', linewidth = 0.5, color = colors[i % ncolors])
        _min, _max = minmax_xy(x = None, y = ydata, x0 = E[0], xstep = Estep, xmin = cparams.Emin, xmax = cparams.Emax)
        ymin = min(ymin, _min)
        ymax = max(ymax, _max)
    _draw_edge_lines(ax1, EV, EC, cparams.EF0, ax1.get_ylim())
    ax1.set_xlabel("$E$ (eV)",               fontsize = cparams.fontsize)
    ax1.set_ylabel("DOS (states/cm$^3$/eV)", fontsize = cparams.fontsize)
    ax1.set_xlim([cparams.Emin, cparams.Emax])
    # Skip the sentinel range (+/-1e300) when there was nothing to plot.
    if ymin <= ymax:
        ax1.set_ylim([ymin, ymax])
    ax1.legend(fontsize = cparams.legend_fontsize)
    ax1.tick_params(labelsize = cparams.fontsize)

    if cparams.plot_figure_Ne:
        # N_e panel: spin up always; spin down added only when ISPIN == 2.
        ax2.plot(E, Neup, label = 'Ne up', linestyle = '-', linewidth = 0.5, color = 'black')
        ymin, ymax = minmax_xy(x = None, y = Neup, x0 = E[0], xstep = Estep, xmin = cparams.Emin, xmax = cparams.Emax)
        if ISPIN == 2:
            ax2.plot(E, Nedn, label = 'Ne dn', linestyle = '-', linewidth = 0.5, color = 'red')
            _min, _max = minmax_xy(x = None, y = Nedn, x0 = E[0], xstep = Estep, xmin = cparams.Emin, xmax = cparams.Emax)
            ymin = min(ymin, _min)
            ymax = max(ymax, _max)
        _draw_edge_lines(ax2, EV, EC, cparams.EF0, ax2.get_ylim())
        ax2.set_xlabel("$E$ (eV)",              fontsize = cparams.fontsize)
        ax2.set_ylabel("$N_e$ (states/cm$^3$)", fontsize = cparams.fontsize)
        ax2.set_xlim([cparams.Emin, cparams.Emax])
        ax2.set_ylim([ymin, ymax])
        ax2.legend(fontsize = cparams.legend_fontsize)
        ax2.tick_params(labelsize = cparams.fontsize)

    # Rearrange the graph axes so that they do not overlap
    plt.tight_layout()

    if cparams.save_figure:
        outfig_path = 'dos.png'
        print("")
        print(f"Save DOS to figure file [{outfig_path}]")
        plt.savefig(outfig_path, dpi=300, transparent=True)
        if not cparams.plot_figure:
            input("Press ENTER to terminate>>")

    use_pause = True
    if cparams.plot_figure:
        if use_pause:
            plt.pause(0.1)
            print("")
            input("Press ENTER to exit>>")
        else:
            # Print the hint before plt.show(): show() blocks until the
            # window is closed, so a message after it would come too late.
            print("")
            print("Close graph window to terminate")
            plt.show()

    terminate("", usage = lambda: usage(app, cparams))


def main():
    """Script entry point: set up parameters, open the log file, dispatch by mode."""
    app     = tkApplication()
    cparams = app.get_params()

    initialize(app, cparams)
    updatevars(app, cparams)

    # Log file name is built from the input path via this template
    # (presumably <dirname>/<filebody>-out.txt).
    logfile = app.replace_path(cparams.infile, template = ["{dirname}", "{filebody}-out.txt"])
    print(f"Open logfile [{logfile}]")
    # Presumably tees subsequent output to stdout and the log file.
    app.redirect(targets = ["stdout", logfile], mode = 'w')

    print("")
    print("=============== Plot DOS calculated by VASP ============")
    print("")
    print("mode: ", cparams.mode)

    if cparams.mode == 'DOS':
        plot_dos(app, cparams)
    else:
        # Report the bad value from cparams.mode. The previous code
        # referenced an undefined local `mode` and raised NameError here.
        terminate("Error: Invalid mode [{}]".format(cparams.mode), usage = lambda: usage(app, cparams))


# Run the script only when executed directly, not when imported.
if __name__ == "__main__":
    main()
