import os
import sys
import shutil
import glob
import csv
import numpy as np
from numpy import exp, log, sin, cos, tan, arcsin, arccos, arctan, pi
from scipy.interpolate import interp1d
from pprint import pprint
from matplotlib import pyplot as plt


from tklib.tkfile import tkFile
from tklib.tkutils import IsDir, IsFile, SplitFilePath
from tklib.tkutils import terminate, pint, pfloat, getarg, getintarg, getfloatarg
from tklib.tksci.tkmatrix import make_matrix1, make_matrix2, make_matrix3
from tklib.tksci.tkconvolution import Gaussian, convolute_by_func
from tklib.tkcrystal.tkcif import tkCIF
from tklib.tkcrystal.tkcrystal import tkCrystal
from tklib.tkcrystal.tkatomtype import tkAtomType


#================================
# Global parameters
#================================
debug = 0

# Calculation mode:
#   'atom' : RDFs averaged per atom type
#   'site' : RDFs per asymmetric atom site
#   'cn'   : per-atom-type RDFs plus coordination-number-resolved RDFs
mode = 'cn'

# Input CIF file. The output CSV file names are derived from it in updatevars().
ciffile      = 'IGZO.cif'
rdfcsvfile   = None
cnrdfcsvfile = None

# Flags echoed by rdf(); findvalidstructure is passed on to tkCIF.ReadCIF().
single             = 1
findvalidstructure = 1

# Radial grid in angstrom. Rstep is derived as Rmax / (nR - 1) in updatevars().
Rmin, Rmax = 0.1, 5.0
nR         = 1001
Rstep      = None

# Largest coordination number analysed in 'cn' mode.
nCN = 6

# Half width (A) of the Gaussian convolution applied to the RDFs
# (the running coordination numbers are not convoluted).
# wConv = 0 disables the convolution.
wConv = 0.0

#=============================
# Figure configuration
#=============================
figuresize      = (5, 8)
fontsize        = 12
legend_fontsize = 8


#=============================
# Command-line arguments
#=============================
def usage():
    """Print the command-line syntax and an example built from the current settings."""
    script = sys.argv[0]
    print("")
    print("Usage:")
    print("  (i) python {} mode ciffile Rmin Rmax nR nCN wConv".format(script))
    print("      mode: [atom|site|cn]")
    example = "      ex: python {} {} {} {} {} {} {} {}".format(
        script, mode, ciffile, Rmin, Rmax, nR, nCN, wConv)
    print(example)

def updatevars():
    """Override the module-level settings from positional command-line arguments.

    Arguments (each optional; a missing one keeps the current default):
      1: mode, 2: ciffile, 3: Rmin, 4: Rmax, 5: nR, 6: nCN, 7: wConv

    Also derives the output CSV names (rdfcsvfile, cnrdfcsvfile) from the CIF
    file name and the radial step Rstep = Rmax / (nR - 1).
    """
    global mode
    global ciffile, rdfcsvfile, cnrdfcsvfile
    global nCN, Rmin, Rmax, nR, Rstep, wConv

    mode    = getarg     (1, mode)
    ciffile = getarg     (2, ciffile)
    Rmin    = getfloatarg(3, Rmin)
    Rmax    = getfloatarg(4, Rmax)
    nR      = getintarg  (5, nR)
    nCN     = getarg     (6, nCN)
    wConv   = getfloatarg(7, wConv)

    # nCN arrives as a string from the command line (only the first
    # comma-separated field is used) or as the int default when the argument
    # is absent; str() makes both cases work.
    nCN = pint(str(nCN).split(',')[0])

    filebody     = os.path.basename(os.path.splitext(ciffile)[0])
    rdfcsvfile   = filebody + '-RDF.csv'
    cnrdfcsvfile = filebody + '-CNRDF.csv'

    # The radial grid needs at least two points to define a step.
    if nR < 2:
        terminate("Error: nR must be >= 2 (got {})".format(nR), usage = usage)
    Rstep = Rmax / (nR - 1)


