import argparse
import sys
import traceback
from types import SimpleNamespace

import matplotlib.pyplot as plt
import numpy as np
from numpy import cos, cosh, exp, sin, sinh, sqrt
import numpy.linalg as LA


"""
Kronig-Penneyモデルによる1次元バンド計算スクリプト。

概要:
    矩形ポテンシャルを用いたKronig-Penneyモデルに基づき、電子のエネルギーバンド構造、
    波動関数、およびデルタ関数グラフを計算しプロットします。
詳細説明:
    このスクリプトは、以下の3つの主要なモードで動作します。
    1. delta(E) グラフのプロット (graphモード)
    2. バンド構造のプロット (bandモード)
    3. 波動関数のプロット (wfモード)
    各モードでは、ポテンシャルの幅、高さ、格子定数などのパラメータを調整できます。
主な機能:
    - Kronig-Penney方程式の解に基づくエネルギーデルタ関数の計算。
    - エネルギーバンドの計算とプロット。
    - 選択されたエネルギー準位に対応する波動関数の計算とプロット。
関連リンク:
    kronig_penney_refactored_usage
"""


# =============================================================================
# Constants
# =============================================================================
PI = 3.14159265358979323846
PI2 = 2.0 * PI

H = 6.6260755e-34
HBAR = 1.05459e-34
ELEMENTARY_CHARGE = 1.60218e-19
ELECTRON_MASS = 9.1093897e-31

DEFAULT = SimpleNamespace(
    mode="graph",
    a=5.4064,          # A
    bwidth=0.5,        # A
    bpot=10.0,         # eV
    kg=0.0,            # pi/a
    emin=0.0,          # eV
    emax=9.5,          # eV
    nE=51,
    nEsearch=51,
    eps=1.0e-8,
    nmaxiter=100,
    dump=0.0,
    kmin=-0.5,         # pi/a
    kmax=0.5,          # pi/a
    nk=21,
    erange_min=0.0,    # eV
    erange_max=10.0,   # eV
    nMaxLevel=15,
    xwmin=0.0,         # A
    xwmax=None,        # A; None -> 3.0 * a
    nxw=101,
    kw=0.0,            # pi/a
    iLevel=0,
    figsize_w=6.0,
    figsize_h=8.0,
    wf_figsize_w=16.0,
    wf_figsize_h=4.0,
    fontsize=12,
    legend_fontsize=8,
    show=1,
    save=0,
    output="kronig_penney.png",
)


