"""
エワルド法を用いて結晶のマデルングポテンシャルを計算するスクリプトです。

概要:
    このスクリプトは、指定された結晶の格子パラメータとサイト情報に基づき、
    エワルド法を適用してマデルングポテンシャルを計算します。
    実空間和、逆空間和、および自己項の3つの成分を合計して、最終的なポテンシャルを求めます。

詳細説明:
    プログラムは以下の手順でマデルングポテンシャルを計算します。
    1.  初期化された格子パラメータとサイト情報（原子の種類、電荷、位置など）を使用します。
    2.  コマンドライン引数からエワルドパラメータ alpha と計算精度 prec を受け取ることができます。
    3.  格子ベクトル、逆格子ベクトル、体積、および関連するメトリックテンソルを計算し表示します。
    4.  エワルドパラメータに基づき、実空間および逆空間の計算範囲（rdmax, G2max）を決定します。
    5.  選択された中心サイトに対する実空間和 (UC1) を計算します。この項は実空間でのクーロン相互作用を表します。
    6.  逆空間和 (UC2) を計算します。この項は逆格子空間でのクーロン相互作用を表し、高速フーリエ変換に似た形式です。
    7.  自己相互作用項 (UC3) を計算します。これは原子自身の電場による自己エネルギーを補正する項です。
    8.  これら3つの項を合計し、マデルングポテンシャルおよびマデルング定数をJouleとeV単位で出力します。
    9.  計算にかかった時間も合わせて表示されます。

関連リンク:
    crystal_MP_Ewald_usage
    このスクリプトは tkcrystalbase.py モジュールに定義された関数を使用します。
"""

import sys
import os
import time
from math import erf, erfc
from numpy import sin, cos, tan, arcsin, arccos, arctan, exp, log, log10, sqrt
import numpy as np
from numpy import linalg as la
import matplotlib.pyplot as plt

from tkcrystalbase import *


pi          = 3.14159265358979323846
pi2         = pi + pi
torad       = 0.01745329251944 # rad/deg";
todeg       = 57.29577951472   # deg/rad";
basee       = 2.71828183

h           = 6.6260755e-34    # Js";
h_bar       = 1.05459e-34      # "Js";
hbar        = h_bar
c           = 2.99792458e8     # m/s";
e           = 1.60218e-19      # C";
me          = 9.1093897e-31    # kg";
mp          = 1.6726231e-27    # kg";
mn          = 1.67495e-27      # kg";
u0          = 4.0 * 3.14*1e-7; # . "Ns<sup>2</sup>C<sup>-2</sup>";
e0          = 8.854418782e-12; # C<sup>2</sup>N<sup>-1</sup>m<sup>-2</sup>";
e2_4pie0    = 2.30711e-28      # Nm<sup>2</sup>";
a0          = 5.29177e-11      # m";
kB          = 1.380658e-23     # JK<sup>-1</sup>";
NA          = 6.0221367e23     # mol<sup>-1</sup>";
R           = 8.31451          # J/K/mol";
F           = 96485.3          # C/mol";
g           = 9.81             # m/s2";


# Lattice parameters (angstrom and degree)
#lattice_parameters = [ 5.62, 5.62, 5.62, 60.0, 60.0, 60.0]
lattice_parameters = [ 5.62, 5.62, 5.62, 90.0, 90.0, 90.0]
#lattice_parameters = [ 1.0, 1.0, 1.0, 90.0, 90.0, 90.0]

