"""
量子ドット (Quantum Dot) におけるエネルギー準位計算および関連関数のプロットスクリプト

概要:
    本スクリプトは、球対称ポテンシャル井戸モデルを用いた量子ドットのエネルギー準位を計算し、
    その結果を表示する機能、および関連する物理関数（球Bessel関数や水素原子の動径波動関数）を
    プロットする機能を提供する。

詳細説明:
    cal モードでは、指定された有効質量と半径を持つ量子ドットの電子のエネルギー準位を、
    球Bessel関数の零点を用いて計算します。
    plot モードでは、球Bessel関数のグラフを描画し、その零点を視覚化します。
    plotH モードでは、水素原子の動径波動関数を描画します。
    コマンドライン引数により実行モードを選択できます。

関連リンク:
    3DQD_usage
    量子力学I/球対称井戸型ポテンシャル: https://dora.bk.tsukuba.ac.jp/~takeuchi/?%E9%87%8F%E5%AD%90%E5%8A%9B%E5%AD%A6%E2%85%A0%2F%E7%90%83%E5%AF%BE%E7%A7%B0%E4%BA%95%E6%88%B8%E5%9E%8B%E3%83%9D%E3%83%86%E3%83%B3%E3%82%B7%E3%83%A3%E3%83%AB
    量子力学I/3次元調和振動子: https://dora.bk.tsukuba.ac.jp/~takeuchi/?%E9%87%8F%E5%AD%90%E5%8A%9B%E5%AD%A6%E2%85%A0%2F%EF%BC%93%E6%AC%A1%E5%85%83%E8%AA%BF%E5%92%8C%E6%8C%AF%E5%8B%95%E5%AD%90#fe164113
"""

import sys
import numpy as np
from numpy import sqrt
from math import factorial
from scipy.special import spherical_jn, genlaguerre
import scipy.optimize as opt
import matplotlib.pyplot as plt


# 定数
url = "https://dora.bk.tsukuba.ac.jp/~takeuchi/?%E9%87%8F%E5%AD%90%E5%8A%9B%E5%AD%A6%E2%85%A0%2F%E7%90%83%E5%AF%BE%E7%A7%B0%E4%BA%95%E6%88%B8%E5%9E%8B%E3%83%9D%E3%83%86%E3%83%B3%E3%82%B7%E3%83%A3%E3%83%AB"
me = 9.10938356e-31   # kg
hbar = 1.0545718e-34    # J·s
eV   = 1.60218e-19      # J/eV

lstr = ['s', 'p', 'd', 'f', 'g', 'h', 'l']


meff = 0.067
R = 5e-9 # m
Z = 1

mode = "cal"


if __name__ == "__main__":
    nargs = len(sys.argv)
    if nargs >= 2: mode = sys.argv[1]
    if nargs >= 3: R = 1.0e-9 * float(sys.argv[2]) # m
    if nargs >= 4: meff = float(sys.argv[3])
    if nargs >= 5: Z = float(sys.argv[4])


def get_zeros(func, xmin = 0.0, xmax = 10.0, dx = 0.1, eps = 1.0e-10, nmaxiter = 50, h = 1.0e-10, dump = 1.0, print_level = 0):
    """
    概要:
        関数の零点を計算します。
    詳細説明:
        与えられた関数 func の指定された区間 [xmin, xmax] 内における零点を探索します。
        零点の探索には、関数値の符号反転を検出した後、ニュートン法に似た数値的な手法を使用します。
        各ステップで接線近似を用いて次点の推定を行い、指定された収束条件 eps または最大反復回数 nmaxiter に達するまで繰り返します。
    引数:
        :param func: 零点を探す関数。引数を1つ取るcallableオブジェクト。
        :type func: callable
        :param xmin: 探索範囲の最小値。
        :type xmin: float
        :param xmax: 探索範囲の最大値。
        :type xmax: float
        :param dx: 初期探索におけるステップサイズ。
        :type dx: float
        :param eps: 零点探索の収束判定に用いる許容誤差。
        :type eps: float
        :param nmaxiter: ニュートン法系反復の最大回数。
        :type nmaxiter: int
        :param h: 数値微分の計算に使用する微小な差分。
        :type h: float
        :param dump: ニュートン法におけるステップサイズの調整係数（ダンプファクター）。1.0は標準的なステップサイズ。
        :type dump: float
        :param print_level: デバッグ情報の出力レベル。0で非表示、1で表示。
        :type print_level: int
    戻り値:
        :returns: 見つかった零点のリスト。
        :rtype: list[float]
    """

    zeros = []
    nnode = 0
    prev_r = 0.0
    prev_f = 1.0
    for r in np.arange(xmin, xmax + dx, dx):
        f = func(r)

        if f == 0.0:
            zeros.append(r)
        elif r > 0.0 and f * prev_f < 0.0:
            r0 = prev_r
            r1 = r
            for i in range(nmaxiter):
                if print_level:
                    print(f" iter#{i}/{nmaxiter}: nnode={nnode} dr={r1:8.4g} - {r0:8.4g} = {r1-r0:8.4g}")
                if abs(r1 - r0) < eps: 
                    zeros.append(r1)
                    break

                r0 = r1
                f1 = func(r1)

                rh = r1 + h
                fh = func(rh)
                fdiff = (fh - f1) / (rh - r1)

                r1 = r1 - f1 / fdiff / dump

            nnode += 1

        prev_r = r
        prev_f = f

    return zeros

