import os
import sys
import shutil
import glob
import csv
import re
import numpy as np
from numpy import exp, log, sin, cos, tan, arcsin, arccos, arctan, pi
from scipy.interpolate import interp1d
from pprint import pprint
from matplotlib import pyplot as plt


from tklib.tkfile import tkFile
from tklib.tkutils import IsDir, IsFile, SplitFilePath
from tklib.tkutils import terminate, pint, pfloat, getarg, getintarg, getfloatarg
import tklib.tkre as tkre
from tklib.tkcrystal.tkcif import tkCIF, tkCIFData
from tklib.tkcrystal.tkcrystal import tkCrystal
from tklib.tkcrystal.tkatomtype import tkAtomType


"""
Manage VASP output files
"""

#================================
# Module-level defaults
#================================
# Debug switch (0 by default).
debug = 0

# Conversion to run: 'poscar2cif' or 'outcar2cif'.
mode = 'outcar2cif'

# Default input file locations.
poscarpath, outcarpath = 'POSCAR', 'OUTCAR'

# Output CIF location; empty by default, in which case the converters
# derive a file name themselves.
cifpath = ''

#=============================
# Handle command-line arguments
#=============================
def usage():
    """Print the command-line synopsis for both conversion modes.

    The example lines are built from ``sys.argv[0]`` and the current
    module-level ``poscarpath``, ``outcarpath`` and ``cifpath`` values.
    """
    script = sys.argv[0]

    print("")
    print("Usage:")
    print("  (a)  python {} poscar2cif poscar_path cif_path".format(script))
    print("     ex: python {} {} {} {}".format(script, 'poscar2cif', poscarpath, cifpath))
    # The second entry is labelled (b); it used to repeat (a).
    # NOTE(review): updatevars() currently does not read cif_path in
    # outcar2cif mode (the getarg call is commented out) -- confirm whether
    # the synopsis or the parser should change.
    print("  (b)  python {} outcar2cif outcar_path cif_path".format(script))
    print("     ex: python {} {} {} {}".format(script, 'outcar2cif', outcarpath, cifpath))

def updatevars():
    """Refresh the module-level settings from the command line.

    Argument 1 selects the mode; the following arguments are read per mode
    through ``getarg``, with the current value passed as its second argument.
    Any other mode is reported through ``terminate``.
    """
    global mode
    global cifpath, poscarpath, outcarpath

    mode = getarg(1, mode)

    if mode == 'outcar2cif':
        # Only the OUTCAR path is read from the command line in this mode.
        outcarpath = getarg(2, outcarpath)
        return

    if mode == 'poscar2cif':
        poscarpath = getarg(2, poscarpath)
        cifpath    = getarg(3, cifpath)
        return

    terminate("Error in updatevars(): Invalide mode [{}]".format(mode), usage = usage, pause = True)


