from numpy import sin, cos, tan, arcsin, arccos, arctan, exp, log, sqrt
import numpy as np
from numpy import linalg as la
from pprint import pprint

pi          = 3.14159265358979323846
pi2         = pi + pi
torad       = pi / 180.0           # rad/deg (exact to double precision)
todeg       = 180.0 / pi           # deg/rad (exact to double precision)
basee       = 2.718281828459045    # Euler's number

# Physical constants in SI units.
h           = 6.6260755e-34        # J s
h_bar       = h / pi2              # J s  (derived from h so the two always agree)
hbar        = h_bar
c           = 2.99792458e8         # m/s
e           = 1.60218e-19          # C  (elementary charge)
me          = 9.1093897e-31        # kg (electron mass)
mp          = 1.6726231e-27        # kg (proton mass)
mn          = 1.67495e-27          # kg (neutron mass)
u0          = 4.0 * pi * 1e-7      # N s^2 C^-2 (vacuum permeability; was 4*3.14e-7)
e0          = 1.0 / (u0 * c * c)   # C^2 N^-1 m^-2 (vacuum permittivity, ~8.854187817e-12)
e2_4pie0    = 2.30711e-28          # N m^2  (e^2 / (4 pi e0))
a0          = 5.29177e-11          # m  (Bohr radius)
kB          = 1.380658e-23         # J K^-1
NA          = 6.0221367e23         # mol^-1
R           = 8.31451              # J K^-1 mol^-1
F           = 96485.3              # C mol^-1
g           = 9.81                 # m s^-2


# Lattice parameters [a, b, c, alpha, beta, gamma] (angstrom and degree)
# Alternative 60-degree cell, kept for reference:
#lattice_parameters = [ 5.62, 5.62, 5.62, 60.0, 60.0, 60.0]
lattice_parameters = [ 5.62, 5.62, 5.62, 90.0, 90.0, 90.0]

# Site information (atom name, site label, atomic number, atomic mass, charge, radius, color, position)
# Positions are fractional coordinates. Cl carries charge -1 so the cell is neutral.
sites = [
    ['Na', 'Na1', 11, 22.98997, +1.0, 0.7, 'red',  np.array([0.0, 0.0, 0.0])],
    ['Na', 'Na2', 11, 22.98997, +1.0, 0.7, 'red',  np.array([0.0, 0.5, 0.5])],
    ['Na', 'Na3', 11, 22.98997, +1.0, 0.7, 'red',  np.array([0.5, 0.0, 0.5])],
    ['Na', 'Na4', 11, 22.98997, +1.0, 0.7, 'red',  np.array([0.5, 0.5, 0.0])],
    ['Cl', 'Cl1', 17, 35.4527,  -1.0, 1.4, 'blue', np.array([0.5, 0.0, 0.0])],
    ['Cl', 'Cl2', 17, 35.4527,  -1.0, 1.4, 'blue', np.array([0.5, 0.5, 0.5])],
    ['Cl', 'Cl3', 17, 35.4527,  -1.0, 1.4, 'blue', np.array([0.0, 0.0, 0.5])],
    ['Cl', 'Cl4', 17, 35.4527,  -1.0, 1.4, 'blue', np.array([0.0, 0.5, 0.0])],
]


def reduce01(x):
    """Return x reduced modulo 1 into the interval [0, 1).

    Uses floor rather than int() truncation so that negative inputs wrap
    correctly (e.g. -0.25 -> 0.75, not -0.25). Works element-wise on NumPy
    arrays too. A tiny negative x can still round up to exactly 1.0 in
    floating point.
    """
    return x - np.floor(x)

def round01(x, tol=0.0002):
    """Map x into [0, 1], snapping values within tol of 0 or 1 to exactly 0.0 / 1.0.

    Parameters
    ----------
    x : float
    tol : float, optional
        Snapping tolerance (default 0.0002).

    Returns
    -------
    float
        1.0 if |x - 1| < tol, 0.0 if |x| < tol, otherwise x shifted by an
        integer into [0, 1].
    """
    if abs(x - 1.0) < tol:
        return 1.0
    if abs(x) < tol:
        return 0.0
    if x > 1.0:
        return x - int(x)
    # Only negative values need lifting; values already in (0, 1) are returned
    # unchanged (previously `x < 1.0` wrongly sent 0.5 to 1.5).
    if x < 0.0:
        return x - int(x) + 1.0
    return x

