import os
import sys
import shutil
import glob
import csv
import numpy as np
from numpy import exp, log, sin, cos, tan, arcsin, arccos, arctan, pi
from scipy.interpolate import interp1d
from pprint import pprint
from matplotlib import pyplot as plt


from tklib.tkfile import tkFile
from tklib.tkutils import IsDir, IsFile, SplitFilePath
from tklib.tkutils import terminate, pint, pfloat, getarg, getintarg, getfloatarg
from tklib.tksci.tksci import Reduce01, Round
from tklib.tksci.tkmatrix import make_matrix1, make_matrix2, make_matrix3
from tklib.tkcrystal.tkcif import tkCIF, tkCIFData
from tklib.tkcrystal.tkcrystal import tkCrystal
from tklib.tkcrystal.tkatomtype import tkAtomType


#================================
# global parameters
#================================
# Module-level defaults. ciffile, Rmax and targets are overwritten by
# updatevars() from the command line; the others are fixed here.

# Debug level copied to the CIF reader's `debug` attribute in find_clusters().
debug = 0

# Input CIF path (first command-line argument).
ciffile = 'STD0.cif'
# Passed to tkCIF.ReadCIF() as `find_valid_structure`.
findvalidstructure = 1
# Element name(s) of the sites to extract (third command-line argument);
# several names are written comma-separated as in the example below.
targets = 'O'
#targets = 'O,In,Ga,Zn'
# Interatomic distance window: find_next() accepts Rmin <= distance <= Rmax.
# Rmax is the second command-line argument; the script prints it with the
# unit "A" (presumably angstrom -- confirm against tkCrystal).
Rmin    = 0.1
Rmax    = 1.8

#=============================
# Treat arguments
#=============================
def usage():
    """Print the command-line synopsis followed by an example invocation.

    The example line is built from the current values of the module-level
    settings ``ciffile``, ``Rmax`` and ``targets``.
    """
    script = sys.argv[0]
    synopsis = "  (i) python {} ciffile Rmax targets".format(script)
    example = "      ex: python {} {} {} {}".format(script, ciffile, Rmax, targets)

    for line in ("", "Usage:", synopsis, example):
        print(line)

def updatevars():
    """Refresh the module-level settings from the command line.

    Positional arguments 1, 2 and 3 are read through ``getarg`` /
    ``getfloatarg`` with the current values of ``ciffile``, ``Rmax`` and
    ``targets`` passed as the second argument (presumably the fallback used
    when the argument is absent -- confirm in tklib.tkutils).

    Also sets the global ``extciffile`` to the input file's base name,
    without directory or extension, followed by ``-extracted.cif``.
    """
    global ciffile, extciffile
    global Rmax, targets

    ciffile = getarg(1, ciffile)
    Rmax = getfloatarg(2, Rmax)
    targets = getarg(3, targets)

    # Strip the extension first, then the directory part.
    stem = os.path.splitext(ciffile)[0]
    extciffile = os.path.basename(stem) + '-extracted.cif'

def find_next(ihit, cry, sites, isite0, itargets, cluster, flag, nx, ny, nz):
    """Find the next target atom within the distance window of a reference site.

    Candidates are ``itargets[ihit + 1:]`` taken in order.  For each candidate
    every integer offset (ix, iy, iz) with -nx <= ix <= nx, -ny <= iy <= ny,
    -nz <= iz <= nz is added to the candidate's position (iz varies slowest,
    ix fastest) and the distance to the reference site is obtained from
    ``cry.GetInterAtomicDistance``.  The first combination satisfying
    ``Rmin <= distance <= Rmax`` (module-level limits) is returned.

    Args:
        ihit: index into ``itargets``; the search starts just after it.
        cry: object providing ``GetInterAtomicDistance(pos0, pos1)``.
        sites: sequence of site objects with ``AtomNameOnly()`` and
            ``Position()``.
        isite0: index into ``sites`` of the reference atom.
        itargets: indexes into ``sites`` of the candidate atoms.
        cluster: list that receives the hit's site index, or None.
        flag: per-target list; truthy entries are skipped and the hit's entry
            is set to 1.  May be None.
        nx, ny, nz: largest offset tried along each axis.

    Returns:
        ``(i, ix, iy, iz, distance)`` where ``i`` is an index into
        ``itargets``, or five ``None`` values when nothing is found.
    """
    reference = sites[isite0]
    ref_name = reference.AtomNameOnly()  # not needed by the search; handy when debugging
    x0, y0, z0 = reference.Position()

    for pos in range(ihit + 1, len(itargets)):
        # Skip targets that have already been consumed.
        if flag and flag[pos]:
            continue

        isite1 = itargets[pos]
        candidate = sites[isite1]
        cand_name = candidate.AtomNameOnly()  # not needed by the search; handy when debugging
        x1, y1, z1 = candidate.Position()

        # Same visiting order as three nested loops: z outermost, x innermost.
        shifts = ((sx, sy, sz)
                  for sz in range(-nz, nz + 1)
                  for sy in range(-ny, ny + 1)
                  for sx in range(-nx, nx + 1))
        for sx, sy, sz in shifts:
            dis = cry.GetInterAtomicDistance([x0, y0, z0], [x1 + sx, y1 + sy, z1 + sz])
            if not (Rmin <= dis <= Rmax):
                continue

            if cluster is not None:
                cluster.append(isite1)
            if flag is not None:
                flag[pos] = 1
            return pos, sx, sy, sz, dis

    return None, None, None, None, None

