import os
import sys
import shutil
import glob
import csv
import re
import numpy as np
from numpy import sqrt, exp, log, sin, cos, tan, arcsin, arccos, arctan, pi
from scipy.interpolate import interp1d
from pprint import pprint
from matplotlib import pyplot as plt


from tklib.tkfile import tkFile
import tklib.tkre as tkre
from tklib.tkutils import save_csv
from tklib.tkutils import IsDir, IsFile, SplitFilePath
from tklib.tkutils import terminate, pint, pfloat, getarg, getintarg, getfloatarg
from tklib.tksci.tksci import Reduce01, Round
from tklib.tksci.tkmatrix import make_matrix1, make_matrix2, make_matrix3
from tklib.tksci.tkconvolution import convolution, convolve_func
from tklib.tkcrystal.tkcif import tkCIF, tkCIFData
from tklib.tkcrystal.tkcrystal import tkCrystal
from tklib.tkcrystal.tkatomtype import tkAtomType
from tklib.tkcrystal.tkvasp import tkVASP
from tklib.tksci import tkequation
import tklib.tkcsv


"""
Plot DOS calculated by VASP
"""

#================================
# Global parameters
#   These are defaults. mode, CAR_dir, Emin, Emax, occ_th, width,
#   save_figure, plot_figure and plot_figure_Ne may be overridden by
#   positional command-line arguments in updatevars().
#================================
debug = 0

# Calculation mode (updatevars() accepts only 'DOS')
mode = 'DOS'

# Path used to locate the VASP files (INCAR, POSCAR, OUTCAR, EIGENVAL, DOSCAR)
CAR_dir = '.'

# Energy (eV) used when searching for the band-edge energies
EF0 = 0.1
# Criterion DOS value for searching band edges
DOSth = 1.0e-5

# Maximum number of electrons that occupy one state
Nemax = 1.0

# Width (eV) of the Gaussian function used for convolution
width = 0.0

# Marker settings for band plots
band_marker_size, band_marker_edge_width = 4, 0.5

# Energy range of the plot (eV)
Emin, Emax = -10.0, 10.0

# Occupancy threshold used to separate HOMO and LUMO
occ_th = 0.5

# Figure layout and font sizes
figsize = (8, 6)
fontsize, labelfontsize, legend_fontsize = 16, 12, 12

# Output switches (nonzero = enabled)
save_figure = 1
plot_figure = 1
plot_figure_Ne = 1

#=============================
# Treat argments
#=============================
def usage():
    print("")
    print("Usage:")
    print("  (a)  python {} mode CAR_dir Emin Emax Gaussian_width, save_figure plot_figure plot_figure_Ne".format(sys.argv[0]))
    print("     ex: python {} {} {} {} {} {} {} {} {}"
                .format(sys.argv[0], mode, CAR_dir, Emin, Emax, width, save_figure, plot_figure, plot_figure_Ne))

def updatevars():
    """Overwrite the module-level settings from positional command-line arguments.

    Positions: 1 mode, 2 CAR_dir, 3 Emin, 4 Emax, 5 occ_th, 6 Gaussian width,
    7 save_figure, 8 plot_figure, 9 plot_figure_Ne. Each getarg-family call
    receives the current global value as its second argument (presumably the
    fallback when the argument is absent). Calls terminate() if mode is not 'DOS'.
    """
    global mode, CAR_dir
    global Emin, Emax, width, occ_th
    global save_figure, plot_figure, plot_figure_Ne

    # String-valued arguments
    mode = getarg(1, mode)
    CAR_dir = getarg(2, CAR_dir)

    # Float-valued arguments
    Emin = getfloatarg(3, Emin)
    Emax = getfloatarg(4, Emax)
    occ_th = getfloatarg(5, occ_th)
    width = getfloatarg(6, width)

    # Integer on/off switches
    save_figure = getintarg(7, save_figure)
    plot_figure = getintarg(8, plot_figure)
    plot_figure_Ne = getintarg(9, plot_figure_Ne)

    # Only the 'DOS' mode is implemented
    if mode != 'DOS':
        terminate("Error: Invalide mode [{}]".format(mode), usage = usage)


#def save_csv(path, headerlist, datalist, is_print = 0):

def _fmt_energy(value):
    """Format an energy (eV) as '{:10.6f}'; a missing value (None) is shown as a right-aligned 'None'."""
    if value is None:
        return "{:>10}".format("None")
    return f"{value:10.6f}"


