"""
概要: マデルングポテンシャルを単純総和法で計算するスクリプト。
詳細説明:
    このスクリプトは、指定された結晶構造（格子定数とサイト情報）に基づき、
    単純総和法を用いてマデルングポテンシャルを計算します。
    計算されたポテンシャルは、中心イオンからの距離の関数としてプロットされます。
    tkcrystalbase.pyモジュールに依存します。

関連リンク:
    crystal_MP_simple_usage
"""
import sys
import os
from numpy import sin, cos, tan, arcsin, arccos, arctan, exp, log, sqrt
import numpy as np
from numpy import linalg as la
import matplotlib.pyplot as plt

from tkcrystalbase import *


pi          = 3.14159265358979323846
pi2         = pi + pi
torad       = 0.01745329251944 # rad/deg";
todeg       = 57.29577951472   # deg/rad";
basee       = 2.71828183

h           = 6.6260755e-34    # Js";
h_bar       = 1.05459e-34      # "Js";
hbar        = h_bar
c           = 2.99792458e8     # m/s";
e           = 1.60218e-19      # C";
me          = 9.1093897e-31    # kg";
mp          = 1.6726231e-27    # kg";
mn          = 1.67495e-27      # kg";
u0          = 4.0 * 3.14*1e-7; # . "Ns<sup>2</sup>C<sup>-2</sup>";
e0          = 8.854418782e-12; # C<sup>2</sup>N<sup>-1</sup>m<sup>-2</sup>";
e2_4pie0    = 2.30711e-28      # Nm<sup>2</sup>";
a0          = 5.29177e-11      # m";
kB          = 1.380658e-23     # JK<sup>-1</sup>";
NA          = 6.0221367e23     # mol<sup>-1</sup>";
R           = 8.31451          # J/K/mol";
F           = 96485.3          # C/mol";
g           = 9.81             # m/s2";



# Lattice parameters (angstrom and degree)
#lattice_parameters = [ 5.62, 5.62, 5.62, 60.0, 60.0, 60.0]
lattice_parameters = [ 5.62, 5.62, 5.62, 90.0, 90.0, 90.0]

# Site information (atom name, site label, atomic number, atomic mass, charge, radius, color, position)
sites = [
         ['Na', 'Na1', 11, 22.98997, +1.0, 0.7, 'red',  np.array([0.0, 0.0, 0.0])]
        ,['Na', 'Na2', 11, 22.98997, +1.0, 0.7, 'red',  np.array([0.0, 0.5, 0.5])]
        ,['Na', 'Na3', 11, 22.98997, +1.0, 0.7, 'red',  np.array([0.5, 0.0, 0.5])]
        ,['Na', 'Na4', 11, 22.98997, +1.0, 0.7, 'red',  np.array([0.5, 0.5, 0.0])]
        ,['Cl', 'Cl1', 17, 35.4527,  -1.0, 1.4, 'blue', np.array([0.5, 0.0, 0.0])]
        ,['Cl', 'Cl2', 17, 35.4527,  -1.0, 1.4, 'blue', np.array([0.5, 0.5, 0.5])]
        ,['Cl', 'Cl3', 17, 35.4527,  -1.0, 1.4, 'blue', np.array([0.0, 0.0, 0.5])]
        ,['Cl', 'Cl4', 17, 35.4527,  -1.0, 1.4, 'blue', np.array([0.0, 0.5, 0.0])]
        ]

# r range
rmin =   0.1
rmax = 100.0
nr   = 101


# Figure configuration
figsize = (6, 6)

rstep = (rmax - rmin) / (nr - 1)


def usage():
    """
    概要: スクリプトの正しい使用方法を表示します。
    詳細説明: コマンドライン引数rmaxとnrの指定方法をユーザーに示します。
    :returns: なし
    :rtype: None
    """
    print("")
    print("Usage: python {} rmax nr".format(argv[0]))
    print("   ex: python {} {} {}".format(argv[0], rmax, nr))
    print("")