#=============================
# other functions
#=============================
def savecsv(outfile, header, datalist):
    """Write column-oriented data to a CSV file.

    Parameters
    ----------
    outfile : str
        Output path.
    header : sequence
        Column labels, written as the first row.
    datalist : sequence of sequences
        One sequence per column; row i of the file is
        (datalist[0][i], datalist[1][i], ...).

    A failure to open the file is reported on stdout and the function returns
    without raising (best-effort save, as the callers expect).
    """
    print("Write to [{}]".format(outfile))
    try:
        f = open(outfile, 'w')
    except OSError:
        print("Error: Can not write to [{}]".format(outfile))
        return

    # 'with' guarantees the file is closed even if writing fails.
    with f:
        fout = csv.writer(f, lineterminator='\n')
        fout.writerow(header)
        # Transpose the list of columns into rows.
        fout.writerows(zip(*datalist))

def read_csv(infile, xmin = None, xmax = None, delimiter = ','):
    """Read the first two columns of a CSV file, optionally restricted to an x range.

    Parameters
    ----------
    infile : str
        Path of the CSV file; the first row is taken as the header.
    xmin, xmax : float or None
        Inclusive bounds on the first column. None means "no limit" on
        that side.
    delimiter : str
        Field delimiter passed to csv.reader.

    Returns
    -------
    (header, x, y) : header row, list of first-column values and list of
    second-column values (both converted with pfloat).

    Any read/parse error terminates the program via terminate().
    """
    print("xrange=", xmin, xmax)
    xs = []
    ys = []
    try:
        with open(infile, "r") as infp:
            reader = csv.reader(infp, delimiter = delimiter)
            header = next(reader)
            print("header=", header)
            for row in reader:
                if not row:
                    # Skip blank lines instead of failing on row[0].
                    continue
                x = pfloat(row[0])
                if xmin is not None and x < xmin:
                    continue
                if xmax is not None and x > xmax:
                    continue
                xs.append(x)
                ys.append(pfloat(row[1]))
    except Exception:
        terminate("Error in read_csv: Can not read [{}]".format(infile), usage = usage)

    return header, xs, ys


def read_xyzfile(cfgfile, xyzfile):
    """Build a tkCrystal from a cfg file (lattice vectors) and an xyz file (atoms).

    The cfg file supplies the sample name (its second line) and the three
    lattice vectors following the "Defining vectors" marker. The xyz file
    supplies the number of sites, a blank/comment line, then one
    "name x y z" line per site in Cartesian coordinates.
    Returns the crystal after ExpandCoordinates() and PrintInf().
    """
    cfg_reader = tkFile(cfgfile, 'r')
    if not cfg_reader:
        terminate("Error in read_xyzfile: Can not read [{}]".format(cfgfile), usage = usage)

    cfg_reader.ReadLine()    # first line is not used
    sample_name = cfg_reader.ReadLine().strip()
    print("sample name: [{}]".format(sample_name))

    cfg_reader.SkipTo("Defining vectors")
    lattice = np.empty([3, 3])
    for row in range(3):
        lattice[row] = cfg_reader.ReadLine().split()
    cfg_reader.Close()

    for row in range(3):
        for col in range(3):
            lattice[row][col] = pfloat(lattice[row][col])

    xyz_reader = tkFile(xyzfile, 'r')
    if not xyz_reader:
        terminate("Error in read_xyzfile: Can not read [{}]".format(xyzfile), usage = usage)

    nsites = pint(xyz_reader.ReadLine())
    print("nsites = ", nsites)
    xyz_reader.ReadLine()    # blank line

    sites = []
    for _ in range(nsites):
        text = xyz_reader.ReadLine()
        if not text:
            break
        element, xs, ys, zs = text.split()
        sites.append([element, pfloat(xs), pfloat(ys), pfloat(zs)])
    xyz_reader.Close()

    print("")
    print("Build Crystal object:")
    cry = tkCrystal()
    cry.SetSampleName(sample_name)
    cry.SetCrystalName(sample_name)
    cry.SetLatticeVectors(lattice)
    cry.LatticeParameters()
    for element, xc, yc, zc in sites:
        fx, fy, fz = cry.CartesianToFractional(xc, yc, zc)
        cry.AddAtomSite(name = element, pos = [fx, fy, fz])

    cry.ExpandCoordinates()
    cry.PrintInf()

    return cry


