import sys
import os
from math import erf, erfc
from numpy import sin, cos, tan, arcsin, arccos, arctan, exp, log, log10, sqrt
import numpy as np
from numpy import linalg as la
import matplotlib.pyplot as plt

from tkcrystalbase import *


# Mathematical constants
pi          = 3.14159265358979323846
pi2         = pi + pi
torad       = pi / 180.0       # rad/deg
todeg       = 180.0 / pi       # deg/rad
basee       = 2.71828183

# Physical constants (SI units)
h           = 6.6260755e-34    # Js
h_bar       = h / pi2          # Js (Dirac constant, h / 2pi, kept consistent with h)
hbar        = h_bar
c           = 2.99792458e8     # m/s
e           = 1.60218e-19      # C
me          = 9.1093897e-31    # kg
mp          = 1.6726231e-27    # kg
mn          = 1.67495e-27      # kg
u0          = 4.0 * pi * 1e-7  # Ns^2C^-2 (vacuum permeability mu0 = 4 pi x 10^-7)
e0          = 8.854187817e-12  # C^2N^-1m^-2 (vacuum permittivity, = 1/(mu0 c^2))
e2_4pie0    = 2.30711e-28      # Nm^2 (e^2 / (4 pi e0))
a0          = 5.29177e-11      # m (Bohr radius)
kB          = 1.380658e-23     # JK^-1
NA          = 6.0221367e23     # mol^-1
R           = 8.31451          # J/K/mol
F           = 96485.3          # C/mol
g           = 9.81             # m/s2


# Lattice parameters: a, b, c (angstrom) and alpha, beta, gamma (degree)
#lattice_parameters = [ 5.62, 5.62, 5.62, 60.0, 60.0, 60.0]
lattice_parameters = [ 5.62, 5.62, 5.62, 90.0, 90.0, 90.0]
#lattice_parameters = [ 1.0, 1.0, 1.0, 90.0, 90.0, 90.0]

# Site information: atom name, site label, atomic number, atomic mass,
# charge (in units of e), radius, color, fractional position
sites = [
         ['Na', 'Na1', 11, 22.98997, +1.0, 0.7, 'red',  np.array([0.0, 0.0, 0.0])]
        ,['Na', 'Na2', 11, 22.98997, +1.0, 0.7, 'red',  np.array([0.0, 0.5, 0.5])]
        ,['Na', 'Na3', 11, 22.98997, +1.0, 0.7, 'red',  np.array([0.5, 0.0, 0.5])]
        ,['Na', 'Na4', 11, 22.98997, +1.0, 0.7, 'red',  np.array([0.5, 0.5, 0.0])]
        ,['Cl', 'Cl1', 17, 35.4527,  -1.0, 1.4, 'blue', np.array([0.5, 0.0, 0.0])]
        ,['Cl', 'Cl2', 17, 35.4527,  -1.0, 1.4, 'blue', np.array([0.5, 0.5, 0.5])]
        ,['Cl', 'Cl3', 17, 35.4527,  -1.0, 1.4, 'blue', np.array([0.0, 0.0, 0.5])]
        ,['Cl', 'Cl4', 17, 35.4527,  -1.0, 1.4, 'blue', np.array([0.0, 0.5, 0.0])]
        ]

# Minimum distance (angstrom) below which two sites are judged identical
rmin = 0.1

# Ewald alpha parameter (angstrom^-1); may be overridden by argv[1]
ew_alpha = 0.3

# Target precision of the Ewald cut-offs (0 < prec < 1); may be overridden by argv[2]
prec = 1.0e-5

argv = sys.argv
narg = len(argv)
try:
    if narg >= 2:
        ew_alpha = float(argv[1])
    if narg >= 3:
        prec = float(argv[2])
except ValueError as err:
    # usage() is defined further down the file, so report the problem directly.
    sys.exit("Error: {}\nUsage: python {} alpha prec".format(err, argv[0]))

# alpha divides the real-space cut-off, and -log(prec) must be positive for the
# reciprocal-space cut-off; anything else crashes (or is meaningless) in main().
# The negated comparisons also reject NaN.
if (not ew_alpha > 0.0) or (not 0.0 < prec < 1.0):
    sys.exit("Error: require alpha > 0 and 0 < prec < 1 (got alpha={}, prec={})\n"
             "Usage: python {} alpha prec".format(ew_alpha, prec, argv[0]))


def usage():
    """Print the command-line synopsis plus an example using the current alpha/prec."""
    script = argv[0]
    message_lines = [
        "",
        "Usage: python {} alpha prec".format(script),
        "   ex: python {} {} {}".format(script, ew_alpha, prec),
        "",
    ]
    # One print of the newline-joined lines emits exactly what four
    # separate print() calls would.
    print("\n".join(message_lines))

def terminate():
    """Print the usage message and end the program with exit status 0.

    Uses sys.exit() rather than the bare exit() helper: exit() is injected by
    the site module for interactive use and is missing under ``python -S`` or
    in embedded/frozen interpreters.
    """
    usage()
    sys.exit()
    

