import sys
import os
import time
from math import erf, erfc
from numpy import sin, cos, tan, arcsin, arccos, arctan, exp, log, log10, sqrt
import numpy as np
from numpy import linalg as la
import matplotlib.pyplot as plt


from tklib.tkapplication import tkApplication
from tklib.tkutils import terminate, pint, pfloat, getarg, getintarg, getfloatarg
from tklib.tkfile import tkFile
from tklib.tkcrystal.tkcif import tkCIF
from tklib.tkcrystal.tkcrystal import tkCrystal
import tklib.tkcrystal.tkatomtypeobject as tkAtomTypeObject
#from tklib.tkcrystal.tkatomtypeobject import tkAtomTypeObject
from tklib.tkcrystal.tkatomtype import tkAtomType
from tkcrystalbase import *


"""
Calculate the Madelung potential by the Ewald summation method
  Requirement: tkcrystalbase.py
"""


# Mathematical constants
pi          = 3.14159265358979323846
pi2         = pi + pi
torad       = pi / 180.0       # rad/deg
todeg       = 180.0 / pi       # deg/rad
basee       = 2.71828183

# Physical constants (SI units)
h           = 6.6260755e-34    # Js
h_bar       = h / pi2          # Js (h / 2 pi)
hbar        = h_bar
c           = 2.99792458e8     # m/s
e           = 1.60218e-19      # C
me          = 9.1093897e-31    # kg
mp          = 1.6726231e-27    # kg
mn          = 1.67495e-27      # kg
u0          = 4.0 * pi * 1e-7  # Ns^2 C^-2 (vacuum permeability, 4 pi x 1e-7)
e0          = 8.854187817e-12  # C^2 N^-1 m^-2 (vacuum permittivity, = 1 / (u0 c^2))
e2_4pie0    = 2.30711e-28      # Nm^2 (e^2 / 4 pi e0)
a0          = 5.29177e-11      # m
kB          = 1.380658e-23     # J K^-1
NA          = 6.0221367e23     # mol^-1
R           = 8.31451          # J/K/mol
F           = 96485.3          # C/mol
g           = 9.81             # m/s2


# Input CIF file path; replaced by the 1st command-line argument
infile = None


# The crystal structure is read from the CIF file in main().
# Each site is held as a record
#   [name, label, z, M, q, r, color, pos]
# e.g. ['Na', 'Na1', 11, 22.98997, +1.0, 0.7, 'red', np.array([0.0, 0.0, 0.0])]

# Pairs closer than this distance (angstrom) are regarded as the same site
rmin = 1.0e-1

# Ewald splitting parameter alpha (angstrom^-1); replaced by the 2nd argument
ew_alpha = 3.0e-1

# Target precision of the lattice sums; replaced by the 3rd argument
prec = 1e-5


def usage():
    """Print the command-line synopsis and an example built from the current settings."""
    script = sys.argv[0]
    message = [
        "",
        "Usage: python {} CIF_path alpha prec".format(script),
        "   ex: python {} {} {} {}".format(script, infile, ew_alpha, prec),
        "",
    ]
    for line in message:
        print(line)


#==========================================
# Main program: command-line arguments
#   1: CIF path, 2: Ewald alpha, 3: precision
#   The current values are passed as the second argument of each getter
#   (presumably the fallback when an argument is absent).
#==========================================
infile, ew_alpha, prec = (
    getarg(1, infile),
    getfloatarg(2, ew_alpha),
    getfloatarg(3, prec),
)