# Site information (atom name, site label, atomic number, atomic mass, charge, radius, color, position)
sites = [
         ['Na', 'Na1', 11, 22.98997, +1.0, 0.7, 'red',  np.array([0.0, 0.0, 0.0])]
        ,['Na', 'Na2', 11, 22.98997, +1.0, 0.7, 'red',  np.array([0.0, 0.5, 0.5])]
        ,['Na', 'Na3', 11, 22.98997, +1.0, 0.7, 'red',  np.array([0.5, 0.0, 0.5])]
        ,['Na', 'Na4', 11, 22.98997, +1.0, 0.7, 'red',  np.array([0.5, 0.5, 0.0])]
        ,['Cl', 'Cl1', 17, 35.4527,  -1.0, 1.4, 'blue', np.array([0.5, 0.0, 0.0])]
        ,['Cl', 'Cl2', 17, 35.4527,  -1.0, 1.4, 'blue', np.array([0.5, 0.5, 0.5])]
        ,['Cl', 'Cl3', 17, 35.4527,  -1.0, 1.4, 'blue', np.array([0.0, 0.0, 0.5])]
        ,['Cl', 'Cl4', 17, 35.4527,  -1.0, 1.4, 'blue', np.array([0.0, 0.5, 0.0])]
        ]

# Minimum distance to judge idential site
rmin = 0.1

# Ewald alpha parameter
ew_alpha = 0.3

# Precision
prec = 1.0e-5


def usage():
    """
    概要:
        プログラムの正しい使用方法を標準出力に表示します。
    詳細説明:
        プログラムをコマンドラインから実行する際の引数のフォーマットと例を示します。
    """
    print("")
    print("Usage: python {} alpha prec".format(argv[0]))
    print("   ex: python {} {} {}".format(argv[0], ew_alpha, prec))
    print("")

def terminate():
    """
    概要:
        使用方法を表示した後、プログラムを終了します。
    詳細説明:
        usage() 関数を呼び出し、その後にプログラムを強制終了します。
    """
    usage()
    exit()
    

