#import os
import sys
#import shutil
import glob
#import csv
#import re
import numpy as np
from numpy import exp, log, sin, cos, tan, arcsin, arccos, arctan, pi
#from scipy.interpolate import interp1d
#from pprint import pprint
#from matplotlib import pyplot as plt


from tklib.tkfile import tkFile
from tklib.tkutils import IsDir, IsFile, SplitFilePath
from tklib.tkutils import terminate, pint, pfloat, getarg, getintarg, getfloatarg
from tklib.tkcrystal.tkcif import tkCIF
from tklib.tkcrystal.tkcrystal import tkCrystal
from tklib.tkcrystal.tkatomtype import tkAtomType


"""
Manage VASP input files
"""

#================================
# global parameters
#================================
debug = 0

# mode: 'cif2poscar'
mode = 'cif2poscar'

# input cif path
cifpath    = '*.cif'
# output poscar path
poscarpath = 'POSCAR'

# cif read configuration
single = 1
findvalidstructure = 1


#=============================
# Command-line argument handling
#=============================
def usage():
    """Print the command-line usage, with an example built from the current defaults."""
    script = sys.argv[0]
    lines = [
        "",
        "Usage:",
        "  (a)  python {} cif2poscar cif_path POSCAR_path".format(script),
        "     ex: python {} {} {} {}".format(script, 'cif2poscar', cifpath, poscarpath),
    ]
    for line in lines:
        print(line)

def updatevars():
    """Override the module-level settings from positional command-line arguments.

    Argument 1 selects ``mode``. In 'cif2poscar' mode, argument 2 sets
    ``cifpath`` and argument 3 sets ``poscarpath``. Each call passes the
    current global as the second argument of getarg(), presumably the value
    used when the argument is absent (NOTE(review): confirm in tkutils).

    Any other mode is reported through terminate() together with usage().
    """
    global mode
    global cifpath, poscarpath

    mode = getarg(1, mode)
    if mode == 'cif2poscar':
        cifpath    = getarg(2, cifpath)
        poscarpath = getarg(3, poscarpath)
    else:
        terminate("Error: Invalid mode [{}]".format(mode), usage = usage)


def _resolve_cif_path(cifpath):
    """Return a concrete CIF file path for *cifpath*, or report failure via terminate().

    If *cifpath* is not an existing file, it is treated as a glob pattern
    (e.g. ``*.cif``) and the first match in sorted order is taken.
    """
    if not IsFile(cifpath):
        matches = sorted(glob.glob(cifpath))
        if matches:
            cifpath = matches[0]
        else:
            terminate("Error: Invalid ciffile [{}]".format(cifpath), usage = usage)

    # A glob match is not necessarily a regular file (e.g. a directory).
    if not IsFile(cifpath):
        terminate("Error: Invalid ciffile [{}]".format(cifpath), usage = usage)

    return cifpath


def _poscar_lines(sample_name, cry):
    """Build the POSCAR contents for crystal *cry* as a list of newline-terminated lines.

    Layout: comment line (*sample_name*), scale factor (lattice constant a),
    the three lattice vectors divided by a, species names, per-species atom
    counts, "Selective dynamics", "Direct", then one coordinate line per site
    with all three selective-dynamics flags set to T.
    """
    a, _b, _c, _alpha, _beta, _gamma = cry.LatticeParameters()
    # The lattice vectors are written in units of the scale factor a.
    aij = cry.LatticeVectors() / a
    atom_types = cry.AtomTypeList()

    # Group the expanded sites by atom-type index in one pass, keeping their
    # original order within each type. The counts line is taken from these
    # same groups, so it always matches the number of coordinate lines
    # written for each species (POSCAR requires the two to agree).
    sites_by_type = {}
    for atom in cry.ExpandedAtomSiteList():
        sites_by_type.setdefault(atom.iAtomType(), []).append(atom)
    groups = [sites_by_type.get(it, []) for it in range(len(atom_types))]

    lines = [sample_name + '\n', "{}\n".format(a)]
    for k in range(3):
        lines.append("  {:17.14f}  {:17.14f}  {:17.14f}\n".format(*aij[k]))

    # AtomTypeOnly() gives the bare element name (e.g. "Ca" rather than "Ca2+").
    lines.append("".join("{:>6} ".format(t.AtomTypeOnly()) for t in atom_types) + "\n")
    lines.append("".join("{:6} ".format(len(group)) for group in groups) + "\n")

    lines.append("Selective dynamics\n")
    # NOTE(review): assumes atom.Position() returns fractional coordinates,
    # as required by the "Direct" keyword.
    lines.append("Direct\n")
    for group in groups:
        for atom in group:
            lines.append("  {:17.14f}  {:17.14f}  {:17.14f}  {:2}  {:2}  {:2}\n"
                         .format(*atom.Position(), 'T', 'T', 'T'))
    return lines


def cif2poscar(cifpath, poscarpath):
    """Convert a CIF file into a VASP POSCAR file.

    Parameters
    ----------
    cifpath : str
        Path to a CIF file. If it is not an existing file, it is treated as a
        glob pattern (e.g. ``*.cif``) and the first match in sorted order is used.
    poscarpath : str
        Output POSCAR path (opened with mode 'w').

    Errors are reported through terminate(). On success the function finishes
    by calling terminate(None, usage=usage, pause=True).
    """
    print("")
    print("cif path: ", cifpath)
    cifpath = _resolve_cif_path(cifpath)

    _dirname, _basename, sample_name, _ext = SplitFilePath(cifpath)
    print("sample name: ", sample_name)

    print("")
    print("Read [{}]".format(cifpath))

    cif = tkCIF()
    cif.debug = debug
    try:
        cifdata = cif.ReadCIF(cifpath, find_valid_structure = findvalidstructure)
    finally:
        cif.Close()
    if not cifdata:
        terminate("Error: Could not get cifdata from ciffile [{}]".format(cifpath), usage = usage)

    # Build the full contents before opening the output, so a failure while
    # querying the crystal cannot leave a partially written POSCAR behind.
    lines = _poscar_lines(sample_name, cifdata.GetCrystal())

    print("Write to [{}]".format(poscarpath))
    out = tkFile(poscarpath, 'w')
    if not out:
        terminate("Error: Could not write to [{}]".format(poscarpath), usage = usage)
    try:
        for line in lines:
            out.Write(line)
    finally:
        out.Close()

    terminate(None, usage = usage, pause = True)


def main():
    """Script entry point: apply command-line overrides, then dispatch on ``mode``."""
    # updatevars() rebinds the module-level mode/cifpath/poscarpath; this
    # function only reads them, so no `global` declaration is needed.
    updatevars()

    print("")
    print("=============== Manage VASP input files ============")
    print("")
    print("mode: ", mode)

    if mode == 'cif2poscar':
        cif2poscar(cifpath, poscarpath)
    else:
        # updatevars() already rejects unknown modes, so this branch is only
        # reached if that terminate() call returns.
        terminate("Error: Invalid mode [{}]".format(mode), usage = usage, pause = True)


if __name__ == '__main__':
    # Run only when executed as a script, not when imported.
    main()