def EWALD(sites, lattice_parameters, Rlatt, gij, Rgij, volume, Rvolume, ew_alpha, prec, debug = False):
    """
    Calculate the Madelung (electrostatic) potential at every site by the Ewald method.

    The potential at site i is assembled from three terms (charges in units of e):
      UC1: real-space sum    sum_j q_j * erfc(alpha * r_ij) / r_ij
      UC2: reciprocal sum    1 / (pi V) * sum_{G != 0} exp(-pi^2 G^2 / alpha^2) / G^2
                                 * sum_j q_j * cos(2 pi G.(r_i - r_j))
      UC3: self term         2 * alpha * q_i / sqrt(pi)
    and MP = UC1 + UC2 - UC3.
    The real-space sum skips every pair closer than rmin, which removes the
    site's interaction with itself.

    Parameters
    ----------
    sites : list of [name, label, z, M, q, r, color, pos]
        Only q (charge) and pos (fractional coordinates) are used.
    lattice_parameters : [a, b, c, alpha, beta, gamma] in angstrom and degree.
    Rlatt : reciprocal lattice parameters; the angles [3:6] are read (degree).
    gij, Rgij : real-space (angstrom^2) and reciprocal-space (angstrom^-2) metric
        tensors, passed to distance() / distance2().
    volume : unit cell volume in angstrom^3.
    Rvolume : reciprocal cell volume. It is used only for the printed term estimate.
    ew_alpha : Ewald parameter alpha in angstrom^-1.
    prec : target precision. It sets both the real- and reciprocal-space cutoffs.
    debug : if True, print every reciprocal-space contribution.

    Returns
    -------
    dict with
      "UC1_list", "UC2_list", "UC3_list", "MP_list" : numpy arrays of per-site
          potentials in V (eV per elementary charge)
      "time_real_space", "time_reciprocal_space", "time_total" : wall-clock
          times in seconds, accumulated over all sites
    """
    inf = {}

    nsites = len(sites)

    print("")
    print("Ewald parameters")
    print("  alpha:", ew_alpha)
    norder = -log10(prec)
    print("  precision = {} = 10^-{}".format(prec, norder))

    # Real-space cutoff radius (angstrom). The printed erfc(alpha*RDmax) is the
    # damping factor of the real-space terms at the cutoff.
    rdmax      = (2.26 + 0.26 * norder) / ew_alpha
    erfc_rdmax = erfc(ew_alpha * rdmax)
    print("  RDmax = {} A, where erfc(alpha*RDmax) = {}".format(rdmax, erfc_rdmax))

    # Number of lattice translations scanned along a, b, c in the real-space sum.
    # This is an estimate from the cutoff and the cell geometry.
    lsin  = np.empty(3, dtype = float)
    nrmax = np.empty(3, dtype = int)
    lsin[0] = sin(torad * lattice_parameters[3])
    lsin[1] = sin(torad * lattice_parameters[4])
    lsin[2] = sin(torad * lattice_parameters[5])
    nrmax[0] = int(rdmax / sqrt(gij[0][0] * lsin[1] * lsin[2])) + 1
    nrmax[1] = int(rdmax / sqrt(gij[1][1] * lsin[2] * lsin[0])) + 1
    nrmax[2] = int(rdmax / sqrt(gij[2][2] * lsin[0] * lsin[1])) + 1
    print("  nrmax:", *nrmax)

    # Rough count of real-space terms inside the cutoff sphere (informational)
    cal_N = int(4.0 / 3.0 * pi * rdmax**3 / volume * nsites)
    print("  cal_N(real):", cal_N)

    # Reciprocal-space cutoff on |G|^2, chosen so that exp(-pi^2 * G2max / alpha^2) = prec
    G2max = ew_alpha**2 / pi**2 * (-log(prec))
    print("  G2max:", G2max)
    # G2max is already |G|^2, so it must not be squared again in this check
    print("      exp(-pi^2 * G2max / alpha^2) = ", exp(-pi**2 * G2max / ew_alpha**2))
    lsin[0] = sin(torad * Rlatt[3])
    lsin[1] = sin(torad * Rlatt[4])
    lsin[2] = sin(torad * Rlatt[5])
    hgmax = np.empty(3, dtype = int)
    hgmax[0] = int(sqrt(G2max / (Rgij[0][0] * lsin[1] * lsin[2]))) + 1
    hgmax[1] = int(sqrt(G2max / (Rgij[1][1] * lsin[0] * lsin[2]))) + 1
    hgmax[2] = int(sqrt(G2max / (Rgij[2][2] * lsin[0] * lsin[1]))) + 1
    print("  hgmax:", *hgmax)

    # Rough count of reciprocal-space terms inside the cutoff sphere (informational)
    cal_N = int(4.0 / 3.0 * pi * G2max**1.5 / Rvolume * nsites)
    print("  cal_N(reciprocal):", cal_N)

    # Coefficients to calculate the electrostatic potential (SI units)
    Ke   = e * e / 4.0 / pi / e0            # e^2 / (4 pi e0)
    Kexp = pi * pi / ew_alpha / ew_alpha    # pi^2 / alpha^2 (angstrom^2)
    Krec = 1.0 / pi / (volume * 1.0e-30)    # 1 / (pi V), V converted to m^3

    charges   = [site[4] for site in sites]
    positions = [site[7] for site in sites]

    UC1_list = np.zeros(nsites)
    UC2_list = np.zeros(nsites)
    UC3_list = np.zeros(nsites)

    # Timings are accumulated over all sites
    time_real = 0.0
    time_rec  = 0.0
    stime = time.time()

    # The reciprocal-lattice vectors inside the cutoff and their structure factors
    # sum_j q_j cos(2 pi G.r_j), sum_j q_j sin(2 pi G.r_j) do not depend on site i,
    # so they are computed once here. Only l >= 0 is enumerated. G and -G contribute
    # equally, so the l > 0 terms are doubled in the per-site loop below.
    t0 = time.time()
    origin = np.array([0.0, 0.0, 0.0])
    gterms = []
    for l in range(0, hgmax[2]+1):
        for k in range(-hgmax[1], hgmax[1]+1):
            for h in range(-hgmax[0], hgmax[0]+1):
                G2 = distance2(origin, np.array([h, k, l]), Rgij)
                if G2 == 0.0 or G2 > G2max:
                    continue

                cossum_j = 0.0
                sinsum_j = 0.0
                for qj, pos_j in zip(charges, positions):
                    phi_j = pi2 * (h * pos_j[0] + k * pos_j[1] + l * pos_j[2])
                    cossum_j += qj * cos(phi_j)
                    sinsum_j += qj * sin(phi_j)

                # exp(-pi^2 G^2 / alpha^2) / G^2 with G^2 converted to m^-2
                expg = exp(-Kexp * G2) / (G2 * 1.0e+20)
                gterms.append((h, k, l, cossum_j, sinsum_j, expg))
    time_rec += time.time() - t0

    for isite in range(nsites):
        qi    = charges[isite]
        pos_i = positions[isite]

        # Real-space sum over lattice translations (ix, iy, iz) and all sites j
        t0 = time.time()
        for iz in range(-nrmax[2], nrmax[2]+1):
            for iy in range(-nrmax[1], nrmax[1]+1):
                for ix in range(-nrmax[0], nrmax[0]+1):
                    shift = np.array([ix, iy, iz])
                    for qj, pos_j in zip(charges, positions):
                        rij = distance(pos_i, pos_j + shift, gij)
                        if rij < rmin:
                            continue

                        erfcar = erfc(ew_alpha * rij)
                        UC1_list[isite] += qj * erfcar / (rij * 1.0e-10)   # rij converted to m
        time_real += time.time() - t0

        # Reciprocal-space sum
        t0 = time.time()
        for h, k, l, cossum_j, sinsum_j, expg in gterms:
            phi_i = pi2 * (h * pos_i[0] + k * pos_i[1] + l * pos_i[2])
            cosphi_i = cos(phi_i)
            sinphi_i = sin(phi_i)

            # sum_j q_j cos(2 pi G.(r_i - r_j)) by the cosine addition theorem
            fcal = cosphi_i * cossum_j + sinphi_i * sinsum_j
            if l != 0:
                fcal *= 2.0
            if debug:
                print("fcal=", fcal, cosphi_i, cossum_j, sinphi_i, sinsum_j)

            UC2_list[isite] += Krec * expg * fcal
        time_rec += time.time() - t0

        # Self term (alpha converted to m^-1)
        UC3_list[isite] = 2.0 * qi * (ew_alpha * 1.0e10) / sqrt(pi)

    inf["time_real_space"]       = time_real
    inf["time_reciprocal_space"] = time_rec
    inf["time_total"]            = time.time() - stime

    # Ke / e = e / (4 pi e0) turns the sums of q / r (m^-1) into potentials in V
    scale = Ke / e
    inf["UC1_list"] = scale * UC1_list
    inf["UC2_list"] = scale * UC2_list
    inf["UC3_list"] = scale * UC3_list
    inf["MP_list"]  = scale * (UC1_list + UC2_list - UC3_list)

    return inf