def main():
    """
    概要:
        エワルド法によりマデルングポテンシャルを計算し、結果を表示します。
    詳細説明:
        この関数は、設定された格子パラメータとサイト情報に基づき、
        エワルド法を用いて結晶のマデルングポテンシャルを計算し、その結果を標準出力に表示します。
        具体的な計算手順は以下の通りです。

        1.  設定された格子パラメータから格子ベクトル、メトリックテンソル、単位胞の体積を計算し、表示します。
            これには cal_lattice_vectors と cal_metrics 関数が使用されます。
        2.  逆格子パラメータ、逆格子ベクトル、逆格子メトリックテンソル、逆格子単位胞の体積を計算し、表示します。
            これには cal_reciprocal_lattice_vectors と cal_reciprocal_lattice_parameters 関数が使用されます。
        3.  コマンドライン引数またはデフォルト設定で指定されたエワルドパラメータ ew_alpha と計算精度 prec を表示します。
        4.  これらのパラメータに基づき、実空間和の最大距離 rdmax と逆空間和の最大Gベクトル二乗値 G2max を決定し、
            それぞれに対応する最大繰返し回数 nrmax および hgmax を推定し、表示します。
        5.  実空間和 UC1 の計算を実行します。これは、中心サイトと周期的に配置された他のサイトとの間のクーロン相互作用を、
            erfc 関数を用いて収束させた合計です。
        6.  逆空間和 UC2 の計算を実行します。これは逆格子空間における電荷分布のフーリエ成分の相互作用の合計で、
            Gベクトルが G2max を超えない範囲で計算されます。
        7.  自己相互作用項 UC3 を計算します。これはエワルド法の導入により生じる、
            原子自身の電場による自己エネルギーを補正する項です。
        8.  各計算フェーズ（実空間和、逆空間和、合計）にかかった時間を表示します。
        9.  計算された UC1, UC2, UC3 の3つの項を合計し、最終的なマデルングポテンシャル MP を算出します。
        10. 算出されたマデルングポテンシャル MP をJoule単位とeV単位で表示します。
        11. 選択された中心サイトの電荷 qi と格子定数 a を用いて、マデルング定数を計算し、表示します。
    """
    print("")
    print("Lattice parameters:", lattice_parameters)
    aij = cal_lattice_vectors(lattice_parameters)
    print("Lattice vectors:")
    print("  ax: ({:10.4g}, {:10.4g}, {:10.4g}) A".format(aij[0][0], aij[0][1], aij[0][2]))
    print("  ay: ({:10.4g}, {:10.4g}, {:10.4g}) A".format(aij[1][0], aij[1][1], aij[1][2]))
    print("  az: ({:10.4g}, {:10.4g}, {:10.4g}) A".format(aij[2][0], aij[2][1], aij[2][2]))
    inf = cal_metrics(lattice_parameters)
    gij = inf['gij']
    print("Metric tensor:")
    print("  gij: ({:10.4g}, {:10.4g}, {:10.4g}) A".format(*gij[0]))
    print("       ({:10.4g}, {:10.4g}, {:10.4g}) A".format(*gij[1]))
    print("       ({:10.4g}, {:10.4g}, {:10.4g}) A".format(*gij[2]))
    volume = cal_volume(aij)
    print("Volume: {:12.4g} A^3".format(volume))

    print("")
    print("Unit cell volume: {:12.4g} A^3".format(volume))
    Raij  = cal_reciprocal_lattice_vectors(aij)
    Rlatt = cal_reciprocal_lattice_parameters(Raij)
    Rinf  = cal_metrics(Rlatt)
    Rgij  = Rinf['gij']
    print("Reciprocal lattice parameters:", Rlatt)
    print("Reciprocal lattice vectors:")
    print("  Rax: ({:10.4g}, {:10.4g}, {:10.4g}) A^-1".format(*Raij[0]))
    print("  Ray: ({:10.4g}, {:10.4g}, {:10.4g}) A^-1".format(*Raij[1]))
    print("  Raz: ({:10.4g}, {:10.4g}, {:10.4g}) A^-1".format(*Raij[2]))
    print("Reciprocal lattice metric tensor:")
    print("  Rgij: ({:10.4g}, {:10.4g}, {:10.4g}) A^-1".format(*Rgij[0]))
    print("        ({:10.4g}, {:10.4g}, {:10.4g}) A^-1".format(*Rgij[1]))
    print("        ({:10.4g}, {:10.4g}, {:10.4g}) A^-1".format(*Rgij[2]))
    Rvolume = cal_volume(Raij)
    print("Reciprocal unit cell volume: {:12.4g} A^-3".format(Rvolume))

    nsites = len(sites)
    
    print("")
    print("Ewald parameters")
    print("  alpha:", ew_alpha)
    norder = -log10(prec)
    print("  precision = {} = 10^-{}".format(prec, norder))

    rdmax     = (2.26 + 0.26 * norder) / ew_alpha
    erfc_rdmax = erfc(ew_alpha * rdmax)
    print("  RDmax = {} A, where erfc(alpha*RDmax) = {}".format(rdmax, erfc_rdmax));

    lsin  = np.empty(3, dtype = float)
    nrmax = np.empty(3, dtype = int)
    lsin[0] = sin(torad * lattice_parameters[3])
    lsin[1] = sin(torad * lattice_parameters[4])
    lsin[2] = sin(torad * lattice_parameters[5])
    nrmax[0] = int(rdmax / sqrt(gij[0][0] * lsin[1] * lsin[2])) + 1
    nrmax[1] = int(rdmax / sqrt(gij[1][1] * lsin[2] * lsin[0])) + 1
    nrmax[2] = int(rdmax / sqrt(gij[2][2] * lsin[0] * lsin[1])) + 1
    print("  nrmax:", *nrmax)

    cal_N = int(4.0 / 3.0 * pi * rdmax**3 / volume * nsites)
    print("  cal_N(real):", cal_N)

    G2max = ew_alpha**2 / pi**2 * (-log(prec))
    print("  G2max:", G2max)
    print("      exp(-pi^2 * G2max^2 / alpha^2) = ", exp(-pi**2 * G2max**2 / ew_alpha**2))
    lsin[0] = sin(torad * Rlatt[3])
    lsin[1] = sin(torad * Rlatt[4])
    lsin[2] = sin(torad * Rlatt[5])
    hgmax = np.empty(3, dtype = int)
    hgmax[0] = int(sqrt(G2max / (Rgij[0][0] * lsin[1] * lsin[2]))) + 1
    hgmax[1] = int(sqrt(G2max / (Rgij[1][1] * lsin[0] * lsin[2]))) + 1
    hgmax[2] = int(sqrt(G2max / (Rgij[2][2] * lsin[0] * lsin[1]))) + 1
    print("  hgmax:", *hgmax)

    cal_N = int(4.0 / 3.0 * pi * G2max**1.5 / Rvolume * nsites)
    print("  cal_N(reciprocal):", cal_N)

    namei, labeli, zi, Mi, qi, ri, colori, pos_i = sites[0]

    stime1 = time.time()
    UC1 = 0.0
    for iz in range(-nrmax[2], nrmax[2]+1):
     for iy in range(-nrmax[1], nrmax[1]+1):
      for ix in range(-nrmax[0], nrmax[0]+1):
        for j in range(nsites):
            namej, labelj, zj, Mj, qj, rj, colorj, pos_j = sites[j]
            rij  = distance(pos_i, pos_j + np.array([ix, iy, iz]), gij)
            
            if rij < rmin:
                 continue

            erfcar = erfc(ew_alpha * rij)
            UC1 += qj * erfcar / (rij * 1.0e-10)   # in eV
    etime1 = time.time()

    origin = np.array([0.0, 0.0, 0.0])
    UC2 = 0.0
    Kexp = pi * pi / ew_alpha / ew_alpha
    Krec = 1.0 / pi / (volume * 1.0e-30)