def round_parameter(x, tol):
    """Return a multiple of tol obtained by truncating (x + 0.1*tol) / tol toward zero.

    For non-negative x this rounds down to a multiple of tol. The 0.1*tol
    bias keeps values that are multiples of tol up to floating-point error
    (e.g. 0.3 / 0.1 == 2.999...) from dropping one step.
    """
    steps = int((x + 0.1 * tol) / tol)
    return tol * steps


def cal_lattice_vectors(latt):
    """Return the Cartesian lattice vectors for a unit cell as a 3x3 array.

    Convention: a lies along x, b lies in the xy plane, and c has a positive
    z component.

    Parameters
    ----------
    latt : sequence of 6 floats
        [a, b, c, alpha, beta, gamma]; lengths in any unit, angles in degrees.

    Returns
    -------
    numpy.ndarray, shape (3, 3)
        Row i is the i-th lattice vector (aij[0] = a, aij[1] = b, aij[2] = c).
    """
    a, b, c = latt[0], latt[1], latt[2]
    cosa = cos(np.radians(latt[3]))
    cosb = cos(np.radians(latt[4]))
    cosg = cos(np.radians(latt[5]))
    sing = sin(np.radians(latt[5]))

    aij = np.zeros((3, 3), dtype=float)
    aij[0][0] = a
    aij[1][0] = b * cosg
    aij[1][1] = b * sing
    aij[2][0] = c * cosb
    # c_y = c (cos(alpha) - cos(beta) cos(gamma)) / sin(gamma); divide by sin(gamma) exactly once.
    cy = c * (cosa - cosb * cosg) / sing
    # Snap round-off noise (cos(90 deg) is not exactly 0 in floating point) to an exact zero.
    aij[2][1] = 0.0 if abs(cy) < 1.0e-8 else cy
    # Clamp so tiny negative round-off cannot make the square root return NaN.
    aij[2][2] = sqrt(max(c * c - aij[2][0] * aij[2][0] - aij[2][1] * aij[2][1], 0.0))

    return aij

def cal_metrics(latt):
    """Return the lattice vectors and metric tensor for a set of lattice parameters.

    Parameters
    ----------
    latt : sequence of 6 floats
        [a, b, c, alpha, beta, gamma] with angles in degrees, as accepted by
        cal_lattice_vectors.

    Returns
    -------
    dict
        'aij': 3x3 array from cal_lattice_vectors (row i = i-th lattice vector).
        'gij': 3x3 metric tensor with gij[i][j] = aij[i] . aij[j].
    """
    aij = cal_lattice_vectors(latt)
    # G = A A^T: the matrix of all pairwise dot products of the lattice vectors.
    gij = aij @ aij.T
    return {'aij': aij, 'gij': gij}

def cal_volume(aij):
    """Return the signed volume of the cell spanned by aij[0], aij[1], aij[2].

    Computed as the scalar triple product (a x b) . c, so it is positive for
    a right-handed set of vectors and negative for a left-handed one.
    """
    a_vec, b_vec, c_vec = aij[0], aij[1], aij[2]
    return np.dot(np.cross(a_vec, b_vec), c_vec)

def cal_reciprocal_lattice_vectors(aij):
    """Return the reciprocal lattice vectors [a*, b*, c*] of the cell aij.

    Crystallographic convention without a 2*pi factor:
    a* = (b x c) / V, b* = (c x a) / V, c* = (a x b) / V, where V is the
    signed volume from cal_volume(aij). Hence a_i . a*_j = delta_ij.
    """
    volume = cal_volume(aij)
    a_vec, b_vec, c_vec = aij[0], aij[1], aij[2]
    return [np.cross(first, second) / volume
            for first, second in ((b_vec, c_vec), (c_vec, a_vec), (a_vec, b_vec))]

