"""
概要: シュレーディンガー方程式の固有エネルギーをBrent法で精密化するスクリプト。

詳細説明:
本スクリプトは、schrodinger1dモジュールを使用してシュレーディンガー方程式を解き、
その結果に基づいて境界条件が満たされない残差関数を定義します。
与えられた初期エネルギー値から出発し、残差関数の符号反転区間を自動的に探索します。
その後、SciPyライブラリのBrent法 (scipy.optimize.brentq) を利用して、
残差がゼロとなるエネルギー（固有エネルギー）を高精度で求めます。
精密化された固有エネルギーと計算の詳細は、refineE.csvファイルに記録されます。

関連リンク:
refineE_schrodinger1d_usage
"""
import os
import argparse
import numpy as np
from scipy.optimize import brentq
from schrodinger1d import solve_schrodinger

# ----------------------------------------
# residual(E): ψ(L) を返す関数
# ----------------------------------------
def residual(E, args, print_level = 0):
    """
    概要: シュレーディンガー方程式を解き、波動関数の右端での値 (psi(L)) を返す。

    詳細説明:
    schrodinger1d.solve_schrodinger 関数を呼び出し、与えられたエネルギー E
    に対するシュレーディンガー方程式の波動関数を計算する。
    この関数は、Brent法などの根探索アルゴリズムで使用される残差関数として機能し、
    波動関数の右端 (x=L) での値がゼロになるエネルギーを探索するために用いられる。

    引数:
        :param E: 試行するエネルギー値。
        :type E: float
        :param args: main 関数で定義されたコマンドライン引数を格納するオブジェクト。
                     シュレーディンガー方程式のパラメータ (L, nx, psi_max, bc, eps) を含む。
        :type args: argparse.Namespace
        :param print_level: 0以外の場合、計算されたエネルギーと psi(L) の値を出力する。
        :type print_level: int
    戻り値:
        :returns: 計算された波動関数の右端 (x=L) での値。
        :rtype: float
    """
    x, psi, info = solve_schrodinger(
        E=E,
        L=args.L,
        nx=args.nx,
        psi_max=args.psi_max,
        report_step=0,
        bc=args.bc,
        eps=args.eps
    )
    
    if print_level:
        print(f"E={E:12.8g}  psi(L)={psi[-1]:14.6e}")

    return psi[-1]  # ψ(L)


# ----------------------------------------
# 初期値 E0 から符号反転区間を自動探索
# ----------------------------------------
def find_bracket(E0, args, max_expand=20, direction='both'):
    """
    概要: 与えられた初期エネルギー E0 を中心に、残差関数の符号が反転するエネルギー区間を探索する。

    詳細説明:
    E0 を出発点とし、初期の delta 値を用いて E0 - delta と E0 + delta
    の範囲で residual 関数の符号をチェックする。
    もし符号反転が見つからない場合、delta を2倍にして探索範囲を広げ、
    max_expand 回数までこのプロセスを繰り返す。
    符号反転区間が見つかった場合、その下限 E_low と上限 E_high を返す。

    引数:
        :param E0: 探索を開始する初期エネルギー値。
        :type E0: float
        :param args: residual 関数に渡すシュレーディンガー方程式のパラメータ。
        :type args: argparse.Namespace
        :param max_expand: 探索範囲を拡大する最大回数。この回数を超えても符号反転区間が見つからない場合はエラーとなる。
        :type max_expand: int
        :param direction: 探索する方向 ('upper', 'lower', 'both')。
                          'upper': E0より大きい側のみ探索。
                          'lower': E0より小さい側のみ探索。
                          'both': E0を挟んで両側を探索。
        :type direction: str
    戻り値:
        :returns: 符号反転区間の下限 E_low と上限 E_high のタプル。
        :rtype: tuple[float, float]
    例外:
        :raises ValueError: direction が 'upper', 'lower', 'both' のいずれでもない場合。
        :raises RuntimeError: max_expand 回数内に符号反転区間が見つからなかった場合。
    """
    delta = 0.05 * abs(E0) + 0.01

    # 初期値の符号
    f0 = residual(E0, args)

    for _ in range(max_expand):

        if direction == 'upper':
            E_low  = E0
            E_high = E0 + delta

        elif direction == 'lower':
            E_low  = E0 - delta
            E_high = E0

        elif direction == 'both':
            E_low  = E0 - delta
            E_high = E0 + delta

        else:
            raise ValueError("direction must be 'upper', 'lower', or 'both'")

        f_low  = residual(E_low, args)
        f_high = residual(E_high, args)

        if f_low * f_high < 0:
            print(f"[INFO] Found bracket: [{E_low}, {E_high}]")
            return E_low, E_high

        delta *= 2

    raise RuntimeError("Could not find sign change interval.")

