import sys
import os
from numpy import sin, cos, tan, arcsin, arccos, arctan, exp, log, sqrt
import numpy as np
from numpy import linalg as la
import matplotlib.pyplot as plt

from tkcrystalbase import *


# Mathematical constants
pi          = 3.14159265358979323846
pi2         = pi + pi
torad       = pi / 180.0        # rad/deg
todeg       = 180.0 / pi        # deg/rad
basee       = 2.71828183        # base of the natural logarithm

# Physical constants (SI units)
h           = 6.6260755e-34     # J s
h_bar       = h / pi2           # J s (derived from h so the two stay consistent)
hbar        = h_bar
c           = 2.99792458e8      # m/s
e           = 1.60218e-19       # C
me          = 9.1093897e-31     # kg
mp          = 1.6726231e-27     # kg
mn          = 1.67495e-27       # kg
u0          = 4.0 * pi * 1e-7   # N s^2 C^-2 (vacuum permeability, 4 pi x 10^-7)
e0          = 8.854187817e-12   # C^2 N^-1 m^-2 (vacuum permittivity, 1 / (u0 c^2))
e2_4pie0    = 2.30711e-28       # N m^2 (e^2 / (4 pi e0))
a0          = 5.29177e-11       # m
kB          = 1.380658e-23      # J K^-1
NA          = 6.0221367e23      # mol^-1
R           = 8.31451           # J K^-1 mol^-1
F           = 96485.3           # C/mol
g           = 9.81              # m/s^2


# Crystal lattice: a, b, c in angstrom, followed by alpha, beta, gamma in
# degrees.  Other settings that have been tried:
#   [5.62, 5.62, 5.62, 60.0, 60.0, 60.0]
#   [1.0, 1.0, 1.0, 90.0, 90.0, 90.0]
lattice_parameters = [5.62, 5.62, 5.62] + [90.0] * 3

# Basis sites in fractional coordinates.  Every entry is a list:
#   [atom name, site label, atomic number, atomic mass, charge, radius,
#    display color, position (numpy array)]
sites = [
    [name, label, znum, mass, charge, radius, color, np.array(frac)]
    for name, label, znum, mass, charge, radius, color, frac in (
        ('Na', 'Na1', 11, 22.98997, +1.0, 0.7, 'red',  (0.0, 0.0, 0.0)),
        ('Na', 'Na2', 11, 22.98997, +1.0, 0.7, 'red',  (0.0, 0.5, 0.5)),
        ('Na', 'Na3', 11, 22.98997, +1.0, 0.7, 'red',  (0.5, 0.0, 0.5)),
        ('Na', 'Na4', 11, 22.98997, +1.0, 0.7, 'red',  (0.5, 0.5, 0.0)),
        ('Cl', 'Cl1', 17, 35.4527,  -1.0, 1.4, 'blue', (0.5, 0.0, 0.0)),
        ('Cl', 'Cl2', 17, 35.4527,  -1.0, 1.4, 'blue', (0.5, 0.5, 0.5)),
        ('Cl', 'Cl3', 17, 35.4527,  -1.0, 1.4, 'blue', (0.0, 0.0, 0.5)),
        ('Cl', 'Cl4', 17, 35.4527,  -1.0, 1.4, 'blue', (0.0, 0.5, 0.0)),
    )
]

# Fractional-coordinate window [low, high] along a, b and c.  Every periodic
# image of a basis site that falls inside it is counted as part of the cell.
nrange0 = [[-0.1, 1.1] for _axis in range(3)]

# Half-width, in unit cells, of the translation block summed for the
# Madelung potential.  The first command-line argument overrides it.
nmax = 1

# Sites closer than this to the origin ion (presumably in angstrom, like the
# lattice parameters) are treated as identical to it and skipped in the sum.
rmin = 0.1


# Command-line handling: an optional first argument overrides nmax.
argv = sys.argv
narg = len(argv)
if narg >= 2:
    try:
        nmax = int(argv[1])
    except ValueError:
        sys.exit("Error: nmax must be an integer, got {!r}".format(argv[1]))
    # range(-nmax, nmax) in main() is empty for nmax <= 0, which would
    # silently report a zero Madelung potential and constant.
    if nmax < 1:
        sys.exit("Error: nmax must be >= 1, got {}".format(nmax))


