import os
import sys
import shutil
import glob
import csv
import re
import numpy as np
from numpy import exp, log, sin, cos, tan, arcsin, arccos, arctan, pi
from scipy.interpolate import interp1d
from pprint import pprint
from matplotlib import pyplot as plt


from tklib.tkfile import tkFile
import tklib.tkre as tkre
#from tklib.tkutils import save_csv
from tklib.tkvariousdata import tkVariousData
from tklib.tkutils import IsDir, IsFile, SplitFilePath
from tklib.tkutils import terminate, pint, pfloat, getarg, getintarg, getfloatarg
from tklib.tksci.tksci import Reduce01, Round
from tklib.tksci.tkmatrix import make_matrix1, make_matrix2, make_matrix3
from tklib.tkcrystal.tkcif import tkCIF, tkCIFData
from tklib.tkcrystal.tkcrystal import tkCrystal
from tklib.tkcrystal.tkatomtype import tkAtomType
from tklib.tkcrystal.tkvasp import tkVASP
from tklib.tkcrystal.tkbandstructure import plot_band
from tklib.tksci import tkequation
import tklib.tkcsv


"""
Plot band structure calculated by VASP
"""

#================================
# global parameters
# (Several of these are overridden from the command line by updatevars().)
#================================
# Debug flag (not referenced in the visible code of this file)
debug = 0

# Plot mode: 'band' (markers), 'bandline' (lines), 'bandocc' (markers, with
# occupations also passed to plot_band). validmodes lists every accepted value;
# mode is overridden by argv[1].
validmodes = ['band', 'bandline', 'bandocc']
mode = 'band'
#mode = 'bandocc'

# Path handed to tkVASP.getdir() to locate the VASP files; overridden by argv[2].
CAR_dir = '.'

# Reference energy [eV] passed as EF0 to find_band_edges_from_eigenval() and
# drawn as the E_F0 line in the plot; overridden by argv[3].
EF0 = 0.1

# Maximum number of electrons that can occupy one state
# (not referenced in the visible code of this file).
Nemax = 1.0

# Occupancy threshold used to separate HOMO and LUMO; overridden by argv[6].
occ_th = 0.5

# Output path of the band-structure table written by save_bandstructure()
#band_path = 'band.csv'
band_path = 'band.xlsx'

# Nonzero: save the plot to band.png; overridden by argv[7].
save_figure = 1
# Nonzero: display the plot window; overridden by argv[8].
plot_figure = 1

#===================================
# figure configuration
#===================================
# Marker size / edge width used in the 'band' and 'bandocc' modes
band_marker_size       = 4
band_marker_edge_width = 0.5

# Plotted energy range; overridden by argv[4] and argv[5].
Emin = -10.0  # eV
Emax =  10.0  # eV

# Figure size passed to plt.figure(figsize = ...)
figsize = (6, 8)
#figsize = (6, 4)

# Alternative color tables (currently disabled)
#colors = ['#000000', '#ff0000', '#00aa00', '#0000ff', '#aaaa00', '#ff00ff', '#00ffff', '#aa0000', '#00aa00', '#0000aa']
#colors = ['k', 'r', 'g', 'b', 'y', 'm', 'c']
# Font sizes (not referenced in the visible code of this file)
fontsize        = 16
labelfontsize   = 12
legend_fontsize = 8

#=============================
# Command-line argument handling
#=============================
def usage():
    """Print command-line usage and an example built from the current settings.

    Passed as the ``usage`` callback to terminate(). The argument list matches
    the positions read by updatevars():
    mode CAR_dir EF0 Emin Emax occ_th save_figure plot_figure.
    If the current ``mode`` is not valid, the first entry of ``validmodes`` is
    used in the example line only. The global ``mode`` is left untouched.
    """
    example_mode = mode if mode in validmodes else validmodes[0]

    print("")
    print("Usage:")
    print("  (a)  python {} mode CAR_dir EF0 Emin Emax occ_th save_figure plot_figure".format(sys.argv[0]))
    print("         mode: {}".format(validmodes))
    print("     ex: python {} {} {} {} {} {} {} {} {}".format(
        sys.argv[0], example_mode, CAR_dir, EF0, Emin, Emax, occ_th, save_figure, plot_figure))