def terminate():
    """
    概要: エラー発生時に使用方法を表示してスクリプトを終了します。
    詳細説明: usage関数を呼び出して使用方法を出力した後、システムを終了します。
    :returns: なし (スクリプトが終了するため)
    :rtype: None
    """
    usage()
    exit()


def draw_box(ax, aij, nrange, color = 'black'):
    """
    概要: 結晶の単位格子境界を3Dプロットに描画します。
    詳細説明: 単位格子の辺を黒線（または指定された色）で描画します。
              この関数はdraw_unitcellから呼び出されますが、nrangeは現在の実装では使用されていません。
    引数:
        :param ax: matplotlibの3D軸オブジェクト。
        :type ax: matplotlib.axes._subplots.Axes3DSubplot
        :param aij: (3, 3)のndarray、格子ベクトルa, b, cを表す。
        :type aij: numpy.ndarray
        :param nrange: 描画する単位格子の範囲。[[xmin, xmax], [ymin, ymax], [zmin, zmax]]の形式。（現状未使用）
        :type nrange: list of list of float
        :param color: 描画する線の色。デフォルトは'black'。
        :type color: str
    戻り値:
        :returns: なし
        :rtype: None
    """
# (0,0,0) -> ax
    ax.plot([0.0, aij[0][0]], 
            [0.0, aij[0][1]], 
            [0.0, aij[0][2]], color = color)
# (0,0,0) -> ay
    ax.plot([0.0, aij[1][0]], 
            [0.0, aij[1][1]], 
            [0.0, aij[1][2]], color = color)
# (0,0,0) -> az
    ax.plot([0.0, aij[2][0]], 
            [0.0, aij[2][1]], 
            [0.0, aij[2][2]], color = color)

# ax -> ax + ay
    ax.plot([aij[0][0], aij[0][0] + aij[1][0]], 
            [aij[0][1], aij[0][1] + aij[1][1]], 
            [aij[0][2], aij[0][2] + aij[1][2]], color = color)
# ax -> ax + az
    ax.plot([aij[0][0], aij[0][0] + aij[2][0]], 
            [aij[0][1], aij[0][1] + aij[2][1]], 
            [aij[0][2], aij[0][2] + aij[2][2]], color = color)

# ay -> ay + ax
    ax.plot([aij[1][0], aij[1][0] + aij[0][0]], 
            [aij[1][1], aij[1][1] + aij[0][1]], 
            [aij[1][2], aij[1][2] + aij[0][2]], color = color)
# ay -> ay + az
    ax.plot([aij[1][0], aij[1][0] + aij[2][0]], 
            [aij[1][1], aij[1][1] + aij[2][1]], 
            [aij[1][2], aij[1][2] + aij[2][2]], color = color)

# az -> az + ax
    ax.plot([aij[2][0], aij[2][0] + aij[0][0]], 
            [aij[2][1], aij[2][1] + aij[0][1]], 
            [aij[2][2], aij[2][2] + aij[0][2]], color = color)
# az -> ax + ay
    ax.plot([aij[2][0], aij[2][0] + aij[1][0]], 
            [aij[2][1], aij[2][1] + aij[1][1]], 
            [aij[2][2], aij[2][2] + aij[1][2]], color = color)

# ax + ay -> ax + ay + az
    ax.plot([aij[0][0] + aij[1][0], aij[0][0] + aij[1][0] + aij[2][0]], 
            [aij[0][1] + aij[1][1], aij[0][1] + aij[1][1] + aij[2][1]], 
            [aij[0][2] + aij[1][2], aij[0][2] + aij[1][2] + aij[2][2]], color = color)

# ax + az -> ax + ay + az
    ax.plot([aij[0][0] + aij[2][0], aij[0][0] + aij[1][0] + aij[2][0]], 
            [aij[0][1] + aij[2][1], aij[0][1] + aij[1][1] + aij[2][1]], 
            [aij[0][2] + aij[2][2], aij[0][2] + aij[1][2] + aij[2][2]], color = color)