def CalRDFi(cry, Rmin, Rmax, nR, atom0, iSite, normalize = True):
    """Histogram neighbour distances around one reference atom, per atom type.

    Every expanded atom site, in every lattice translation within
    cry.GetLatticeRange(Rmax), whose distance d from atom0 satisfies
    Rmin <= d <= Rmax adds its occupancy to bin int(d / dR) of its atom
    type, where dR = Rmax / (nR - 1).

    Returns
    -------
    xR  : list of nR radii, k * dR
    RDF : RDF[iAtomType][iR]; divided by dR when normalize is True
    """
    print("Calculating RDF for  site #{}:".format(iSite), end = '')

    dR = Rmax / (nR - 1)
    xR = [dR * k for k in range(nR)]
    nx, ny, nz = cry.GetLatticeRange(Rmax)

    nTypes     = len(cry.AtomTypeList())
    neighbours = cry.ExpandedAtomSiteList()

    name0      = atom0.AtomNameOnly()
    iAtomType0 = atom0.iAtomType()
    x0, y0, z0 = atom0.Position(1)
    occ        = atom0.Occupancy()
    print(" {:<4}[{}] ({:8.4f}, {:8.4f}, {:8.4f}) occ={:8.4f}".format(name0, iAtomType0, x0, y0, z0, occ))

    RDF = make_matrix2(nTypes, nR, 0.0)
    for atom1 in neighbours:
        itype1     = atom1.iAtomType()
        x1, y1, z1 = atom1.Position(1)
        weight     = atom1.Occupancy()
        for iz in range(-nz, nz + 1):
            for iy in range(-ny, ny + 1):
                for ix in range(-nx, nx + 1):
                    dis = cry.GetInterAtomicDistance([x0, y0, z0], [x1 + ix, y1 + iy, z1 + iz])
                    if dis < Rmin or dis > Rmax:
                        continue
                    RDF[itype1][int(dis / dR)] += weight

    if normalize:
        # Convert counts per bin into a density per unit radius.
        for row in RDF:
            for k in range(nR):
                row[k] /= dR

    return xR, RDF


def CalculateCNRDFs(cry, Rmin, Rmax, nR, nCN):
    """Compute coordination-number-resolved RDFs (CNRDFs) and their running sums.

    For each asymmetric site, the un-normalized RDF from CalRDFi (summed over
    neighbour atom types) is accumulated into a running coordination number
    RCN(R). The bin at which RCN first reaches icn (icn = 1..nCN) is recorded
    as one CN>=icn event, weighted by Multiplicity() * Occupancy() of the
    site, and binned by the atom type of the central site.

    Returns
    -------
    xR     : list of nR radii, k * Rmax / (nR - 1)
    CNRDFs : CNRDFs[itype][icn][iR], weighted CN>=icn events per bin
    CNRCNs : CNRCNs[itype][icn][iR], running sum of CNRDFs over iR
             (index icn = 0 is unused and stays zero in both)
    """
    print("")
    print("Calculate Coordination number RDFs and RCNs.")

    nTypes     = len(cry.AtomTypeList())
    AsymSites  = cry.AtomSiteList()
    nAsymSites = len(AsymSites)

    # Same radial grid as CalRDFi; computed here so xR is defined even
    # when there are no sites.
    Rstep = Rmax / (nR - 1)
    xR = [Rstep * k for k in range(nR)]

    # RDFs[isite][iR]: neighbour count per bin, summed over neighbour types
    RDFs = np.zeros((nAsymSites, nR))
    for isite, atom in enumerate(AsymSites):
        _, RDF0 = CalRDFi(cry, Rmin, Rmax, nR, atom, isite, normalize = False)
        for row in RDF0:
            RDFs[isite] += row

    # RCNs[isite][iR]: running coordination number; bin 0 is not accumulated.
    RCNs = np.zeros_like(RDFs)
    RCNs[:, 1:] = np.cumsum(RDFs[:, 1:], axis = 1)

    CNRDFs = make_matrix3(nTypes, nCN+1, nR, defval = 0.0)
    CNRCNs = make_matrix3(nTypes, nCN+1, nR, defval = 0.0)
    for isite, atom in enumerate(AsymSites):
        itype  = atom.iAtomType()
        weight = atom.Multiplicity() * atom.Occupancy()
        icn    = 1
        for iR in range(nR):
            if debug:
                print("n=", nAsymSites, nCN+1, nR)
                print("i=", isite, iR, icn, weight)

            # Several neighbours (e.g. symmetry-equivalent ones) can fall
            # into the same bin, so one bin may cross several CN thresholds.
            while icn <= nCN and RCNs[isite][iR] >= icn:
                CNRDFs[itype][icn][iR] += weight
                icn += 1
            if icn > nCN:
                break

    for itype in range(nTypes):
        for icn in range(1, nCN+1):
            for iR in range(1, nR):
                CNRCNs[itype][icn][iR] = CNRCNs[itype][icn][iR-1] + CNRDFs[itype][icn][iR]

    return xR, CNRDFs, CNRCNs