def _draw_energy_markers(ax, EV, EC, EHOMO, ELUMO):
    """Draw dashed vertical reference lines at E_V, E_C, E_F0, E_HOMO and E_LUMO on *ax*.

    Axes.axvline spans the full height of the axes whatever the y-limits are.
    The markers therefore need no get_ylim() snapshot and do not feed back into
    y-autoscaling. A line drawn through the get_ylim() values could make
    autoscale widen the limits again. E_F0 comes from the module-level EF0
    setting.
    """
    markers = (
        (EV,    '$E_V$',      'red'),
        (EC,    '$E_C$',      'red'),
        (EF0,   '$E_{F0}$',   'blue'),
        (EHOMO, '$E_{HOMO}$', 'green'),
        (ELUMO, '$E_{LUMO}$', 'green'),
    )
    for energy, label, color in markers:
        ax.axvline(energy, label = label, linestyle = 'dashed', linewidth = 0.5, color = color)


def plot_dos(mode, CAR_path, Emin, Emax):
    """Read the VASP outputs found via *CAR_path* and plot the total DOS (and optionally N_e).

    Parameters
    ----------
    mode : str
        Plot mode. Not used inside this function; kept for interface compatibility.
    CAR_path : str
        Path handed to tkVASP.getdir() to locate INCAR/POSCAR/CONTCAR/OUTCAR/
        EIGENVAL/DOSCAR.
    Emin, Emax : float
        x-axis range (eV) of the plots.

    Also reads the module-level settings occ_th, width, EF0, figsize, fontsize,
    legend_fontsize, save_figure, plot_figure and plot_figure_Ne. Calls
    terminate() if POSCAR cannot be read, and always calls it at the end.
    """
    vasp = tkVASP()

    base_path = vasp.getdir(CAR_path)

    print("")
    INCAR_path   = vasp.get_INCAR(base_path)
    POSCAR_path  = vasp.get_POSCAR(base_path)
    CONTCAR_path = vasp.get_CONTCAR(base_path)
    OUTCAR_path  = vasp.get_OUTCAR(base_path)
    EIGENVAL_path = vasp.get_VASPPath(base_path, 'EIGENVAL')
    DOSCAR_path   = vasp.get_VASPPath(base_path, 'DOSCAR')
    print("CAR dir(ideal) : ", CAR_path)
    print("  INCAR   : ", INCAR_path)
    print("  POSCAR  : ", POSCAR_path)
    print("  CONTCAR : ", CONTCAR_path)
    print("  OUTCAR  : ", OUTCAR_path)
    print("  EIGENVAL: ", EIGENVAL_path)
    print("  DOSCAR  : ", DOSCAR_path)
    print("")
    print(f"Plot E range: {Emin} - {Emax} eV")
    print(f"Occupancy threshold to seprate HOMO and LUMO: {occ_th}")
    print(f"Gaussian width for convolution: {width} eV")
    print(f"save_figure   : {save_figure}")
    print(f"plot_figure   : {plot_figure}")
    print(f"plot_figure_Ne: {plot_figure_Ne}")

    print("")
    print("*** Read crystal structure from [{}]".format(POSCAR_path))
    cry1 = vasp.read_poscar(POSCAR_path)
    if cry1 is None:
        terminate("Error: Can not read [{}]".format(POSCAR_path), usage = usage)

    # Lattice parameters and volume are not used further in this function
    a1, b1, c1, alpha1, beta1, gamm1 = cry1.LatticeParameters()
    Vcell1        = cry1.Volume()
    cry1.PrintInf("cell")

    print("")
    # The result is not used below; the call is kept for whatever it reports
    incarinf = vasp.read_incar_inf(INCAR_path)

    print("")
    outcarinf = vasp.read_outcar_inf(OUTCAR_path)
    EF    = outcarinf["EF"]
    ISPIN = outcarinf["ISPIN"]
    print("Information in OUTCAR:")
    print("  ISPIN: ", ISPIN)
    print("  EF = {} eV corrected to 0".format(EF))

    doscarinf = vasp.read_doscar(DOSCAR_path, unit = '/cm3')
    nE     = doscarinf["nE"]
    E      = doscarinf["E"]
    if nE == 0:
        Edmin = 0.0
        Edmax = 0.0
    else:
        Edmin  = min(E)
        Edmax  = max(E)
    tDOSup = doscarinf["TotalDOSup"]
    Neup   = doscarinf["Neup"]
    tDOSdn = doscarinf["TotalDOSdn"]
    Nedn   = doscarinf["Nedn"]

    if nE > 0:
        print("")
        print("DOS E range: {:10.6g} - {:10.6g} eV, {} points".format(Edmin, Edmax, nE))

    eigenvalinf = vasp.read_eigenval(EIGENVAL_path, EF = EF)
    nk      = eigenvalinf["nk"]
    nLevels = eigenvalinf["nLevels"]
    # Two independent band-edge estimates, printed side by side below
    bandedgeinf  = vasp.find_band_edges_from_eigenval(EF0 = EF0, eigenvalinf = eigenvalinf, ISPIN = ISPIN, occ_th = occ_th)
    bandedgeinf2 = vasp.gbandedges(CAR_path, OUTCAR_path, EIGENVAL_path, EF)

    print("k points in EIGENVAL:")
    print("nk=", nk)
    print("nLevels=", nLevels)
    for i in range(nk):
        el = eigenvalinf['EList'][i]
        kx, ky, kz, wk, dk, ktot, Eups, occups, Edns, occdns = el
        print("  ({:8.4f} {:8.4f} {:8.4f}) w={:8.4g}  {:8.4f} {:12.4f}".format(kx, ky, kz, wk, dk, ktot))

    # Edges from find_band_edges_from_eigenval()
    EV1   = bandedgeinf["EV"]
    EC1   = bandedgeinf["EC"]
    Eg1   = bandedgeinf["Eg"]
    # Edges from gbandedges(), shifted by EF; these are the ones drawn on the plots
    EV    = bandedgeinf2["EVBM"] - EF
    EC    = bandedgeinf2["ECBM"] - EF
    Eg    = bandedgeinf2["Eg"]
    EHOMO = bandedgeinf["EHOMO"]
    ELUMO = bandedgeinf["ELUMO"]
    print(f"Band edge from EIGENVAL:")
    print(f"EF0={EF0:10.6f} eV")
    # Report the find_band_edges results here (previously the gbandedges values were printed twice)
    print(f"  find_band_edges: EV={_fmt_energy(EV1)}  EC={_fmt_energy(EC1)}  Eg={_fmt_energy(Eg1)} eV")
    print(f"             HOMO:{EHOMO:10.6f} eV")
    print(f"             LUMO:{ELUMO:10.6f} eV")
    print(f"  gbandedges()   : EV={EV:10.6f}  EC={EC:10.6f}  Eg={Eg:10.6f} eV")

    #=============================
    # Plot graphs
    #=============================
    print("")

    fig = plt.figure(figsize = figsize)
    if plot_figure_Ne:
        ax1    = fig.add_subplot(2, 1, 1)
        ax2    = fig.add_subplot(2, 1, 2)
    else:
        ax1    = fig.add_subplot(1, 1, 1)

    # Total DOS, Gaussian-broadened by `width`
    tDOSup_s = convolution(E, tDOSup, width, func_type = 'gauss')
    ax1.plot(E, tDOSup_s, label = 'up', linestyle = '-', linewidth = 0.5, color = 'black')
    if ISPIN == 2:
        tDOSdn_s = convolution(E, tDOSdn, width, func_type = 'gauss')
        ax1.plot(E, tDOSdn_s, label = 'dn', linestyle = '-', linewidth = 0.5, color = 'red')
    _draw_energy_markers(ax1, EV, EC, EHOMO, ELUMO)
    ax1.set_xlabel("$E$ (eV)",    fontsize = fontsize)
    ax1.set_ylabel("DOS (states/cm$^3$/eV)", fontsize = fontsize)
    ax1.set_xlim([Emin, Emax])
    ax1.legend(fontsize = legend_fontsize)
    ax1.tick_params(labelsize = fontsize)

    if plot_figure_Ne:
        ax2.plot(E, Neup, label = 'up', linestyle = '-', linewidth = 0.5, color = 'black')
        if ISPIN == 2:
            ax2.plot(E, Nedn, label = 'dn', linestyle = '-', linewidth = 0.5, color = 'red')
        _draw_energy_markers(ax2, EV, EC, EHOMO, ELUMO)
        ax2.set_xlabel("$E$ (eV)",    fontsize = fontsize)
        ax2.set_ylabel("$N_e$ (states/cm$^3$)", fontsize = fontsize)
        ax2.set_xlim([Emin, Emax])
        ax2.legend(fontsize = legend_fontsize)
        ax2.tick_params(labelsize = fontsize)

    # Rearrange the graph axes so that they do not overlap
    plt.tight_layout()

    if save_figure:
        outfig_path = 'dos.png'
        print("")
        print(f"Save DOS to figure file [{outfig_path}]")
        plt.savefig(outfig_path, dpi=300, transparent=True)
        if not plot_figure:
            input("Press ENTER to terminate>>")

    use_pause = True
    if plot_figure:
        if use_pause:
            plt.pause(0.1)
            print("")
            input("Press ENTER to exit>>")
        else:
            # Print the hint first: plt.show() blocks until the window is closed
            print("")
            print("Close graph window to terminate")
            plt.show()

    terminate("", usage = usage)


def main():
    """Script entry point: parse command-line arguments, then run the selected mode.

    Only 'DOS' is handled (via plot_dos). Any other mode ends through terminate().
    main() only reads the module-level settings, so it needs no `global`
    declarations. updatevars() performs the assignments.
    """
    updatevars()

    print("")
    print("=============== Plot DOS calculated by VASP ============")
    print("")
    print("mode: ", mode)

    if mode == 'DOS':
        plot_dos(mode, CAR_dir, Emin, Emax)
    else:
        terminate("Error: Invalide mode [{}]".format(mode), usage = usage)


if __name__ == "__main__":
    main()
