#https://dora.bk.tsukuba.ac.jp/~takeuchi/?%E9%87%8F%E5%AD%90%E5%8A%9B%E5%AD%A6%E2%85%A0%2F%E7%90%83%E5%AF%BE%E7%A7%B0%E4%BA%95%E6%88%B8%E5%9E%8B%E3%83%9D%E3%83%86%E3%83%B3%E3%82%B7%E3%83%A3%E3%83%AB
#https://dora.bk.tsukuba.ac.jp/~takeuchi/?%E9%87%8F%E5%AD%90%E5%8A%9B%E5%AD%A6%E2%85%A0%2F%EF%BC%93%E6%AC%A1%E5%85%83%E8%AA%BF%E5%92%8C%E6%8C%AF%E5%8B%95%E5%AD%90#fe164113

import sys
import numpy as np
from numpy import sqrt
from math import factorial
from scipy.special import spherical_jn, genlaguerre
import scipy.optimize as opt
import matplotlib.pyplot as plt


# Physical constants (CODATA 2018; the elementary charge is exact in the 2019 SI)
url = "https://dora.bk.tsukuba.ac.jp/~takeuchi/?%E9%87%8F%E5%AD%90%E5%8A%9B%E5%AD%A6%E2%85%A0%2F%E7%90%83%E5%AF%BE%E7%A7%B0%E4%BA%95%E6%88%B8%E5%9E%8B%E3%83%9D%E3%83%86%E3%83%B3%E3%82%B7%E3%83%A3%E3%83%AB"
me = 9.1093837015e-31   # kg
hbar = 1.054571817e-34  # J·s
eV   = 1.602176634e-19  # J/eV

# Spectroscopic letters indexed by l = 0, 1, 2, ...
# By convention the letter j is skipped, so l = 6 is 'i'.
lstr = ['s', 'p', 'd', 'f', 'g', 'h', 'i']


meff = 0.067  # effective mass in units of the free-electron mass me
R = 5e-9      # quantum-dot radius (m)
Z = 1         # nuclear charge used by plot_H

# Run mode: "cal" (default), "plot" or "plotH"; overridden by the first CLI argument.
mode = "cal"


nargs = len(sys.argv)
if nargs >= 2: mode = sys.argv[1]


def _refine_bracketed_zero(func, a, fa, b, eps, nmaxiter, h, dump, print_level, nnode):
    """
    Refine the zero of func inside the sign-change bracket [a, b] (fa = func(a)).

    Uses damped Newton steps with a forward-difference derivative, starting from b.
    Any step that would leave the current bracket (or has an undefined slope) is
    replaced by bisection, so the result always lies inside the original bracket.
    """
    x = b
    for i in range(nmaxiter):
        fx = func(x)
        if fx == 0.0:
            return x

        # Keep the sign change inside [a, b]: move whichever end shares fx's sign.
        if fx * fa < 0.0:
            b = x
        else:
            a, fa = x, fx

        # Newton step with a forward-difference derivative, damped by `dump`.
        xh = x + h
        fdiff = (func(xh) - fx) / (xh - x) if xh != x else 0.0
        x_new = x - fx / fdiff / dump if fdiff != 0.0 else float("nan")

        # Bisect when the Newton step leaves the bracket or is undefined
        # (NaN fails every comparison, so it falls through to here as well).
        lo, hi = min(a, b), max(a, b)
        if not (lo < x_new < hi):
            x_new = 0.5 * (a + b)

        if print_level:
            print(f" iter#{i}/{nmaxiter}: nnode={nnode} dr={x_new:8.4g} - {x:8.4g} = {x_new-x:8.4g}")
        if abs(x_new - x) < eps:
            return x_new
        x = x_new

    # Not converged: the bracket still contains the sign change, so report its midpoint
    # rather than silently dropping the zero.
    if print_level:
        print(f" warning: nnode={nnode} not converged in {nmaxiter} iterations; using bracket midpoint")
    return 0.5 * (a + b)