def updatevars():
    """Override the module-level settings from positional command-line arguments.

    Positions: 1 mode, 2 CAR_dir, 3 EF0, 4 Emin, 5 Emax, 6 occ_th,
    7 save_figure, 8 plot_figure. Each getter receives the current value as its
    fallback (presumably used when the argument is absent; see tklib.tkutils).
    Calls terminate() when ``mode`` is not one of ``validmodes``.
    """
    global mode
    global CAR_dir
    global EF0, occ_th
    global Emin, Emax
    global save_figure, plot_figure

    mode    = getarg     (1, mode)
    CAR_dir = getarg     (2, CAR_dir)
    EF0     = getfloatarg(3, EF0)
    Emin    = getfloatarg(4, Emin)
    Emax    = getfloatarg(5, Emax)
    occ_th  = getfloatarg(6, occ_th)
    save_figure = getintarg(7, save_figure)
    plot_figure = getintarg(8, plot_figure)

    # Validate against the single source of truth for accepted modes.
    if mode not in validmodes:
        terminate("Error: Invalide mode [{}]".format(mode), usage = usage)


def _band_table(kvec, Eup, Edn):
    """Build (labels, columns) for the band table; one column per label."""
    labels = ["kx", "ky", "kz", "dk", "ktot"]
    columns = [[kv[i] for kv in kvec] for i in range(5)]

    # Transposing the (nk x nLevels) energies gives one row per level, and each
    # row becomes one column of the table. An empty Eup gives no level columns.
    EupT = np.array(Eup).T
    labels.extend(f'Eup({iL})' for iL in range(len(EupT)))
    columns.extend(EupT)

    if len(Edn) > 0:
        EdnT = np.array(Edn).T
        labels.extend(f'Edn({iL})' for iL in range(len(EdnT)))
        # extend() takes ONE iterable; passing EdnT (not *EdnT) appends each
        # level row as its own column.
        columns.extend(EdnT)
    return labels, columns


def _write_band_csv(band_path, labels, columns):
    """Write labels as a header row, then the columns transposed into rows, as CSV."""
    with open(band_path, 'w', newline = '', encoding = 'utf-8') as f:
        writer = csv.writer(f)
        writer.writerow(labels)
        writer.writerows(zip(*columns))


def save_bandstructure(band_path, ISPIN, kvec, Eup, Edn):
    """Save band-structure data as a table with one row per k point.

    Columns: kx, ky, kz, dk, ktot, Eup(0..), plus Edn(0..) when Edn is non-empty.

    Parameters
    ----------
    band_path : str
        Output path. A path ending in '.csv' (case-insensitive) is written as
        CSV with the standard csv module. Any other path is written with
        tkVariousData().to_excel().
    ISPIN : int
        Kept for interface compatibility. Whether spin-down columns are written
        depends only on Edn being non-empty.
    kvec : list
        Per-k-point [kx, ky, kz, dk, ktot].
    Eup, Edn : list
        Per-k-point lists of level energies for spin up / spin down
        (Edn may be empty).

    Returns
    -------
    bool
        Always True.
    """
    print("")
    print("Save band structure data to [{}]".format(band_path))

    labels, data_list = _band_table(kvec, Eup, Edn)
    if str(band_path).lower().endswith('.csv'):
        _write_band_csv(band_path, labels, data_list)
    else:
        tkVariousData().to_excel(band_path, labels = labels, data_list = data_list)
    return True

