import os
import sys
import shutil
import glob
import csv
import re
import numpy as np
from numpy import exp, log, sin, cos, tan, arcsin, arccos, arctan, pi
from scipy.interpolate import interp1d
from pprint import pprint
from matplotlib import pyplot as plt


from tklib.tkfile import tkFile
import tklib.tkre as tkre
from tklib.tkutils import save_csv
from tklib.tkutils import IsDir, IsFile, SplitFilePath
from tklib.tkutils import terminate, pint, pfloat, getarg, getintarg, getfloatarg
from tklib.tksci.tksci import Reduce01, Round
from tklib.tksci.tkmatrix import make_matrix1, make_matrix2, make_matrix3
from tklib.tkvariousdata import tkVariousData
from tklib.tkcrystal.tkcif import tkCIF, tkCIFData
from tklib.tkcrystal.tkcrystal import tkCrystal
from tklib.tkcrystal.tkatomtype import tkAtomType
from tklib.tkcrystal.tkvasp import tkVASP
from tklib.tksci import tkequation
import tklib.tkcsv


"""
Plot optical dielectric constant from OUTCAR
"""

#================================
# global parameters
#================================
# Debug flag (0 = off).
# NOTE(review): not read by any function shown in this script -- confirm before removing.
debug = 0

# Which data set plot_epsilon() uses: 'density' selects the density-density
# tables of OUTCAR, 'current' the current-current tables.
# Replaced by command-line argument 1 in updatevars().
mode = 'density'

# Path handed to tkVASP to locate INCAR / POSCAR / CONTCAR / OUTCAR.
# Replaced by command-line argument 2 in updatevars().
CAR_dir = '.'


# Plot configuration
# Requested x-axis window; plot_epsilon() clips it to the energy range found
# in the data. Replaced by command-line arguments 3 and 4 in updatevars().
Emin = 0.0  # eV
Emax = 8.0  # eV
#colors = ['#000000', '#ff0000', '#00aa00', '#0000ff', '#aaaa00', '#ff00ff', '#00ffff', '#aa0000', '#00aa00', '#0000aa']
#colors = ['k', 'r', 'g', 'b', 'y', 'm', 'c']
# Font size for axis labels and tick labels, and for legend entries.
fontsize = 16
legend_fontsize = 12


#=============================
# Parse command-line arguments
#=============================
def usage():
    """Print the command-line synopsis and an example built from the current settings.

    Reads the module-level ``mode``, ``CAR_dir``, ``Emin`` and ``Emax`` so that
    the example line shows the values in effect when the function is called.
    Returns nothing; output goes to standard output.
    """
    script = sys.argv[0]

    print("")
    print("Usage:")
    # Arguments are separated by spaces on the command line, so no comma here.
    print("  (a)  python {} mode CAR_dir Emin Emax".format(script))
    # One placeholder per value (script, mode, CAR_dir, Emin, Emax); with only
    # four placeholders str.format() silently dropped Emax from the example.
    print("     ex: python {} {} {} {} {}".format(script, mode, CAR_dir, Emin, Emax))
    
def updatevars():
    """Refresh the module-level settings from the command line.

    Positional arguments 1 to 4 are mapped onto ``mode``, ``CAR_dir``,
    ``Emin`` and ``Emax``. The current value of each setting is passed as the
    second argument of the tklib getter (presumably the fallback used when
    the argument is absent -- confirm against tklib.tkutils).

    Calls ``terminate`` with the usage printer when ``mode`` is neither
    'density' nor 'current'.
    """
    global mode, CAR_dir, Emin, Emax

    mode = getarg(1, mode)
    CAR_dir = getarg(2, CAR_dir)
    Emin = getfloatarg(3, Emin)
    Emax = getfloatarg(4, Emax)

    # Guard clause: only the two known data-set names are accepted.
    if mode not in ('density', 'current'):
        terminate("Error: Invalide mode [{}]".format(mode), usage = usage)


#def save_csv(path, headerlist, datalist, is_print = 0):

