"""
水素原子の波動関数とエネルギー準位の計算モジュール

関連リンク:
:doc:`tkWavefunction_H_usage`
"""

#https://zenn.dev/shittoku_xxx/articles/061ad99166afcc

import re
import math
import numpy as np
from numpy import abs, exp, sqrt, cos, sin, arccos, arctan2, angle, pi
from scipy.special import factorial
try:
    from scipy.special import sph_harm_y as sph_harm
except ImportError:
    from scipy.special import sph_harm
from scipy.constants import physical_constants


# 定数
a0 = physical_constants['Bohr radius'][0] * 1e10  # Bohr半径。オングストローム単位に変換
#a_0 = 4 * np.pi * epsilon_0 * h_bar**2 / (m_e * e**2)  # ボーア半径 (m)
h = 6.62607015e-34  # プランク定数 (J·s)
h_bar = h / (2.0 * pi)  # ディラック定数 (J·s)
m_e = 9.10938356e-31  # 電子の質量 (kg)
e = 1.602176634e-19  # 電子の電荷 (C)
e0 = 8.854187817e-12  # 真空の誘電率 (F/m)
pi = np.pi

# エネルギー準位を計算する関数
def En(n):
    """
    主量子数に対する水素原子のエネルギー準位を計算する。

    :param n: int, 主量子数
    :returns: float, エネルギー準位 (eV)
    """
    E = - (m_e * e**4) / (8 * e0**2 * h**2) / n**2
    return E / e

def get_orb_name(n, l, m):
    """
    量子数から軌道の名前（例: '1s', '2px'など）を取得する。

    :param n: int, 主量子数
    :param l: int または str, 方位量子数
    :param m: int または str, 磁気量子数
    :returns: str, 軌道名
    """
    print(n,l,m)
    names1 = ['s']  # for l=0
    names2 = ['py', 'pz', 'px']  # for l=1, m=-1, 0, 1
    names3 = ['dxy', 'dyz', 'd3z2-r2', 'dzx', 'dx2-y2']  # for l=2, m=-2,-1,0,1,2
    names4 = ['fy(3x2-y2)', 'fxyz', 'fy(5z2-r2)', 'fz(5z2-3r2)',
              'fx(5z2-r2)', 'fz(x2-y2)', 'fx(x2-3y2)']  # for l=3, m=-3,-2,-1,0,1,2,3
    names = [names1, names2, names3, names4]

    if type(l) is str:
        pNames = l
    else:
        pNames = names[l]

    if type(m) == str:
        name = l + m
    else:
        name = pNames[m + l]
    
    return f"{n}{name}"

def get_qnumbers(n, l, m):
    """
    軌道を表す文字列または数値から量子数(n, l, m)の組を取得する。

    :param n: int, 主量子数
    :param l: int または str, 方位量子数
    :param m: int または str, 磁気量子数
    :returns: tuple, (n, l, m) の量子数の組
    """
    qnumber_l = {"s": 0, "p": 1, "d": 2, "f": 3, "g": 4}
    qnumber_m = {"": 0, 
                 "x": 1, "y": -1, "z": 0, 
                 "xy+": 1, "xy-": -1,
                 "xy": -2, "yz": -1, "z2": 0, "zx": 1, "x2-y2": 2,
                 "y(3x2-y2)": -3, "xyz": -2, "y(5z2-r2)": -1, 
                 "z(5z2-3r2)": 0, "x(5z2-r2)": 1, "z(x2-y2)": 2, "x(x2-3y2)": 3}
    
    if type(l) is str:
        l = qnumber_l.get(l, None)
    if type(m) is str:
        m = qnumber_m.get(m, None)
    return n, l, m

def analyze_function(s):
    """
    文字列表現を解析し、波動関数のタイプとパラメータを取得する。

    :param s: str, 解析対象の文字列 (例: '3dxy', '210r')
    :returns: tuple, (type_str, list_of_params) の形式で返す
    """
