#https://zenn.dev/shittoku_xxx/articles/061ad99166afcc

import re
import math
import numpy as np
from numpy import abs, exp, sqrt, cos, sin, arccos, arctan2, angle, pi
from scipy.special import factorial
from scipy.special import sph_harm, lpmv, assoc_laguerre
from scipy.constants import physical_constants


# Physical constants
a0 = physical_constants['Bohr radius'][0] * 1e10  # Bohr radius, scaled by 1e10 to express it in angstroms
# Reference formula (SI, metres): a_0 = 4 * pi * epsilon_0 * h_bar**2 / (m_e * e**2)
h = 6.62607015e-34  # Planck constant (J·s)
h_bar = h / (2.0 * pi)  # reduced Planck (Dirac) constant (J·s); `pi` here is the name imported from numpy above
m_e = 9.10938356e-31  # electron mass (kg)
e = 1.602176634e-19  # elementary charge (C)
e0 = 8.854187817e-12  # vacuum permittivity (F/m)
pi = np.pi  # rebinds `pi` at module level to numpy's value of pi

# Energy level as a function of the principal quantum number
def En(n):
    """Return the energy of level ``n`` in electron-volts.

    Evaluates ``-m_e * e**4 / (8 * e0**2 * h**2) / n**2`` with the
    module-level SI constants (joules) and divides by the elementary
    charge ``e`` to convert the result to eV.
    """
    ground_scale = (m_e * e**4) / (8 * e0**2 * h**2)
    energy_joule = - ground_scale / n**2
    return energy_joule / e

def get_orb_name(n, l, m):
    """Build an orbital label such as ``"2px"`` or ``"3dxy"``.

    Parameters
    ----------
    n : int
        Principal quantum number; used verbatim as the label prefix.
    l : int or str
        Azimuthal quantum number (0..3) or its shell letter
        ('s', 'p', 'd', 'f').
    m : int or str
        Magnetic quantum number (-l..l), or an already spelled-out
        suffix such as ``'x'`` or ``'xy'`` that is appended as-is.

    Returns
    -------
    str
        ``n`` followed by the orbital name.

    Raises
    ------
    IndexError
        If an integer ``l`` is outside 0..3 or an integer ``m`` is
        outside ``-l..l``.
    ValueError
        If ``l`` is an unknown shell letter while ``m`` is an integer.
    """
    shell_letters = 'spdf'
    shell_index = {letter: idx for idx, letter in enumerate(shell_letters)}
    # names[l][m + l] is the real-orbital name for quantum numbers (l, m).
    names = (
        ('s',),                                          # l=0
        ('py', 'pz', 'px'),                              # l=1, m=-1, 0, 1
        ('dxy', 'dyz', 'd3z2-r2', 'dzx', 'dx2-y2'),      # l=2, m=-2..2
        ('fy(3x2-y2)', 'fxyz', 'fy(5z2-r2)', 'fz(5z2-3r2)',
         'fx(5z2-r2)', 'fz(x2-y2)', 'fx(x2-3y2)'),       # l=3, m=-3..3
    )

    if isinstance(m, str):
        # The suffix is given explicitly; only the shell letter is needed.
        if isinstance(l, str):
            shell = l
        else:
            if not 0 <= l < len(shell_letters):
                raise IndexError(f"l={l} is outside the supported range 0..3")
            shell = shell_letters[l]
        return f"{n}{shell}{m}"

    if isinstance(l, str):
        if l not in shell_index:
            raise ValueError(f"unknown shell letter {l!r}")
        l_num = shell_index[l]
    else:
        l_num = l

    # Explicit range checks: negative indices would otherwise wrap around
    # silently and return the name of a different orbital.
    if not 0 <= l_num < len(names):
        raise IndexError(f"l={l_num} is outside the supported range 0..3")
    if not -l_num <= m <= l_num:
        raise IndexError(f"m={m} is outside -{l_num}..{l_num}")

    return f"{n}{names[l_num][m + l_num]}"