# =============================================================================
# argparse
# =============================================================================
def parse_args() -> argparse.Namespace:
    """
    概要:
        コマンドライン引数を解析します。
    詳細説明:
        Kronig-Penneyモデルの計算およびプロットに必要なパラメータをコマンドラインから受け取ります。
        従来の引数形式もサポートしています。xwmax が指定されない場合は、格子定数 a の3倍がデフォルト値となります。
    引数:
        なし。
    戻り値:
        :returns: 解析されたコマンドライン引数を格納したNamespaceオブジェクト。
        :rtype: argparse.Namespace
    """
    parser = argparse.ArgumentParser(
        description="Kronig-Penney model: graph, band, and wavefunction plotter.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # 従来の位置引数形式も維持するため、mode と legacy_args を受ける。
    parser.add_argument("mode", nargs="?", default=DEFAULT.mode, choices=["graph", "band", "wf"])
    parser.add_argument("legacy_args", nargs="*")

    parser.add_argument("--a", type=float, default=DEFAULT.a)
    parser.add_argument("--bwidth", type=float, default=DEFAULT.bwidth)
    parser.add_argument("--bpot", type=float, default=DEFAULT.bpot)

    parser.add_argument("--kg", type=float, default=DEFAULT.kg)
    parser.add_argument("--Emin", dest="emin", type=float, default=DEFAULT.emin)
    parser.add_argument("--Emax", dest="emax", type=float, default=DEFAULT.emax)
    parser.add_argument("--nE", type=int, default=DEFAULT.nE)

    parser.add_argument("--nEsearch", type=int, default=DEFAULT.nEsearch)
    parser.add_argument("--eps", type=float, default=DEFAULT.eps)
    parser.add_argument("--nmaxiter", type=int, default=DEFAULT.nmaxiter)
    parser.add_argument("--dump", type=float, default=DEFAULT.dump)

    parser.add_argument("--kmin", type=float, default=DEFAULT.kmin)
    parser.add_argument("--kmax", type=float, default=DEFAULT.kmax)
    parser.add_argument("--nk", type=int, default=DEFAULT.nk)

    parser.add_argument("--ErangeMin", dest="erange_min", type=float, default=DEFAULT.erange_min)
    parser.add_argument("--ErangeMax", dest="erange_max", type=float, default=DEFAULT.erange_max)
    parser.add_argument("--nMaxLevel", type=int, default=DEFAULT.nMaxLevel)

    parser.add_argument("--kw", type=float, default=DEFAULT.kw)
    parser.add_argument("--iLevel", type=int, default=DEFAULT.iLevel)
    parser.add_argument("--xwmin", type=float, default=DEFAULT.xwmin)
    parser.add_argument("--xwmax", type=float, default=DEFAULT.xwmax)
    parser.add_argument("--nxw", type=int, default=DEFAULT.nxw)

    parser.add_argument("--figsizeW", dest="figsize_w", type=float, default=DEFAULT.figsize_w)
    parser.add_argument("--figsizeH", dest="figsize_h", type=float, default=DEFAULT.figsize_h)
    parser.add_argument("--wfFigsizeW", dest="wf_figsize_w", type=float, default=DEFAULT.wf_figsize_w)
    parser.add_argument("--wfFigsizeH", dest="wf_figsize_h", type=float, default=DEFAULT.wf_figsize_h)
    parser.add_argument("--fontsize", type=int, default=DEFAULT.fontsize)
    parser.add_argument("--legendFontsize", dest="legend_fontsize", type=int, default=DEFAULT.legend_fontsize)

    parser.add_argument("--show", type=int, choices=[0, 1], default=DEFAULT.show)
    parser.add_argument("--save", type=int, choices=[0, 1], default=DEFAULT.save)
    parser.add_argument("--output", type=str, default=DEFAULT.output)

    args = parser.parse_args()
    apply_legacy_args(args)
    if args.xwmax is None:
        args.xwmax = 3.0 * args.a
    args.nEsearch = args.nE if args.nEsearch == DEFAULT.nEsearch else args.nEsearch
    return args


def apply_legacy_args(args: argparse.Namespace) -> None:
    """
    概要:
        旧形式のコマンドライン引数を適用します。
    詳細説明:
        modeの後に続く位置引数として指定された値を解析し、argsオブジェクトの対応する属性に設定します。
        これにより、スクリプトの古いバージョンとの互換性が保たれます。
    引数:
        :param args: コマンドライン引数を格納したNamespaceオブジェクト。
        :type args: argparse.Namespace
    戻り値:
        なし。
    """
    values = args.legacy_args
    if not values:
        return

    if len(values) >= 1:
        args.a = float(values[0])
    if len(values) >= 2:
        args.bwidth = float(values[1])
    if len(values) >= 3:
        args.bpot = float(values[2])

    if args.mode == "graph":
        if len(values) >= 4:
            args.kg = float(values[3])
        if len(values) >= 5:
            args.emin = float(values[4])
        if len(values) >= 6:
            args.emax = float(values[5])
        if len(values) >= 7:
            args.nE = int(values[6])
    elif args.mode == "band":
        if len(values) >= 4:
            args.kmin = float(values[3])
        if len(values) >= 5:
            args.kmax = float(values[4])
        if len(values) >= 6:
            args.nk = int(values[5])
    elif args.mode == "wf":
        if len(values) >= 4:
            args.kw = float(values[3])
        if len(values) >= 5:
            args.iLevel = int(values[4])
        if len(values) >= 6:
            args.xwmin = float(values[5])
        if len(values) >= 7:
            args.xwmax = float(values[6])
        if len(values) >= 8:
            args.nxw = int(values[7])


# =============================================================================
# utility functions
# =============================================================================
def round01(x: float, a: float) -> tuple[float, int]:
    """
    概要:
        値を周期 a で区間 [0, a) に丸め、周期数を返します。
    詳細説明:
        入力値 x を周期 a で正規化し、周期内の値 x0 と、何周期分シフトしたかを示す整数 n を計算します。
        x >= 0 の場合は n = floor(x / a)、x < 0 の場合は n = floor(x / a) - 1 となります。
        結果として x0 は常に [0, a) の範囲に収まります。
    引数:
        :param x: 丸める対象の浮動小数点数。
        :type x: float
        :param a: 周期を表す浮動小数点数。
        :type a: float
    戻り値:
        :returns: 周期内の値 x0 と周期数 n のタプル。
        :rtype: tuple
    """
    if x >= 0.0:
        n = int(x / a)
    else:
        n = int(x / a) - 1
    x0 = x - n * a
    return x0, n


def validate_args(args: argparse.Namespace) -> None:
    """
    概要:
        コマンドライン引数の値を検証します。
    詳細説明:
        各引数が物理的に妥当な範囲内にあるか、または計算に必要な最小値を満たしているかを確認します。
        例えば、幅やポテンシャル高さが正であること、格子定数 a がポテンシャル幅 bwidth より大きいことなどをチェックします。
    引数:
        :param args: コマンドライン引数を格納したNamespaceオブジェクト。
        :type args: argparse.Namespace
    戻り値:
        なし。
    例外:
        :raises ValueError: 引数の値が不正な場合に発生します。
    """
    if args.bwidth <= 0.0:
        raise ValueError("--bwidth must be positive.")
    if args.a <= args.bwidth:
        raise ValueError("--a must be larger than --bwidth.")
    if args.bpot <= 0.0:
        raise ValueError("--bpot must be positive.")
    if args.nE < 2:
        raise ValueError("--nE must be >= 2.")
    if args.nEsearch < 2:
        raise ValueError("--nEsearch must be >= 2.")
    if args.nk < 2:
        raise ValueError("--nk must be >= 2.")
    if args.nxw < 2:
        raise ValueError("--nxw must be >= 2.")
    if args.mode == "wf" and args.iLevel < 0:
        raise ValueError("--iLevel must be >= 0.")


def create_context(args: argparse.Namespace) -> SimpleNamespace:
    """
    概要:
        解析された引数からコンテキストオブジェクトを作成します。
    詳細説明:
        argparse.Namespaceオブジェクトの引数を元に、計算やプロットに必要な全てのパラメータを
        SimpleNamespaceオブジェクトとして集約します。これにより、関数の引数渡しが簡潔になります。
        特に、ポテンシャルの幅 b と井戸の幅 w を計算して格納します。
    引数:
        :param args: コマンドライン引数を格納したNamespaceオブジェクト。
        :type args: argparse.Namespace
    戻り値:
        :returns: 計算コンテキストを格納したSimpleNamespaceオブジェクト。
        :rtype: types.SimpleNamespace
    """
    return SimpleNamespace(
        mode=args.mode,
        a=args.a,
        bwidth=args.bwidth,
        bpot=args.bpot,
        w=args.a - args.bwidth,
        b=args.bwidth,
        V0=args.bpot,
        kg=args.kg,
        emin=args.emin,
        emax=args.emax,
        nE=args.nE,
        nEsearch=args.nEsearch,
        eps=args.eps,
        nmaxiter=args.nmaxiter,
        dump=args.dump,
        kmin=args.kmin,
        kmax=args.kmax,
        nk=args.nk,
        erange=(args.erange_min, args.erange_max),
        nMaxLevel=args.nMaxLevel,
        kw=args.kw,
        iLevel=args.iLevel,
        xwmin=args.xwmin,
        xwmax=args.xwmax,
        nxw=args.nxw,
        figsize=(args.figsize_w, args.figsize_h),
        wf_figsize=(args.wf_figsize_w, args.wf_figsize_h),
        fontsize=args.fontsize,
        legend_fontsize=args.legend_fontsize,
        show=args.show,
        save=args.save,
        output=args.output,
    )


def save_show_close(fig: plt.Figure, ctx: SimpleNamespace) -> None:
    """
    概要:
        Matplotlibの図を保存、表示、閉じます。
    詳細説明:
        コンテキストオブジェクトの設定 (ctx.save, ctx.show, ctx.output) に応じて、
        生成されたMatplotlibの図 (Figureオブジェクト) をファイルに保存し、画面に表示し、
        最終的に閉じます。
    引数:
        :param fig: 保存、表示、閉じる対象のMatplotlib Figureオブジェクト。
        :type fig: matplotlib.pyplot.Figure
        :param ctx: 計算コンテキストを格納したSimpleNamespaceオブジェクト。
        :type ctx: types.SimpleNamespace
    戻り値:
        なし。
    """
    if ctx.save:
        fig.savefig(ctx.output, dpi=300, bbox_inches="tight")
        print(f"Saved figure: {ctx.output}")
    if ctx.show:
        plt.show()
    plt.close(fig)


# =============================================================================
# core functions
# =============================================================================
def compute_potential_value(x: float, ctx: SimpleNamespace) -> float:
    """
    概要:
        指定された位置 x におけるポテンシャルの値を計算します。
    詳細説明:
        Kronig-Penneyモデルにおける矩形ポテンシャル関数を定義します。
        x が周期 a のセル内でポテンシャル幅 bwidth の範囲内にある場合、ポテンシャルの高さ ctx.bpot を返し、
        それ以外の場合は 0.0 を返します。
    引数:
        :param x: ポテンシャルの値を評価するx座標。
        :type x: float
        :param ctx: 計算コンテキストを格納したSimpleNamespaceオブジェクト。
        :type ctx: types.SimpleNamespace
    戻り値:
        :returns: xにおけるポテンシャルの値。
        :rtype: float
    """
    xred, _ = round01(x, ctx.a)
    if ctx.a - ctx.bwidth <= xred < ctx.a:
        return ctx.bpot
    return 0.0


def compute_potential_profile(xmin: float, xstep: float, n: int, ctx: SimpleNamespace) -> tuple[np.ndarray, np.ndarray]:
    """
    概要:
        指定された範囲とステップでポテンシャルプロファイルを計算します。
    詳細説明:
        xminから始まり、xstep間隔でn個の点におけるポテンシャル値を計算し、x座標配列とポテンシャル値配列を生成します。
        各点のポテンシャル値は compute_potential_value を呼び出して決定されます。
    引数:
        :param xmin: プロファイルの開始x座標。
        :type xmin: float
        :param xstep: 各点間のxステップサイズ。
        :type xstep: float
        :param n: 計算する点の総数。
        :type n: int
        :param ctx: 計算コンテキストを格納したSimpleNamespaceオブジェクト。
        :type ctx: types.SimpleNamespace
    戻り値:
        :returns: x座標のnumpy.ndarrayと対応するポテンシャル値のnumpy.ndarrayのタプル。
        :rtype: tuple
    """
    xpot = xmin + np.arange(n) * xstep
    ypot = np.array([compute_potential_value(x, ctx) for x in xpot], dtype=float)
    return xpot, ypot


def compute_delta(E: float, k: float, w: float, b: float, V0: float) -> float:
    """
    概要:
        Kronig-Penneyモデルのデルタ関数 (delta(E)) の値を計算します。
    詳細説明:
        与えられたエネルギー E、波数 k、井戸の幅 w、ポテンシャル障壁の幅 b、
        およびポテンシャルの高さ V0 に基づいて、Kronig-Penneyモデルのデルタ関数を評価します。
        この関数は、電子の透過・反射特性を表し、delta(E) の値が -1 から 1 の間に収まるエネルギーが
        許容エネルギーバンドに対応します。
        計算には物理定数 HBAR, ELECTRON_MASS, ELEMENTARY_CHARGE を使用します。
    引数:
        :param E: 電子のエネルギー (eV)。
        :type E: float
        :param k: 電子の波数 (pi/a 単位)。
        :type k: float
        :param w: 井戸の幅 (A)。
        :type w: float
        :param b: ポテンシャル障壁の幅 (A)。
        :type b: float
        :param V0: ポテンシャルの高さ (eV)。
        :type V0: float
    戻り値:
        :returns: デルタ関数の値。
        :rtype: float
    """
    alpha = sqrt(2.0 * ELECTRON_MASS * E * ELEMENTARY_CHARGE) / HBAR
    beta = sqrt(2.0 * ELECTRON_MASS * (V0 - E) * ELEMENTARY_CHARGE) / HBAR
    ka = k * PI2
    alphaw = alpha * w * 1.0e-10
    betab = beta * b * 1.0e-10

    delta = (
        (beta * beta - alpha * alpha) / 2.0 / alpha / beta * sin(alphaw) * sinh(betab)
        + cos(alphaw) * cosh(betab)
        - cos(ka)
    )
    return float(delta)


def compute_boundary_matrix(k: float, E: float, w: float, b: float, V0: float) -> np.ndarray:
    """
    概要:
        波動関数の境界条件を記述する4x4行列を計算します。
    詳細説明:
        Kronig-Penneyモデルにおける各領域 (井戸と障壁) での波動関数の係数間の関係を記述する行列を構築します。
        この行列は、井戸と障壁の界面における波動関数とその導関数の連続性条件に基づいています。
        電子のエネルギー E、波数 k、井戸の幅 w、障壁の幅 b、およびポテンシャルの高さ V0 が入力として必要です。
    引数:
        :param k: 電子の波数 (pi/a 単位)。
        :type k: float
        :param E: 電子のエネルギー (eV)。
        :type E: float
        :param w: 井戸の幅 (A)。
        :type w: float
        :param b: ポテンシャル障壁の幅 (A)。
        :type b: float
        :param V0: ポテンシャルの高さ (eV)。
        :type V0: float
    戻り値:
        :returns: 境界条件を記述する4x4の複素数行列。
        :rtype: numpy.ndarray
    """
    alpha = sqrt(2.0 * ELECTRON_MASS * E * ELEMENTARY_CHARGE) / HBAR
    beta = sqrt(2.0 * ELECTRON_MASS * (V0 - E) * ELEMENTARY_CHARGE) / HBAR
    ka = k * PI2
    lambda_ = exp(1.0j * ka)
    alphaw = alpha * w * 1.0e-10
    betab = beta * b * 1.0e-10
    alpha *= 1.0e-10
    beta *= 1.0e-10

    matrix = np.empty((4, 4), dtype=complex)
    matrix[0, 0] = matrix[0, 1] = 1.0
    matrix[0, 2] = matrix[0, 3] = -1.0
    matrix[1, 0] = 1.0j * alpha
    matrix[1, 1] = -1.0j * alpha
    matrix[1, 2] = -beta
    matrix[1, 3] = beta
    matrix[2, 0] = exp(1.0j * alphaw)
    matrix[2, 1] = exp(-1.0j * alphaw)
    matrix[2, 2] = -lambda_ * exp(-betab)
    matrix[2, 3] = -lambda_ * exp(betab)
    matrix[3, 0] = 1.0j * alpha * exp(1.0j * alphaw)
    matrix[3, 1] = -1.0j * alpha * exp(-1.0j * alphaw)
    matrix[3, 2] = -lambda_ * beta * exp(-betab)
    matrix[3, 3] = lambda_ * beta * exp(betab)
    return matrix


def compute_wave_coefficients(k: float, E: float, w: float, b: float, V0: float) -> list[complex]:
    """
    概要:
        波動関数の係数を計算します。
    詳細説明:
        周期ポテンシャル中の電子の波動関数は、Blochの定理により周期部分と平面波部分に分けられます。
        この関数は、境界条件行列を解くことで、井戸と障壁領域における波動関数の未知の係数 A, B, C, D を計算します。
        特に、井戸領域の左側から右側への入射波の係数を 1.0 と仮定し、残りの3つの係数を連立一次方程式の解として求めます。
    引数:
        :param k: 電子の波数 (pi/a 単位)。
        :type k: float
        :param E: 電子のエネルギー (eV)。
        :type E: float
        :param w: 井戸の幅 (A)。
        :type w: float
        :param b: ポテンシャル障壁の幅 (A)。
        :type b: float
        :param V0: ポテンシャルの高さ (eV)。
        :type V0: float
    戻り値:
        :returns: 波動関数の4つの複素数係数 (A, B, C, D) のリスト。
        :rtype: list
    """
    matrix = compute_boundary_matrix(k, E, w, b, V0)

    a0 = 1.0
    matrix3 = np.empty((3, 3), dtype=complex)
    vector3 = np.empty((3, 1), dtype=complex)

    matrix3[0, 0] = matrix[1, 1]
    matrix3[0, 1] = matrix[1, 2]
    matrix3[0, 2] = matrix[1, 3]
    matrix3[1, 0] = matrix[2, 1]
    matrix3[1, 1] = matrix[2, 2]
    matrix3[1, 2] = matrix[2, 3]
    matrix3[2, 0] = matrix[3, 1]
    matrix3[2, 1] = matrix[3, 2]
    matrix3[2, 2] = matrix[3, 3]

    vector3[0, 0] = -a0 * matrix[1, 0]
    vector3[1, 0] = -a0 * matrix[2, 0]
    vector3[2, 0] = -a0 * matrix[3, 0]

    ai = LA.solve(matrix3, vector3)
    return [a0, ai[0, 0], ai[1, 0], ai[2, 0]]


def check_wave_coefficients(
    ci: list[complex],
    kw: float,
    E: float,
    w: float,
    b: float,
    V0: float,
    eps: float,
    is_print: int = 0,
) -> None:
    """
    概要:
        計算された波動関数係数が境界条件を満たしているか検証します。
    詳細説明:
        計算された波動関数の係数 ci を用いて、境界条件行列 Mij との積 Mij @ ci が
        ほぼゼロになっていることを確認します。これは、係数が正しく計算されているかの健全性チェックです。
        積の絶対値の最大値が許容誤差 eps を超える場合、RuntimeErrorを発生させます。
    引数:
        :param ci: 波動関数の複素数係数のリスト。
        :type ci: list
        :param kw: 電子の波数 (pi/a 単位)。
        :type kw: float
        :param E: 電子のエネルギー (eV)。
        :type E: float
        :param w: 井戸の幅 (A)。
        :type w: float
        :param b: ポテンシャル障壁の幅 (A)。
        :type b: float
        :param V0: ポテンシャルの高さ (eV)。
        :type V0: float
        :param eps: 許容誤差。
        :type eps: float
        :param is_print: 結果を標準出力に表示するかどうかを示すフラグ (0:表示しない, 1:表示する)。
        :type is_print: int
    戻り値:
        なし。
    例外:
        :raises RuntimeError: 係数が境界条件を満たさない場合に発生します。
    """
    matrix = compute_boundary_matrix(kw, E, w, b, V0)
    values = matrix @ np.asarray(ci, dtype=complex)

    if is_print:
        for i, coef in enumerate(ci):
            print(f"  ci[{i}] = {coef.real:12.4g}+j{coef.imag:12.4g}")
        for i, value in enumerate(values):
            print(f"  abs(Mij@ci[{i}]) = {abs(value)} {eps}")

    vmax = float(np.max(np.abs(values)))
    if vmax > eps:
        raise RuntimeError(f"Mij @ ci is not zero: abs(Mij@ci)={vmax} > eps={eps}")


def compute_refined_energy(
    E0: float,
    E1: float,
    k: float,
    w: float,
    b: float,
    V0: float,
    ctx: SimpleNamespace,
    is_print: int = 0,
) -> tuple[float | None, float | None, float | None]:
    """
    概要:
        ニュートン法を用いてエネルギー準位を精密化します。
    詳細説明:
        Kronig-Penneyモデルのデルタ関数 compute_delta がゼロに近づくエネルギー E を、
        初期推定値 E0 と E1 からニュートン法を適用して見つけます。
        許容誤差 ctx.eps または最大反復回数 ctx.nmaxiter に達するまで反復計算を行います。
        収束しない場合はNoneを返します。
    引数:
        :param E0: エネルギーの初期推定値1 (eV)。
        :type E0: float
        :param E1: エネルギーの初期推定値2 (eV)。
        :type E1: float
        :param k: 電子の波数 (pi/a 単位)。
        :type k: float
        :param w: 井戸の幅 (A)。
        :type w: float
        :param b: ポテンシャル障壁の幅 (A)。
        :type b: float
        :param V0: ポテンシャルの高さ (eV)。
        :type V0: float
        :param ctx: 計算コンテキストを格納したSimpleNamespaceオブジェクト。
        :type ctx: types.SimpleNamespace
        :param is_print: 結果を標準出力に表示するかどうかを示すフラグ (0:表示しない, 1:表示する)。
        :type is_print: int
    戻り値:
        :returns: 精密化されたエネルギー、エネルギー変化 dE、およびデルタ関数の値のタプル。
                  収束しない場合は (None, None, None) を返します。
        :rtype: tuple
    """
    delta0 = compute_delta(E0, k, w, b, V0)
    delta1 = compute_delta(E1, k, w, b, V0)

    E2 = E1
    dE = 0.0
    delta2 = delta1

    for _ in range(ctx.nmaxiter):
        diff = (delta1 - delta0) / (E1 - E0)
        if diff >= 0.0:
            diff += ctx.dump
        else:
            diff = -(abs(diff) + ctx.dump)

        dE = -delta1 / diff
        E2 = E1 + dE
        delta2 = compute_delta(E2, k, w, b, V0)

        if abs(dE) < ctx.eps:
            if is_print:
                print(f"  converged at E = {E2:12.6g} with dE = {dE:12.6g}  delta = {delta2:12.6g}")
            return E2, dE, delta2

        E0 = E1
        E1 = E2
        delta0 = delta1
        delta1 = delta2

    print(f"  Not converged for {ctx.nmaxiter} iterations.")
    print(f"    E = {E2:12.6g} with dE = {dE:12.6g}  delta = {delta2:12.6g}")
    return None, None, None


def compute_energy_levels(
    emin: float,
    emax: float,
    nEsearch: int,
    k: float,
    w: float,
    b: float,
    V0: float,
    ctx: SimpleNamespace,
) -> tuple[list[float], list[list[complex]]]:
    """
    概要:
        指定されたエネルギー範囲内で許容されるエネルギー準位を計算します。
    詳細説明:
        エネルギー範囲 emin から emax を nEsearch 個のステップで探索し、
        compute_delta 関数の符号が反転する点を特定します。
        これらの点を初期値として compute_refined_energy を呼び出し、
        各エネルギー準位を精密化します。対応する波動関数の係数も計算して返します。
    引数:
        :param emin: 探索開始エネルギー (eV)。
        :type emin: float
        :param emax: 探索終了エネルギー (eV)。
        :type emax: float
        :param nEsearch: エネルギー探索点の数。
        :type nEsearch: int
        :param k: 電子の波数 (pi/a 単位)。
        :type k: float
        :param w: 井戸の幅 (A)。
        :type w: float
        :param b: ポテンシャル障壁の幅 (A)。
        :type b: float
        :param V0: ポテンシャルの高さ (eV)。
        :type V0: float
        :param ctx: 計算コンテキストを格納したSimpleNamespaceオブジェクト。
        :type ctx: types.SimpleNamespace
    戻り値:
        :returns: 許容されるエネルギー準位のリストと、それぞれのエネルギー準位に対応する波動関数係数のリストのタプル。
        :rtype: tuple
    """
    estep = (emax - emin) / (nEsearch - 1)
    previous_delta = None
    iband = 0
    energy_list = []
    coefficient_list = []

    for iE in range(nEsearch):
        E = emin + iE * estep
        if E == 0.0:
            continue
        if V0 <= E:
            break

        delta = compute_delta(E, k, w, b, V0)

        if previous_delta is None:
            previous_delta = delta
            continue

        if previous_delta * delta < 0.0:
            previous_delta = delta
            refined_E, dE, delta0 = compute_refined_energy(E - estep, E, k, w, b, V0, ctx, is_print=0)
            print(f"  E[{iband}]={refined_E:12.6g} eV  dE={dE:12.6g} delta={delta0:12.6g}")

            if refined_E is not None:
                energy_list.append(refined_E)
                coefficient_list.append(compute_wave_coefficients(k, refined_E, w, b, V0))
                iband += 1

    return energy_list, coefficient_list


def compute_wavefunction(ci: list[complex], x: float, kw: float, E: float, w: float, b: float, V0: float) -> complex:
    """
    概要:
        指定された位置 x における波動関数の値を計算します。
    詳細説明:
        計算された波動関数係数 ci を使用して、Blochの定理に基づく波動関数を評価します。
        位置 x を周期 a で正規化し、それが井戸領域か障壁領域かに応じて適切な波動関数形式を適用します。
        結果は複素数値の波動関数となります。
    引数:
        :param ci: 波動関数の複素数係数のリスト (A, B, C, D)。
        :type ci: list
        :param x: 波動関数の値を評価するx座標 (A)。
        :type x: float
        :param kw: 電子の波数 (pi/a 単位)。
        :type kw: float
        :param E: 電子のエネルギー (eV)。
        :type E: float
        :param w: 井戸の幅 (A)。
        :type w: float
        :param b: ポテンシャル障壁の幅 (A)。
        :type b: float
        :param V0: ポテンシャルの高さ (eV)。
        :type V0: float
    戻り値:
        :returns: xにおける波動関数の複素数値。
        :rtype: complex
    例外:
        :raises ValueError: 内部でのx座標の正規化が不正な場合に発生します。
    """
    a = w + b
    xmin = -b
    xmax = w
    x0, n_period = round01(x, a)

    if x0 < -xmin:
        x0 += a
    if x0 >= xmax:
        x0 -= a
    if not xmin <= x0 < xmax:
        raise ValueError(f"x0 out of range: x={x:8.4g} {n_period} x0={x0:8.4g} w={w:8.4g} b={b:8.4g}")

    alpha = sqrt(2.0 * ELECTRON_MASS * E * ELEMENTARY_CHARGE) / HBAR * 1.0e-10
    beta = sqrt(2.0 * ELECTRON_MASS * (V0 - E) * ELEMENTARY_CHARGE) / HBAR * 1.0e-10
    phase0 = PI2 / a * kw * x0
    kph0 = exp(1.0j * phase0)

    if xmin <= x0 < 0.0:
        phi = ci[2] * exp(beta * x0) + ci[3] * exp(-beta * x0)
        periodic_part = phi / kph0
    else:
        phi = ci[0] * exp(1.0j * alpha * x0) + ci[1] * exp(-1.0j * alpha * x0)
        periodic_part = phi / kph0

    return exp(1.0j * PI2 / a * kw * x) * periodic_part + 0.0j


def normalize_coefficients(
    ci: list[complex],
    E: float,
    kw: float,
    xstep: float,
    ctx: SimpleNamespace,
) -> list[complex]:
    """
    概要:
        波動関数の係数を正規化します。
    詳細説明:
        波動関数の全空間での確率密度積分 (abs(psi(x))^2 の積分) が1になるように、
        波動関数の係数 ci を調整します。積分は周期 a の範囲で数値的に行われます。
        これにより、物理的な解釈が可能な波動関数が得られます。
    引数:
        :param ci: 正規化する前の波動関数係数のリスト。
        :type ci: list
        :param E: 電子のエネルギー (eV)。
        :type E: float
        :param kw: 電子の波数 (pi/a 単位)。
        :type kw: float
        :param xstep: 積分に使用するxステップサイズ。
        :type xstep: float
        :param ctx: 計算コンテキストを格納したSimpleNamespaceオブジェクト。
        :type ctx: types.SimpleNamespace
    戻り値:
        :returns: 正規化された波動関数係数のリスト。
        :rtype: list
    """
    nxintg = int(ctx.a / xstep + 1.0001)
    xintgstep = ctx.a / (nxintg - 1)
    charge = 0.0

    for i in range(nxintg):
        x = i * xintgstep
        yval = compute_wavefunction(ci, x, kw, E, ctx.w, ctx.b, ctx.V0)
        charge += yval * yval.conjugate()

    charge = charge.real * xintgstep
    coefficient = 1.0 / sqrt(charge)

    print("integ(|psi(x)|^2) = ", charge)
    print("Normalization coefficient = ", coefficient)

    return [c * coefficient for c in ci]


def compute_graph_data(ctx: SimpleNamespace) -> tuple[list[float], list[float]]:
    """
    概要:
        delta(E) グラフプロット用のデータを計算します。
    詳細説明:
        コンテキストオブジェクト (ctx) で指定されたエネルギー範囲 (emin から emax) と
        ステップ数 (nE) に基づいて、各エネルギー E におけるデルタ関数 (compute_delta) の値を計算します。
        このデータは、許容エネルギーバンドを視覚化するために使用されます。
    引数:
        :param ctx: 計算コンテキストを格納したSimpleNamespaceオブジェクト。
        :type ctx: types.SimpleNamespace
    戻り値:
        :returns: エネルギー値のリストと対応するデルタ関数の値のリストのタプル。
        :rtype: tuple
    """
    estep = (ctx.emax - ctx.emin) / (ctx.nE - 1)
    xE = []
    yD = []

    for i in range(1, ctx.nE):
        E = ctx.emin + i * estep
        if ctx.V0 <= E:
            break
        xE.append(E)
        yD.append(compute_delta(E, ctx.kg, ctx.w, ctx.b, ctx.V0))

    return xE, yD


def compute_band_data(ctx: SimpleNamespace) -> tuple[list[float], np.ndarray, int]:
    """
    概要:
        バンド構造プロット用のデータを計算します。
    詳細説明:
        コンテキストオブジェクト (ctx) で指定された波数範囲 (kmin から kmax) と
        ステップ数 (nk) に基づいて、各波数 k における許容エネルギー準位を計算します。
        これらのエネルギー準位は、compute_energy_levels 関数を用いて探索され、
        バンド構造としてプロットされます。
    引数:
        :param ctx: 計算コンテキストを格納したSimpleNamespaceオブジェクト。
        :type ctx: types.SimpleNamespace
    戻り値:
        :returns: 波数 k のリスト、各波数におけるエネルギー準位のnumpy.ndarray、および見つかった最大バンドレベル数のタプル。
        :rtype: tuple
    """
    kstep = (ctx.kmax - ctx.kmin) / (ctx.nk - 1)
    xk = [ctx.kmin + i * kstep for i in range(ctx.nk)]
    yE = np.zeros((ctx.nMaxLevel, ctx.nk))
    n_max_band = 0

    for ik, k in enumerate(xk):
        print(f"at k={k:8.4g}")
        energy_list, _ = compute_energy_levels(0.0, ctx.V0, ctx.nEsearch, k, ctx.w, ctx.b, ctx.V0, ctx)
        n_energy = len(energy_list)
        n_max_band = max(n_max_band, n_energy)

        for iband in range(min(n_energy, ctx.nMaxLevel)):
            yE[iband, ik] = energy_list[iband]

    return xk, yE, n_max_band


def compute_wavefunction_data(ctx: SimpleNamespace) -> tuple[np.ndarray, np.ndarray, list[float], np.ndarray, np.ndarray]:
    """
    概要:
        波動関数プロット用のデータを計算します。
    詳細説明:
        コンテキストオブジェクト (ctx) で指定された波数 (kw) とエネルギー準位インデックス (iLevel) に基づいて、
        対応する波動関数とその確率密度を計算します。
        また、ポテンシャルプロファイルも計算し、波動関数と共にプロットできるように準備します。
        波動関数は計算後に正規化されます。
    引数:
        :param ctx: 計算コンテキストを格納したSimpleNamespaceオブジェクト。
        :type ctx: types.SimpleNamespace
    戻り値:
        :returns: x座標、ポテンシャル値、エネルギー準位のリスト、波動関数の複素数値、確率密度のタプル。
        :rtype: tuple
    例外:
        :raises IndexError: 指定されたiLevelが利用可能なエネルギー準位の範囲外である場合に発生します。
    """
    xwstep = (ctx.xwmax - ctx.xwmin) / (ctx.nxw - 1)
    xplot, ypot = compute_potential_profile(ctx.xwmin, xwstep, ctx.nxw, ctx)

    energy_list, coefficient_list = compute_energy_levels(0.0, ctx.V0, ctx.nEsearch, ctx.kw, ctx.w, ctx.b, ctx.V0, ctx)
    if ctx.iLevel >= len(energy_list):
        raise IndexError(f"iLevel={ctx.iLevel} is out of range. Found {len(energy_list)} energy levels.")

    E = energy_list[ctx.iLevel]
    ci = coefficient_list[ctx.iLevel]

    print("")
    print("=== Calculate wave function ===")
    print("Energy levels:", energy_list, "eV")
    print(f"at k = {ctx.kw}")
    print(f"{ctx.iLevel}-th energy level")
    print(f"  E = {E:12.6g} eV")
    print_coefficients(ci)

    sumci = abs(ci[0] + ci[1] - ci[2] - ci[3])
    alpha = sqrt(2.0 * ELECTRON_MASS * E * ELEMENTARY_CHARGE) / HBAR * 1.0e-10
    beta = sqrt(2.0 * ELECTRON_MASS * (ctx.V0 - E) * ELEMENTARY_CHARGE) / HBAR * 1.0e-10
    print(f"  sum(ci) = {sumci:12.4e}")
    print(f"  alpha = {alpha:12.6g} A^-1")
    print(f"  beta  = {beta:12.6g} A^-1")

    print("")
    print("Normalization")
    ci = normalize_coefficients(ci, E, ctx.kw, xwstep, ctx)
    print_coefficients(ci)

    ywf = np.array(
        [compute_wavefunction(ci, ctx.xwmin + i * xwstep, ctx.kw, E, ctx.w, ctx.b, ctx.V0) for i in range(ctx.nxw)],
        dtype=complex,
    )
    charge = np.array([(value * value.conjugate()).real for value in ywf], dtype=float)
    return xplot, ypot, energy_list, ywf, charge


def print_coefficients(ci: list[complex]) -> None:
    """
    概要:
        波動関数の係数を整形して標準出力に表示します。
    詳細説明:
        与えられた波動関数係数 ci (通常は A, B, C, D に対応) を、
        実部と虚部に分けて読みやすい形式でコンソールに出力します。
    引数:
        :param ci: 波動関数の複素数係数のリスト。
        :type ci: list
    戻り値:
        なし。
    """
    names = ["A", "B", "C", "D"]
    for name, value in zip(names, ci):
        print(f"  {name} = {value.real:12.4g}+j{value.imag:12.4g}")


# =============================================================================
# plot functions
# =============================================================================
def plot_graph(ctx: SimpleNamespace) -> None:
    """
    概要:
        Kronig-Penneyモデルのdelta(E)グラフをプロットします。
    詳細説明:
        与えられたコンテキストオブジェクト (ctx) のパラメータに基づいて、
        エネルギー E に対するデルタ関数 (compute_delta) の値を計算し、プロットします。
        デルタ関数の絶対値が1を超える領域は許容エネルギーバンドに対応しません。
        生成された図は、ctx.save と ctx.show の設定に応じて保存または表示されます。
    引数:
        :param ctx: 計算コンテキストを格納したSimpleNamespaceオブジェクト。
        :type ctx: types.SimpleNamespace
    戻り値:
        なし。
    """
    estep = (ctx.emax - ctx.emin) / (ctx.nE - 1)

    print("")
    print("=== Input parameters ===")
    print("mode:", ctx.mode)
    print("a=", ctx.a, "A")
    print(f"  barrier: w={ctx.b} A  h={ctx.V0} eV")
    print(f"  well   : w={ctx.w} A  h={0.0} eV")
    print(f"Energy range: {ctx.emin} - {ctx.emax}, {estep} eV step  {ctx.nE} points")
    print(f"at k = {ctx.kg}")
    print("")

    xE, yD = compute_graph_data(ctx)

    fig = plt.figure(figsize=ctx.figsize)
    ax = fig.add_subplot(1, 1, 1)
    ax.plot(xE, yD)
    ax.set_xlim([ctx.emin, ctx.emax])
    ax.axhline(0.0, linestyle="dashed", linewidth=0.5)
    ax.set_xlabel("E (eV)", fontsize=ctx.fontsize)
    ax.set_ylabel("delta", fontsize=ctx.fontsize)
    ax.tick_params(labelsize=ctx.fontsize)
    plt.tight_layout()

    save_show_close(fig, ctx)


def plot_band(ctx: SimpleNamespace) -> None:
    """
    概要:
        Kronig-Penneyモデルのバンド構造をプロットします。
    詳細説明:
        与えられたコンテキストオブジェクト (ctx) のパラメータに基づいて、
        波数 k に対する許容エネルギー準位 (バンド) を計算し、プロットします。
        これにより、電子が占めることのできるエネルギー領域と、バンドギャップを視覚化します。
        生成された図は、ctx.save と ctx.show の設定に応じて保存または表示されます。
    引数:
        :param ctx: 計算コンテキストを格納したSimpleNamespaceオブジェクト。
        :type ctx: types.SimpleNamespace
    戻り値:
        なし。
    """
    kstep = (ctx.kmax - ctx.kmin) / (ctx.nk - 1)

    print("")
    print("=== Input parameters ===")
    print("mode:", ctx.mode)
    print("a=", ctx.a, "A")
    print(f"potential: w={ctx.bwidth} A  h={ctx.bpot} eV")
    print(f"k range: {ctx.kmin} - {ctx.kmax} at {kstep} step, {ctx.nk} points")
    print("")

    xk, yE, n_max_band = compute_band_data(ctx)

    fig = plt.figure(figsize=ctx.figsize)
    ax = fig.add_subplot(1, 1, 1)
    ax.set_xlim([-0.5, 0.5])
    ax.set_ylim(ctx.erange)

    for i_level in range(n_max_band):
        ax.plot(
            xk,
            yE[i_level],
            linestyle="",
            marker="o",
            markersize=5.0,
            markerfacecolor="none",
            markeredgewidth=0.5,
        )

    ax.set_xlabel("$k$ $(\\pi$$/a)$", fontsize=ctx.fontsize)
    ax.set_ylabel("E (eV)", fontsize=ctx.fontsize)
    ax.tick_params(labelsize=ctx.fontsize)
    plt.tight_layout()

    save_show_close(fig, ctx)


def plot_wavefunction(ctx: SimpleNamespace) -> None:
    """
    概要:
        Kronig-Penneyモデルの波動関数とその確率密度、およびポテンシャルプロファイルをプロットします。
    詳細説明:
        与えられたコンテキストオブジェクト (ctx) のパラメータに基づいて、
        特定の波数 k とエネルギー準位 iLevel における波動関数 (実部、虚部) とその確率密度 (charge) を計算します。
        これらのデータは、計算されたポテンシャルプロファイルと共に一つの図にプロットされ、
        電子の局在化とポテンシャルの関係を視覚化します。
        生成された図は、ctx.save と ctx.show の設定に応じて保存または表示されます。
    引数:
        :param ctx: 計算コンテキストを格納したSimpleNamespaceオブジェクト。
        :type ctx: types.SimpleNamespace
    戻り値:
        なし。
    """
    xwstep = (ctx.xwmax - ctx.xwmin) / (ctx.nxw - 1)

    print("")
    print("=== Input parameters ===")
    print("mode:", ctx.mode)
    print("a=", ctx.a, "A")
    print(f"Wave function to be plotted: k = {ctx.kw}  iLevel = {ctx.iLevel}")
    print(f"x range: {ctx.xwmin} - {ctx.xwmax} at {xwstep} step, {ctx.nxw} points")
    print(f"potential: w={ctx.bwidth} A  h={ctx.bpot} eV")
    print("")
    print(f"at k={ctx.kw:8.4g}")

    xplot, ypot, _, ywf, charge = compute_wavefunction_data(ctx)

    fig = plt.figure(figsize=ctx.wf_figsize)
    ax_wave = fig.add_subplot(1, 1, 1)
    ax_potential = ax_wave.twinx()

    ax_potential.set_xlim([ctx.xwmin, ctx.xwmax])
    ax_potential.plot(xplot, ypot, linewidth=0.5, label="U(x)")
    ax_potential.axhline(0.0, linestyle="dashed", linewidth=0.5)

    ax_wave.set_xlim([ctx.xwmin, ctx.xwmax])
    ax_wave.plot(xplot, ywf.real, linewidth=1.5, label="real")
    ax_wave.plot(xplot, ywf.imag, linewidth=1.5, label="imaginary")
    ax_wave.plot(xplot, charge, linewidth=0.5, label="charge")
    ax_wave.axhline(0.0, linestyle="dashed", linewidth=0.5)

    ax_potential.set_xlabel("x (A)", fontsize=ctx.fontsize)
    ax_potential.set_ylabel("U(x)", fontsize=ctx.fontsize)
    ax_wave.set_xlabel("x (A)", fontsize=ctx.fontsize)
    ax_wave.set_ylabel("$\\Psi$($x$)", fontsize=ctx.fontsize)

    handler1, label1 = ax_potential.get_legend_handles_labels()
    handler2, label2 = ax_wave.get_legend_handles_labels()
    ax_wave.legend(
        handler1 + handler2,
        label1 + label2,
        loc=2,
        borderaxespad=0.0,
        fontsize=ctx.legend_fontsize,
    )

    ax_potential.tick_params(labelsize=ctx.fontsize)
    ax_wave.tick_params(labelsize=ctx.fontsize)
    plt.tight_layout()

    save_show_close(fig, ctx)


# =============================================================================
# run / main
# =============================================================================
def run(args: argparse.Namespace) -> None:
    """
    概要:
        Kronig-Penneyモデルの計算とプロットのメイン処理を実行します。
    詳細説明:
        まずコマンドライン引数を検証し、その後コンテキストオブジェクトを作成します。
        args.mode の値に応じて、delta(E)グラフ、バンド構造、または波動関数のいずれかのプロット関数を呼び出します。
    引数:
        :param args: コマンドライン引数を格納したNamespaceオブジェクト。
        :type args: argparse.Namespace
    戻り値:
        なし。
    例外:
        :raises ValueError: args.mode が認識されない値である場合に発生します。
    """
    validate_args(args)
    ctx = create_context(args)

    if ctx.mode == "graph":
        plot_graph(ctx)
    elif ctx.mode == "band":
        plot_band(ctx)
    elif ctx.mode == "wf":
        plot_wavefunction(ctx)
    else:
        raise ValueError(f"Invalid mode: {ctx.mode}")


def main() -> None:
    """
    概要:
        スクリプトのエントリポイントです。
    詳細説明:
        コマンドライン引数を解析し、run関数を呼び出してKronig-Penneyモデルの計算とプロットを実行します。
        例外が発生した場合は、エラーメッセージを表示してプログラムを終了します。
    引数:
        なし。
    戻り値:
        なし。
    """
    args = parse_args()
    run(args)


if __name__ == "__main__":
    try:
        main()
    except Exception as exc:
        print("")
        print(f"Error: {exc}")
        traceback.print_exc()
        input("\nPress ENTER to terminate>>\n")
        sys.exit(1)