def _real_space_sum(pos_i, nrmax, gij):
    """Return the real-space Ewald sum at fractional position ``pos_i``.

    Sums q_j * erfc(ew_alpha * r_ij) / r_ij over every site j of ``sites`` in
    the cells shifted by -nrmax[k]..+nrmax[k] along each axis, skipping images
    closer than ``rmin`` (site i itself).  r_ij comes from ``distance`` with the
    metric ``gij`` (angstrom) and is converted to metres, so the result is in
    units of e/m; multiply by e^2/(4 pi e0) to obtain joules.
    """
    total = 0.0
    for iz in range(-nrmax[2], nrmax[2] + 1):
        for iy in range(-nrmax[1], nrmax[1] + 1):
            for ix in range(-nrmax[0], nrmax[0] + 1):
                shift = np.array([ix, iy, iz])
                for site in sites:
                    qj, pos_j = site[4], site[7]
                    rij = distance(pos_i, pos_j + shift, gij)
                    if rij < rmin:
                        continue
                    total += qj * erfc(ew_alpha * rij) / (rij * 1.0e-10)
    return total


def _reciprocal_space_sum(pos_i, hgmax, Rgij, volume, G2max):
    """Return the reciprocal-space Ewald sum at fractional position ``pos_i``.

    Computes (1 / (pi V)) * sum_{G != 0, G^2 <= G2max} exp(-pi^2 G^2 / alpha^2) / G^2
    * sum_j q_j cos(2 pi G.(r_i - r_j)).  Only the half-space l >= 0 is
    enumerated; terms with l != 0 are doubled because the summand is even in G.
    ``volume`` (A^3) and G^2 (A^-2) are converted to SI, so the result is in e/m.
    """
    origin = np.array([0.0, 0.0, 0.0])
    Kexp = pi * pi / ew_alpha / ew_alpha
    Krec = 1.0 / pi / (volume * 1.0e-30)
    total = 0.0
    for il in range(0, hgmax[2] + 1):
        for ik in range(-hgmax[1], hgmax[1] + 1):
            for ih in range(-hgmax[0], hgmax[0] + 1):
                G2 = distance2(origin, np.array([ih, ik, il]), Rgij)
                if G2 == 0.0 or G2 > G2max:
                    continue

                phi_i = pi2 * (ih * pos_i[0] + ik * pos_i[1] + il * pos_i[2])
                cossum_j = 0.0
                sinsum_j = 0.0
                for site in sites:
                    qj, pos_j = site[4], site[7]
                    phi_j = pi2 * (ih * pos_j[0] + ik * pos_j[1] + il * pos_j[2])
                    cossum_j += qj * cos(phi_j)
                    sinsum_j += qj * sin(phi_j)

                # cos(phi_i)cos(phi_j) + sin(phi_i)sin(phi_j) = cos(phi_i - phi_j)
                fcal = cos(phi_i) * cossum_j + sin(phi_i) * sinsum_j
                if il != 0:
                    fcal *= 2.0   # accounts for the mirrored -G term (l < 0)
                expg = exp(-Kexp * G2) / (G2 * 1.0e+20)
                total += Krec * expg * fcal
    return total


