import sys
import os
from numpy import sin, cos, tan, arcsin, arccos, arctan, exp, log, sqrt
import numpy as np
from numpy import linalg as la
import matplotlib.pyplot as plt

from tkcrystalbase import *


"""
Calculate Madelung potential by simple summation
  Requirement: tkcrystalbase.py
"""


# Mathematical constants
pi          = 3.14159265358979323846
pi2         = pi + pi
torad       = pi / 180.0           # rad/deg
todeg       = 180.0 / pi           # deg/rad
basee       = 2.718281828459045    # base of the natural logarithm

# Physical constants (SI units)
h           = 6.6260755e-34        # J s
h_bar       = h / pi2              # J s; derived from h so the two stay consistent
hbar        = h_bar
c           = 2.99792458e8         # m/s
e           = 1.60218e-19          # C
me          = 9.1093897e-31        # kg
mp          = 1.6726231e-27        # kg
mn          = 1.67495e-27          # kg
u0          = 4.0 * pi * 1e-7      # N s^2 C^-2  (vacuum permeability)
e0          = 8.854187817e-12      # C^2 N^-1 m^-2  (vacuum permittivity, = 1/(u0 c^2))
e2_4pie0    = e * e / (4.0 * pi * e0)  # N m^2; derived so it agrees with e and e0
a0          = 5.29177e-11          # m
kB          = 1.380658e-23         # J K^-1
NA          = 6.0221367e23         # mol^-1
R           = 8.31451              # J K^-1 mol^-1
F           = 96485.3              # C/mol
g           = 9.81                 # m/s^2



# ---- Crystal definition ------------------------------------------------
# Lattice parameters: a, b, c in angstrom, then alpha, beta, gamma in degree.
# Alternative 60-degree cell kept for reference:
#lattice_parameters = [ 5.62, 5.62, 5.62, 60.0, 60.0, 60.0]
lattice_parameters = [5.62] * 3 + [90.0] * 3

# Site table for the rock-salt (NaCl) conventional cell.  Each row is
#   [atom name, site label, atomic number, atomic mass, charge, radius,
#    color, fractional position (numpy array)]
sites = [
    [name, label, znum, mass, charge, radius, color, np.array(frac)]
    for (name, label, znum, mass, charge, radius, color, frac) in (
        ('Na', 'Na1', 11, 22.98997, +1.0, 0.7, 'red',  (0.0, 0.0, 0.0)),
        ('Na', 'Na2', 11, 22.98997, +1.0, 0.7, 'red',  (0.0, 0.5, 0.5)),
        ('Na', 'Na3', 11, 22.98997, +1.0, 0.7, 'red',  (0.5, 0.0, 0.5)),
        ('Na', 'Na4', 11, 22.98997, +1.0, 0.7, 'red',  (0.5, 0.5, 0.0)),
        ('Cl', 'Cl1', 17, 35.4527,  -1.0, 1.4, 'blue', (0.5, 0.0, 0.0)),
        ('Cl', 'Cl2', 17, 35.4527,  -1.0, 1.4, 'blue', (0.5, 0.5, 0.5)),
        ('Cl', 'Cl3', 17, 35.4527,  -1.0, 1.4, 'blue', (0.0, 0.0, 0.5)),
        ('Cl', 'Cl4', 17, 35.4527,  -1.0, 1.4, 'blue', (0.0, 0.5, 0.0)),
    )
]

# ---- Radial grid (angstrom) on which the summed potential is reported ---
rmin, rmax, nr = 0.1, 100.0, 101

# ---- Matplotlib figure size (inches) -----------------------------------
figsize = (6, 6)


# Optional command-line overrides:  python <script> [rmax [nr]]
argv = sys.argv
narg = len(argv)
if narg >= 2:
    rmax = float(argv[1])
if narg >= 3:
    nr = int(argv[2])

# The radial grid needs at least two points and a positive width; otherwise
# rstep below divides by zero (nr == 1) or comes out non-positive and the
# summation in main() fails later with a confusing error.
if nr < 2 or rmax <= rmin:
    sys.exit("Error: need nr >= 2 and rmax > rmin (got rmin={}, rmax={}, nr={})\n"
             "Usage: python {} rmax nr".format(rmin, rmax, nr, argv[0]))

# Spacing of the radial grid rmin, rmin + rstep, ..., rmax.
rstep = (rmax - rmin) / (nr - 1)


def usage():
    """Print the command-line synopsis, using the current rmax and nr as the example."""
    prog = argv[0]
    message = [
        "",
        "Usage: python {} rmax nr".format(prog),
        "   ex: python {} {} {}".format(prog, rmax, nr),
        "",
    ]
    print("\n".join(message))

def terminate():
    """Print the usage message and stop the program with exit status 0.

    Uses sys.exit() rather than the bare exit() helper: exit() is injected by
    the site module for interactive use and is unavailable when Python runs
    with -S (or in some embedded/frozen environments).
    """
    usage()
    sys.exit()