def main():
    """
    Read a CIF file, calculate the Madelung potential of every site with EWALD()
    and print the potentials, the total Madelung energy and the Madelung constant.

    Uses the module-level settings infile, ew_alpha and prec. Stores the lattice
    parameters and the site list in the module globals lattice_parameters and sites.
    Output is sent to stdout and to a log file whose path is derived from infile by
    app.replace_path (template ["{dirname}", "{filebody}-out.txt"]).
    """
    global lattice_parameters, sites

    app = tkApplication()

    logfile = app.replace_path(infile, template = ["{dirname}", "{filebody}-out.txt"])
    print(f"Open logfile [{logfile}]")
    app.redirect(targets = ["stdout", logfile], mode = 'w')

    print("")
    print("=============== CIF file read test ============")
    print(f"input: {infile}")
    print(f"log file: {logfile}")
    print(f"ew_alpha: {ew_alpha}")
    print(f"prec: {prec}")

    print("")
    print("Read [{}]".format(infile))
    cif = tkCIF()
    cif.debug = False
    cifdata = cif.ReadCIF(infile, find_valid_structure = True)
    cif.Close()
    if not cifdata:
        app.terminate("Error: Could not get cifdata from infile [{}]".format(infile), pause = True)
        # Never continue with missing data, even if terminate() returns
        return

    cry = cifdata.GetCrystal()

    # Real- and reciprocal-lattice information required by EWALD()
    lattice_parameters = cry.LatticeParameters()
    gij, Rgij = cry.Metrics()
    volume    = cry.Volume()
    Rlatt     = cry.ReciprocalLatticeParameters()
    Rvolume   = cry.ReciprocalVolume()

    print("")
    print("Lattice parameters:", lattice_parameters)

    # Site records: [name, label, z, M, q, r, color, pos].
    # EWALD() uses only q and pos (fractional coordinates). z and M are fixed
    # placeholders, and r / color depend only on the sign of the charge.
    sites = []
    for atom in cry.ExpandedAtomSiteList():
        label = atom.Label()
        name  = atom.AtomNameOnly()
        q     = atom.Charge()
        pos   = atom.Position()
        if q >= 0.0:
            r, color = 0.7, 'red'
        else:
            r, color = 1.4, 'blue'
        sites.append([name, label, 1, 1.0, q, r, color, pos])

    if not sites:
        app.terminate("Error: No atom site found in infile [{}]".format(infile), pause = True)
        return

    inf = EWALD(sites, lattice_parameters, Rlatt, gij, Rgij, volume, Rvolume, ew_alpha, prec)

    print("")
    print(f"Time for real space sum     : {inf['time_real_space']:6}")
    print(f"Time for reciprocal sum     : {inf['time_reciprocal_space']:6}")
    print(f"Total time                  : {inf['time_total']:6}")

    print("")
    print("Madelung potential (electrostatic potential):")
    MP_tot = 0.0
    for isite, (name, label, z, M, q, r, color, pos) in enumerate(sites):
        UC1 = inf['UC1_list'][isite]
        UC2 = inf['UC2_list'][isite]
        UC3 = inf['UC3_list'][isite]
        MP  = inf['MP_list'][isite]
        MP_tot += q * MP
        print(f"{name:4}: {label:6}: q={q:8.2g}: "
            + f"MP = {MP * e:12.6g} J  (= {UC1 * e:12.6g} + {UC2 * e:12.6g} - {UC3 * e:12.6g})")
        print(f"                          MP = {MP:12.6g} eV (= {UC1:12.6g} + {UC2:12.6g} - {UC3:12.6g})")

    # Every pair interaction appears twice in sum_i q_i * MP_i
    MP_tot *= 0.5
    print(f"Total Madelung energy in unit cell: {MP_tot:12.6g} eV")

    # Madelung constant A defined by E_cell = -A * Ke * q^2 / a, with the a-axis length
    # as the representative distance and the 0-th ion's charge as the representative
    # charge. The energy scales with q^2, so divide by q^2 (dividing by |q| was wrong
    # for |q| != 1).
    Ke = e * e / 4.0 / pi / e0
    qj = sites[0][4]
    print("")
    print("Madelung constant")
    print(f"NOTE: The a-axis length is taken as the representative atomic distance: a = {lattice_parameters[0]}")
    print(f"      The charge of the 0-th ion is taken as the representative ion charge: q = {qj}")
    print(f"NOTE: This value is in the unit cell chemical formula")
    print("       The following value must be divided by Z to get the Madelung constant in the standard definition")
    if qj == 0.0:
        print("Madelung constant in unit cell: undefined (the representative ion charge is zero)")
    else:
        print("Madelung constant in unit cell: {:14.8g}".format(
            -MP_tot / Ke * e / (qj * qj) * lattice_parameters[0] * 1.0e-10))

    app.terminate(pause = True)


if __name__ == "__main__":
    # Run the calculation only when executed as a script, not when imported
    main()