def read_poscar(poscarpath, save_cif = True):
    """Read a POSCAR file and build a tkCrystal object from it.

    The file is expected to contain, in order: a title line, the scaling
    factor, three lattice-vector lines, the atom-type line, the atom-count
    line, an optional "Selective dynamics" line, the coordinate-mode line and
    one position line per atom.

    Parameters
    ----------
    poscarpath : str
        Path of the POSCAR file to read.
    save_cif : bool
        When true and the module-level ``cifpath`` is non-empty, that path is
        opened for writing and closed again as a writability check (nothing
        is written to it here). The process is terminated if it cannot be
        opened.

    Returns
    -------
    The populated tkCrystal object (after ``ExpandCoordinates()``).
    """
    infp = tkFile(poscarpath, 'r')
    if not infp:
        terminate("Error in poscar2cif: Can not read [{}]".format(poscarpath), usage = usage, pause = True)

    try:
        sample_name = infp.ReadLine().strip()
        print("sample name: [{}]".format(sample_name))

        # An empty cifpath is skipped: there is no file to probe, and callers
        # derive a default name only after this function returns.
        if save_cif and cifpath:
            print("Write to [{}]".format(cifpath))
            outfp = tkFile(cifpath, 'w')
            if not outfp:
                terminate("Error in poscar2cif: Can not write to [{}]".format(cifpath), usage = usage, pause = True)
            outfp.Close()

        # Scaling factor followed by the three (unscaled) lattice vectors.
        a0 = pfloat(infp.ReadLine())
        aij = [[pfloat(v) for v in infp.ReadLine().split()[:3]] for _ in range(3)]

        print("")
        print("Base lattice parameter: {} A".format(a0))
        print("Lattice vectors:")
        for row in aij:
            print("  {:16.12f}  {:16.12f}  {:16.12f}".format(*row))

        atomtypes = infp.ReadLine().split()
        natoms = [pint(n) for n in infp.ReadLine().split()]
        print("")
        print("Atom types and numbers:")
        for atype, n in zip(atomtypes, natoms):
            print("  {:>4}: {:4d}".format(atype, n))

        # One entry per atom site, in file order.
        sitetypes = []
        for atype, n in zip(atomtypes, natoms):
            sitetypes.extend([atype] * n)

        # Optional "Selective dynamics" line, recognised by its first letter
        # regardless of case. The previous exact comparison with 'direct'
        # treated a capitalised "Direct" header as the selective-dynamics
        # line and consumed the first atom position.
        line = infp.ReadLine().strip()
        if line[:1] in ('S', 's'):
            line = infp.ReadLine().strip()
        # Coordinate mode: C/c/K/k means Cartesian, anything else direct.
        cartesian = line[:1] in ('C', 'c', 'K', 'k')

        # Read at most one position per declared site so that trailing
        # sections of the file cannot overrun sitetypes.
        pos = []
        while len(pos) < len(sitetypes):
            line = infp.ReadLine()
            if not line:
                break

            fields = line.split()
            if len(fields) < 3:
                break

            pos.append([pfloat(fields[0]), pfloat(fields[1]), pfloat(fields[2])])
    finally:
        infp.Close()

    if cartesian and pos:
        # Cartesian positions share the scaling factor with the lattice, so
        # the fractional coordinates are x @ inv(A) with A the unscaled
        # lattice vectors stored as rows.
        pos = (np.array(pos, dtype = float) @ np.linalg.inv(np.array(aij, dtype = float))).tolist()

    print("")
    print("Atom positions:")
    for name, xyz in zip(sitetypes, pos):
        print("  {:4}: ({:16.12f}, {:16.12f}, {:16.12f})"
            .format(name, xyz[0], xyz[1], xyz[2]))

    print("")
    print("Build Crystal object:")
    cry = tkCrystal()
    cry.SetSampleName(sample_name)
    cry.SetCrystalName(sample_name)
    cry.SetLatticeVectors(a0 * np.array(aij))
    for name, xyz in zip(sitetypes, pos):
        cry.AddAtomSite(name = name, pos = xyz)

    cry.ExpandCoordinates()

    cry.PrintInf()

    return cry

def read_outcar_inf(outcarpath):
    """Collect summary information from an OUTCAR file.

    Parameters
    ----------
    outcarpath : str
        Path of the OUTCAR file to read.

    Returns
    -------
    None when the file cannot be opened, otherwise a dict with the keys

    * ``'SYSTEM'``       -- text after ``SYSTEM =`` ('' when not found),
    * ``'FreeEnergy'``   -- list of values parsed from every
      ``free  energy   TOTEN  =`` line, in file order,
    * ``'FinalCharges'`` -- list of ``{'s', 'p', 'd', 'tot'}`` dicts, one per
      row of the charge table,
    * ``'nIon'``         -- number of rows read from that table.
    """
    fp = tkFile(outcarpath, 'r')
    if not fp:
        return None

    # 'SYSTEM' is always present so callers can index it without a KeyError.
    inf = {'SYSTEM': ''}

    try:
        # read input parameters
        line = fp.SkipTo(r'^\s*SYSTEM\s*=')
        if line:
            match = tkre.Search(r'SYSTEM\s*=\s*(\S.*?)\s*$', line)
            if match:
                inf['SYSTEM'] = match[1]

        # read the free energy reported at each ion step
        energy = []
        while 1:
            line = fp.SkipTo("\\s*free  energy   TOTEN  =")
            if not line:
                break

            match = tkre.Search(r'free  energy   TOTEN  =\s*(\S+)', line)
            if not match:
                # The line has no value after '='; skip it instead of
                # failing on match[1].
                continue
            energy.append(pfloat(match[1]))
        inf['FreeEnergy'] = energy

        # read final charges
        # NOTE(review): the meaning of the second SkipTo argument (0) is not
        # visible in this file; it is passed through unchanged -- confirm.
        fp.SkipTo("\\s*total charge", 0)
        fp.SkipTo("# of ion     s       p       d       tot")
        # Skip the line that follows the column header.
        fp.ReadLine()

        chargeinf = []
        while 1:
            line = fp.ReadLine()
            if not line:
                break
            line = line.strip()
            if tkre.Match(r'^-----', line):
                break

            fields = line.split()
            if len(fields) != 5:
                # A blank or malformed line ends the table (previously a
                # ValueError from the five-way unpack).
                break

            _ion, s, p, d, tot = fields
            chargeinf.append({
                's': pfloat(s),
                'p': pfloat(p),
                'd': pfloat(d),
                'tot': pfloat(tot),
            })
        inf['FinalCharges'] = chargeinf
        inf['nIon'] = len(chargeinf)
    finally:
        fp.Close()

    return inf