#    for l in range(-hgmax[2], hgmax[2]+1):
    for l in range(0, hgmax[2]+1):
     for k in range(-hgmax[1], hgmax[1]+1):
      for h in range(-hgmax[0], hgmax[0]+1):
          G2 = distance2(origin, np.array([h, k, l]), Rgij)
          if G2 == 0.0 or G2 > G2max:
              continue

          phi_i  = pi2 * (h * pos_i[0] + k * pos_i[1] + l * pos_i[2])
          cosphi_i = cos(phi_i)
          sinphi_i = sin(phi_i)

          cossum_j = 0.0
          sinsum_j = 0.0
          for j in range(nsites):
              namej, labelj, zj, Mj, qj, rj, colorj, pos_j = sites[j]

              phi_j = pi2 * (h * pos_j[0] + k * pos_j[1] + l * pos_j[2])
              cossum_j += qj * cos(phi_j)
              sinsum_j += qj * sin(phi_j)

          fcal = cosphi_i * cossum_j + sinphi_i * sinsum_j
          if l != 0:
            fcal *= 2.0
          expg = exp(-Kexp * G2) / (G2 * 1.0e+20)
          UC2 += Krec * expg * fcal
    etime2 = time.time()

    UC3 = qi * 2.0 * (ew_alpha * 1.0e10) / sqrt(pi)

    MP = UC1 + UC2 - UC3
    etime3 = time.time()

# Coefficient to calculate electrostatic potential
    Ke = e * e / 4.0 / pi / e0

    print("")
    print("Time for real space sum     : {:6}".format(etime1 - stime1))
    print("Time for real reciprocal sum: {:6}".format(etime2 - etime1))
    print("Total time                  : {:6}".format(etime3 - stime1))
    
    print("  Madelung potential: {:12.6g} J  (= {:12.6g} + {:12.6g} - {:12.6g})".format(Ke * MP, Ke * UC1, Ke * UC2, Ke * UC3))
    print("  Madelung potential: {:12.6g} eV (= {:12.6g} + {:12.6g} - {:12.6g})".format(Ke / e * MP, Ke / e * UC1, Ke / e * UC2, Ke / e * UC3))
# Charge is represented by q0 to define Madelung constant
# Lattice parameter a is represented by q0 to define Madelung constant    
    print("  Madelung constant: {:14.8g}".format(0.5 * MP / abs(qi) * (lattice_parameters[0] * 1.0e-10)))


    terminate()


if __name__ == '__main__':
    argv = sys.argv
    narg = len(argv)
    if narg >= 2:
        ew_alpha = float(argv[1])
    if narg >= 3:
        prec = float(argv[2])

    main()