"""
概要: 1次元シュレーディンガー方程式ソルバー

詳細説明:
    このスクリプトは、1次元の定常シュレーディンガー方程式を射撃法を用いて数値的に解きます。
    原子単位系 (atomic units) を使用し、Verlet積分を用いて波動関数を計算します。
    境界条件として「漸近条件 (asymptotic)」または「ゼロ条件 (zero)」を選択でき、
    結果はExcelファイルに保存され、matplotlibで可視化されます。

関連リンク:
    schrodinger1d_usage
"""
import argparse
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# =========================
# Potential (atomic units)
# =========================
def V(x):
    """
    概要: ポテンシャルエネルギーを計算します。

    詳細説明:
        現在の実装では、調和振動子ポテンシャル V(x) = 0.5 * x**2 (Ha) を使用しています。
        他のポテンシャルを使用する場合は、この関数を変更してください。

    引数:
        :param x: 位置 (原子単位)
        :type x: float or numpy.ndarray
    戻り値:
        :returns: ポテンシャルエネルギー V(x) (Ha)
        :rtype: float or numpy.ndarray
    """
    return 0.5 * x**2
#    return 0.5 * x**2 + 1e-2 * x**4


# =========================
# Argument initialization
# =========================
def initialize():
    """
    概要: コマンドライン引数を初期化し、パースします。

    詳細説明:
        argparse モジュールを使用して、エネルギーの推定値、計算領域、メッシュ点数、
        波動関数の発散閾値、レポート間隔、出力ファイル名、境界条件、初期微分値といった
        パラメータを定義します。

    戻り値:
        :returns: argparse.ArgumentParserオブジェクトとargparse.Namespaceオブジェクトを含むタプルです。
        :rtype: tuple[argparse.ArgumentParser, argparse.Namespace]
    """
    parser = argparse.ArgumentParser(
        description="Solve 1D Schrodinger equation by shooting method (atomic units)"
    )

    parser.add_argument("--E", type=float, required=True,
                        help="Energy eigenvalue guess (Ha)")
    parser.add_argument("--L", type=float, default=10.0,
                        help="Half domain size [-L, L]")
    parser.add_argument("--nx", type=int, default=2000,
                        help="Number of mesh points")
    parser.add_argument("--psi_max", type=float, default=1e10,
                        help="Divergence threshold for psi")
    parser.add_argument("--report_step", type=int, default=200,
                        help="Interval for progress report")
    parser.add_argument("--outfile", type=str, default="schrodinger.xlsx",
                        help="Output Excel filename")

    # --- boundary condition ---
    parser.add_argument("--bc", choices=["asymptotic", "zero"],
                        default="asymptotic",
                        help="Boundary condition at x=-L")

    # used only for bc=zero
    parser.add_argument("--eps", type=float, default=1e-6,
                        help="Initial derivative dpsi/dx at x=-L (bc=zero only)")

    args = parser.parse_args()
    return parser, args


# =========================
# Schrodinger solver
# =========================
def solve_schrodinger(E, L, nx, psi_max, report_step, bc, eps):
    """
    概要: 1次元シュレーディンガー方程式を射撃法で解きます。

    詳細説明:
        Verlet積分を用いて波動関数を数値的に伝播させます。
        指定された境界条件 (bc) に基づいて初期条件を設定し、
        計算中に波動関数が psi_max を超えた場合、発散と判断して計算を中断します。
        計算の進行状況は report_step 間隔で出力されます。

    引数:
        :param E: エネルギー固有値の推定値 (Ha)
        :type E: float
        :param L: 計算領域の半分のサイズ。xは[-L, L]の範囲になります。
        :type L: float
        :param nx: メッシュ点の数
        :type nx: int
        :param psi_max: 波動関数の発散を検出するための閾値
        :type psi_max: float
        :param report_step: 計算の進行状況を報告するステップ間隔。0以下の場合は報告しません。
        :type report_step: int
        :param bc: 境界条件のタイプ ("asymptotic" または "zero")
        :type bc: str
        :param eps: bc="zero" の場合に使用される、x=-Lにおける波動関数の初期微分値
        :type eps: float
    戻り値:
        :returns:
            x座標の配列 x_full、波動関数の配列 psi_full、および計算情報を含む辞書 info のタプルを返します。
            info辞書には、"success" (bool)、"diverged" (bool)、"diverged_x" (float or None)、
            "message" (str) のキーが含まれています。
        :rtype: tuple[numpy.ndarray, numpy.ndarray, dict]
    """
    x_full = np.linspace(-L, L, nx)
    dx = x_full[1] - x_full[0]

    psi_full = np.zeros(nx)

    # =========================
    # Initial conditions
    # =========================
    if bc == "zero":
        # ψ(-L)=0, ψ'(-L)=eps
        psi_full[0] = 0.0
        psi_full[1] = eps * dx

    elif bc == "asymptotic":
        # ψ'(-L) = κ ψ(-L)
        psi0 = 1e-6
        kappa = np.sqrt(2.0 * (V(-L) - E))

        psi_full[0] = psi0
        psi_full[1] = psi0 * (1.0 + kappa * dx)

    else:
        raise ValueError("Unknown boundary condition")

    diverged_x = None
    diverge_index = None

    # =========================
    # Verlet integration
    # =========================
    for i in range(1, nx - 1):
        k = 2.0 * (V(x_full[i]) - E)
        psi_full[i + 1] = (
            2.0 * psi_full[i]
            - psi_full[i - 1]
            + dx**2 * k * psi_full[i]
        )

        if abs(psi_full[i + 1]) > psi_max:
            diverge_index = i + 1
            diverged_x = x_full[i + 1]
            break

        if report_step > 0 and i % report_step == 0:
            print(f"[INFO] step={i}, x={x_full[i]:.3f}, psi={psi_full[i]:.3e}")

    info = {}

    if diverge_index is None:
        info["success"] = True
        info["diverged"] = False
        info["diverged_x"] = None
        info["message"] = "integration completed successfully"
        return x_full, psi_full, info
    else:
        info["success"] = False
        info["diverged"] = True
        info["diverged_x"] = diverged_x
        info["message"] = f"diverged at x={diverged_x:.3f}"
        return x_full[:diverge_index], psi_full[:diverge_index], info