def cal_reciprocal_lattice_parameters(Raij):
    """Return [a*, b*, c*, alpha*, beta*, gamma*] from three reciprocal lattice vectors.

    Parameters
    ----------
    Raij : sequence of three 3-vectors
        Reciprocal lattice vectors a*, b*, c* (e.g. from
        cal_reciprocal_lattice_vectors).

    Returns
    -------
    list
        The lengths |a*|, |b*|, |c*| followed by the angles (in degrees)
        alpha* = angle(b*, c*), beta* = angle(c*, a*), gamma* = angle(a*, b*).
    """
    Ra = la.norm(Raij[0])
    Rb = la.norm(Raij[1])
    Rc = la.norm(Raij[2])

    def _angle_deg(u, v, nu, nv):
        # Clip so round-off pushing |cos| slightly past 1 cannot make arccos return NaN.
        cosv = np.clip(np.dot(u, v) / nu / nv, -1.0, 1.0)
        return np.degrees(arccos(cosv))

    Ralpha = _angle_deg(Raij[1], Raij[2], Rb, Rc)
    Rbeta  = _angle_deg(Raij[2], Raij[0], Rc, Ra)
    Rgamma = _angle_deg(Raij[0], Raij[1], Ra, Rb)

    return [Ra, Rb, Rc, Ralpha, Rbeta, Rgamma]

def fractional_to_cartesian(pos, aij):
    """Convert fractional coordinates to Cartesian coordinates.

    Computes r = pos[0]*a + pos[1]*b + pos[2]*c, where a, b, c are the rows
    aij[0], aij[1], aij[2]. Returns the components as a tuple (x, y, z).
    """
    u, v, w = pos[0], pos[1], pos[2]
    a_vec, b_vec, c_vec = aij[0], aij[1], aij[2]
    return tuple(u * a_vec[k] + v * b_vec[k] + w * c_vec[k] for k in range(3))

def distance2(pos0, pos1, gij):
    """Return the squared distance between two fractional positions.

    Evaluates dx^T G dx with dx = pos1 - pos0, so pos0/pos1 must support
    element-wise subtraction (e.g. NumPy arrays). Only the diagonal and the
    upper triangle of gij are read, so gij is assumed to be symmetric.
    """
    d = pos1 - pos0
    d0, d1, d2 = d[0], d[1], d[2]
    diagonal = gij[0][0] * d0 * d0 + gij[1][1] * d1 * d1 + gij[2][2] * d2 * d2
    off_diagonal = gij[0][1] * d0 * d1 + gij[0][2] * d0 * d2 + gij[1][2] * d1 * d2
    return diagonal + 2.0 * off_diagonal

def distance(pos0, pos1, gij):
    """Return the distance between two fractional positions under metric gij.

    This is sqrt(dx^T G dx) with dx = pos1 - pos0. The quadratic form is taken
    from distance2 rather than duplicated here.
    """
    return sqrt(distance2(pos0, pos1, gij))

def angle(pos0, pos1, pos2, gij):
    """Return the angle pos1-pos0-pos2 (vertex at pos0) in degrees.

    Positions are fractional coordinates and gij is the 3x3 (symmetric)
    metric tensor, so the inner product of displacements u, v is u^T G v.

    Parameters
    ----------
    pos0, pos1, pos2 : array-like of 3 floats
    gij : array-like, shape (3, 3)

    Returns
    -------
    float
        Angle in [0, 180] degrees, or 0.0 if pos1 or pos2 coincides with pos0.
    """
    g = np.asarray(gij, dtype=float)
    p0 = np.asarray(pos0, dtype=float)
    u = np.asarray(pos1, dtype=float) - p0
    v = np.asarray(pos2, dtype=float) - p0

    len_u = sqrt(u @ g @ u)
    if len_u == 0.0:
        return 0.0
    len_v = sqrt(v @ g @ v)
    if len_v == 0.0:
        return 0.0

    # Full bilinear form: the off-diagonal part is g_ij (u_i v_j + u_j v_i).
    # Writing it as 2 g_ij u_i v_j is only valid when u == v.
    cosa = (u @ g @ v) / (len_u * len_v)
    # arccos already returns [0, pi]; the clip only guards |cos| > 1 from round-off.
    return np.degrees(arccos(np.clip(cosa, -1.0, 1.0)))