def main():
    """Compute the Madelung potential and constant at ``sites[0]`` by Ewald summation.

    Reads the module-level settings ``lattice_parameters``, ``sites``, ``rmin``,
    ``ew_alpha`` (A^-1) and ``prec``; prints the lattice and reciprocal-lattice
    information, the Ewald cut-offs, and the resulting potential (J and eV)
    and Madelung constant; then calls ``terminate()``.
    """
    print("")
    print("Lattice parameters:", lattice_parameters)
    aij = cal_lattice_vectors(lattice_parameters)
    print("Lattice vectors:")
    print("  ax: ({:10.4g}, {:10.4g}, {:10.4g}) A".format(aij[0][0], aij[0][1], aij[0][2]))
    print("  ay: ({:10.4g}, {:10.4g}, {:10.4g}) A".format(aij[1][0], aij[1][1], aij[1][2]))
    print("  az: ({:10.4g}, {:10.4g}, {:10.4g}) A".format(aij[2][0], aij[2][1], aij[2][2]))
    inf = cal_metrics(lattice_parameters)
    gij = inf['gij']
    print("Metric tensor:")
    print("  gij: ({:10.4g}, {:10.4g}, {:10.4g}) A".format(*gij[0]))
    print("       ({:10.4g}, {:10.4g}, {:10.4g}) A".format(*gij[1]))
    print("       ({:10.4g}, {:10.4g}, {:10.4g}) A".format(*gij[2]))
    volume = cal_volume(aij)
    print("Volume: {:12.4g} A^3".format(volume))

    print("")
    print("Unit cell volume: {:12.4g} A^3".format(volume))
    Raij  = cal_reciprocal_lattice_vectors(aij)
    Rlatt = cal_reciprocal_lattice_parameters(Raij)
    Rinf  = cal_metrics(Rlatt)
    Rgij  = Rinf['gij']
    print("Reciprocal lattice parameters:", Rlatt)
    print("Reciprocal lattice vectors:")
    print("  Rax: ({:10.4g}, {:10.4g}, {:10.4g}) A^-1".format(*Raij[0]))
    print("  Ray: ({:10.4g}, {:10.4g}, {:10.4g}) A^-1".format(*Raij[1]))
    print("  Raz: ({:10.4g}, {:10.4g}, {:10.4g}) A^-1".format(*Raij[2]))
    print("Reciprocal lattice metric tensor:")
    print("  Rgij: ({:10.4g}, {:10.4g}, {:10.4g}) A^-1".format(*Rgij[0]))
    print("        ({:10.4g}, {:10.4g}, {:10.4g}) A^-1".format(*Rgij[1]))
    print("        ({:10.4g}, {:10.4g}, {:10.4g}) A^-1".format(*Rgij[2]))
    Rvolume = cal_volume(Raij)
    print("Reciprocal unit cell volume: {:12.4g} A^-3".format(Rvolume))

    nsites = len(sites)

    print("")
    print("Ewald parameters")
    print("  alpha:", ew_alpha)
    norder = -log10(prec)
    print("  precision = {} = 10^-{}".format(prec, norder))

    # Real-space cut-off: empirical rule linear in the number of digits of
    # prec; the printed erfc value lets the user check it against prec.
    rdmax      = (2.26 + 0.26 * norder) / ew_alpha
    erfc_rdmax = erfc(ew_alpha * rdmax)
    print("  RDmax = {} A, where erfc(alpha*RDmax) = {}".format(rdmax, erfc_rdmax))

    # Cells to scan along each real-space axis.  The fractional coordinate of a
    # vector r along axis k is r . a*_k, so |x_k| <= rdmax * |a*_k|; the +1
    # covers the (< 1) fractional offset between pos_i and pos_j.  This is
    # exact for any cell shape, whereas a sin(beta)*sin(gamma) estimate
    # under-counts in oblique (e.g. 60-degree) cells.
    # NOTE(review): assumes Rgij[k][k] = |a*_k|^2 with a_i . a*_j = delta_ij
    # (no 2 pi), the same convention the exp(-pi^2 G^2 / alpha^2) term assumes.
    nrmax = np.array([int(rdmax * sqrt(Rgij[k][k])) + 1 for k in range(3)], dtype=int)
    print("  nrmax:", *nrmax)

    cal_N = int(4.0 / 3.0 * pi * rdmax**3 / volume * nsites)
    print("  cal_N(real):", cal_N)

    # Reciprocal-space cut-off chosen so that exp(-pi^2 * G2max / alpha^2) == prec.
    G2max = ew_alpha**2 / pi**2 * (-log(prec))
    print("  G2max:", G2max)
    # G2max is already |G|^2 and must not be squared again in this check.
    print("      exp(-pi^2 * G2max / alpha^2) = ", exp(-pi**2 * G2max / ew_alpha**2))

    # Reciprocal counterpart of nrmax: index h_k = G . a_k, so |h_k| <= |G|max * |a_k|.
    hgmax = np.array([int(sqrt(G2max * gij[k][k])) + 1 for k in range(3)], dtype=int)
    print("  hgmax:", *hgmax)

    cal_N = int(4.0 / 3.0 * pi * G2max**1.5 / Rvolume * nsites)
    print("  cal_N(reciprocal):", cal_N)

    # The potential is evaluated at the first site only.
    qi, pos_i = sites[0][4], sites[0][7]

    UC1 = _real_space_sum(pos_i, nrmax, gij)                        # e/m
    UC2 = _reciprocal_space_sum(pos_i, hgmax, Rgij, volume, G2max)  # e/m
    UC3 = qi * 2.0 * (ew_alpha * 1.0e10) / sqrt(pi)                 # self term, e/m

    MP = UC1 + UC2 - UC3

    # Coefficient converting the e/m sums above into an energy in J
    Ke = e * e / 4.0 / pi / e0

    print("")
    print("  Madelung potential: {:12.6g} J  (= {:12.6g} + {:12.6g} - {:12.6g})".format(Ke * MP, Ke * UC1, Ke * UC2, Ke * UC3))
    print("  Madelung potential: {:12.6g} eV (= {:12.6g} + {:12.6g} - {:12.6g})".format(Ke / e * MP, Ke / e * UC1, Ke / e * UC2, Ke / e * UC3))

    # Madelung constant referred to the charge magnitude |q_i| and to the
    # nearest-neighbour distance r0 = a/2 of the rock-salt arrangement in
    # `sites` (the 0.5 is specific to that structure).  Its sign follows the
    # potential at site i.
    r0 = 0.5 * lattice_parameters[0] * 1.0e-10
    print("  Madelung constant: {:14.8g}".format(MP / abs(qi) * r0))

    terminate()


if __name__ == "__main__":
    # Run the calculation only when executed as a script, not on import.
    main()