# ay + az -> ax + ay + az
    ax.plot([aij[1][0] + aij[2][0], aij[0][0] + aij[1][0] + aij[2][0]], 
            [aij[1][1] + aij[2][1], aij[0][1] + aij[1][1] + aij[2][1]], 
            [aij[1][2] + aij[2][2], aij[0][2] + aij[1][2] + aij[2][2]], color = color)

def draw_unitcell(ax, sites, aij, nrange, color = 'black'):
    """
    概要: 結晶の単位格子とその中の原子を3Dプロットに描画します。
    詳細説明: draw_box関数を呼び出して単位格子を描画し、その後、sitesリスト内の原子を分数座標から
              デカルト座標に変換してプロットします。nrangeは描画する単位格子の範囲を指定しますが、
              このスクリプトのmain関数では現在呼び出されていません。
    引数:
        :param ax: matplotlibの3D軸オブジェクト。
        :type ax: matplotlib.axes._subplots.Axes3DSubplot
        :param sites: サイト情報のリスト。各サイトは[atom_name, site_label, atomic_number, atomic_mass, charge, radius, color, position]の形式。
        :type sites: list of list
        :param aij: (3, 3)のndarray、格子ベクトルa, b, cを表す。
        :type aij: numpy.ndarray
        :param nrange: 描画する単位格子の範囲。[[xmin, xmax], [ymin, ymax], [zmin, zmax]]の形式。
        :type nrange: list of list of int
        :param color: 単位格子を描画する線の色。デフォルトは'black'。
        :type color: str
    戻り値:
        :returns: なし
        :rtype: None
    """
    draw_box(ax, aij, nrange, color)

    if sites is None:
        return

    for site in sites:
        name, label, z, M, q, r, color, pos = site
        pos01 = [reduce01(pos[0]), reduce01(pos[1]), reduce01(pos[2])]
        for iz in range(int(nrange[2][0]) - 1, int(nrange[2][1]) + 1):
         for iy in range(int(nrange[1][0]) - 1, int(nrange[1][1]) + 1):
          for ix in range(int(nrange[0][0]) - 1, int(nrange[0][1]) + 1):
            posn = [pos01[0] + ix, pos01[1] + iy, pos01[2] + iz]
            if    posn[0] < nrange[0][0] or nrange[0][1] < posn[0]  \
               or posn[1] < nrange[1][0] or nrange[1][1] < posn[1]  \
               or posn[2] < nrange[2][0] or nrange[2][1] < posn[2]:
                  continue

            x, y, z = fractional_to_cartesian(posn, aij)
            ax.scatter([x], [y], [z], marker = 'o', c = color, s = kr *r)


def main():
    """
    概要: マデルングポテンシャルの計算と結果のプロットを実行します。
    詳細説明:
        格子定数から格子ベクトルや逆格子ベクトルを計算し、その情報を表示します。
        指定された範囲rmax内で、原点にあるイオンに対するマデルングポテンシャルを
        単純総和法で計算します。
        計算されたポテンシャルは、距離rに対するグラフとして表示されます。
        プログラム起動時のコマンドライン引数でrmaxとnrを設定することができます。
    戻り値:
        :returns: なし
        :rtype: None
    """
    print("")
    print("Lattice parameters:", lattice_parameters)
    aij = cal_lattice_vectors(lattice_parameters)
    print("Lattice vectors:")
    print("  ax: ({:10.4g}, {:10.4g}, {:10.4g}) A".format(aij[0][0], aij[0][1], aij[0][2]))
    print("  ay: ({:10.4g}, {:10.4g}, {:10.4g}) A".format(aij[1][0], aij[1][1], aij[1][2]))
    print("  az: ({:10.4g}, {:10.4g}, {:10.4g}) A".format(aij[2][0], aij[2][1], aij[2][2]))
    inf = cal_metrics(lattice_parameters)
    gij = inf['gij']
    print("Metric tensor:")
    print("  gij: ({:10.4g}, {:10.4g}, {:10.4g}) A".format(*gij[0]))
    print("       ({:10.4g}, {:10.4g}, {:10.4g}) A".format(*gij[1]))
    print("       ({:10.4g}, {:10.4g}, {:10.4g}) A".format(*gij[2]))
    volume = cal_volume(aij)
    print("Volume: {:12.4g} A^3".format(volume))

    print("")
    print("Unit cell volume: {:12.4g} A^3".format(volume))
    Raij  = cal_reciprocal_lattice_vectors(aij)
    Rlatt = cal_reciprocal_lattice_parameters(Raij)
    Rinf  = cal_metrics(Rlatt)
    Rgij  = Rinf['gij']
    print("Reciprocal lattice parameters:", Rlatt)
    print("Reciprocal lattice vectors:")
    print("  Rax: ({:10.4g}, {:10.4g}, {:10.4g}) A^-1".format(*Raij[0]))
    print("  Ray: ({:10.4g}, {:10.4g}, {:10.4g}) A^-1".format(*Raij[1]))
    print("  Raz: ({:10.4g}, {:10.4g}, {:10.4g}) A^-1".format(*Raij[2]))
    print("Reciprocal lattice metric tensor:")
    print("  Rgij: ({:10.4g}, {:10.4g}, {:10.4g}) A^-1".format(*Rgij[0]))
    print("        ({:10.4g}, {:10.4g}, {:10.4g}) A^-1".format(*Rgij[1]))
    print("        ({:10.4g}, {:10.4g}, {:10.4g}) A^-1".format(*Rgij[2]))
    Rvolume = cal_volume(Raij)
    print("Reciprocal unit cell volume: {:12.4g} A^-3".format(Rvolume))