# ----------------------------------------
# Brent 法の maxiter を自動計算
# ----------------------------------------
def compute_maxiter(E_low, E_high, E_tol):
    """
    概要: Brent法の最大反復回数を自動的に計算する。

    詳細説明:
    Brent法は、探索区間の幅を指数関数的に減少させる。
    この関数は、初期区間 [E_low, E_high] の幅と目標精度 E_tol に基づいて、
    必要な反復回数を推定する。これにより、不要な計算を避けつつ、
    十分な精度に達するための反復回数を設定できる。
    経験的に5回の余裕を持たせている。

    引数:
        :param E_low: 符号反転区間の下限。
        :type E_low: float
        :param E_high: 符号反転区間の上限。
        :type E_high: float
        :param E_tol: ターゲットとするエネルギーの許容誤差。
        :type E_tol: float
    戻り値:
        :returns: Brent法に推奨される最大反復回数。
        :rtype: int
    """
    width = abs(E_high - E_low)
    return int(np.ceil(np.log2(width / E_tol))) + 5


# ----------------------------------------
# main
# ----------------------------------------
def main():
    """
    概要: コマンドライン引数を受け取り、Brent法を用いてシュレーディンガー方程式の固有エネルギーを精密化する。

    詳細説明:
    このメイン関数は以下の手順で動作する:
    1.  argparse を用いてコマンドライン引数を解析し、初期エネルギー E0、許容誤差 E_tol、
        およびシュレーディンガー方程式の物理パラメータなどを取得する。
    2.  find_bracket 関数を呼び出し、与えられた初期エネルギー args.E を含む符号反転区間
        [E_low, E_high] を探索する。
    3.  もし maxiter がコマンドラインで指定されていない場合、compute_maxiter 関数を用いて、
        目標精度に基づいて自動的に最大反復回数を設定する。
    4.  scipy.optimize.brentq 関数を使用して、residual 関数がゼロとなるエネルギー値
        （すなわち固有エネルギー）を E_low と E_high の間で精密に探索する。
        探索中には residual 関数の呼び出しごとに途中経過が出力される。
    5.  精密化された固有エネルギーと、その探索結果（反復回数、関数呼び出し回数など）を出力する。
    6.  最終的な結果は refineE.csv ファイルに追記される。ファイルが存在しない場合は、
        最初にヘッダー行が追加される。
    """
    parser = argparse.ArgumentParser(
        description="Refine eigenvalue using Brent method"
    )

    parser.add_argument("--E", type=float, required=True,
                        help="Initial guess for eigenvalue")
    parser.add_argument("--E_tol", type=float, default=1e-6,
                        help="Target precision for eigenvalue")
    parser.add_argument("--maxiter", type=int, default=None,
                        help="Max iterations for Brent method (optional)")

    # schrodinger1d.py と同じ引数を流用
    parser.add_argument("--L", type=float, default=10.0)
    parser.add_argument("--nx", type=int, default=2000)
    parser.add_argument("--psi_max", type=float, default=1e10)
    parser.add_argument("--bc", choices=["asymptotic", "zero"], default="asymptotic")
    parser.add_argument("--eps", type=float, default=1e-6)

    args = parser.parse_args()

    # 1. 初期値 E0 から符号反転区間を探す
    E_low, E_high = find_bracket(args.E, args, direction = 'upper')

    # 2. maxiter を自動設定
    if args.maxiter is None:
        args.maxiter = compute_maxiter(E_low, E_high, args.E_tol)
        print(f"[INFO] maxiter set to {args.maxiter}")

    # 3. Brent 法で固有値を精密化
    E_root, result = brentq(
        lambda E: residual(E, args, print_level=1),
        E_low, E_high,
        xtol=args.E_tol,
        maxiter=args.maxiter,
        full_output=True
    )

    print("====================================")
    print(" Refined eigenvalue:")
    print(f"   E = {E_root:.12f}")
    print("====================================")
    
    # f_low, f_high を計算
    f_low  = residual(E_low, args)
    f_high = residual(E_high, args)

    log_csv = "refineE.csv"
    print()
    if not os.path.exists(log_csv):
        print(f"{log_csv} does not exist. Add header labels.")
        with open(log_csv, "a") as fp:
            fp.write("E0, E_low, E_high, E_root, f_low, f_high, iterations, funcalls\n")

    print(f"Append the result to {log_csv}")
    with open(log_csv, "a") as fp:
        fp.write(
            f"{args.E},"
            f"{E_low},{E_high},"
            f"{E_root},"
            f"{f_low},{f_high},"
            f"{result.iterations},{result.function_calls}\n"
        )


if __name__ == "__main__":
    main()