def read_next_crystalstructure(fp, cry):
    """Load the next converged structure from an open OUTCAR into ``cry``.

    Starting at the current position of ``fp``, the function looks for the
    next 'aborting loop because EDIFF is reached' marker, then reads the
    lattice vectors listed after 'VOLUME and BASIS-vectors' and the
    position/force table listed after 'POSITION'.

    Parameters
    ----------
    fp : tkFile
        OUTCAR file opened for reading; its position is advanced.
    cry : tkCrystal
        Crystal whose lattice vectors and atom sites are overwritten in
        place. It must already hold one site per atom in the table.

    Returns
    -------
    ``cry`` after the update, or None when no further complete structure is
    found in the file.
    """
    if not fp.SkipTo('aborting loop because EDIFF is reached'):
        return None

    if not fp.SkipTo('VOLUME and BASIS-vectors'):
        return None
    # Four lines separate the marker from the first lattice-vector row
    # (presumably separator, energy-cutoff, volume and column-header lines --
    # confirm against the OUTCAR layout). ReadLine is used here as
    # everywhere else in this file.
    for _ in range(4):
        fp.ReadLine()

    # Only the first three columns of each row are read as lattice vectors.
    aij = [[pfloat(v) for v in fp.ReadLine().split()[:3]] for _ in range(3)]
    if any(len(row) < 3 for row in aij):
        # Truncated file: no complete lattice block.
        return None

    print("")
    print("Lattice vectors:")
    for row in aij:
        print("  {:16.12f}  {:16.12f}  {:16.12f}".format(*row))
    cry.SetLatticeVectors(aij)
    latt = cry.LatticeParameters()
    print("lattice parameters:", *latt)

    if not fp.SkipTo('POSITION'):
        return None
    # skip ' ------...'
    fp.ReadLine()

    # Cartesian r and fractional f are related by r = f @ A with the lattice
    # vectors as rows of A, hence f = r @ inv(A). Dividing by the cell
    # lengths, as done before, is only valid for axis-aligned orthogonal
    # cells.
    inv_latt = np.linalg.inv(np.array(aij, dtype = float))

    isite = 0
    while 1:
        line = fp.ReadLine()
        if not line or '-----' in line:
            break

        fields = line.split()
        if len(fields) < 6:
            break

        cart = [pfloat(v) for v in fields[:3]]
        force = [pfloat(v) for v in fields[3:6]]
        frac = (np.array(cart, dtype = float) @ inv_latt).tolist()
        cry.SetAtomSite(isite, pos = frac, force = force)
        isite += 1

    cry.ExpandCoordinates()

    return cry