#数字だけ
    match = re.match(r"^([+-]?\d)([+-]?\d)([+-]?\d)([cria2]+)?$", s)
    if match:
        n = int(match.group(1))
        l = int(match.group(2))
        m = int(match.group(3))
        t = match.group(4)
        if t is None: t = 'r'
        return 'c', [n, l, m, t]
    
# 3dxyのような文字列形式
    match = re.match(r"^(\d)([a-z])([a-z]+[\-\+]?)([cria2]+)?$", s)
    if match:
        n = int(match.group(1))
        l = match.group(2)
        m = match.group(3)
        t = match.group(4)
        if t is None: t = 'r'
        return 'r', [n, l, m, t]

#その他 
    return 'f', [s]

def get_by_type(f, rettype = 'c'):
    """
    指定された戻り値の型に応じて複素数の成分を返す。

    :param f: complex, 評価する複素数
    :param rettype: str, 戻り値のタイプ ('r': 実部, 'i': 虚部, 'a': 絶対値, 'a2': 絶対値の2乗, 'c': 実部と虚部)
    :returns: tuple, 要求された成分と位相、または実部と虚部のタプル
    """
    if rettype == 'r':
        return f.real, angle(f)
    elif rettype == 'i':
        return f.imag, angle(f)
    elif rettype == 'a':
        return abs(f), angle(f)
    elif rettype == 'a2':
        fnorm = abs(f)
        return fnorm**2, angle(f)
    else:
        return f.real, f.imag

#動径波動関数:
#kRnl = 1.0
kRnl = 1.0 / sqrt(4.0 * pi) # integ(4 * pi * r**2 * Rnl(r)**2)で規格化
def Rnl_if(r, n, l, Z = 1.0):
    """
    条件分岐による解析的な動径波動関数を計算する。

    :param r: float または ndarray, 距離
    :param n: int, 主量子数
    :param l: int, 方位量子数
    :param Z: float, 原子番号 (デフォルトは 1.0)
    :returns: float または ndarray, 動径波動関数の値。サポート外の場合は None
    """
    rho = 2 * Z * r / n / a0
    if n == 1 and l == 0:  # 1s
        Rr = 2 * (Z / a0)**(3/2) * np.exp(-rho / 2)
    elif n == 2 and l == 0:  # 2s
        Rr = (1 / (2 * np.sqrt(2))) * (Z / a0) ** (3/2) * (2 - rho) * np.exp(-rho / 2)
    elif n == 2 and l == 1:  # 2p
        Rr = (1 / (2 * np.sqrt(6))) * (Z / a0) ** (3/2) * rho * np.exp(-rho / 2)
    elif n == 3 and l == 0:  # 3s
        Rr = (1 / (9 * np.sqrt(3))) * (Z / a0) ** (3/2) * (6 - 6 * rho + rho ** 2) * np.exp(-rho / 2)
    elif n == 3 and l == 1:  # 3p
        Rr = (1 / (9 * np.sqrt(6))) * (Z / a0) ** (3/2) * (4 - rho) * rho * np.exp(-rho / 2)
    elif n == 3 and l == 2:  # 3d
        Rr = (1 / (9 * np.sqrt(30))) * (Z / a0) ** (3/2) * rho ** 2 * np.exp(-rho / 2)
    elif n == 4 and l == 0:  # 4s
        Rr = 1/768 * (1/a0)**(3/2) * (192 - 144*r/a0 + 24*r**2/a0**2 - r**3/a0**3)*np.exp(-r/4/a0)
    else:
        Rr = None
#        print(f"\nError in tkWavefunctin_H.Rnl_if(): Unsuport (n,l)=({n}, {l})\n")
#        exit()

    if Rr is None: return None
    return kRnl * Rr

def Rnl(r, n, l):
    """
    ラゲールの陪多項式を用いて動径波動関数を計算する。

    :param r: float または ndarray, 距離
    :param n: int, 主量子数
    :param l: int, 方位量子数
    :returns: float または ndarray, 動径波動関数の値
    """
    zeta = 2.0 * r / n / a0
    L_coff = (-1)**(2 * l + 1) * factorial(n + l)
    f_coff = - ((2.0 / n / a0)**3 * factorial(n - l - 1) / 2.0 / n / factorial(n + l)**3)**(1.0 / 2.0)
    K = f_coff * np.exp(-zeta / 2.0) * zeta**l * L_coff

    return kRnl * K * assoc_laguerre(zeta, n - l - 1, 2 * l + 1)