# =========================
# Save & plot
# =========================
def save_and_plot(x, psi, outfile, diverged=False, diverged_x=None):
    """
    概要: 計算結果の波動関数データをExcelファイルに保存し、matplotlibでプロットします。

    詳細説明:
        波動関数 psi と確率密度 abs(psi)**2 をExcelファイル (outfile) に保存します。
        matplotlibを用いて psi と abs(psi)**2 をグラフ表示します。
        波動関数が発散した場合は、発散が始まったx座標に垂直な破線が追加され、y軸が対数スケールになります。

    引数:
        :param x: x座標の配列
        :type x: numpy.ndarray
        :param psi: 計算された波動関数の配列
        :type psi: numpy.ndarray
        :param outfile: データ保存先のExcelファイル名
        :type outfile: str
        :param diverged: 波動関数が計算中に発散したかどうかを示すフラグ。デフォルトはFalse。
        :type diverged: bool
        :param diverged_x: 波動関数が発散し始めたx座標。diverged がTrueの場合にのみ使用されます。
                           デフォルトはNone。
        :type diverged_x: float or None
    戻り値:
        :returns: なし
        :rtype: None
    """
    df = pd.DataFrame({
        "x": x,
        "psi": psi,
        "psi2": psi**2
    })
    df.to_excel(outfile, index=False)
    print(f"[INFO] Data saved to {outfile}")

    plt.figure(figsize=(8, 5))
    if not diverged:
        plt.plot(x, psi, label=r"$\psi(x)$")
    plt.plot(x, psi**2, label=r"$|\psi(x)|^2$")

    if diverged and diverged_x is not None:
        plt.axvline(diverged_x, color="r", linestyle="--",
                    label="divergence start")

    plt.xlabel("x (a.u.)")
    plt.ylabel("Wavefunction")
    if diverged:
        plt.yscale("log")
    plt.legend()
    plt.grid()
    plt.tight_layout()
    plt.show()


# =========================
# Main
# =========================
def main():
    """
    概要: プログラムのメイン処理を実行します。

    詳細説明:
        本関数は、まずコマンドライン引数をパースして初期設定を行います。
        次に、パースされた引数を使用して solve_schrodinger 関数を呼び出し、
        1次元シュレーディンガー方程式を解きます。
        その後、計算結果のメッセージと終了点の波動関数値を出力し、
        最後に save_and_plot 関数を呼び出して結果をExcelに保存し、グラフを表示します。

    戻り値:
        :returns: なし
        :rtype: None
    """
    _, args = initialize()

    print("=== 1D Schrodinger Solver (atomic units) ===")
    print(f"E = {args.E} Ha,  L = {args.L},  nx = {args.nx}")
    print(f"Boundary condition: {args.bc}")

    x, psi, info = solve_schrodinger(
        E=args.E,
        L=args.L,
        nx=args.nx,
        psi_max=args.psi_max,
        report_step=args.report_step,
        bc=args.bc,
        eps=args.eps
    )

    print("[RESULT]", info["message"])
    print(f"psi(end) = {psi[-1]:.3e}")

    save_and_plot(
        x,
        psi,
        args.outfile,
        diverged=info["diverged"],
        diverged_x=info["diverged_x"]
    )


if __name__ == "__main__":
    main()