def get_zeros(func, xmin = 0.0, xmax = 10.0, dx = 0.1, eps = 1.0e-10, nmaxiter = 50, h = 1.0e-10, dump = 1.0, print_level = 0):
    """
    Find the zeros of a scalar function func(x) on [xmin, xmax].

    The range is scanned on the grid np.arange(xmin, xmax + dx, dx). A grid point
    where func is exactly 0.0 is recorded as a zero; every sign change between two
    consecutive grid points is refined to a single zero inside that interval.
    Two zeros within one grid step cancel in sign and are not detected.

    :param func: callable returning a float for a float argument
    :param xmin: start of the scan
    :param xmax: end of the scan (included, up to np.arange rounding)
    :param dx: scan step
    :param eps: convergence threshold on the change between refinement iterates
    :param nmaxiter: maximum refinement iterations per sign change
    :param h: step of the forward-difference derivative
    :param dump: damping factor; each Newton step is divided by this value
    :param print_level: if nonzero, print refinement progress
    :return: list of zeros, in scan order
    """
    zeros = []
    nnode = 0          # number of sign changes refined so far (progress output only)
    prev_r = None
    prev_f = None      # no previous sample yet: the first grid point cannot close a bracket
    for r in np.arange(xmin, xmax + dx, dx):
        f = func(r)

        if f == 0.0:
            zeros.append(r)
        elif prev_f is not None and f * prev_f < 0.0:
            zeros.append(_refine_bracketed_zero(func, prev_r, prev_f, r, eps, nmaxiter,
                                                h, dump, print_level, nnode))
            nnode += 1

        prev_r = r
        prev_f = f

    return zeros

def get_bessel_zeros(l, xmin = 0.0, xmax = 10.0, dx = 0.1, print_level = 0):
    """
    Return the zeros of the spherical Bessel function j_l(x) on [xmin, xmax].

    Thin wrapper around get_zeros, scanning with step dx and using get_zeros'
    default refinement settings.
    """
    def j_l(x):
        return spherical_jn(l, x)

    return get_zeros(j_l, xmin, xmax, dx, print_level=print_level)


def energy_level(meff, R, n, l, zeros):
    """
    Energy of the (n, l) level of a particle in a spherical quantum well of radius R.

    E_nl = hbar^2 k_nl^2 / (2 meff me), with k_nl = alpha_nl / R, where alpha_nl is
    the n-th zero of j_l stored in zeros[l].

    :param meff: effective mass in units of the free-electron mass me
    :param R: radius of the quantum dot (m)
    :param n: radial index, 1-based (n = 1 selects zeros[l][0])
    :param l: orbital angular momentum quantum number
    :param zeros: zeros[l][n-1] is the n-th zero of the spherical Bessel function j_l
    :return: (E_nl in eV, alpha_nl), or (None, None) if zeros has no entry for (n, l)
    :raises ValueError: if n < 1 or l < 0 (these would otherwise index from the end)
    """
    if n < 1 or l < 0:
        raise ValueError(f"n must be >= 1 and l >= 0 (got n={n}, l={l})")
    if l >= len(zeros) or len(zeros[l]) < n:
        return None, None

    aR = zeros[l][n - 1]

    knl = aR / R
    Enl = hbar * hbar * knl * knl / 2.0 / meff / me
    Enl_eV = Enl / eV

    return Enl_eV, aR


def cal(nmax = 5, lmax = None, xmax = 10.0):
    """
    Print the energy levels of a spherical quantum dot, sorted by energy.

    Uses the module-level effective mass `meff` and radius `R`. Zeros of j_l are
    searched on [0.1, xmax]; levels whose zero lies beyond xmax are not listed.

    :param nmax: radial indices n = 1 .. nmax-1 are listed
    :param lmax: largest l included (default nmax - 1, i.e. l = 0 .. nmax-1)
    :param xmax: upper bound of the zero search of j_l
    """
    print()
    print("Energy levels for quantum sphere")
    print(f"  effective mass: {meff} me")
    print(f"  radius: {R*1.0e9} nm")

    if lmax is None:
        lmax = nmax - 1

    # Start the search above 0 so the zero of j_l (l >= 1) at the origin is excluded.
    zeros = []
    print("Zero points:")
    for l in range(lmax + 1):
        zero_points = get_bessel_zeros(l, xmin = 0.1, xmax = xmax, dx = 0.1, print_level = 0)
        zeros.append(zero_points)
        # float() keeps the listing readable (numpy >= 2 reprs scalars as np.float64(...)).
        print(f"l={l}:", [float(z) for z in zero_points])

    level_list = []
    for n in range(1, nmax):
        for l in range(lmax + 1):
            Enl, R_zero = energy_level(meff, R, n, l, zeros)
            if Enl is None:
                continue

            orbital = lstr[l] if l < len(lstr) else f"[l={l}]"
            # The label uses n + l as the principal quantum number.
            level_list.append({"E": Enl, "label": f"{n+l}{orbital}", "n": n, "l": l,
                               "R0": R_zero})

    print("l n l+n  E(orb)")
    for d in sorted(level_list, key=lambda x: x["E"]):
        E     = d["E"]
        label = d["label"]
        n     = d["n"]
        l     = d["l"]
        R0    = d["R0"]
        print(f"{l} {n} {l+n}    E({label})={E:12.8g} eV  alpha_{n}_{l}={R0:12.8g}  alpha_{n}_{l}^2={R0*R0:12.8g}")
    print(f"see {url} for the definitions of quantum numbers")