def _print_kpoint_table(kplist, n, print_start, empty_name = None):
    """Print segments kplist[0..n-1] as a fixed-width k-point table.

    kplist is accessed by integer index only. If print_start is true and kplist
    is non-empty, the start point (kx0/ky0/kz0, 'kname0_conv') of the first
    segment is printed first with dk = ktot = 0. When empty_name is given,
    blank names are replaced by it.
    """
    fmt = "  {:10} ({:8.4f} {:8.4f} {:8.4f}) w={:.3f}  {:8.4f} {:12.4f}"

    def _name(name):
        if empty_name is not None and name == '':
            return empty_name
        return name

    if print_start and len(kplist) > 0:
        kp = kplist[0]
        print(fmt.format(_name(kp["kname0_conv"]), kp["kx0"], kp["ky0"], kp["kz0"], kp["kw"], 0.0, 0.0))
    for i in range(n):
        kp = kplist[i]
        print(fmt.format(_name(kp["kname1_conv"]), kp["kx1"], kp["ky1"], kp["kz1"],
                         kp["kw"], kp["dk"], kp["ktot"]))


def _kpath_boundaries(vasp, kptdic):
    """Return (ktotallist, ktotal_namelist): accumulated k distance and name at each path boundary."""
    if len(kptdic) == 0:
        return [], []
    ktotallist      = [0.0]
    ktotal_namelist = [vasp.convert_kname(kptdic[0]["kname0"])]
    for i in range(len(kptdic)):
        kp = kptdic[i]
        ktotallist.append(kp["ktot"])
        ktotal_namelist.append(vasp.convert_kname(kp["kname1"]))
    return ktotallist, ktotal_namelist


def _collect_band_data(EList, ISPIN):
    """Split EIGENVAL entries into (xk, kvec, yEup, yNup, yEdn, yNdn); spin-down lists stay empty unless ISPIN == 2."""
    xk, kvec, yEup, yNup, yEdn, yNdn = [], [], [], [], [], []
    for i in range(len(EList)):
        kx0, ky0, kz0, wk, dk, ktot, Eups, occups, Edns, occdns = EList[i]
        xk.append(ktot)
        kvec.append([kx0, ky0, kz0, dk, ktot])
        yEup.append(Eups)
        yNup.append(occups)
        if ISPIN == 2:
            yEdn.append(Edns)
            yNdn.append(occdns)
    return xk, kvec, yEup, yNup, yEdn, yNdn