def read_spectrum(fp, key):
    """Read one seven-column spectrum table from an open file object.

    The table layout expected here is the one shown in OUTCAR:

        E(ev)      X         Y         Z        XY        YZ        ZX
        ------------------------------------------------------------

    Args:
        fp: Object providing ``SkipTo(key)`` and ``ReadLine()`` (a tkFile in
            this script). ``SkipTo`` is presumably a regex search that leaves
            the file positioned just after the matching line -- confirm
            against tklib.tkfile.
        key: Pattern handed to ``fp.SkipTo`` to locate the table title.

    Returns:
        ``(energies, components)``: ``energies`` is the list of first-column
        values and ``components`` is a list of six lists holding columns
        2 to 7, in file order. Reading stops at the first falsy line or the
        first line with fewer than seven whitespace-separated fields.
    """
    fp.SkipTo(key)
    fp.ReadLine()  # column-title line
    fp.ReadLine()  # dashed separator line

    energies = []
    components = [[] for _ in range(6)]

    while True:
        row = fp.ReadLine()
        if not row:
            break

        fields = row.split()
        # A blank or short line marks the end of the table.
        if len(fields) < 7:
            break

        energies.append(pfloat(fields[0]))
        for column, text in zip(components, fields[1:7]):
            column.append(pfloat(text))

    return energies, components

def read_epsilon(OUTCAR_path):
    """Read the four frequency-dependent dielectric-function tables from OUTCAR.

    The tables are searched for on one file object in this fixed order, which
    mirrors the order of the section titles in OUTCAR:

      1. IMAGINARY DIELECTRIC FUNCTION ... density-density
      2. REAL DIELECTRIC FUNCTION      ... density-density
      3. IMAGINARY DIELECTRIC FUNCTION ... current-current
      4. REAL DIELECTRIC FUNCTION      ... current-current

    Args:
        OUTCAR_path: Path of the OUTCAR file to open for reading.

    Returns:
        ``(Elist, e1list1, e2list1, e1list2, e2list2)`` where ``Elist`` is the
        energy column of the last table read, ``e1*``/``e2*`` are the real and
        imaginary parts, and suffix 1/2 means density-density/current-current.
        Each ``e*`` value is a list of six column lists as returned by
        ``read_spectrum``.
    """
    outcar = tkFile(OUTCAR_path, 'r')
    try:
        # The energy grid is overwritten by each call; the last one is returned.
        Elist, e2list1 = read_spectrum(outcar, r"IMAGINARY DIELECTRIC FUNCTION.+density")
        Elist, e1list1 = read_spectrum(outcar, r"REAL DIELECTRIC FUNCTION.+density")
        Elist, e2list2 = read_spectrum(outcar, r"IMAGINARY DIELECTRIC FUNCTION.+current")
        Elist, e1list2 = read_spectrum(outcar, r"REAL DIELECTRIC FUNCTION.+current")
    finally:
        # Release the file handle even when parsing raises.
        outcar.Close()

    return Elist, e1list1, e2list1, e1list2, e2list2