def configure_axis_structure(ax, xrange, yrange, zrange, fontsize = 12, legend_fontsize = 12):
    """Strip a 3D axes object down to a bare canvas for drawing a structure.

    Sets the tick-label size and, when fontsize > 0, labels the axes x/y/z.
    Hides the axis lines, grid, background panes, ticks and tick labels, then
    applies the given limits with an equal aspect ratio.

    Parameters
    ----------
    ax : 3D axes with .xaxis/.yaxis/.zaxis and .pane attributes
        (presumably an mpl_toolkits.mplot3d axes; confirm at call site).
    xrange, yrange, zrange : limits passed to set_xlim / set_ylim / set_zlim.
    fontsize : tick-label and axis-label font size; labels drawn only if > 0.
    legend_fontsize : accepted for interface compatibility; not used here.
    """
    ax.tick_params(axis = 'both', which = 'major', labelsize = fontsize)
    if fontsize > 0:
        label_specs = ((ax.set_xlabel, '$x$', -5),
                       (ax.set_ylabel, '$y$', -5),
                       (ax.set_zlabel, '$z$', 0))
        for set_label, text, pad in label_specs:
            set_label(text, fontsize = fontsize, labelpad = pad)

    axis_objs = (ax.xaxis, ax.yaxis, ax.zaxis)

    # Hide the x, y and z axis lines.
    for axis_obj in axis_objs:
        axis_obj.line.set_visible(False)

    ax.grid(False)

    # Make the background panes invisible: no fill, white edges, fully transparent.
    for axis_obj in axis_objs:
        axis_obj.pane.fill = False
    for axis_obj in axis_objs:
        axis_obj.pane.set_edgecolor('w')
    for axis_obj in axis_objs:
        axis_obj.pane.set_alpha(0)

    # Remove every tick mark and tick label.
    for set_ticks in (ax.set_xticks, ax.set_yticks, ax.set_zticks):
        set_ticks([])
    for set_ticklabels in (ax.set_xticklabels, ax.set_yticklabels, ax.set_zticklabels):
        set_ticklabels([])

    for set_lim, limits in ((ax.set_xlim, xrange), (ax.set_ylim, yrange), (ax.set_zlim, zrange)):
        set_lim(limits)
    ax.set_aspect('equal', 'box')

def draw_box(ax, aij, nrange, color = 'black'):
    """Draw the 12 edges of the cell spanned by aij[0], aij[1], aij[2] at the origin.

    Each edge is drawn with one ax.plot(xs, ys, zs, color=color) call. nrange
    is accepted for interface compatibility but is not used: only the single
    cell at the origin is outlined.
    """
    a, b, c = aij[0], aij[1], aij[2]
    origin = [0.0, 0.0, 0.0]
    ab  = [a[k] + b[k] for k in range(3)]
    ac  = [a[k] + c[k] for k in range(3)]
    bc  = [b[k] + c[k] for k in range(3)]
    abc = [a[k] + b[k] + c[k] for k in range(3)]

    edges = [
        (origin, a), (origin, b), (origin, c),   # edges leaving the origin
        (a, ab), (a, ac),                        # edges leaving a
        (b, ab), (b, bc),                        # edges leaving b
        (c, ac), (c, bc),                        # edges leaving c
        (ab, abc), (ac, abc), (bc, abc),         # edges meeting at a + b + c
    ]
    for start, end in edges:
        ax.plot([start[0], end[0]],
                [start[1], end[1]],
                [start[2], end[2]], color = color)

def draw_unitcell(ax, sites, aij, nrange, color = 'black', facecolor = 'black', edgecolor = 'white', alpha = 0.7, kr = 1.0):
    """Draw the cell outline plus every periodic image of each site inside nrange.

    Parameters
    ----------
    ax : 3D axes providing plot() (via draw_box) and scatter().
    sites : list of 8-element site records
        [name, label, Z, mass, charge, radius, color, position], or None to
        draw only the cell outline.
    aij : lattice vectors as rows, used to convert fractional to Cartesian.
    nrange : [[xmin, xmax], [ymin, ymax], [zmin, zmax]] in fractional units.
        An image is drawn when every coordinate lies in its closed interval.
    color : line colour of the cell outline.
    facecolor, edgecolor, alpha : marker styling applied to every atom.
    kr : marker-size scale; each marker gets s = kr * radius.
    """
    draw_box(ax, aij, nrange, color)

    if sites is None:
        return

    for site in sites:
        # NOTE(review): the per-site colour field is unpacked but unused; every
        # atom is drawn with facecolor. Confirm this is intended.
        _name, _label, _znum, _mass, _charge, radius, _site_color, frac = site
        base = [reduce01(frac[0]), reduce01(frac[1]), reduce01(frac[2])]
        # Try one extra cell on each side of nrange, then keep only images inside it.
        for iz in range(int(nrange[2][0]) - 1, int(nrange[2][1]) + 1):
            for iy in range(int(nrange[1][0]) - 1, int(nrange[1][1]) + 1):
                for ix in range(int(nrange[0][0]) - 1, int(nrange[0][1]) + 1):
                    image = [base[0] + ix, base[1] + iy, base[2] + iz]
                    if any(image[k] < nrange[k][0] or nrange[k][1] < image[k] for k in range(3)):
                        continue

                    cx, cy, cz = fractional_to_cartesian(image, aij)
                    ax.scatter([cx], [cy], [cz], marker = 'o', c = facecolor,
                               edgecolors = edgecolor, alpha = alpha, s = kr * radius)