# 球面調和関数
KPhi = 1.0 / sqrt(2.0 * pi)
def Ylm_phi_r(phi, m):
    """
    球面調和関数の実数型の方位角成分を計算する。

    :param phi: float または ndarray, 方位角
    :param m: int, 磁気量子数
    :returns: float または ndarray, 方位角成分の値
    """
    if m >= 0:
        return KPhi * exp(1j * m * phi).real
    else:
        return KPhi * exp(-1j * m * phi).imag

def Ylm_phi(phi, m):
    """
    球面調和関数の複素数型の方位角成分を計算する。

    :param phi: float または ndarray, 方位角
    :param m: int, 磁気量子数
    :returns: complex または ndarray, 方位角成分の値
    """
    return KPhi * exp(1j * m * phi)

KYlm = sqrt(4.0 * pi)
def Ylm_theta(theta, l, m):
    """
    ルジャンドルの陪関数を用いて球面調和関数の極角成分を計算する。

    :param theta: float または ndarray, 極角
    :param l: int, 方位量子数
    :param m: int, 磁気量子数
    :returns: float または ndarray, 極角成分の値
    """
    ma = m if m >= 0 else -m
    KTheta = sqrt((2 * l + 1) / 2 * factorial(l - ma) / factorial(l + ma))
#    KTheta = sqrt((2 * l + 1) / (4 * pi) * factorial(l - m) / factorial(l + m))
#    KTheta /= KPhi
    if m >= 1 and m % 2 == 1:
        KTheta *= -1.0

    return KYlm * KTheta * lpmv(m, l, cos(theta))

def Ylm(theta, phi, l, m, phase_m = 0.0):
    """
    複素数型の球面調和関数を計算する。

    :param theta: float または ndarray, 極角
    :param phi: float または ndarray, 方位角
    :param l: int, 方位量子数
    :param m: int, 磁気量子数
    :param phase_m: float, 位相シフト (デフォルトは 0.0)
    :returns: complex または ndarray, 球面調和関数の値
    """
#    return exp(-1.0j * phase_m) * sph_harm(m, l, phi, theta)
    return exp(-1.0j * phase_m) * Ylm_phi(phi, m) * Ylm_theta(theta, l, m)

def Ylm_r(theta, phi, l, m, phase_m = 0.0):
    """
    実数型の球面調和関数を計算する。

    :param theta: float または ndarray, 極角
    :param phi: float または ndarray, 方位角
    :param l: int, 方位量子数
    :param m: int, 磁気量子数
    :param phase_m: float, 位相シフト (デフォルトは 0.0)
    :returns: float または ndarray, 実数型の球面調和関数の値
    """
#    return exp(-1.0j * phase_m) * sph_harm(m, l, phi, theta)
    return exp(-1.0j * phase_m) * Ylm_phi_r(phi, m) * Ylm_theta(theta, l, m)

def Ylm_xyz(x, y, z, l, m, phase_m = 0.0):
    """
    直交座標系から実数型の球面調和関数を計算する。

    :param x: float または ndarray, x座標
    :param y: float または ndarray, y座標
    :param z: float または ndarray, z座標
    :param l: int, 方位量子数
    :param m: int, 磁気量子数
    :param phase_m: float, 位相シフト (デフォルトは 0.0)
    :returns: float または ndarray, 球面調和関数の値
    """
    r2 = x**2 + y**2 + z**2
    r = sqrt(r2)
    theta = arccos(z / r)
    phi = arctan2(y, x)

    return Ylm_r(theta, phi, l, m, phase_m)