def get_bessel_zeros(l, xmin = 0.0, xmax = 10.0, dx = 0.1, print_level = 0):
    """
    概要:
        球Bessel関数 j_l(x) の零点を計算します。
    詳細説明:
        指定された次数 l の球Bessel関数 scipy.special.spherical_jn(l, x) を対象として、
        get_zeros 関数を用いて零点を探索します。
    引数:
        :param l: 球Bessel関数の次数 (軌道角運動量量子数)。
        :type l: int
        :param xmin: 探索範囲の最小値。
        :type xmin: float
        :param xmax: 探索範囲の最大値。
        :type xmax: float
        :param dx: 初期探索におけるステップサイズ。
        :type dx: float
        :param print_level: デバッグ情報の出力レベル。get_zeros 関数に渡されます。
        :type print_level: int
    戻り値:
        :returns: 球Bessel関数の零点のリスト。
        :rtype: list[float]
    """

    return get_zeros(lambda x: spherical_jn(l, x), xmin, xmax, dx, print_level = print_level)


def energy_level(meff, R, n, l, zeros):
    """
    概要:
        球対称量子井戸（量子ドット）のエネルギー準位を計算します。
    詳細説明:
        無限に深い球対称ポテンシャル井戸モデルに基づき、与えられた量子数 n と l に対応する
        エネルギー準位を計算します。エネルギーは、球Bessel関数の零点 alpha_nl を用いて
        E_nl = (hbar^2 * k_nl^2) / (2 * m_eff * m_e) の式で求められます。
        ここで k_nl = alpha_nl / R です。
    引数:
        :param meff: 電子の有効質量 (自由電子質量 me に対する比率)。
        :type meff: float
        :param R: 量子ドットの半径 (m)。
        :type R: float
        :param n: 主量子数 (1から始まる)。球Bessel関数のn番目の零点に対応。
        :type n: int
        :param l: 軌道角運動量量子数。
        :type l: int
        :param zeros: 球Bessel関数の零点のリスト。zeros[l][n-1] の形式で零点にアクセスします。
        :type zeros: list[list[float]]
    戻り値:
        :returns: 計算されたエネルギー準位 (eV) と対応する球Bessel関数の零点 alpha_nl のタプル。零点が見つからない場合は (None, None) を返します。
        :rtype: tuple[float, float] or tuple[None, None]
    """
    # 球Bessel関数のn番目の零点αnlを求め、波数ベクトルknlを計算
    if len(zeros[l]) <= n - 1: return None, None
    aR = zeros[l][n-1]

    knl = aR / R
    Enl = hbar * hbar * knl * knl / 2.0 / meff / me
    Enl_eV = Enl / eV
    
    return Enl_eV, aR


def cal():
    """
    概要:
        量子球のエネルギー準位を計算し、結果を標準出力に表示します。
    詳細説明:
        設定された有効質量 meff と半径 R を持つ量子球（量子ドット）に対し、
        指定された範囲の主量子数 n と軌道角運動量量子数 l について、
        そのエネルギー準位を計算します。
        まず get_bessel_zeros を用いて球Bessel関数の零点を取得し、
        次に energy_level を用いて各準位のエネルギーを計算します。
        計算結果はエネルギーの低い順にソートされ、各準位の情報（量子数、エネルギー、零点値）が出力されます。
        
    戻り値:
        :returns: なし
        :rtype: None
    """
    print()
    print("Energy levels for quantum sphere")
    print(f"  effective mass: {meff} me")
    print(f"  radius: {R*1.0e9} nm")

    nmax = 5
    zeros = []
    print("Zero points:")
    for l in range(0, nmax):
        zero_points = get_bessel_zeros(l, xmin = 0.1, xmax = 10.0, dx = 0.1, print_level = 0)
        zeros.append(zero_points)
        print(f"l={l}:", [float(v) for v in zero_points])

    level_list = []
    for n in range(1, nmax):
       for l in range(0, nmax):
            Enl, R_zero = energy_level(meff, R, n, l, zeros)
            if Enl is None: continue

            level_list.append({"E": Enl, "label": f"{n+l}{lstr[l]}", "n": n, "l": l, 
                                "R0": R_zero})

    print(f"l n l+n  E(orb)")
    for d in sorted(level_list, key=lambda x: x["E"]):
        E     = d["E"]
        label = d["label"]
        n     = d["n"]
        l     = d["l"]
        R0 = d["R0"]
        print(f"{l} {n} {l+n}    E({label})={E:12.8g} eV  alpha_{n}_{l}={R0:12.8g}  alpha_{n}_{l}^2={R0*R0:12.8g}")
    print(f"see {url} for the definitions of quantum numbers")