def plot_spherical_bessel(lmax, rmax, rmesh = 500):
    """
    Plot the spherical Bessel functions j_l(x) for l = 0 .. lmax and mark their zeros.

    :param lmax: largest orbital angular momentum quantum number to plot
    :param rmax: upper end of the plotted range of x = k*r
    :param rmesh: number of sample points on [0, rmax]
    """

    plt.figure(figsize=(8, 6))
    plt.title("Spherical Bessel Function $j_l(kr)$", fontsize=16)
    plt.xticks(fontsize=16)
    plt.yticks(fontsize=16)

    r = np.linspace(0, rmax, rmesh)
    for l in range(lmax + 1):
        j_l = spherical_jn(l, r)
        # Search the same range that is plotted so the markers match the curves.
        zero_points = get_bessel_zeros(l, xmin = 0.1, xmax = rmax, dx = 0.1, print_level = 1)
        print("zeros: l=", l, zero_points)

        plt.plot(r, j_l, label = f"$j_{l}(kr)$")
        plt.plot(zero_points, [0.0] * len(zero_points), label = f"zero points of $j_{l}(kr)$", linestyle = '', marker = 'o')

    plt.axhline(y=0, color='red', linestyle='--', linewidth = 0.5)
    plt.axvline(x=0, color='red', linestyle='--', linewidth = 0.5)
    plt.xlabel(r"$k$$\cdot$$r$", fontsize=16)
    plt.ylabel(r"$j_l$($k$$\cdot$$r$)", fontsize=16)

    plt.legend(fontsize=12)

    plt.show()

def plot_H(nmax, rmax, normalize = False):
    """
    Plot radial wave functions R_nl(r) of a hydrogen-like atom for n = 1 .. nmax, l = 0 .. n-1.

    R_nl is proportional to rho^l exp(-rho/2) L^(2l+1)_(n-l-1)(rho) with
    rho = 2 Z r / (n a0), r in units of a0, and Z the module-level nuclear charge.

    :param nmax: largest principal quantum number to plot
    :param rmax: upper end of the r axis (units of a0)
    :param normalize: if True, multiply by
        sqrt((2Z/(n a0))^3 (n-l-1)! / (2n (n+l)!)); by default the
        unnormalized functions are plotted
    """
    a0 = 1.0

    plt.figure(figsize=(8, 6))
    plt.title("Radial Wave Functions $R_{nl}(r)$ for Hydrogen Atom", fontsize=16)
    plt.xticks(fontsize=16)
    plt.yticks(fontsize=16)

    r = np.linspace(0, rmax, 500)
    for n in range(1, nmax + 1):
        rho = 2 * Z * r / (n * a0)
        for l in range(n):
            Lnl = genlaguerre(n - l - 1, 2 * l + 1)(rho)
            R_nl = (rho ** l) * np.exp(-rho / 2) * Lnl
            if normalize:
                R_nl = sqrt((2 * Z / (n * a0))**3
                            * factorial(n - l - 1) / (2 * n * factorial(n + l))) * R_nl
            plt.plot(r, R_nl, label = f"$R_{n}$$_{l}$($r$)")

    plt.axhline(y=0, color='red', linestyle='--', linewidth = 0.5)
    plt.axvline(x=0, color='red', linestyle='--', linewidth = 0.5)
    plt.xlabel("r (in units of $a_0$)", fontsize=16)
    plt.ylabel("$R_{nl}$($r$)", fontsize=16)

    plt.legend(fontsize=12)

    plt.show()

def main():
    """
    Dispatch on the module-level `mode` ("cal", "plot" or "plotH").

    `mode` defaults to "cal" and is taken from the first command-line argument.
    An unknown mode prints an error to stderr and exits with status 1.
    """
    if mode == 'cal':
        cal()
    elif mode == 'plot':
        plot_spherical_bessel(lmax = 3, rmax = 10.0)
    elif mode == 'plotH':
        # Wide enough for the n = 3 radial functions to decay within the plot.
        plot_H(nmax = 3, rmax = 20.0)
    else:
        print(f"Error: Invalid mode={mode}", file=sys.stderr)
        print("  valid modes: cal, plot, plotH", file=sys.stderr)
        # sys.exit is always available (the bare `exit` builtin is added by `site`),
        # and a nonzero status signals the failure to the caller.
        sys.exit(1)

if __name__ == "__main__":
    # Run only when executed as a script, not when imported as a module.
    main()
    
    