def usage():
    """Print the command-line usage of this script.

    Reads the module-level ``argv`` (for the program name) and ``nmax``
    (shown as the example value, i.e. the setting used for this run).
    """
    print("")
    # The single positional argument is parsed into nmax, so name it that.
    print("Usage: python {} nmax".format(argv[0]))
    print("   ex: python {} {}".format(argv[0], nmax))
    print("")

def terminate():
    """Print the usage message and end the program.

    Uses ``sys.exit()`` rather than the ``exit()`` helper injected by the
    ``site`` module.  That helper is intended for the interactive shell and is
    missing when Python runs with ``-S``.
    """
    usage()
    sys.exit()


def _count_multiplicity(site_list, nrange):
    """Collect every periodic image of each site that lies inside *nrange*.

    Each fractional coordinate is first passed through ``reduce01`` and then
    shifted by integer lattice translations.  An image is kept when every
    coordinate k satisfies ``nrange[k][0] <= x <= nrange[k][1]``.

    Returns:
        mult:   numpy int array; ``mult[i]`` is the number of images kept
                for ``site_list[i]``.
        extpos: list of ``[name, label, z, M, q, r, pos, isite]``, one entry
                per kept image.  ``pos`` is the shifted fractional position
                and ``isite`` indexes back into *site_list*.
    """
    mult = np.zeros(len(site_list), dtype=int)
    extpos = []
    # Candidate integer shifts per axis.  Images outside the window are
    # rejected below, so the bounds only need to be wide enough.  They cover
    # all candidates if reduce01 folds into [0, 1) (assumed, not verified).
    shifts = [range(int(lo) - 1, int(hi) + 1) for lo, hi in nrange]
    for isite, (name, label, z, M, q, r, color, pos) in enumerate(site_list):
        pos01 = [reduce01(pos[0]), reduce01(pos[1]), reduce01(pos[2])]
        for iz in shifts[2]:
            for iy in shifts[1]:
                for ix in shifts[0]:
                    posn = [pos01[0] + ix, pos01[1] + iy, pos01[2] + iz]
                    if any(x < lo or hi < x for x, (lo, hi) in zip(posn, nrange)):
                        continue
                    mult[isite] += 1
                    extpos.append([name, label, z, M, q, r, posn, isite])
    return mult, extpos


def _evjen_madelung_sum(pos0, extpos, mult, gij, nmax, rmin):
    """Return the Evjen-weighted sum of q / r (r in metres) around *pos0*.

    The extended cell *extpos* is translated by every integer vector
    (ix, iy, iz) with each component in ``range(-nmax, nmax)``.  Each image
    is weighted by ``1 / mult`` of its parent site.  Images closer than
    *rmin* to *pos0* (the origin ion itself) are skipped.
    """
    MP = 0.0
    # For the default configuration (origin at fractional (0, 0, 0) and a
    # window of about [0, 1] per axis), these translations tile the cube
    # [-nmax, nmax]^3 centred on the origin ion.
    for iz in range(-nmax, nmax):
        for iy in range(-nmax, nmax):
            for ix in range(-nmax, nmax):
                shift = np.array([ix, iy, iz])
                for name1, label1, z1, M1, q1, r1, pos1, idxsite in extpos:
                    w = 1.0 / mult[idxsite]
                    r = distance(pos0, pos1 + shift, gij)
                    if r < rmin:
                        continue
                    # r is presumably in angstrom (the lattice parameters
                    # are); the 1e-10 factor converts it to metres.
                    MP += q1 / (r * 1.0e-10) * w
    return MP