def plot_spherical_bessel(lmax, rmax, rmesh = 500):
    """
    概要:
        球Bessel関数 j_l(x) をプロットします。
    詳細説明:
        0から lmax までの次数 l について、球Bessel関数 j_l(x) を
        0 から rmax までの区間で計算し、matplotlib を用いてグラフを描画します。
        各関数の零点も get_bessel_zeros を使用して計算し、グラフ上にマークします。
    引数:
        :param lmax: プロットする球Bessel関数の最大次数 (l)。
        :type lmax: int
        :param rmax: r軸の最大値。
        :type rmax: float
        :param rmesh: r軸のデータポイント数。
        :type rmesh: int
    戻り値:
        :returns: なし
        :rtype: None
    """

    plt.figure(figsize=(8, 6))
    plt.title(f"Spherical Bessel Function $j_l(kr)$", fontsize=16)
    plt.xticks(fontsize=16)
    plt.yticks(fontsize=16)

    r = np.linspace(0, rmax, rmesh)
    for l in range(lmax + 1):
        j_l = spherical_jn(l, r)
        zero_points = get_bessel_zeros(l, xmin = 0.1, xmax = 10.0, dx = 0.1, print_level = 1)
        print("zeros: l=", l, zero_points)

        plt.plot(r, j_l, label = f"$j_{l}(kr)$")
        plt.plot(zero_points, [0.0] * len(zero_points), label = f"zero points of $j_{l}(kr)$", linestyle = '', marker = 'o')

    plt.axhline(y=0, color='red', linestyle='--', linewidth = 0.5)
    plt.axvline(x=0, color='red', linestyle='--', linewidth = 0.5)
    plt.xlabel(r"$k$$\cdot$$r$", fontsize=16)
    plt.ylabel(r"$j_l$($k$$\cdot$$r$)", fontsize=16)
#    plt.grid(True)

    plt.legend(fontsize=12)
    
    plt.show()

def plot_H(nmax, rmax):
    """
    概要:
        水素原子の動径波動関数 R_nl(r) をプロットします。
    詳細説明:
        水素原子の動径波動関数 R_nl(r) を、主量子数 n と軌道角運動量量子数 l の
        組み合わせに対して計算し、matplotlib を用いてグラフを描画します。
        genlaguerre 関数（一般化されたラゲール多項式）を用いて計算が行われます。
        プロットされる動径 r の範囲は 0 から 20 まで（ボーア半径 a0 単位）。
        なお、rmax 引数は現在の実装ではプロット範囲に影響を与えません。
    引数:
        :param nmax: プロットする主量子数 n の最大値。n は 1 から nmax まで。
        :type nmax: int
        :param rmax: (未使用) r軸の最大値として想定される値（ボーア半径 a0 単位）。
        :type rmax: float
    戻り値:
        :returns: なし
        :rtype: None
    """
    a0 = 1.0

    plt.figure(figsize=(8, 6))
    plt.title("Radial Wave Function $R_{20}(r)$ for Hydrogen Atom", fontsize=16)
    plt.xticks(fontsize=16)
    plt.yticks(fontsize=16)

    r = np.linspace(0, 20, 500)
    for n in range(nmax + 1):
        for l in range(n):
            rho = 2 * Z * r / (n * a0)
            Lnl = genlaguerre(n - l - 1, 2 * l + 1)(rho)
            R_nl = (rho ** l) * np.exp(-rho / 2) * Lnl
#            R_nl = sqrt((2 * Z / (n * a0))**3 \
#                * factorial(n - l - 1) / (2 * n * factorial(n + l))) \
#                * R_nl
            plt.plot(r, R_nl, label = f"$R_{n}$$_{l}$($r$)")

    plt.axhline(y=0, color='red', linestyle='--', linewidth = 0.5)
    plt.axvline(x=0, color='red', linestyle='--', linewidth = 0.5)
    plt.xlabel("r (in units of $a_0$)", fontsize=16)
    plt.ylabel("$R_{nl}$($r$)", fontsize=16)
#    plt.grid(True)

    plt.legend(fontsize=12)
    
    plt.show()

def main():
    """
    概要:
        スクリプトのメイン実行関数。
    詳細説明:
        コマンドライン引数 sys.argv[1] の値に基づいて、実行モードを決定します。
        'cal' モードでは cal() 関数を呼び出し、量子球のエネルギー準位計算と表示を実行します。
        'plot' モードでは plot_spherical_bessel() 関数を呼び出し、球Bessel関数のプロットを実行します。
        'plotH' モードでは plot_H() 関数を呼び出し、水素原子の動径波動関数のプロットを実行します。
        上記以外のモードが指定された場合はエラーメッセージを表示し、スクリプトを終了します。
        
    戻り値:
        :returns: なし
        :rtype: None
    """
    if mode == 'cal':
        cal()
    elif mode == 'plot':
        plot_spherical_bessel(lmax = 3, rmax = 10.0)
    elif mode == 'plotH':
        plot_H(nmax = 3, rmax = 10.0)
    else:
        print(f"Error: Invalid mode={mode}")
        exit()

if __name__ == '__main__':
    main()