def draw_box(ax, aij, nrange, color = 'black'):
    """Draw the twelve edges of the parallelepiped spanned by the lattice vectors.

    ax     : 3-D axes; only ax.plot(xs, ys, zs, color=...) is called
    aij    : the three lattice vectors, aij[i][k] = Cartesian component k of vector i
    nrange : accepted for interface compatibility; not used here
    color  : line color applied to every edge
    """
    def corner(*vectors):
        # Cartesian position of the corner reached from the origin by adding
        # the listed lattice vectors in order; no vectors means the origin.
        if not vectors:
            return [0.0, 0.0, 0.0]
        point = [aij[vectors[0]][k] for k in range(3)]
        for i in vectors[1:]:
            point = [point[k] + aij[i][k] for k in range(3)]
        return point

    # (start corner, end corner) per edge, drawn in this fixed order:
    # three edges from the origin, six edges from single-vector corners,
    # and three edges ending at the far corner a0 + a1 + a2.
    edges = (
        ((), (0,)), ((), (1,)), ((), (2,)),
        ((0,), (0, 1)), ((0,), (0, 2)),
        ((1,), (0, 1)), ((1,), (1, 2)),
        ((2,), (0, 2)), ((2,), (1, 2)),
        ((0, 1), (0, 1, 2)), ((0, 2), (0, 1, 2)), ((1, 2), (0, 1, 2)),
    )
    for start, end in edges:
        p, q = corner(*start), corner(*end)
        ax.plot([p[0], q[0]], [p[1], q[1]], [p[2], q[2]], color = color)

def draw_unitcell(ax, sites, aij, nrange, color = 'black'):
    """Draw the cell outline plus every periodic image of each site inside nrange.

    ax     : 3-D axes; the outline is drawn by draw_box, atoms via ax.scatter
    sites  : list of 8-field site rows (name, label, atomic number, mass,
             charge, radius, color, fractional position); None draws the
             outline only
    aij    : lattice vectors, passed to draw_box and fractional_to_cartesian
    nrange : nrange[k] = (lower, upper) fractional bounds along axis k; an
             image is drawn only if every coordinate lies within its bounds
             (inclusive)
    color  : outline color (atoms use their own site color)
    """
    draw_box(ax, aij, nrange, color)
    if sites is None:
        return

    for site in sites:
        _name, _label, _znum, _mass, _charge, radius, site_color, pos = site
        # NOTE(review): reduce01 presumably wraps a coordinate into [0, 1) — confirm
        # in tkcrystalbase.
        home = [reduce01(pos[k]) for k in range(3)]

        # Integer shifts tried per axis: from one below int(lower) up to int(upper).
        z_shifts = range(int(nrange[2][0]) - 1, int(nrange[2][1]) + 1)
        y_shifts = range(int(nrange[1][0]) - 1, int(nrange[1][1]) + 1)
        x_shifts = range(int(nrange[0][0]) - 1, int(nrange[0][1]) + 1)
        for dz in z_shifts:
            for dy in y_shifts:
                for dx in x_shifts:
                    image = [home[0] + dx, home[1] + dy, home[2] + dz]
                    if any(image[k] < nrange[k][0] or nrange[k][1] < image[k]
                           for k in range(3)):
                        continue
                    px, py, pz = fractional_to_cartesian(image, aij)
                    # NOTE(review): `kr` (marker-size scale) is not defined in this
                    # file; presumably supplied by `from tkcrystalbase import *` — confirm.
                    ax.scatter([px], [py], [pz], marker = 'o', c = site_color, s = kr * radius)


def _translation_limits(aij, rcut):
    """Return (nxmax, nymax, nzmax): cell translations needed to reach all ions within rcut.

    A Cartesian vector of length r has fractional component at most r / d_i
    along axis i, where d_i = V / |a_j x a_k| is the spacing of the lattice
    planes spanned by the other two vectors.  The "+ 1" covers the offset
    between two sites inside the cell (assumes fractional positions lie in
    [0, 1) — TODO confirm for all site tables).  For an orthogonal cell d_i
    equals the lattice constant, which reproduces int(rcut / a_i) + 1; for
    oblique cells (e.g. 60-degree angles) d_i is smaller and more cells are
    needed.
    """
    vecs = np.asarray(aij, dtype=float)
    volume = abs(np.dot(vecs[0], np.cross(vecs[1], vecs[2])))
    limits = []
    for i in range(3):
        j, k = (i + 1) % 3, (i + 2) % 3
        spacing = volume / np.linalg.norm(np.cross(vecs[j], vecs[k]))
        limits.append(int(rcut / spacing) + 1)
    return tuple(limits)