def plot_epsilon(mode, CAR_path, Emin, Emax):
    """Export the dielectric function read from OUTCAR to Excel and plot it.

    Args:
        mode: 'current' selects the current-current data set; any other value
            (in practice 'density') uses the density-density data set.
        CAR_path: Path handed to ``tkVASP.getdir`` to locate the VASP files.
        Emin: Requested lower bound of the plotted energy range in eV.
        Emax: Requested upper bound of the plotted energy range in eV.
            Both bounds are clipped to the energy range present in the data.

    Side effects:
        Prints the resolved file paths and energy ranges, writes
        ``epsilon.xlsx`` (relative path), opens a 2x3 matplotlib figure with
        one panel per tensor component, waits for ENTER and finally calls
        ``terminate``. Calls ``terminate`` with an error message instead when
        no spectrum rows could be read.
    """
    # Tensor components in the column order of the OUTCAR tables
    # (E, X, Y, Z, XY, YZ, ZX). Shared by the Excel headers and the legends
    # so that the two cannot drift apart.
    components = ['XX', 'YY', 'ZZ', 'XY', 'YZ', 'ZX']

    vasp = tkVASP()

    base_path = vasp.getdir(CAR_path)

    print("")
    print("mode: ", mode)
    INCAR_path   = vasp.get_INCAR(base_path)
    POSCAR_path  = vasp.get_POSCAR(base_path)
    CONTCAR_path = vasp.get_CONTCAR(base_path)
    OUTCAR_path  = vasp.get_OUTCAR(base_path)
    print("CAR dir : ", CAR_path)
    print("INCAR   : ", INCAR_path)
    print("POSCAR  : ", POSCAR_path)
    print("CONTCAR : ", CONTCAR_path)
    print("OUTCAR  : ", OUTCAR_path)

    print("")
    Elist, e1list1, e2list1, e1list2, e2list2 = read_epsilon(OUTCAR_path)

    # Without this guard min()/max() below fail with an unhelpful ValueError.
    if not Elist:
        terminate("Error: No dielectric function data found in [{}]".format(OUTCAR_path), usage = usage)
        return

    E_min = min(Elist)
    E_max = max(Elist)
    nE = len(Elist)
    print("Data E range: {} - {} eV, {} points".format(E_min, E_max, nE))
    Emin = max(Emin, E_min)
    Emax = min(Emax, E_max)
    print("Plot E range: {} - {} eV".format(Emin, Emax))

    if mode == 'current':
        e1list1 = e1list2
        e2list1 = e2list2

    outxlsx = 'epsilon.xlsx'
    print("")
    print(f"Save epsilon data to [{outxlsx}]")
    # Columns are written in file order (xx, yy, zz, xy, yz, zx). The previous
    # hard-coded header list labelled the xy/yz/zx columns as yz/zx/xy.
    headers = ["E(eV)"]
    columns = [Elist]
    for name, e1, e2 in zip(components, e1list1, e2list1):
        suffix = name.lower()
        headers += ["e1" + suffix, "e2" + suffix]
        columns += [e1, e2]
    tkVariousData().to_excel(outxlsx, headers, columns)

#=============================
# Plot graphs
#=============================
    print("")

    fig = plt.figure(figsize = (12, 8))
    axes = [fig.add_subplot(2, 3, i + 1) for i in range(6)]

    for axis, name, e1, e2 in zip(axes, components, e1list1, e2list1):
        axis.plot(Elist, e1, label = '$\\epsilon_1$_$_{%s}$' % (name), color = 'blue', linewidth = 0.5)
        axis.plot(Elist, e2, label = '$\\epsilon_2$_$_{%s}$' % (name), color = 'red',  linewidth = 0.5)
        axis.set_xlim(Emin, Emax)
        axis.set_xlabel("$E$ (eV)", fontsize = fontsize)
        axis.set_ylabel("$\\epsilon$", fontsize = fontsize)
        axis.legend(fontsize = legend_fontsize)
        axis.tick_params(labelsize = fontsize)

    # Rearrange the axes so that the panels do not overlap
    plt.tight_layout()

    # True: draw briefly and keep the console in control until ENTER.
    # False: hand control to the GUI event loop until the window is closed.
    use_pause = True

    if use_pause:
        plt.pause(0.1)
        print("")
        print("Press ENTER to exit>>", end = '')
        input()
    else:
        # Show the hint first: plt.show() blocks until the window is closed.
        print("")
        print("Close graph window to terminate")
        plt.show()

    terminate("", usage = usage)


def main():
    """Script entry point: parse the arguments, print a banner and run the plot.

    ``updatevars`` refreshes the module-level ``mode``, ``CAR_dir``, ``Emin``
    and ``Emax``; those values are then passed to ``plot_epsilon``.
    """
    updatevars()

    print("")
    # The frequency-dependent dielectric-function tables parsed by this script
    # are written to OUTCAR by LOPTICS=.TRUE. (LEPSILON produces the static
    # dielectric tensor), so the banner names LOPTICS.
    print("=============== Plot optical dielectric function calculated by VASP with LOPTICS=.TRUE. ============")
    print("")
    print("mode: ", mode)

    # updatevars() already validates mode; the check is kept as a safeguard
    # in case main() is reached with an unexpected value.
    if mode in ('density', 'current'):
        plot_epsilon(mode, CAR_dir, Emin, Emax)
    else:
        terminate("Error: Invalide mode [{}]".format(mode), usage = usage)


# Run the command-line workflow only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