def CalculateRDFs(cry, Rmin, Rmax, nR, isprint = True):
    """Compute averaged RDFs and running coordination numbers (RCNs).

    The normalized per-site RDFs from CalRDFi are summed over neighbour atom
    types, then grouped according to the module-level `mode`:
      'atom' / 'cn' : averaged over the asymmetric sites of each atom type
      'site'        : one curve per asymmetric site
    Any other mode terminates the program.

    NOTE(review): the average counts each asymmetric site once; it is not
    weighted by Multiplicity()/Occupancy() (unlike CalculateCNRDFs). Confirm
    this is intended.

    Returns
    -------
    xR   : list of nR radii, k * Rmax / (nR - 1)
    RDFs : ndarray (ngroups, nR)
    RCNs : ndarray (ngroups, nR), running integral of RDFs * Rstep
           (bin 0 is not accumulated)
    `isprint` is accepted for compatibility and is not used.
    """
    print("")
    print("Calculate RDFs.")

    AtomTypes  = cry.AtomTypeList()
    nTypes     = len(AtomTypes)
    AsymSites  = cry.AtomSiteList()
    nAsymSites = len(AsymSites)

    # Derive the step from the arguments (the same grid CalRDFi uses) rather
    # than relying on the module-level Rstep, which only updatevars() sets.
    Rstep = Rmax / (nR - 1)
    xR = [Rstep * k for k in range(nR)]

    # RDF[isite][itype][iR]
    RDF = []
    for isite, atom0 in enumerate(AsymSites):
        _, RDF0 = CalRDFi(cry, Rmin, Rmax, nR, atom0, isite)
        RDF.append(RDF0)

    if mode == 'atom' or mode == 'cn':
        igroup  = [atom.iAtomType() for atom in AsymSites]
        ngroup  = nTypes
        caption = "Atom type multiplicity:"
        names   = [t.AtomType() for t in AtomTypes]
    elif mode == 'site':
        igroup  = list(range(nAsymSites))
        ngroup  = nAsymSites
        caption = "Atom site multiplicity:"
        names   = [atom.AtomName() for atom in AsymSites]
    else:
        terminate("Error: Invalid mode [{}]".format(mode), usage = usage)
        # Only reached if terminate() returns instead of exiting.
        raise ValueError("Invalid mode [{}]".format(mode))

    RDFs = np.zeros((ngroup, nR))
    mult = [0] * ngroup
    for isite in range(nAsymSites):
        ig = igroup[isite]
        mult[ig] += 1
        for row in RDF[isite]:
            RDFs[ig] += row

    print(caption)
    for ig in range(ngroup):
        print("  {}: {}: {}".format(ig, names[ig], mult[ig]))

    for ig in range(ngroup):
        # A group without sites keeps a zero RDF instead of becoming NaN.
        if mult[ig] > 0:
            RDFs[ig] /= mult[ig]

    RCNs = np.zeros_like(RDFs)
    RCNs[:, 1:] = np.cumsum(RDFs[:, 1:] * Rstep, axis = 1)

    return xR, RDFs, RCNs


def _print_atom_site(atom):
    """Print a one-line summary of an atom site (asymmetric or expanded)."""
    label     = atom.Label()
    atomname  = atom.AtomName()
    charge    = atom.Charge()
    pos       = atom.Position()
    occ       = atom.Occupancy()
    id_asym   = atom.IdAsymmetricAtomSite()
    iAtomType = atom.iAtomType()
    m         = atom.Multiplicity()
    print("  %3d: (iAtomType=%d) %5s: %-4s charge=%6.3f (%6.3f, %6.3f, %6.3f) occ=%8.4f  m=%d"
          % (id_asym, iAtomType, label, atomname, charge, pos[0], pos[1], pos[2], occ, m))