def get_qnumbers(n, l, m):
    """Convert symbolic ``l`` / ``m`` labels to integer quantum numbers.

    ``l`` and ``m`` are translated through lookup tables only when they are
    plain ``str`` objects; any other value is passed through untouched.
    A string that is not in the corresponding table becomes ``None``.

    Returns the tuple ``(n, l, m)`` with ``n`` unchanged.
    """
    l_by_letter = dict(s=0, p=1, d=2, f=3, g=4)
    m_by_suffix = {
        "": 0,
        # p-type suffixes
        "x": 1, "y": -1, "z": 0,
        "xy+": 1, "xy-": -1,
        # d-type suffixes
        "xy": -2, "yz": -1, "z2": 0, "zx": 1, "x2-y2": 2,
        # f-type suffixes
        "y(3x2-y2)": -3, "xyz": -2, "y(5z2-r2)": -1,
        "z(5z2-3r2)": 0, "x(5z2-r2)": 1, "z(x2-y2)": 2, "x(x2-3y2)": 3,
    }

    l_value = l_by_letter.get(l) if type(l) is str else l
    m_value = m_by_suffix.get(m) if type(m) is str else m
    return n, l_value, m_value

# Optional trailing return-type code (characters from "cria2"), anchored at the end.
_TYPE_TAIL = r"([cria2]+)?$"

# Orbital suffixes listed in get_qnumbers()' m lookup table (the empty one is
# handled separately for bare s orbitals).
_KNOWN_SUFFIXES = (
    "x", "y", "z", "xy+", "xy-",
    "xy", "yz", "z2", "zx", "x2-y2",
    "y(3x2-y2)", "xyz", "y(5z2-r2)",
    "z(5z2-3r2)", "x(5z2-r2)", "z(x2-y2)", "x(x2-3y2)",
)

# Three signed single digits, e.g. "321" or "21-1a2".
_NUMERIC_FORM = re.compile(r"^([+-]?\d)([+-]?\d)([+-]?\d)" + _TYPE_TAIL)
# Digit, shell letter, one known suffix. Longest suffixes come first so that
# "xy+" is preferred over "xy" and "z2" over "z".
_NAMED_FORM = re.compile(
    r"^(\d)([a-z])("
    + "|".join(re.escape(sfx) for sfx in sorted(_KNOWN_SUFFIXES, key=len, reverse=True))
    + r")" + _TYPE_TAIL
)
# A bare s orbital such as "1s" has no suffix at all.
_BARE_S_FORM = re.compile(r"^(\d)(s)()" + _TYPE_TAIL)
# Fallback: any run of letters as suffix (kept so that inputs accepted before
# are still accepted).
_GENERIC_FORM = re.compile(r"^(\d)([a-z])([a-z]+[\-\+]?)" + _TYPE_TAIL)


def analyze_function(s):
    """Classify an orbital specification string and split it into fields.

    Parameters
    ----------
    s : str
        Either three (optionally signed) digits ``n l m`` such as ``"321"``,
        or a named orbital such as ``"3dxy"``, ``"3dx2-y2"`` or ``"1s"``.
        Both forms may end with a return-type code made of the characters
        ``c``, ``r``, ``i``, ``a``, ``2`` (e.g. ``"a2"``).

    Returns
    -------
    tuple
        ``('c', [n, l, m, t])`` for the numeric form (``n``, ``l``, ``m`` are
        ints), ``('r', [n, l, m, t])`` for the named form (``n`` is an int,
        ``l`` and ``m`` are strings), or ``('f', [s])`` if nothing matches.
        ``t`` defaults to ``'r'`` when no type code is present.
    """
    found = _NUMERIC_FORM.match(s)
    if found:
        n, l, m = (int(found.group(i)) for i in (1, 2, 3))
        return 'c', [n, l, m, found.group(4) or 'r']

    # Try the table-driven patterns first: a greedy letter run alone would
    # swallow the type code ("2pxa" -> suffix "xa") and cannot match suffixes
    # containing digits or parentheses ("x2-y2", "z(5z2-3r2)").
    for pattern in (_NAMED_FORM, _BARE_S_FORM, _GENERIC_FORM):
        found = pattern.match(s)
        if found:
            return 'r', [int(found.group(1)), found.group(2),
                         found.group(3), found.group(4) or 'r']

    # Anything else is returned untouched.
    return 'f', [s]

def get_by_type(f, rettype = 'c'):
    """Split a complex value (or array) into a pair selected by ``rettype``.

    ``'r'``  -> (real part, phase angle)
    ``'i'``  -> (imaginary part, phase angle)
    ``'a'``  -> (modulus, phase angle)
    ``'a2'`` -> (squared modulus, phase angle)
    anything else (including the default ``'c'``) -> (real part, imaginary part)
    """
    if rettype == 'a2':
        modulus = abs(f)
        first = modulus**2
    elif rettype == 'a':
        first = abs(f)
    elif rettype == 'i':
        first = f.imag
    elif rettype == 'r':
        first = f.real
    else:
        # Unrecognised codes fall back to the plain complex components.
        return f.real, f.imag

    return first, angle(f)