def append_unique(list, data):
    """Append ``data`` to ``list`` unless an equal element is already present.

    When an equal element (compared with ``==``) exists, the list is left
    untouched and the marker ``*** repeated ***`` is printed without a
    trailing newline.

    Args:
        list: the list to extend in place.  The name shadows the builtin but
            is kept so existing keyword callers continue to work.
        data: the element to add.

    Returns:
        The same list object, whether or not ``data`` was appended.
    """
    if any(item == data for item in list):
        print("*** repeated ***", end='')
    else:
        list.append(data)
    return list

def _parse_target_names(target_spec):
    """Split a comma-separated element list such as 'O,In,Ga,Zn' into a set of names."""
    return {name.strip() for name in target_spec.split(',') if name.strip()}


def _print_crystal_summary(cry):
    """Print cell, densities, atom types and expanded sites; return the expanded site list."""
    print("")
    print("==============================================")
    print("    Crystal inf")
    print("==============================================")
    a, b, c, alpha, beta, gamm = cry.LatticeParameters()
    print("cell: {:12.6g} {:12.6g} {:12.6g} A   {:12.6g} {:12.6g} {:12.6g}".format(a, b, c, alpha, beta, gamm))
    density = cry.Density()
    adensity = cry.AtomDensity()
    print("  Density: {:12.6g} g/cm-3".format(density))
    print("  Atom density: {:12.6g} cm^-3".format(adensity))

    atom_types = cry.AtomTypeList()
    sites = cry.ExpandedAtomSiteList()

    print("")
    print("Atom types:")
    for i, t in enumerate(atom_types):
        print("  %3d: %4s (%2s) charge=%8.4f  [Z=%3d  M=%6.4f]"
                % (i, t.AtomType(), t.AtomTypeOnly(), t.Charge(), t.AtomicNumber(), t.AtomicMass()))

    print("")
    print("Expanded atom sites:")
    for atom in sites:
        pos = atom.Position()
        print("  %3d: (iAtomType=%d) %5s: %-4s charge=%6.3f (%6.3f, %6.3f, %6.3f) occ=%8.4f  m=%d"
                % (atom.IdAsymmetricAtomSite(), atom.iAtomType(), atom.Label(), atom.AtomName(),
                   atom.Charge(), pos[0], pos[1], pos[2], atom.Occupancy(), atom.Multiplicity()))

    return sites


def _build_neighbor_table(cry, sites, itargets, nx, ny, nz):
    """Collect [site index, ix, iy, iz, element] rows for target atoms that have a neighbour in range."""
    table = []
    for i0, isite0 in enumerate(itargets):
        name0 = sites[isite0].AtomNameOnly()
        print("{}({}) ".format(isite0, name0), flush = True, end = '')

        # find_next() searches strictly after the index it is given, so the
        # cursor is the last index already handled: the reference atom itself
        # at first, then the previous hit.  Advancing it here as well would
        # skip every candidate that directly follows a hit.
        ihit = i0
        nfound = 0
        while True:
            ihit, ix, iy, iz, dis = find_next(ihit, cry, sites, isite0,
                                        itargets, None, None, nx, ny, nz)
            if ihit is None:
                break

            # ihit indexes itargets; translate it before touching `sites`.
            isite1 = itargets[ihit]
            name1 = sites[isite1].AtomNameOnly()

            if nfound == 0:
                # The reference atom enters the table once, on its first hit.
                append_unique(table, [isite0, 0, 0, 0, name0])
                print("")
            print("  ihit={:04d}({:>2}) +({}, {}, {}) dis={:12.4g}".format(isite1, name1, ix, iy, iz, dis))
            # NOTE(review): hit rows are not de-duplicated, so the same atom and
            # offset can be recorded again from another reference atom.
            table.append([isite1, ix, iy, iz, name1])
            nfound += 1

    return table


def _save_extracted_cif(cry, sites, table):
    """Build a tkCrystal from the table rows and write it to ``extciffile``."""
    print("")
    print("Build Crystal object:")
    table = sorted(table)
    print("nhit=", len(table))

    extcry = tkCrystal()
    extcry.SetSampleName(cry.SampleName())
    extcry.SetCrystalName(cry.SampleName())
    extcry.SetLatticeParameters(cry.LatticeParameters())
    for isite, ix, iy, iz, name in table:
        x, y, z = sites[isite].Position()
        print("site {:04d} {:>2}: ({:8.4f},{:8.4f},{:8.4f}) + ({},{},{})".format(isite, name, x, y, z, ix, iy, iz))
        extcry.AddAtomSite(name = name, pos = [x+ix, y+iy, z+iz])

    # NOTE(review): this expands the source crystal, not extcry -- confirm
    # which object was intended.
    cry.ExpandCoordinates()

    print("Save to [{}]".format(extciffile))
    extcif = tkCIFData()
    extcif.CreateCIFFileFromCCrystal(extcry, extciffile)


