import os
import sys
import time
import shutil
import glob
import csv
import re
import numpy as np
from numpy import exp, log, sin, cos, tan, arcsin, arccos, arctan, pi
from scipy.interpolate import interp1d
from pprint import pprint
from matplotlib import pyplot as plt


from tklib.tkfile import tkFile
from tklib.tkutils import IsDir, IsFile, SplitFilePath
from tklib.tkutils import terminate, pint, pfloat, getarg, getintarg, getfloatarg
import tklib.tkre as tkre
from tklib.tkcrystal.tkcif import tkCIF, tkCIFData
from tklib.tkcrystal.tkcrystal import tkCrystal
from tklib.tkcrystal.tkatomtype import tkAtomType
from tklib.tkcrystal.tkvasp import tkVASP


"""
Watch the progress of a VASP run and plot its convergence
"""

#================================
# Global parameters (module-level defaults)
#================================
# Debug flag (not referenced by the code in this script).
debug = 0

# Watch mode; overridden by command-line argument 1 in updatevars().
mode = 'relax'

# Directory holding the VASP files; overridden by command-line argument 2.
CAR_dir = '.'

# Polling interval between OUTCAR checks, in seconds.
t_sleep = 1.0

# Root used to compress dE/dF for plotting: values are drawn as |x|**(1/npower).
npower = 10.0

# Plot configuration
figsize = (8, 6)  # passed to plt.figure(figsize=...)
# Font sizes; labelfontsize and legend_fontsize are not used by the active plotting code.
fontsize, labelfontsize, legend_fontsize = 14, 12, 6


#=============================
# Handle command-line arguments
#=============================
def usage():
    """Print the command-line usage of this script.

    Only two positional arguments are read (see updatevars()): the watch
    mode (argv[1], currently only 'relax') and the directory holding the
    VASP files (argv[2]).
    """
    print("")
    print("Usage:")
    print("  (a)  python {} relax CAR_dir".format(sys.argv[0]))
    print("     ex: python {} {} {}".format(sys.argv[0], mode, CAR_dir))

def updatevars():
    """Override the global ``mode`` and ``CAR_dir`` from the command line.

    argv[1] selects the mode and argv[2] the VASP directory; the current
    global values are passed to getarg() as defaults (presumably used when
    the argument is absent). Any mode other than 'relax' is rejected via
    terminate().
    """
    global mode, CAR_dir

    mode = getarg(1, mode)
    CAR_dir = getarg(2, CAR_dir)

    # Guard clause: only the 'relax' watcher exists.
    if mode != 'relax':
        terminate("Error: Invalide mode [{}]".format(mode), usage = usage)


# Number of OUTCAR updates detected so far; printed as "(#n)" by update_relax().
# NOTE: the name shadows the builtin iter(); kept for backward compatibility,
# so local loop counters below must never reuse it.
iter = 0


def _hms(t):
    """Format an epoch time as local 'H:M:S' (unpadded), or 'None' for None."""
    if t is None:
        return 'None'
    st = time.localtime(t)
    return '{}:{}:{}'.format(st.tm_hour, st.tm_min, st.tm_sec)


def _signed_root(value, power):
    """Return sign(value) * |value|**(1/power) (non-positive values map to <= 0)."""
    root = pow(abs(value), 1.0 / power)
    return root if value > 0.0 else -root


def _energy_series(all_steps):
    """Print the electronic-step table and return (x, |dE|, signed |dE|^(1/npower)).

    dE is measured relative to the energy of the last recorded step.
    """
    Efin = all_steps[-1]['E']
    xiter, ylogE, ydE = [], [], []
    print("  {:>3}:{:>3}:{:>3}:  {:12} {:12}".format('itr', 'ion', 'e', 'E(eV)', 'dE(eV)'))
    for istep, inf in enumerate(all_steps):
        print("  {:3d}:{:3d}:{:3d}: {:12.8g} {:12.8g}".format(
            istep, inf['ion_step'], inf['electron_step'], inf['E'], inf['dE']))
        dE = inf['E'] - Efin
        xiter.append(istep)
        ylogE.append(abs(dE))
        ydE.append(_signed_root(dE, npower))
    return xiter, ylogE, ydE