def Ylm_real(theta, phi, l, m, phase_m = 0.0):
    """
    複素数型の球面調和関数から実関数表現（実部・虚部の組み合わせ）を計算する。

    :param theta: float または ndarray, 極角
    :param phi: float または ndarray, 方位角
    :param l: int, 方位量子数
    :param m: int, 磁気量子数
    :param phase_m: float, 位相シフト (デフォルトは 0.0)
    :returns: float または ndarray, 実関数表現の球面調和関数の値
    """
    if m >= 0:
        return Ylm(theta, phi, l, m, phase_m).real
    elif m < 0:
        return Ylm(theta, phi, l, -m, phase_m).imag

def Ylm_xyz_real(x, y, z, l, m, phase_m = 0.0):
    """
    直交座標系から実関数表現の球面調和関数を計算する。

    :param x: float または ndarray, x座標
    :param y: float または ndarray, y座標
    :param z: float または ndarray, z座標
    :param l: int, 方位量子数
    :param m: int, 磁気量子数
    :param phase_m: float, 位相シフト (デフォルトは 0.0)
    :returns: float または ndarray, 実関数表現の球面調和関数の値
    """
    r2 = x**2 + y**2 + z**2
    r = sqrt(r2)
    theta = arccos(z / r)
    phi = arctan2(y, x)

    return Ylm_real(theta, phi, l, m, phase_m)

def psi_r(r, theta, phi, n, l, m, phase = 0.0):
    """
    球座標系での水素原子の複素波動関数を計算する。

    :param r: float または ndarray, 距離
    :param theta: float または ndarray, 極角
    :param phi: float または ndarray, 方位角
    :param n: int, 主量子数
    :param l: int, 方位量子数
    :param m: int, 磁気量子数
    :param phase: float, 位相シフト (デフォルトは 0.0)
    :returns: complex または ndarray, 波動関数の値
    """
    return exp(-1.0j * phase) * Rnl(r, n, l) * Ylm(theta, phi, l, m)

def psi_xyz(x, y, z, n, l, m, phase = 0.0):
    """
    直交座標系での水素原子の複素波動関数を計算する。

    :param x: float または ndarray, x座標
    :param y: float または ndarray, y座標
    :param z: float または ndarray, z座標
    :param n: int, 主量子数
    :param l: int, 方位量子数
    :param m: int, 磁気量子数
    :param phase: float, 位相シフト (デフォルトは 0.0)
    :returns: complex または ndarray, 波動関数の値
    """
    r2 = x**2 + y**2 + z**2
    r = sqrt(r2)
    theta = arccos(z / r)
    phi = arctan2(y, x)
    return psi_r(r, theta, phi, n, l, m, phase)

def psi_r_real(r, theta, phi, n, l, m, phase = 0.0):
    """
    球座標系での水素原子の実関数表現の波動関数を計算する。

    :param r: float または ndarray, 距離
    :param theta: float または ndarray, 極角
    :param phi: float または ndarray, 方位角
    :param n: int, 主量子数
    :param l: int, 方位量子数
    :param m: int, 磁気量子数
    :param phase: float, 位相シフト (デフォルトは 0.0)
    :returns: float または ndarray, 実関数表現の波動関数の値
    """
    if m >= 0:
        f = psi_r(r, theta, phi, n, l, m)
        return f.real
    else:
        f = psi_r(r, theta, phi, n, l, -m)
        return f.imag

def psi_xyz_real(x, y, z, n, l, m, phase = 0.0):
    """
    直交座標系での水素原子の実関数表現の波動関数を計算する。

    :param x: float または ndarray, x座標
    :param y: float または ndarray, y座標
    :param z: float または ndarray, z座標
    :param n: int, 主量子数
    :param l: int, 方位量子数
    :param m: int, 磁気量子数
    :param phase: float, 位相シフト (デフォルトは 0.0)
    :returns: float または ndarray, 実関数表現の波動関数の値
    """
    r2 = x**2 + y**2 + z**2
    r = sqrt(r2)
    theta = arccos(z / r)
    phi = arctan2(y, x)
    if m >= 0:
        f = psi_r(r, theta, phi, l, m, phase_m)
        return f.real
    else:
        f = psi_r(r, theta, phi, l, -m, phase_m)
        return f.imag