# Radial wave functions
# kRnl = 1/sqrt(4*pi): chosen so that integ(4 * pi * r**2 * Rnl(r)**2) is normalised.
# (Ylm_theta() carries the matching sqrt(4*pi) factor.)
kRnl = 1.0 / sqrt(4.0 * pi)
def Rnl_if(r, n, l, Z = 1.0):
    """Closed-form radial function R_nl(r) for the tabulated low shells.

    Parameters
    ----------
    r : float or ndarray
        Radius, in the same length unit as the module-level ``a0``.
    n, l : int
        Principal and azimuthal quantum numbers. Supported pairs:
        1s, 2s, 2p, 3s, 3p, 3d and 4s.
    Z : float, optional
        Nuclear charge of the hydrogen-like ion (default 1).

    Returns
    -------
    float or ndarray or None
        ``kRnl * R_nl(r)``, or ``None`` when ``(n, l)`` is not tabulated.
    """
    rho = 2 * Z * r / n / a0
    if n == 1 and l == 0:  # 1s
        Rr = 2 * (Z / a0)**(3/2) * np.exp(-rho / 2)
    elif n == 2 and l == 0:  # 2s
        Rr = (1 / (2 * np.sqrt(2))) * (Z / a0) ** (3/2) * (2 - rho) * np.exp(-rho / 2)
    elif n == 2 and l == 1:  # 2p
        Rr = (1 / (2 * np.sqrt(6))) * (Z / a0) ** (3/2) * rho * np.exp(-rho / 2)
    elif n == 3 and l == 0:  # 3s
        Rr = (1 / (9 * np.sqrt(3))) * (Z / a0) ** (3/2) * (6 - 6 * rho + rho ** 2) * np.exp(-rho / 2)
    elif n == 3 and l == 1:  # 3p
        Rr = (1 / (9 * np.sqrt(6))) * (Z / a0) ** (3/2) * (4 - rho) * rho * np.exp(-rho / 2)
    elif n == 3 and l == 2:  # 3d
        Rr = (1 / (9 * np.sqrt(30))) * (Z / a0) ** (3/2) * rho ** 2 * np.exp(-rho / 2)
    elif n == 4 and l == 0:  # 4s
        # Same function as (1/768)(192 - 144 s + 24 s^2 - s^3) exp(-s/4) with
        # s = Z r / a0 = 2 rho, but written in rho so that Z is honoured like
        # in every other branch.
        Rr = (1 / 96) * (Z / a0) ** (3/2) * (24 - 36 * rho + 12 * rho ** 2 - rho ** 3) * np.exp(-rho / 2)
    else:
        # Unsupported (n, l): signal it with None instead of aborting.
        return None

    return kRnl * Rr

def Rnl(r, n, l, Z = 1.0):
    """General radial function R_nl(r) built from associated Laguerre polynomials.

    Parameters
    ----------
    r : float or ndarray
        Radius, in the same length unit as the module-level ``a0``.
    n, l : int
        Principal and azimuthal quantum numbers.
    Z : float, optional
        Nuclear charge of the hydrogen-like ion (default 1), matching the
        signature of ``Rnl_if``.

    Returns
    -------
    float or ndarray
        ``kRnl * R_nl(r)``.
    """
    zeta = 2.0 * Z * r / n / a0
    # Factor relating the classic L_{n+l}^{2l+1} polynomial to the
    # generalised Laguerre polynomial L_{n-l-1}^{(2l+1)} that scipy provides.
    L_coff = (-1)**(2 * l + 1) * factorial(n + l)
    # Normalisation constant written for the classic polynomial; its minus
    # sign cancels against the one in L_coff.
    f_coff = - ((2.0 * Z / n / a0)**3 * factorial(n - l - 1) / 2.0 / n / factorial(n + l)**3)**(1.0 / 2.0)
    K = f_coff * np.exp(-zeta / 2.0) * zeta**l * L_coff

    return kRnl * K * assoc_laguerre(zeta, n - l - 1, 2 * l + 1)