# Calculate the range of unit cells
    nxmax = int(rmax / lattice_parameters[0]) + 1
    nymax = int(rmax / lattice_parameters[1]) + 1
    nzmax = int(rmax / lattice_parameters[2]) + 1
    print("")
    print("nmax:", nxmax, nymax, nzmax)

# Calculate Madelung potential around the zero-th ion
# First store differential potential to MPdiff
    rlist  = [rmin + i * rstep for i in range(nr)]
    MPdiff = np.zeros(nr)
    name0, label0, z0, M0, q0, r0, color0, pos0 = sites[0]
    Ke = e * e / 4.0 / pi / e0                  # in MKS
    for iz in range(-nzmax, nzmax+1):
     for iy in range(-nymax, nymax+1):
      for ix in range(-nxmax, nxmax+1):
        for isite1 in range(len(sites)):
            site1 = sites[isite1]
            name1, label1, z1, M1, q1, r1, color1, pos1 = site1
            r  = distance(pos0, pos1 + np.array([ix, iy, iz]), gij)
            ir = int((r - rmin) / rstep)
            if r < rmin or ir < 0 or nr <= ir:
                 continue

            MPdiff[ir] += Ke * q1 / (r * 1.0e-10) / e   # in eV

#                print("  {:4} ({:8.4g}, {:8.4g}, {:8.4g}) - {:4} ({:8.4g}, {:8.4g}, {:8.4g}) + ({:2d}, {:2d}, {:2d}): dis = {:10.4g} A"
#                    .format(label0, pos0[0], pos0[1], pos0[2], label1, pos1[0], pos1[1], pos1[2], ix, iy, iz, dis))

    print("")
    print("r (A)      Madelung potential (eV)")
    MP = np.empty(nr)
    MP[0] = MPdiff[0]
    print("{:10.4g}   {:12.6g}".format(rlist[0], MP[0]))
    for i in range (1, len(MPdiff)):
        MP[i] = MP[i-1] + MPdiff[i]
        print("{:10.4g}   {:12.6g}".format(rlist[i], MP[i]))
    
    fig = plt.figure(figsize = figsize)
    ax = fig.add_subplot(111)

    ax.plot(rlist, MP)
    ax.set_xlabel('r / angstrom')
    ax.set_ylabel('Electrostatic potential / eV')

    plt.show()

    
    terminate()


if __name__ == '__main__':
    argv = sys.argv
    narg = len(argv)
    if narg >= 2:
        rmax = float(argv[1])
    if narg >= 3:
        nr = int(argv[2])

    main()