def _print_crystal_info(cry, Rmax):
    """Print cell, densities, atom types, atom sites and the lattice range for Rmax."""
    print("")
    print("==============================================")
    print("    Crystal inf")
    print("==============================================")
    a, b, c, alpha, beta, gamm = cry.LatticeParameters()
    print("cell: {:12.6g} {:12.6g} {:12.6g} A   {:12.6g} {:12.6g} {:12.6g}".format(a, b, c, alpha, beta, gamm))
    density  = cry.Density()
    adensity = cry.AtomDensity()
    print("  Density: {:12.6g} g/cm^3".format(density))
    print("  Atom density: {:12.6g} cm^-3".format(adensity))

    print("")
    print("Atom types:")
    for i, t in enumerate(cry.AtomTypeList()):
        print("  %3d: %4s (%2s) charge=%8.4f  [Z=%3d  M=%6.4f]"
              % (i, t.AtomType(), t.AtomTypeOnly(), t.Charge(), t.AtomicNumber(), t.AtomicMass()))

    print("")
    print("Atom sites:")
    for atom in cry.AtomSiteList():
        _print_atom_site(atom)

    print("")
    print("Expanded atom sites:")
    for atom in cry.ExpandedAtomSiteList():
        _print_atom_site(atom)

    nLatticeX, nLatticeY, nLatticeZ = cry.GetLatticeRange(Rmax)
    print("")
    print("  Lattice Range: {} {} {}".format(nLatticeX, nLatticeY, nLatticeZ))