# Spherical harmonics
# Normalisation constant of the azimuthal factor exp(i*m*phi).
KPhi = 1.0 / sqrt(2.0 * pi)
def Ylm_phi_r(phi, m):
    """Real azimuthal factor of a spherical harmonic.

    For ``m >= 0`` this is ``KPhi * Re[exp(i*m*phi)]``; for ``m < 0`` it is
    ``KPhi * Im[exp(-i*m*phi)]``, i.e. the sine-like partner of order ``|m|``.
    """
    if m >= 0:
        phasor = exp(1j * m * phi)
        component = phasor.real
    else:
        phasor = exp(-1j * m * phi)
        component = phasor.imag
    return KPhi * component

def Ylm_phi(phi, m):
    """Complex azimuthal factor ``KPhi * exp(i*m*phi)`` of a spherical harmonic."""
    phasor = exp(1j * m * phi)
    return KPhi * phasor

# Extra sqrt(4*pi) carried by the polar factor; kRnl = 1/sqrt(4*pi) on the
# radial side compensates for it.
KYlm = sqrt(4.0 * pi)
def Ylm_theta(theta, l, m):
    """Polar (theta) factor of the spherical harmonic, scaled by ``KYlm``.

    The result depends on ``|m|`` only: the normalisation constant and the
    associated Legendre function are both evaluated for the order ``|m|``,
    so ``m`` and ``-m`` give the same polar factor.

    Parameters
    ----------
    theta : float or ndarray
        Polar angle in radians.
    l : int
        Azimuthal quantum number.
    m : int
        Magnetic quantum number, ``-l <= m <= l``.

    Returns
    -------
    float or ndarray
        ``KYlm * sqrt((2l+1)/2 * (l-|m|)!/(l+|m|)!) * |P_l^{|m|}(cos theta)|``-type
        factor without the Condon-Shortley sign.
    """
    ma = m if m >= 0 else -m
    KTheta = sqrt((2 * l + 1) / 2 * factorial(l - ma) / factorial(l + ma))
    # scipy's lpmv includes the Condon-Shortley phase (-1)**m; cancel it so
    # that odd orders do not pick up a minus sign.
    if ma % 2 == 1:
        KTheta *= -1.0

    # Use |m| for the Legendre order as well. Passing a negative order would
    # make lpmv apply its own (l-|m|)!/(l+|m|)! factor on top of KTheta.
    return KYlm * KTheta * lpmv(ma, l, cos(theta))

def Ylm(theta, phi, l, m, phase_m = 0.0):
    """Complex spherical harmonic assembled from its two angular factors.

    Returns ``exp(-i*phase_m) * Ylm_phi(phi, m) * Ylm_theta(theta, l, m)``.
    """
    global_phase = exp(-1.0j * phase_m)
    azimuthal = Ylm_phi(phi, m)
    polar = Ylm_theta(theta, l, m)
    return global_phase * azimuthal * polar

def Ylm_r(theta, phi, l, m, phase_m = 0.0):
    """Spherical harmonic built with the real azimuthal factor.

    Returns ``exp(-i*phase_m) * Ylm_phi_r(phi, m) * Ylm_theta(theta, l, m)``;
    the value is complex whenever ``phase_m`` makes the prefactor complex.
    """
    global_phase = exp(-1.0j * phase_m)
    azimuthal = Ylm_phi_r(phi, m)
    polar = Ylm_theta(theta, l, m)
    return global_phase * azimuthal * polar

def Ylm_xyz(x, y, z, l, m, phase_m = 0.0):
    """Evaluate ``Ylm_r`` at a point given in Cartesian coordinates.

    The point is converted to the polar angle ``arccos(z / r)`` and the
    azimuth ``arctan2(y, x)`` before delegating to ``Ylm_r``.
    """
    radius = sqrt(x**2 + y**2 + z**2)
    polar = arccos(z / radius)
    azimuth = arctan2(y, x)
    return Ylm_r(polar, azimuth, l, m, phase_m)

def Ylm_real(theta, phi, l, m, phase_m = 0.0):
    """Real-valued harmonic taken from the complex ``Ylm`` of order ``|m|``.

    ``m >= 0`` returns the real part of ``Ylm(..., m)``; ``m < 0`` returns
    the imaginary part of ``Ylm(..., -m)``.
    """
    if m >= 0:
        harmonic = Ylm(theta, phi, l, m, phase_m)
        return harmonic.real
    if m < 0:
        harmonic = Ylm(theta, phi, l, -m, phase_m)
        return harmonic.imag