def outcar2cif(poscarpath, outcarpath, cifpath):
    """Write one CIF file per converged structure found in an OUTCAR.

    The initial structure (atom types and site list) is taken from the
    POSCAR; each structure read from the OUTCAR then overwrites its lattice
    and atom sites and is written to ``relax-NNNN.cif`` in the current
    directory. The function ends by calling ``terminate``.

    Parameters
    ----------
    poscarpath : str
        Path of the POSCAR describing the initial structure.
    outcarpath : str
        Path of the OUTCAR to scan.
    cifpath : str
        Output path; when empty a name is derived from the SYSTEM tag.
        NOTE(review): this value is only printed -- the per-step files are
        always named ``relax-NNNN.cif``. Confirm whether that is intended.
    """
    print("")
    print("POSCAR path: ", poscarpath)
    print("OUTCAR path: ", outcarpath)
    print("CIF path: ", cifpath)

    print("")
    print("Read initial crystal structure from [{}]".format(poscarpath))
    cry = read_poscar(poscarpath, save_cif = False)
    cry.PrintInf()

    print("")
    print("Read [{}]".format(outcarpath))
    inf = read_outcar_inf(outcarpath)
    if not inf:
        terminate("Error: Can not read [{}]".format(outcarpath), usage = usage, pause = True)

    # .get(): the SYSTEM entry may be missing when the tag was not found,
    # in which case the placeholder name is used.
    sample_name = inf.get('SYSTEM') or 'hogehoge'
    if not cifpath:
        cifpath = sample_name + '-outcar.cif'
        print("CIF path: ", cifpath)

    print("Energy vs ion step:")
    for i, e in enumerate(inf['FreeEnergy']):
        print("  {:3d}: {:12.8g} eV".format(i, e))
    # This table lists charges; it was previously labelled "Energy vs step:".
    print("Final charges:")
    for i, charges in enumerate(inf['FinalCharges']):
        print("  {:3d}: s={:10.6f} p={:10.6f} d={:10.6f} tot={:10.6f}"
            .format(i, charges['s'], charges['p'], charges['d'], charges['tot']))

    print("")
    print("Read crystal structures from [{}]".format(outcarpath))
    fp = tkFile(outcarpath, 'r')
    if not fp:
        terminate("Error: Can not read [{}]".format(outcarpath), usage = usage, pause = True)

    count = 1
    try:
        while 1:
            cry = read_next_crystalstructure(fp, cry)
            if not cry: break

            print("")
            print("Structure #{}".format(count))
            cry.PrintInf()

            step_cifpath = f"relax-{count:04d}.cif"
            print("")
            print("Write to [{}]".format(step_cifpath))
            cif = tkCIFData()
            cif.CreateCIFFileFromCCrystal(cry, step_cifpath)

            count += 1
    finally:
        # Close the OUTCAR even when reading or writing a structure fails.
        fp.Close()

    terminate(None, usage = usage, pause = True)


def poscar2cif(poscarpath, cifpath):
    """Convert a POSCAR file to a single CIF file.

    Parameters
    ----------
    poscarpath : str
        Path of the POSCAR file to read.
    cifpath : str
        Path of the CIF file to write; when empty, the crystal's sample name
        followed by '-converted.cif' is used.

    The function ends by calling ``terminate``.
    """
    print("")
    print("POSCAR path: ", poscarpath)
    print("CIF path: ", cifpath)

    print("")
    print("Read [{}]".format(poscarpath))
    # save_cif = False: the CIF is written below. The default (True) makes
    # read_poscar open the module-level cifpath, which may still be empty,
    # before the fallback name has been computed.
    cry = read_poscar(poscarpath, save_cif = False)

    if not cifpath:
        cifpath = cry.SampleName() + '-converted.cif'

    print("")
    print("Write to [{}]".format(cifpath))
    cif = tkCIFData()
    cif.CreateCIFFileFromCCrystal(cry, cifpath)

    terminate(None, usage = usage, pause = True)


def main():
    """Parse the command line, print the banner and run the selected mode."""
    global mode
    global poscarpath, outcarpath, cifpath

    updatevars()

    for text in ("", "=============== Manage VASP output files ============", ""):
        print(text)
    print("mode: ", mode)

    if mode == 'poscar2cif':
        poscar2cif(poscarpath, cifpath)
        return

    if mode == 'outcar2cif':
        outcar2cif(poscarpath, outcarpath, cifpath)
        return

    # Neither of the known modes matched.
    terminate("Error in main(): Invalide mode [{}]".format(mode), usage = usage, pause = True)


# Run the converter only when this file is executed as a script.
if __name__ == "__main__":
    main()