def main():
    """Print lattice information and the Evjen-method Madelung potential.

    Uses the module-level configuration (lattice_parameters, sites, nrange0,
    nmax, rmin).  Finally it calls terminate(), which prints the usage
    message and exits.
    """
    print("")
    print("Lattice parameters:", lattice_parameters)
    aij = cal_lattice_vectors(lattice_parameters)
    print("Lattice vectors:")
    print("  ax: ({:10.4g}, {:10.4g}, {:10.4g}) A".format(*aij[0]))
    print("  ay: ({:10.4g}, {:10.4g}, {:10.4g}) A".format(*aij[1]))
    print("  az: ({:10.4g}, {:10.4g}, {:10.4g}) A".format(*aij[2]))
    inf = cal_metrics(lattice_parameters)
    gij = inf['gij']
    # Metric-tensor elements are products of two lattice vectors: length^2.
    print("Metric tensor:")
    print("  gij: ({:10.4g}, {:10.4g}, {:10.4g}) A^2".format(*gij[0]))
    print("       ({:10.4g}, {:10.4g}, {:10.4g}) A^2".format(*gij[1]))
    print("       ({:10.4g}, {:10.4g}, {:10.4g}) A^2".format(*gij[2]))
    volume = cal_volume(aij)
    print("Volume: {:12.4g} A^3".format(volume))

    print("")
    print("Unit cell volume: {:12.4g} A^3".format(volume))
    Raij = cal_reciprocal_lattice_vectors(aij)
    Rlatt = cal_reciprocal_lattice_parameters(Raij)
    Rinf = cal_metrics(Rlatt)
    Rgij = Rinf['gij']
    print("Reciprocal lattice parameters:", Rlatt)
    print("Reciprocal lattice vectors:")
    print("  Rax: ({:10.4g}, {:10.4g}, {:10.4g}) A^-1".format(*Raij[0]))
    print("  Ray: ({:10.4g}, {:10.4g}, {:10.4g}) A^-1".format(*Raij[1]))
    print("  Raz: ({:10.4g}, {:10.4g}, {:10.4g}) A^-1".format(*Raij[2]))
    print("Reciprocal lattice metric tensor:")
    print("  Rgij: ({:10.4g}, {:10.4g}, {:10.4g}) A^-2".format(*Rgij[0]))
    print("        ({:10.4g}, {:10.4g}, {:10.4g}) A^-2".format(*Rgij[1]))
    print("        ({:10.4g}, {:10.4g}, {:10.4g}) A^-2".format(*Rgij[2]))
    Rvolume = cal_volume(Raij)
    print("Reciprocal unit cell volume: {:12.4g} A^-3".format(Rvolume))

    # Gather all images of the basis sites inside the nrange0 window.  An
    # image of a site that appears k times in the window gets weight 1/k.
    mult, extpos = _count_multiplicity(sites, nrange0)

    print("")
    print("Site information (all sites in the unit cell with the range:", nrange0, ")")
    qtot = 0.0
    for name, label, z, M, q, r, pos, isite0 in extpos:
        m = mult[isite0]
        w = 1.0 / m
        print("  {:4} ({:8.3g}, {:8.3g}, {:8.3g}) q={:5.3g} mult={:2d} weight={:8.4g}".format(label, *pos, q, m, w))
        qtot += q * w
    # For a charge-neutral basis (such as the Na/Cl sites) this is ~0.
    print("qtot=", qtot)

    print("")
    print("Calculate Madelung potential around the zero-th ion by Evjen method")
    print("  nmax:", nmax)
    name0, label0, z0, M0, q0, r0, pos0, _ = extpos[0]
    print("  Origin: {} ({}, {}, {})".format(label0, *pos0))
    MP = _evjen_madelung_sum(pos0, extpos, mult, gij, nmax, rmin)

    # Coulomb prefactor e^2 / (4 pi e0), in J m.
    Ke = e * e / 4.0 / pi / e0

    print("")
    print("  Madelung potential: {:12.6g} J".format(Ke * MP))
    print("  Madelung potential: {:12.6g} eV".format(Ke / e * MP))
    # The Madelung constant is normalised by the origin charge |q0| and by
    # the nearest-neighbour distance.  For the rock-salt Na/Cl arrangement in
    # `sites` that distance is a/2, hence 0.5 * lattice_parameters[0].
    print("  Madelung constant: {:14.8g}".format(0.5 * MP / abs(q0) * (lattice_parameters[0] * 1.0e-10)))

    terminate()


if __name__ == "__main__":
    # Run the calculation only when executed as a script, not on import.
    main()
