import os
import sys
import shutil
import copy
import glob
import csv
import numpy as np
from numpy import exp, log, sin, cos, tan, arcsin, arccos, arctan, pi


from tklib.tkfile import tkFile
import tklib.tkre as tkre
from tklib.tkutils import IsDir, IsFile, SplitFilePath
from tklib.tkutils import terminate, pint, pfloat, getarg, getintarg, getfloatarg
from tklib.tksci.tksci import Reduce01, Round
from tklib.tksci.tkmatrix import make_matrix1, make_matrix2, make_matrix3
from tklib.tkcrystal.tkcif import tkCIF, tkCIFData
from tklib.tkcrystal.tkcrystal import tkCrystal
from tklib.tkcrystal.tkvasp import tkVASP
from tklib.tkapplication import tkApplication


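"""
Build an MD history from VASP output files: write the initial (POSCAR) and
final (CONTCAR) structures as CIF files, and write a CSV with, for each
structure in OUTCAR, the temperature, energies, lattice parameters, volume
and per-atom-type mean square displacement (or integrated MSD).
"""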

#================================
# global parameters
#================================
# Debug flag; not referenced by any code visible in this file.
debug = 0

# mode: 'msd'  -> squared displacement measured from the initial POSCAR positions
#       'imsd' -> running sum of squared displacements between consecutive steps
# Can be overridden by command-line argument 1 (see updatevars()).
mode = 'msd'
#mode = 'imsd'

# input path: VASP output file/directory; resolved with tkVASP.getdir().
# Can be overridden by command-line argument 2 (see updatevars()).
CAR_path     = 'OUTCAR'

# NOTE(review): cifpath is not referenced by any code visible in this file.
cifpath = 'test.cif'
# Module-level default only; make_history() uses its own local '<SYSTEM>-history.csv' name.
historycsvfile = None

# cif read configuration
# NOTE(review): single / findvalidstructure are not referenced by any code visible in this file.
single = 1
findvalidstructure = 1


# Shared application object: main() uses it for stdout redirection and
# make_history() uses it for terminate().
app = tkApplication()


#=============================
# Treat argments
#=============================
def usage():
    """Print the command-line usage.

    The example line uses the current values of the module-level ``mode``
    and ``CAR_path`` globals.
    """
    script = sys.argv[0]
    help_lines = (
        "",
        "Usage:",
        "  (a)  python {} mode outcar_path".format(script),
        "         mode: msd: standard mean square displacement <ri(t) - ri(0)>",
        "               imsd: integrated msd sum_t{<ri(t) - ri(t-dt)>}",
        "     ex: python {} {} {}".format(script, mode, CAR_path),
    )
    for help_line in help_lines:
        print(help_line)
    
    
def updatevars():
    """Override the ``mode`` and ``CAR_path`` globals from command-line arguments.

    Argument 1 sets ``mode`` and argument 2 sets ``CAR_path``. Each lookup goes
    through ``getarg`` with the current global value as its fallback
    (presumably used when the argument is absent).
    """
    global mode, CAR_path

    new_mode = getarg(1, mode)
    new_car_path = getarg(2, CAR_path)
    mode, CAR_path = new_mode, new_car_path


def _print_cell(cry):
    """Print the six values returned by cry.LatticeParameters() on one 'cell:' line."""
    a, b, c, alpha, beta, gamm = cry.LatticeParameters()
    print("cell: {:12.8f} {:12.8f} {:12.8f} A   {:10.6f} {:10.6f} {:10.6f}".format(a, b, c, alpha, beta, gamm))


def _incar_value(incar_inf, key, default = None):
    """Return incar_inf[key], or *default* when the key is missing."""
    try:
        return incar_inf[key]
    except KeyError:
        return default


def _squared_displacement_by_type(cry, sites, ref_poslist, poslist, nsitesfortypes, ntypes):
    """Return, per atom type, the mean squared displacement between two position lists.

    For each site i, the distance between ref_poslist[i] and poslist[i] is
    measured with cry.GetNearestInterAtomicDistance. Its square is averaged
    over the nsitesfortypes[itype] sites of that site's type.
    """
    msd = make_matrix1(ntypes, defval = 0.0)
    for i, site in enumerate(sites):
        x0, y0, z0 = ref_poslist[i]
        x, y, z    = poslist[i]
        itype      = site.iAtomType()

        # The coordinates are wrapped into the cell, so the real displacement
        # is recovered as the nearest inter-atomic distance (presumably the
        # minimum periodic image). If an atom moves by more than one unit cell
        # within this interval, the displacement is ambiguous; no attempt is
        # made to correct the MSD for that case.
        r = cry.GetNearestInterAtomicDistance([x0, y0, z0], [x, y, z], AllowZero = True)

        msd[itype] += r * r / nsitesfortypes[itype]
    return msd


def _write_history(vasp, fp, hist, initial_cry, dt, fimsd):
    """Write the CSV header, then one row per structure read from the open OUTCAR *fp*.

    dt is the time step (POTIM, fs). If fimsd is true, each step's squared
    displacement is measured from the previous step and accumulated.
    Otherwise it is measured from the initial (POSCAR) positions.
    """
    types          = initial_cry.AtomTypeList()
    ntypes         = len(types)
    sites          = initial_cry.ExpandedAtomSiteList()
    poslist0       = initial_cry.get_position_list(mode = 'all', IsReduce01 = True)
    nsitesfortypes = make_matrix1(ntypes, defval = 0)
    for site in sites:
        nsitesfortypes[site.iAtomType()] += 1
    print("# of atom sites for types = ", nsitesfortypes)

    # NOTE(review): the energy column names contain spaces after the commas
    # (" EKIN", ...); kept unchanged so existing readers of the CSV still match.
    hist.Write("step,t(fs),T(K)"
            + ",Etot, EKIN, EKIN_LAT, ETOTAL"
            + ",a,b,c,alpha,beta,gamma,V"
            )
    msdlabel = 'imsd' if fimsd else 'msd'
    for atom_type in types:
        hist.Write(",{}({})(A^2)".format(msdlabel, atom_type.AtomType()))
    hist.Write("\n")

    count = 0
    prevposlist = poslist0
    prevmsd = make_matrix1(ntypes, defval = 0.0)
    while True:
        # A falsy crystal marks the end of the OUTCAR structure sequence.
        inf, cry = vasp.read_next_crystalstructure(fp, initial_cry, IsReduce01 = True)
        if not cry:
            break

        t = dt * count
        a, b, c, alpha, beta, gamm = cry.LatticeParameters()
        V = cry.Volume()
        print("step #{:05d}: {:12.8f} {:12.8f} {:12.8f} A   {:10.6f} {:10.6f} {:10.6f}"
                .format(count, a, b, c, alpha, beta, gamm))

        poslist = cry.get_position_list(mode = 'all', IsReduce01 = True)
        ref_poslist = prevposlist if fimsd else poslist0
        msd = _squared_displacement_by_type(cry, sites, ref_poslist, poslist, nsitesfortypes, ntypes)
        if fimsd:
            # Integrated MSD: add this step's increment to the running total.
            for itype in range(ntypes):
                msd[itype] += prevmsd[itype]

        prevposlist = poslist
        prevmsd = msd

        hist.Write("{},{},{}".format(count, t, inf["temperature"])
                 + ",{},{},{},{}".format(inf["TOTEN"], inf["EKIN"], inf["EKIN_LAT"], inf["ETOTAL"])
                 + ",{},{},{},{},{},{},{}".format(a, b, c, alpha, beta, gamm, V)
                 )
        for itype in range(ntypes):
            hist.Write(",{}".format(msd[itype]))
        hist.Write("\n")

        count += 1


def make_history():
    """Write CIF snapshots and a per-step MD history CSV from a VASP run.

    The VASP files are located from the module-level CAR_path, which is
    replaced by tkVASP.getdir(CAR_path). The outputs are written relative to
    the current directory:
      <SYSTEM>-initial.cif : structure read from POSCAR
      <SYSTEM>-final.cif   : structure read from CONTCAR
      <SYSTEM>-history.csv : one row per structure in OUTCAR containing step,
          time (fs), temperature, energies, lattice parameters, volume, and the
          per-atom-type MSD ('msd' mode) or integrated MSD ('imsd' mode).
    <SYSTEM> is the first word of the INCAR SYSTEM tag, or 'hogehoge' if the
    tag is empty or missing. Fatal problems are reported with app.terminate(),
    and the function always ends by calling app.terminate(pause = True).
    """
    global mode, CAR_path

    vasp = tkVASP()

    CAR_path = vasp.getdir(CAR_path)

    print("")
    print("mode   : ", mode)
    fimsd = (mode == 'imsd')
    INCAR_path   = vasp.get_INCAR(CAR_path)
    POSCAR_path  = vasp.get_POSCAR(CAR_path)
    CONTCAR_path = vasp.get_CONTCAR(CAR_path)
    OUTCAR_path  = vasp.get_OUTCAR(CAR_path)
    print("CAR dir: ", CAR_path)
    print("INCAR  : ", INCAR_path)
    print("POSCAR : ", POSCAR_path)
    print("CONTCAR: ", CONTCAR_path)
    print("OUTCAR : ", OUTCAR_path)

    print("")
    print("*** Read initial crystal structure from [{}]".format(POSCAR_path))
    initial_cry = vasp.read_poscar(POSCAR_path)
    _print_cell(initial_cry)

    print("")
    print("*** Read final crystal structure from [{}]".format(CONTCAR_path))
    final_cry = vasp.read_poscar(CONTCAR_path)
    _print_cell(final_cry)

    print("")
    print("Read [{}]".format(INCAR_path))
    incar_inf = vasp.read_incar_inf(INCAR_path)
    if not incar_inf:
        app.terminate("Error: Can not read [{}]".format(INCAR_path), pause = True)

    # SYSTEM is optional: a missing or empty tag falls back to a placeholder.
    sample_name = _incar_value(incar_inf, 'SYSTEM')
    if not sample_name:
        sample_name = 'hogehoge'
    # POTIM (time step, fs) is required to build the time axis.
    # NOTE(review): assumes pfloat() returns None for unparsable input — confirm.
    potim = _incar_value(incar_inf, 'POTIM')
    dt = pfloat(potim) if potim is not None else None
    if dt is None:
        app.terminate("Error: Can not read POTIM from [{}]".format(INCAR_path), pause = True)

    print("")
    print("sample_name:", sample_name)
    print("POTIM: {} fs".format(dt))

    # SYSTEM may contain several words; only the first is used as the file prefix.
    words = sample_name.split()
    sample_name = words[0] if words else 'hogehoge'
    initial_cifpath = sample_name + '-initial.cif'
    final_cifpath   = sample_name + '-final.cif'
    historycsvfile  = sample_name + '-history.csv'

    cif = tkCIFData()
    print("")
    print("Write to [{}]".format(initial_cifpath))
    cif.CreateCIFFileFromCCrystal(initial_cry, initial_cifpath)
    print("Write to [{}]".format(final_cifpath))
    cif = tkCIFData()
    cif.CreateCIFFileFromCCrystal(final_cry, final_cifpath)

    print("")
    print("Write history to [{}]".format(historycsvfile))
    hist = tkFile(historycsvfile, 'w')
    if not hist:
        app.terminate("Error: Can not write to [{}]".format(historycsvfile), pause = True)

    # try/finally so both files are closed even if reading/parsing fails mid-way.
    try:
        nstructures = vasp.read_n_crystalstructures(OUTCAR_path)
        print("# of crystal structures:", nstructures)

        print("")
        print("Read crystal structures from [{}]".format(OUTCAR_path))
        fp = tkFile(OUTCAR_path, 'r')
        if not fp:
            app.terminate("Error: Can not read [{}]".format(OUTCAR_path), pause = True)
        try:
            _write_history(vasp, fp, hist, initial_cry, dt, fimsd)
        finally:
            fp.Close()
    finally:
        hist.Close()

    app.terminate(pause = True)


def main():
    """Script entry point.

    Applies command-line overrides, redirects output to both stdout and a log
    file in the VASP directory, prints a banner, then runs make_history().
    """
    updatevars()

    # The log file is placed in the directory resolved from CAR_path.
    # NOTE(review): 'hisotry' is a typo, kept so the log file name does not change.
    base_path = tkVASP().getdir(CAR_path)
    logfile = os.path.join(base_path, 'make_md_hisotry-out.txt')
    print("")
    print(f"Open logfile [{logfile}]")
    app.redirect(targets = ["stdout", logfile], mode = 'w')

    # NOTE(review): the banner text says "AXSF", but this script actually writes CIF + CSV.
    banner = (
        "",
        "=============== Convert VASP output files to XCrySDen AXSF file ============",
        "",
    )
    for text in banner:
        print(text)

    make_history()


if __name__ == "__main__":
    main()