def main():
    """Print the geometry of the module-level lattice_parameters.

    Reports:
    - the direct lattice vectors, metric tensor and cell volume;
    - the reciprocal lattice parameters, vectors, metric tensor and volume;
    - a few sample distance/angle evaluations in fractional coordinates.
    """
    print("")
    print("Lattice parameters:", lattice_parameters)
    # cal_metrics already builds the lattice vectors; reuse them instead of recomputing.
    inf = cal_metrics(lattice_parameters)
    aij = inf['aij']
    gij = inf['gij']
    print("Lattice vectors:")
    print("  ax: ({:10.4g}, {:10.4g}, {:10.4g}) A".format(aij[0][0], aij[0][1], aij[0][2]))
    print("  ay: ({:10.4g}, {:10.4g}, {:10.4g}) A".format(aij[1][0], aij[1][1], aij[1][2]))
    print("  az: ({:10.4g}, {:10.4g}, {:10.4g}) A".format(aij[2][0], aij[2][1], aij[2][2]))
    # g_ij = a_i . a_j, so its entries carry squared-length units.
    print("Metric tensor:")
    print("  gij: ({:10.4g}, {:10.4g}, {:10.4g}) A^2".format(*gij[0]))
    print("       ({:10.4g}, {:10.4g}, {:10.4g}) A^2".format(*gij[1]))
    print("       ({:10.4g}, {:10.4g}, {:10.4g}) A^2".format(*gij[2]))

    print("")
    volume = cal_volume(aij)
    print("Unit cell volume: {:12.4g} A^3".format(volume))
    Raij  = cal_reciprocal_lattice_vectors(aij)
    Rlatt = cal_reciprocal_lattice_parameters(Raij)
    Rinf  = cal_metrics(Rlatt)
    Rgij  = Rinf['gij']
    print("Reciprocal lattice parameters:", Rlatt)
    print("Reciprocal lattice vectors:")
    print("  Rax: ({:10.4g}, {:10.4g}, {:10.4g}) A^-1".format(*Raij[0]))
    print("  Ray: ({:10.4g}, {:10.4g}, {:10.4g}) A^-1".format(*Raij[1]))
    print("  Raz: ({:10.4g}, {:10.4g}, {:10.4g}) A^-1".format(*Raij[2]))
    print("Reciprocal lattice metric tensor:")
    print("  Rgij: ({:10.4g}, {:10.4g}, {:10.4g}) A^-2".format(*Rgij[0]))
    print("        ({:10.4g}, {:10.4g}, {:10.4g}) A^-2".format(*Rgij[1]))
    print("        ({:10.4g}, {:10.4g}, {:10.4g}) A^-2".format(*Rgij[2]))
    Rvolume = cal_volume(Raij)
    print("Reciprocal unit cell volume: {:12.4g} A^-3".format(Rvolume))

    print("")
    print("dis=", distance(np.array([0,0,0]), np.array([1,1,1]), gij))
    print("angle=", angle(np.array([0,0,0]), np.array([1,1,1]), np.array([1,0,0]), gij))
    print("angle=", angle(np.array([0,0,0]), np.array([1,0,0]), np.array([0,1,0]), gij))

    print("")
    # No exit() here: returning lets callers reuse main() without killing the interpreter.


if __name__ == '__main__':
    main()