def Ylm_xyz_real(x, y, z, l, m, phase_m = 0.0):
    """Evaluate ``Ylm_real`` at a point given in Cartesian coordinates.

    The point is converted to the polar angle ``arccos(z / r)`` and the
    azimuth ``arctan2(y, x)`` before delegating to ``Ylm_real``.
    """
    radius = sqrt(x**2 + y**2 + z**2)
    polar = arccos(z / radius)
    azimuth = arctan2(y, x)
    return Ylm_real(polar, azimuth, l, m, phase_m)

def psi_r(r, theta, phi, n, l, m, phase = 0.0):
    """Complex wave function in spherical coordinates.

    Returns ``exp(-i*phase) * Rnl(r, n, l) * Ylm(theta, phi, l, m)``.
    """
    global_phase = exp(-1.0j * phase)
    radial = Rnl(r, n, l)
    angular = Ylm(theta, phi, l, m)
    return global_phase * radial * angular

def psi_xyz(x, y, z, n, l, m, phase = 0.0):
    """Complex wave function evaluated at Cartesian point(s).

    Parameters
    ----------
    x, y, z : float or ndarray
        Cartesian coordinates (broadcast against each other).
    n, l, m : int
        Quantum numbers forwarded to ``psi_r``.
    phase : float, optional
        Global phase forwarded to ``psi_r``.

    Returns
    -------
    complex or ndarray
        ``psi_r(r, theta, phi, n, l, m, phase)`` for the converted point.
        At the origin the polar angle is taken as 0, so the result is the
        finite value of the wave function there rather than NaN.
    """
    r = sqrt(x**2 + y**2 + z**2)
    # z / r is 0 / 0 at the origin; substitute cos(theta) = 1 there. The inner
    # np.where keeps the division itself free of zero denominators.
    off_origin = r > 0
    cos_theta = np.where(off_origin, z / np.where(off_origin, r, 1.0), 1.0)
    theta = arccos(cos_theta)
    phi = arctan2(y, x)
    return psi_r(r, theta, phi, n, l, m, phase)

def psi_r_real(r, theta, phi, n, l, m, phase = 0.0):
    """Real-valued wave function in spherical coordinates.

    Parameters
    ----------
    r, theta, phi : float or ndarray
        Spherical coordinates.
    n, l, m : int
        Quantum numbers; the sign of ``m`` selects which real combination
        is returned.
    phase : float, optional
        Global phase applied by ``psi_r`` before the real or imaginary part
        is taken (same convention as ``Ylm_real``).

    Returns
    -------
    float or ndarray
        Real part of ``psi_r(..., m, phase)`` for ``m >= 0``, imaginary part
        of ``psi_r(..., -m, phase)`` for ``m < 0``.
    """
    if m >= 0:
        return psi_r(r, theta, phi, n, l, m, phase).real
    return psi_r(r, theta, phi, n, l, -m, phase).imag

def psi_xyz_real(x, y, z, n, l, m, phase = 0.0):
    """Real-valued wave function evaluated at Cartesian point(s).

    Parameters
    ----------
    x, y, z : float or ndarray
        Cartesian coordinates (broadcast against each other).
    n, l, m : int
        Quantum numbers; the sign of ``m`` selects which real combination
        is returned.
    phase : float, optional
        Global phase forwarded to ``psi_r``.

    Returns
    -------
    float or ndarray
        Real part of ``psi_r(..., m, phase)`` for ``m >= 0``, imaginary part
        of ``psi_r(..., -m, phase)`` for ``m < 0``. At the origin the polar
        angle is taken as 0, so the result is finite rather than NaN.
    """
    r = sqrt(x**2 + y**2 + z**2)
    # z / r is 0 / 0 at the origin; substitute cos(theta) = 1 there. The inner
    # np.where keeps the division itself free of zero denominators.
    off_origin = r > 0
    theta = arccos(np.where(off_origin, z / np.where(off_origin, r, 1.0), 1.0))
    phi = arctan2(y, x)
    # psi_r takes (r, theta, phi, n, l, m, phase): pass n and this function's
    # own ``phase`` argument.
    if m >= 0:
        return psi_r(r, theta, phi, n, l, m, phase).real
    return psi_r(r, theta, phi, n, l, -m, phase).imag