def rdf():
    """Read the CIF file, compute RDFs/RCNs, save them to CSV and plot them.

    Uses the module-level settings (mode, ciffile, Rmin, Rmax, nR, nCN,
    wConv, ...) that updatevars() may have overridden. Writes rdfcsvfile
    and, in 'cn' mode, one '<cifbody>-CNRDF-<type>.csv' per atom type.
    Ends by calling terminate().
    """
    print("mode          : {}".format(mode))
    print("CIF file      : {}".format(ciffile))
    print("RDF csv file  : {}".format(rdfcsvfile))
    print("CNRDF csv file: {}".format(cnrdfcsvfile))
    print("  single            : {}".format(single))
    print("  findvalidstructure: {}".format(findvalidstructure))
    print("R range : min {}, max {} A, {} A step, nR = {}".format(Rmin, Rmax, Rstep, nR))
    if mode == 'cn':
        print("nCN     : {}".format(nCN))
    print("wConv       : {}".format(wConv))
    if wConv == 0.0:
        print("  Convolution will not be applied.")

    print("")
    print("Read [{}]".format(ciffile))

    if not IsFile(ciffile):
        terminate("Error in rdf(): Invalid CIF file [{}]".format(ciffile), usage = usage)
    cif = tkCIF()
    cif.debug = debug
    cifdata = cif.ReadCIF(ciffile, find_valid_structure = findvalidstructure)
    cif.Close()

    cry = cifdata.GetCrystal()
    if not cry:
        terminate("Error: Could not get crystal object from infile [{}]".format(ciffile), usage = usage)

    _print_crystal_info(cry, Rmax)

    AtomTypes  = cry.AtomTypeList()
    nTypes     = len(AtomTypes)
    AtomSites  = cry.AtomSiteList()
    type_names = [t.AtomType() for t in AtomTypes]
    site_names = [atom.AtomName() for atom in AtomSites]

    # RDFs[iType or iSite][iR]
    xR, RDFs, RCNs = CalculateRDFs(cry, Rmin, Rmax, nR)
    if mode == 'cn':
        xR, CNRDFs, CNRCNs = CalculateCNRDFs(cry, Rmin, Rmax, nR, nCN)

    if wConv != 0.0:
        print("")
        print("Convoluting RDF(r) by Gaussian function with whalf = {} A".format(wConv))
        # NOTE(review): the 4th argument is xR[i] for the RDFs and xR at the
        # last RDF row index for the CNRDFs; its meaning depends on
        # tkconvolution.convolute_by_func and should be confirmed.
        # RCNs / CNRCNs are not recomputed from the convoluted curves.
        for i in range(len(RDFs)):
            RDFs[i] = convolute_by_func(xR, RDFs[i], Gaussian, xR[i], wConv)
        if mode == 'cn':
            for itype in range(len(CNRCNs)):
                for icn in range(1, len(CNRCNs[itype])):
                    CNRDFs[itype][icn] = convolute_by_func(xR, CNRDFs[itype][icn], Gaussian, xR[i], wConv)

    # Curves are per atom type in 'atom'/'cn' mode and per site in 'site' mode.
    if mode == 'atom' or mode == 'cn':
        series = type_names
    elif mode == 'site':
        series = site_names
    else:
        terminate("Error: Invalid mode [{}]".format(mode), usage = usage)
        series = []

    # Save to csv file
    print("")
    print("Save to [{}]".format(rdfcsvfile))
    label = ["R(A)"]
    data  = [xR]
    for i, name in enumerate(series):
        label.append("RDF({}:{})".format(name, i))
        data.append(RDFs[i])
    for i, name in enumerate(series):
        label.append("RCN({}:{})".format(name, i))
        data.append(RCNs[i])
    savecsv(rdfcsvfile, label, data)

    if mode == 'cn':
        filebody = os.path.basename(os.path.splitext(ciffile)[0])
        for itype, name in enumerate(type_names):
            label = ["R(A)"]
            data  = [xR]
            for icn in range(1, nCN+1):
                label.append("CNRDF(CN>={})({}({}))".format(icn, name, itype))
                data.append(CNRDFs[itype][icn])
            for icn in range(1, nCN+1):
                label.append("CNRCN(CN>={})({}({}))".format(icn, name, itype))
                data.append(CNRCNs[itype][icn])
            savecsv(filebody + '-CNRDF-{}.csv'.format(name), label, data)

    #=============================
    # Plot graphs
    #=============================
    if mode == 'cn':
        fig = plt.figure(figsize = [figuresize[0] * nTypes, figuresize[1]])
        ncol = nTypes + 1
    else:
        fig = plt.figure(figsize = figuresize)
        ncol = 1

    # ax[0][*]: RDF panels, ax[1][*]: RCN panels; column 0 holds the summary.
    ax = make_matrix2(2, ncol)
    for j in range(ncol):
        ax[0][j] = fig.add_subplot(2, ncol, j+1)
        ax[1][j] = fig.add_subplot(2, ncol, j+ncol+1)

    for i, name in enumerate(series):
        ax[0][0].plot(xR, RDFs[i], label = "RDF({}({}))".format(name, i), linewidth = 0.5)
        ax[1][0].plot(xR, RCNs[i], label = "RCN({}({}))".format(name, i), linewidth = 0.5)

    if mode == 'cn':
        for itype, name in enumerate(type_names):
            for icn in range(1, nCN+1):
                # Only the first curve of each panel carries the type prefix.
                if icn == 1:
                    rdflabel = 'CNRDF {}({}):'.format(name, itype)
                    rcnlabel = 'CNRCN {}({}):'.format(name, itype)
                else:
                    rdflabel = ''
                    rcnlabel = ''
                ax[0][itype+1].plot(xR, CNRDFs[itype][icn], label = r"{}CN$\geqq${}".format(rdflabel, icn), linewidth = 0.5)
                ax[1][itype+1].plot(xR, CNRCNs[itype][icn], label = r"{}CN$\geqq${}".format(rcnlabel, icn), linewidth = 0.5)

    # Raw strings keep the mathtext backslashes without invalid-escape warnings.
    ax[0][0].set_ylabel(r"$dN(r\leq$$R)/dR$", fontsize = fontsize)
    ax[1][0].set_ylabel(r"$N(r\leq$$R)$", fontsize = fontsize)
    ax[1][0].set_ylim([0, 12])

    for row in ax:
        for axis in row:
            axis.set_xlabel(r"$R (\AA)$", fontsize = fontsize)
            axis.set_xlim([0.0, Rmax])
            axis.legend(fontsize = legend_fontsize)

    plt.tight_layout()

    print("")
    print("Close the graph window to proceed")
    plt.show()

    terminate(usage = usage)


def main():
    """Script entry point: apply command-line overrides, then run rdf()."""
    updatevars()

    banner = "=============== Calculate RDF from xyz structure file ============"
    print("")
    print(banner)
    rdf()


if __name__ == "__main__":
    main()