def plot_band_structure(mode, CAR_path, Emin, Emax):
    """Read VASP results, report k points and band edges, save and plot the bands.

    Parameters
    ----------
    mode : str
        'band' (markers), 'bandline' (lines) or 'bandocc' (markers, with
        occupations passed to plot_band). Any other value plots nothing.
    CAR_path : str
        Path given to tkVASP.getdir() to locate INCAR/POSCAR/KPOINTS/OUTCAR/
        EIGENVAL/DOSCAR.
    Emin, Emax : float
        Plotted energy range [eV].

    Also reads the module globals EF0, occ_th, band_path, save_figure,
    plot_figure, figsize, band_marker_size and band_marker_edge_width.
    Calls terminate() on read errors and at the end of the run.
    """
    vasp = tkVASP()

    base_path = vasp.getdir(CAR_path)

    print("")
    INCAR_path    = vasp.get_INCAR(base_path)
    POSCAR_path   = vasp.get_POSCAR(base_path)
    KPOINTS_path  = vasp.get_VASPPath(base_path, 'KPOINTS')
    CONTCAR_path  = vasp.get_CONTCAR(base_path)
    OUTCAR_path   = vasp.get_OUTCAR(base_path)
    EIGENVAL_path = vasp.get_VASPPath(base_path, 'EIGENVAL')
    DOSCAR_path   = vasp.get_VASPPath(base_path, 'DOSCAR')
    print("CAR dir(ideal) : ", CAR_path)
    print("  INCAR   : ", INCAR_path)
    print("  POSCAR  : ", POSCAR_path)
    print("  KPOINTS : ", KPOINTS_path)
    print("  CONTCAR : ", CONTCAR_path)
    print("  OUTCAR  : ", OUTCAR_path)
    print("  EIGENVAL: ", EIGENVAL_path)
    print("  DOSCAR  : ", DOSCAR_path)
    print("")
    print(f"Occupancy threshold to seprate HOMO and LUMO: {occ_th}")
    print(f"save_figure: {save_figure}")
    print(f"plot_figure: {plot_figure}")

    print("")
    print("EF0: {} eV".format(EF0))

    print("")
    print("*** Read crystal structure from [{}]".format(POSCAR_path))
    cry1 = vasp.read_poscar(POSCAR_path)
    if cry1 is None:
        terminate("Error: Can not read [{}]".format(POSCAR_path), usage = usage)

    a1, b1, c1, alpha1, beta1, gamm1 = cry1.LatticeParameters()
    cry1.PrintInf("cell")

    print("")
    incarinf = vasp.read_incar_inf(INCAR_path)

    print("")
    outcarinf = vasp.read_outcar_inf(OUTCAR_path)
    EF        = outcarinf.get("EF", None)
    ISPIN     = outcarinf.get("ISPIN", 1)
    IsHF      = outcarinf.get("IsHF", 0)
    IsMETAGGA = outcarinf.get("IsMETAGGA", 0)
    print("Information in OUTCAR:")
    print(f"  ISPIN    : {ISPIN}")
    print(f"  IsHF     : {IsHF}")
    print(f"  IsMETAGGA: {IsMETAGGA}")
    print(f"  EF = {EF} eV corrected to 0")
    # EF is needed below to shift the band edges (EVBM - EF) and to print them.
    # Without it the run would fail later with an obscure TypeError.
    if EF is None:
        terminate("Error: Can not read EF from [{}]".format(OUTCAR_path), usage = usage)
    # meta-GGA runs are treated like HF runs from here on
    if IsMETAGGA:
        IsHF = 1

    kpoints_inf = vasp.read_kpoints(KPOINTS_path)
    nk      = kpoints_inf['nk']
    kpoints = kpoints_inf['kpoints']
    kptdic  = kpoints_inf['kpointsdic']

    print("")
    print("Special k points in KPOINTS:")
    _print_kpoint_table(kptdic, len(kptdic), print_start = (IsHF == 0), empty_name = 'no_name')

    print("")
    print("k points in KPOINTS:")
    _print_kpoint_table(kpoints, nk, print_start = (IsHF == 0 and IsMETAGGA == 0))

    eigenvalinf = vasp.read_eigenval(EIGENVAL_path, KPOINTS_inf = kpoints_inf, ISPIN = ISPIN, IsHF = IsHF, EF = EF)
    nLevels = eigenvalinf["nLevels"]
    EList   = eigenvalinf['EList']
    nk_band = len(EList)
    print("")
    print("k points in EIGENVAL:")
    print("nk(band)=", nk_band)
    print("nLevels =", nLevels)
    for i in range(nk_band):
        kx, ky, kz, wk, dk, ktot, Eups, occups, Edns, occdns = EList[i]
        print(f"  ({kx:8.4f} {ky:8.4f} {kz:8.4f}) w={wk:8.4g}  dk={dk:8.4f} ktot={ktot:8.4f}")

    bandedgeinf = vasp.find_band_edges_from_eigenval(EF0 = EF0, eigenvalinf = eigenvalinf, ISPIN = ISPIN, occ_th = occ_th)
    print("bandedgeinf=", bandedgeinf)
    bandedgeinf2 = vasp.gbandedges(CAR_path, OUTCAR_path, EIGENVAL_path, EF)
    EV1   = bandedgeinf["EV"]
    EC1   = bandedgeinf["EC"]
    Eg1   = bandedgeinf["Eg"]
    # gbandedges() edges are shifted so that EF (from OUTCAR) is the zero
    EV    = bandedgeinf2["EVBM"] - EF
    EC    = bandedgeinf2["ECBM"] - EF
    Eg    = bandedgeinf2["Eg"]
    EHOMO = bandedgeinf["EHOMO"]
    ELUMO = bandedgeinf["ELUMO"]
    print("")
    print("Band edge from EIGENVAL:")
    print(f"EF from OUTCAR: {EF:10.6f} eV")
    print(f"EF0={EF0:10.6f} eV")
    print(f"  find_band_edges: EV={EV1:10.6f}  EC={EC1:10.6f}  Eg={Eg1:10.6f} eV")
    print(f"             HOMO:{EHOMO:10.6f} eV")
    print(f"             LUMO:{ELUMO:10.6f} eV")
    print(f"  gbandedges()   : EV={EV:10.6f}  EC={EC:10.6f}  Eg={Eg:10.6f} eV")

    #=============================
    # Plot graphs
    #=============================
    print("")
    print("plot band structure")

    fig = plt.figure(figsize = figsize)
    axband = fig.add_subplot(1, 1, 1)

    # Data handed to plot_band():
    #   xk:              accumulated k-path distance of each plotted k point
    #   yE:              E(xk[i]); nested lists are fine
    #   ktotallist:      accumulated distance from the first k point at each path boundary
    #   ktotal_namelist: k-point name at each path boundary
    ktotallist, ktotal_namelist = _kpath_boundaries(vasp, kptdic)
    xk, kvec, yEup, yNup, yEdn, yNdn = _collect_band_data(EList, ISPIN)

    # Save the band data table
    save_bandstructure(band_path, ISPIN, kvec, yEup, yEdn)

    # Note: suppressing BZ boundary lines / k-point names for HF/meta-GGA runs
    # (ktotallist = ktotal_namelist = None) is currently disabled.

    common_kwargs = dict(ktotallist = ktotallist, ktotal_namelist = ktotal_namelist,
                         EV = EV, EC = EC, EF = EF0, EHOMO = EHOMO, ELUMO = ELUMO,
                         EFlabel = '$E_{F0}$', legendloc = 'lower right')
    marker_kwargs = dict(marker = 'o', linestyle = '',
                         markersize = band_marker_size, markeredgewidth = band_marker_edge_width)
    if mode == 'bandocc':
        mode_kwargs = dict(occups = yNup, occdns = yNdn, **marker_kwargs)
    elif mode == 'bandline':
        mode_kwargs = dict(marker = None, linestyle = '-')
    elif mode == 'band':
        mode_kwargs = marker_kwargs
    else:
        mode_kwargs = None
    if mode_kwargs is not None:
        plot_band(plt, axband, ISPIN, xk, yEup, yEdn, [Emin, Emax], **common_kwargs, **mode_kwargs)

    # Rearrange the graph axes so that they do not overlap
    plt.tight_layout()

    if save_figure:
        outfig_path = 'band.png'
        print("")
        print(f"Save band structure to figure file [{outfig_path}]")
        plt.savefig(outfig_path, dpi=300, transparent=True)

    use_pause = True
    if plot_figure:
        if use_pause:
            plt.pause(0.1)
            print("")
            input("Press ENTER to exit>>")
        else:
            # Print the hint BEFORE plt.show(), which blocks until the window closes.
            print("")
            print("Close graph window to terminate")
            plt.show()

    terminate("", usage = usage)


def main():
    """Entry point: read command-line arguments, print a banner and plot the band structure.

    Calls terminate() with an error if ``mode`` is not in ``validmodes``.
    """
    updatevars()

    print("")
    print("=============== Plot band structure calculated by VASP ============")
    print("")
    print("mode: ", mode)

    # Check against validmodes so the accepted modes are defined in one place only.
    if mode in validmodes:
        plot_band_structure(mode, CAR_dir, Emin, Emax)
    else:
        terminate("Error: Invalide mode [{}]".format(mode), usage = usage)


# Run main() only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()