def _force_series(ion_steps):
    """Print the ionic-step table and return (x, signed |dF|^(1/npower)).

    dF is measured relative to 'F' of the last ionic step. Returns two empty
    lists when no ionic step has been recorded yet.
    NOTE(review): in VASP's OSZICAR 'F' is usually the free energy (eV), not
    a force; the 'eV/A' labels may be inaccurate -- confirm against
    tkVASP.read_oszicar_inf().
    """
    xioniter, ydF = [], []
    if not ion_steps:
        print("  No completed ionic steps yet")
        return xioniter, ydF

    Ffin = ion_steps[-1]['F']
    netot = 0
    print("  {:>3}:{:>3}:{:>3}: {:12} {:12}".format('ion', 'ne', 'tot', 'F(eV/A)', 'E0(eV)'))
    for istep, inf in enumerate(ion_steps):
        # One x position per ionic step.
        netot += 1
        print("  {:3d}:{:3d}:{:3d}: {:12.8g} {:12.8g}".format(
            istep, inf.get('nelectron_steps', 0), netot, inf['F'], inf['E0']))
        xioniter.append(netot - 1)
        # Keep the sign, consistent with the dE panel (the zero line and the
        # +/- threshold lines on the F panel assume signed data).
        ydF.append(_signed_root(inf['F'] - Ffin, npower))
    return xioniter, ydF


def _hline(ax, xlim, y, color, label=None):
    """Draw a thin dashed horizontal reference line at height y across xlim."""
    kwargs = {} if label is None else {'label': label}
    ax.plot(xlim, [y, y], color=color, linestyle='dashed', linewidth=0.5, **kwargs)


def _label_axis(ax, xlim, ylabel):
    """Apply the shared x range, axis labels and tick font size to one panel."""
    ax.set_xlim(xlim)
    ax.set_xlabel("Iteration", fontsize=fontsize)
    ax.set_ylabel(ylabel, fontsize=fontsize)
    ax.tick_params(labelsize=fontsize)