def main():
    """Print lattice information, then sum and plot the Madelung potential at sites[0].

    Reads the module-level settings lattice_parameters, sites, rmin, rmax, nr,
    rstep and figsize.  Every ion (other than the reference ion itself) in the
    translated unit cells contributes e * q / (4 pi e0 r) — in volts, i.e. eV
    for an elementary charge — at distance r from sites[0].  Each contribution
    is assigned to the first grid point rlist[i] >= r, so the cumulative value
    MP[i] is the potential from all ions with rmin <= r <= rlist[i].  The table
    is printed, the curve is plotted, and the program ends via terminate().
    """
    print("")
    print("Lattice parameters:", lattice_parameters)
    aij = cal_lattice_vectors(lattice_parameters)
    print("Lattice vectors:")
    print("  ax: ({:10.4g}, {:10.4g}, {:10.4g}) A".format(aij[0][0], aij[0][1], aij[0][2]))
    print("  ay: ({:10.4g}, {:10.4g}, {:10.4g}) A".format(aij[1][0], aij[1][1], aij[1][2]))
    print("  az: ({:10.4g}, {:10.4g}, {:10.4g}) A".format(aij[2][0], aij[2][1], aij[2][2]))
    inf = cal_metrics(lattice_parameters)
    gij = inf['gij']
    # gij turns fractional differences into distances in angstrom (see the
    # distance() call below), so its components are in A^2.
    print("Metric tensor:")
    print("  gij: ({:10.4g}, {:10.4g}, {:10.4g}) A^2".format(*gij[0]))
    print("       ({:10.4g}, {:10.4g}, {:10.4g}) A^2".format(*gij[1]))
    print("       ({:10.4g}, {:10.4g}, {:10.4g}) A^2".format(*gij[2]))
    volume = cal_volume(aij)
    print("Volume: {:12.4g} A^3".format(volume))

    print("")
    print("Unit cell volume: {:12.4g} A^3".format(volume))
    Raij  = cal_reciprocal_lattice_vectors(aij)
    Rlatt = cal_reciprocal_lattice_parameters(Raij)
    Rinf  = cal_metrics(Rlatt)
    Rgij  = Rinf['gij']
    print("Reciprocal lattice parameters:", Rlatt)
    print("Reciprocal lattice vectors:")
    print("  Rax: ({:10.4g}, {:10.4g}, {:10.4g}) A^-1".format(*Raij[0]))
    print("  Ray: ({:10.4g}, {:10.4g}, {:10.4g}) A^-1".format(*Raij[1]))
    print("  Raz: ({:10.4g}, {:10.4g}, {:10.4g}) A^-1".format(*Raij[2]))
    print("Reciprocal lattice metric tensor:")
    print("  Rgij: ({:10.4g}, {:10.4g}, {:10.4g}) A^-2".format(*Rgij[0]))
    print("        ({:10.4g}, {:10.4g}, {:10.4g}) A^-2".format(*Rgij[1]))
    print("        ({:10.4g}, {:10.4g}, {:10.4g}) A^-2".format(*Rgij[2]))
    Rvolume = cal_volume(Raij)
    print("Reciprocal unit cell volume: {:12.4g} A^-3".format(Rvolume))

    # Range of cell translations that contains every ion with r <= rmax,
    # also for oblique (non-orthogonal) cells.
    nxmax, nymax, nzmax = _translation_limits(aij, rmax)
    print("")
    print("nmax:", nxmax, nymax, nzmax)

    # Accumulate the potential at the reference ion sites[0] into radial bins.
    rlist  = [rmin + i * rstep for i in range(nr)]
    MPdiff = np.zeros(nr)
    pos0 = sites[0][7]
    Ke = e * e / 4.0 / pi / e0                  # e^2 / (4 pi e0) in MKS
    for iz in range(-nzmax, nzmax + 1):
        for iy in range(-nymax, nymax + 1):
            for ix in range(-nxmax, nxmax + 1):
                shift = np.array([ix, iy, iz])
                for site1 in sites:
                    q1, pos1 = site1[4], site1[7]
                    r = distance(pos0, pos1 + shift, gij)
                    # Skips the reference ion itself (r = 0) and anything closer than rmin.
                    if r < rmin:
                        continue
                    # First grid point at or beyond r (small tolerance for r sitting
                    # exactly on a grid point).  Ions beyond rmax are dropped: the
                    # translation range only guarantees completeness up to rmax.
                    ir = int(np.ceil((r - rmin) / rstep - 1.0e-9))
                    if ir >= nr:
                        continue
                    MPdiff[ir] += Ke * q1 / (r * 1.0e-10) / e   # r: A -> m; result in V (= eV per e)

    print("")
    print("r (A)      Madelung potential (eV)")
    MP = np.cumsum(MPdiff)
    for r_i, mp_i in zip(rlist, MP):
        print("{:10.4g}   {:12.6g}".format(r_i, mp_i))

    fig = plt.figure(figsize = figsize)
    ax = fig.add_subplot(111)

    ax.plot(rlist, MP)
    ax.set_xlabel('r / angstrom')
    ax.set_ylabel('Electrostatic potential / eV')

    plt.show()

    terminate()


if __name__ == '__main__':
    main()