def _search_clusters(cry, sites, itargets, nx, ny, nz):
    """Print chains of target atoms linked by in-range distances (two levels deep)."""
    print("")
    print("Search clusters...")
    flag = make_matrix1(len(itargets), defval = 0)
    for i0, isite0 in enumerate(itargets):
        if flag[i0]:
            continue

        site0      = sites[isite0]
        name0      = site0.AtomNameOnly()
        x0, y0, z0 = site0.Position()

        flag[i0] = 1
        cluster = [isite0]
        # find_next() starts at ihit + 1, so begin from the reference index.
        ihit = i0
        while True:
            ihit, ix, iy, iz, dis = find_next(ihit, cry, sites, isite0,
                                        itargets, cluster, flag, nx, ny, nz)
            if ihit is None:
                break

            flag[ihit] = 1

            isite1     = itargets[ihit]
            site1      = sites[isite1]
            name1      = site1.AtomNameOnly()
            x1, y1, z1 = site1.Position()

            if len(cluster) == 2:
                print("site 0: {:04d} {:2}: ({:8.4f}, {:8.4f}, {:8.4f})".format(isite0, name0, x0, y0, z0))
            print("     1: {:04d} {:2}: ({:8.4f}, {:8.4f}, {:8.4f})  dis={:12.4g} A".format(isite1, name1, x1+ix, y1+iy, z1+iz, dis))

            # Second level: atoms in range of the atom just found.
            while True:
                ihit2, ix, iy, iz, dis = find_next(ihit, cry, sites, isite1,
                                        itargets, cluster, flag, nx, ny, nz)
                if ihit2 is None:
                    break

                ihit = ihit2
                flag[ihit] = 1

                isite2     = itargets[ihit]
                site2      = sites[isite2]
                name2      = site2.AtomNameOnly()
                x2, y2, z2 = site2.Position()

                print("     2: {:04d} {:2}: ({:8.4f}, {:8.4f}, {:8.4f})  dis={:12.4g} A"
                        .format(isite2, name2, x2+ix, y2+iy, z2+iz, dis))

        if len(cluster) >= 2:
            print("")


def find_clusters():
    """Extract target atoms that have a target neighbour within range and save them as a CIF.

    Steps, all driven by the module-level settings (``ciffile``,
    ``extciffile``, ``Rmax``, ``targets``, ``debug``, ``findvalidstructure``):

    1. Read ``ciffile`` with tkCIF and obtain its crystal object; call
       ``terminate`` with an error message if none is returned.
    2. Print the cell, densities, atom types and expanded atom sites.
    3. Select the expanded sites whose element name is one of the
       comma-separated names in ``targets``.
    4. For every selected site, list the later selected sites (optionally
       shifted by integer cell offsets) found by ``find_next``.
    5. Write the collected atoms to ``extciffile`` and call ``terminate``.

    The cluster search that follows step 5 runs only if ``terminate`` returns
    (NOTE(review): it presumably exits the process -- confirm in tklib).
    """
    print("CIF file          : {}".format(ciffile))
    print("Extracted CIF file: {}".format(extciffile))
    print("Rmax              : {}".format(Rmax))
    print("Target elements   : {}".format(targets))

    print("")
    print("Read [{}]".format(ciffile))

    cif = tkCIF()
    cif.debug = debug
    cifdata = cif.ReadCIF(ciffile, find_valid_structure = findvalidstructure)
    cif.Close()

    cry = cifdata.GetCrystal()
    if not cry:
        # Report the file that was actually read (the name `infile` used
        # previously does not exist in this module).
        terminate("Error: Could not get crystal object from infile [{}]".format(ciffile), usage = usage, pause = True)

    sites = _print_crystal_summary(cry)

    nLatticeX, nLatticeY, nLatticeZ = cry.GetLatticeRange(Rmax)
    print("")
    print("  Lattice Range: {} {} {}".format(nLatticeX, nLatticeY, nLatticeZ))

    # Compare whole element names; a substring test against the raw string
    # would let e.g. 'I' match 'In'.
    target_names = _parse_target_names(targets)
    itargets = [i for i, site in enumerate(sites) if site.AtomNameOnly() in target_names]

    print("")
    print("Target indexes: ", itargets)

    print("")
    print("Make table witin {} A for {}...".format(Rmax, targets))
    table = _build_neighbor_table(cry, sites, itargets, nLatticeX, nLatticeY, nLatticeZ)

    _save_extracted_cif(cry, sites, table)

    terminate(usage = usage, pause = True)

    _search_clusters(cry, sites, itargets, nLatticeX, nLatticeY, nLatticeZ)

    terminate(usage = usage, pause = True)

    
def main():
    """Script entry point: apply command-line overrides, print the banner, run the extraction."""
    updatevars()

    banner = "=============== Find clusters from CIF file ============"
    for line in ("", banner):
        print(line)

    find_clusters()


# Run the extraction only when this file is executed as a script.
if __name__ == "__main__":
    main()