def update_relax(mode, CAR_path, update_time_prev, fig_conv, ax_logE, ax_dE, ax_F):
    """Refresh the convergence report and plots if OUTCAR changed.

    Parameters
    ----------
    mode : str
        Watch mode (not used here; kept for interface compatibility).
    CAR_path : str
        Path handed to tkVASP.getdir() to locate the VASP run directory.
    update_time_prev : float or None
        OUTCAR modification time returned by the previous call, or None on
        the first call.
    fig_conv : matplotlib Figure
        Figure that owns the three axes below.
    ax_logE, ax_dE, ax_F : matplotlib Axes
        Panels for |E - E_last| (log scale), signed |E - E_last|^(1/npower)
        and signed |F - F_last|^(1/npower).

    Returns
    -------
    float or None
        The current OUTCAR modification time, or None when OUTCAR or OSZICAR
        is missing.
    """
    global iter

    vasp = tkVASP()
    base_path = vasp.getdir(CAR_path)

    INCAR_path   = vasp.get_INCAR(base_path)
    POSCAR_path  = vasp.get_POSCAR(base_path)
    KPOINTS_path = vasp.get_VASPPath(base_path, 'KPOINTS')
    OUTCAR_path  = vasp.get_OUTCAR(base_path)
    OSZICAR_path = vasp.get_VASPPath(base_path, 'OSZICAR')

    for path in (OUTCAR_path, OSZICAR_path):
        if not os.path.exists(path):
            print(f"Warning in update_relax(): Can not read [{path}]")
            return None

    update_time = os.path.getmtime(OUTCAR_path)
    if update_time_prev is not None and update_time <= update_time_prev:
        # OUTCAR has not changed since the previous call: nothing to redraw.
        return update_time

    iter += 1
    print("")
    print("*** OUTCAR [{}] is updated (#{})".format(OUTCAR_path, iter))
    print("time: ", _hms(update_time_prev), _hms(update_time))

    print("")
    print("CAR dir   : ", CAR_path)
    print("  INCAR   : ", INCAR_path)
    print("  POSCAR  : ", POSCAR_path)
    print("  KPOINTS : ", KPOINTS_path)
    print("  OUTCAR  : ", OUTCAR_path)
    print("  OSZICAR : ", OSZICAR_path)

    # The INCAR/POSCAR/OUTCAR reads are best effort: failures are reported and
    # the watch continues. `except Exception` (not a bare except) so that
    # Ctrl-C still stops the watcher.
    print("")
    print("Read INCAR from [{}]".format(INCAR_path))
    try:
        incarinf = vasp.read_incar_inf(INCAR_path)
        if incarinf and incarinf['SYSTEM']:
            print("  SYSTEM: ", incarinf['SYSTEM'])
    except Exception:
        print("  {} does not exist.".format(INCAR_path))

    print("")
    print("Read POSCAR from [{}]".format(POSCAR_path))
    try:
        cry = vasp.read_poscar(POSCAR_path)
        if cry:
            latt = cry.LatticeParameters()
            print("  latt: {:8.6g}, {:8.6g}, {:8.6g}, {:8.6g}, {:8.6g}, {:8.6g}".format(*latt))
    except Exception:
        print("  {} does not exist.".format(POSCAR_path))

    print("")
    print("Read OUTCAR from [{}]".format(OUTCAR_path))
    EDIFF = None
    EDIFFG = None
    try:
        outcarinf = vasp.read_outcar_inf(OUTCAR_path)
        if outcarinf:
            ISPIN  = outcarinf["ISPIN"]
            IsHF   = outcarinf["IsHF"]
            EDIFF  = outcarinf["EDIFF"]
            EDIFFG = outcarinf["EDIFFG"]
            EF     = outcarinf["EF"]
            print("Information in OUTCAR:")
            print("  ISPIN  : ", ISPIN)
            print("  IsHF: ", IsHF)
            print("  EF = {} eV".format(EF))
            print("  EDIFF  = {} eV".format(EDIFF))
            print("  EDIFFG = {} eV".format(EDIFFG))
    except Exception:
        print("  {} does not exist".format(OUTCAR_path))

    if EDIFF is None:
        # The plots need the EDIFF threshold; wait for the next OUTCAR update.
        return update_time

    print("")
    print("Read OSZICAR from [{}]".format(OSZICAR_path))
    oszicarinf = vasp.read_oszicar_inf(OSZICAR_path)
    if not oszicarinf:
        print("  {} does not exist".format(OSZICAR_path))
        return update_time

    all_steps = oszicarinf['all_steps']
    ion_steps = oszicarinf['ion_steps']
    if not all_steps:
        # Early in a run OSZICAR may not contain any step yet.
        print("  No electronic steps recorded yet")
        return update_time

    xiter, ylogE, ydE = _energy_series(all_steps)
    xioniter, ydF = _force_series(ion_steps)

    ax_logE.cla()
    ax_dE.cla()
    ax_F.cla()

    xlim = [0, len(xiter)]
    has_EDIFFG = EDIFFG is not None

    # Top panel: |E - E_last| on a log scale with the EDIFF line, plus the
    # EDIFFG line when EDIFFG is positive.
    ax_logE.plot(xiter, ylogE, marker='o', markersize=6.0, label='dE')
    _hline(ax_logE, xlim, EDIFF, 'blue', 'EDIFF')
    if has_EDIFFG and EDIFFG > 0.0:
        _hline(ax_logE, xlim, EDIFFG, 'green', 'EDIFFG')
    _label_axis(ax_logE, xlim, "$dE$ (eV)")
    ax_logE.set_yscale('log')

    # Middle panel: signed |dE|^(1/npower) with +/- thresholds in the same scale.
    ax_dE.plot(xiter, ydE, marker='o', markersize=6.0, label='dE^(1/{})'.format(npower))
    _hline(ax_dE, xlim, 0.0, 'red')
    EDIFFn = pow(EDIFF, 1.0 / npower)
    _hline(ax_dE, xlim, EDIFFn, 'blue', 'EDIFF')
    _hline(ax_dE, xlim, -EDIFFn, 'blue', 'EDIFF')
    if has_EDIFFG and EDIFFG > 0.0:
        EDIFFGn = pow(EDIFFG, 1.0 / npower)
        _hline(ax_dE, xlim, EDIFFGn, 'green', 'EDIFFG')
        _hline(ax_dE, xlim, -EDIFFGn, 'green', 'EDIFFG')
    _label_axis(ax_dE, xlim, "$dE^{1/%g}$ (eV)" % (npower))

    # Bottom panel: signed |dF|^(1/npower); a negative EDIFFG is drawn as +/- |EDIFFG|.
    ax_F.plot(xioniter, ydF, marker='o', markersize=6.0, label='dF^(1/{})'.format(npower))
    _hline(ax_F, xlim, 0.0, 'red')
    if has_EDIFFG and EDIFFG < 0.0:
        EDIFFGn = pow(-EDIFFG, 1.0 / npower)
        _hline(ax_F, xlim, +EDIFFGn, 'green', 'EDIFFG')
        _hline(ax_F, xlim, -EDIFFGn, 'green', 'EDIFFG')
    _label_axis(ax_F, xlim, "$dF^{1/%g}$ (eV/A)" % (npower))

    # Rearrange the panels so they do not overlap. Act on fig_conv itself:
    # plt.tight_layout() targets pyplot's current figure, which would be a
    # newly created blank figure if the user had closed this window.
    fig_conv.tight_layout()
    plt.pause(0.001)

    print("")
    print("Press ctrl-C to terminate")

    return update_time


def watch_relax(mode, CAR_dir):
    """Poll a VASP relaxation in CAR_dir and keep its convergence plots current.

    Creates one figure with three vertically stacked panels, then calls
    update_relax() about every ``t_sleep`` seconds. The loop ends on Ctrl-C
    or when the figure window is closed, after which terminate() is called.
    """
    # Plot the graphs: one figure, three panels stacked vertically.
    fig_conv = plt.figure(figsize=figsize)
    ax_logE = fig_conv.add_subplot(3, 1, 1)
    ax_dE = fig_conv.add_subplot(3, 1, 2)
    ax_F = fig_conv.add_subplot(3, 1, 3)

    # Wait ~t_sleep seconds per poll, in 0.1 s slices (rounded up).
    n_slices = int(t_sleep * 10.0 + 0.99999)

    update_time = None
    try:
        # Stop polling once the user closes the plot window instead of
        # continuing invisibly until Ctrl-C.
        while plt.fignum_exists(fig_conv.number):
            update_time = update_relax(mode, CAR_dir, update_time, fig_conv, ax_logE, ax_dE, ax_F)
            for _ in range(n_slices):
                time.sleep(0.1)
            plt.pause(0.001)
    except KeyboardInterrupt:
        pass

    terminate(None, usage=usage)


def main():
    """Script entry point: read arguments, print a banner, run the watcher.

    The watcher is chosen by the global ``mode``; unknown modes are
    rejected via terminate().
    """
    updatevars()

    banner = "=============== Watch progress of VASP run ============"
    for line in ("", banner, ""):
        print(line)
    print("mode   : ", mode)

    watchers = {'relax': watch_relax}
    watcher = watchers.get(mode)
    if watcher is None:
        terminate("Error: Invalide mode [{}]".format(mode), usage = usage)
    else:
        watcher(mode, CAR_dir)


if __name__ == '__main__':
    # Start the watcher only when run as a script, not when imported.
    main()
