"""
Slater-Kosterタイトバインディングモデルを実装するモジュール。

s, p, d軌道を持つ有限クラスター（分子）用の最小限のSlater-Kosterタイトバインディング（TB）モデルを提供します。
直交基底（S = I）と実対称ハミルトニアン（H）を仮定します。ハリスン則による積分パラメータのスケーリングに対応しています。

:doc:`tb_harrison_usage`
"""
import numpy as np
from scipy.linalg import eigh


class SlaterKosterTB:
    """Slater-Kosterタイトバインディング法を実装するクラス。

    s, p, d軌道を持つ有限クラスター（分子）用の最小限のSlater-Kosterタイトバインディング（TB）モデルを提供します。
    直交基底（S = I）と実対称ハミルトニアン（H）を仮定します。

    ハリスン則によるスケーリングをサポートしており、`sk_params` の "harrison" エントリを通じて積分キーごとの指数を指定できます。

    `harrison` パラメータの例:
    ::

        "harrison": {
            "d0": 1.0,
            "power_default": 2.0,
            "power_map": {
                "sd_sigma": 3.5,
                "pd_sigma": 3.5,
                "pd_pi": 3.5,
                # "dd_sigma": 5.0, "dd_pi": 5.0, "dd_delta": 5.0,  # if you implement dd later
            },
            # オプション: キーごとのd0
            # "d0_map": {"pd_sigma": 2.0}
        }

    `harrison` が省略された場合、スケーリングは行われません（スケールは1）。
    後方互換性のため、`{"d0":..., "power":...}` のみが与えられた場合、それがデフォルトとして使用されます。
    """

    def __init__(self):
        """SlaterKosterTBクラスの新しいインスタンスを初期化します。

        既知の軌道タイプ（s, p, d）とそれに対応する軌道のリストを設定し、
        原子リストとハミルトニアン行列を空の状態に初期化します。
        """
        self.orbitals = {
            "s": ["s"],
            "p": ["px", "py", "pz"],
            "d": ["dxy", "dyz", "dzx", "dx2-y2", "dz2"],
        }
        self.atoms = []
        self.hamiltonian = None

    def add_atom(self, symbol, orb_type, x, y, z):
        """モデルに原子を追加します。

        指定された記号、軌道タイプ、および空間座標を持つ原子を内部リストに追加します。
        軌道タイプが`orbitals`辞書に存在しない場合はエラーを発生させます。

        :param symbol: str: 原子の元素記号。
        :param orb_type: str: 原子の軌道タイプ（例: "s", "p", "d"）。
        :param x: float: 原子のX座標。
        :param y: float: 原子のY座標。
        :param z: float: 原子のZ座標。
        :returns: None
        :raises ValueError: 未知の軌道タイプが指定された場合。
        """
        if orb_type not in self.orbitals:
            raise ValueError(f"Unknown orb_type='{orb_type}'. Choose from {list(self.orbitals.keys())}")
        self.atoms.append(
            {
                "symbol": symbol,
                "orb_type": orb_type,
                "pos": np.array([x, y, z], dtype=float),
                "orbs": self.orbitals[orb_type],
            }
        )

    @staticmethod
    def _need_param(p, key):
        """指定されたSKパラメータが存在するか確認し、その値を返します。

        `sk_params` 辞書から `key` に対応する値を取得します。
        `key` が存在しない場合は `KeyError` を発生させます。

        :param p: dict: Slater-Kosterパラメータを格納する辞書。
        :param key: str: 取得するパラメータのキー。
        :returns: float or int: 指定されたキーに対応するパラメータ値。
        :raises KeyError: 指定されたキーが `sk_params` に存在しない場合。
        """
        if key not in p:
            raise KeyError(f"Missing SK parameter '{key}' in sk_params.")
        return p[key]

    @staticmethod
    def _harrison_scale(d, sk_params, key=None):
        """ハリソン則に基づいて、距離によるスケーリング因子を計算します。

        サイト間距離 `d` と `sk_params` の "harrison" セクションに定義されたルールに従って、
        スケーリング因子を計算します。
        `d0` (基準距離) と `power` (指数) は、グローバルまたはキーごとに指定できます。
        `harrison` パラメータが指定されていない場合、または `d` が0以下の場合、スケーリングは行われず1.0を返します。
        指数の選択順序は、1) `power_map[key]` -> 2) `power_default` -> 3) 従来の `power` です。

        :param d: float: サイト間の距離。
        :param sk_params: dict: Slater-Kosterパラメータを含む辞書。
        :param key: str, optional: スケーリング対象の積分キー（例: "ss_sigma"）。キー固有のスケーリング設定に使用されます。デフォルトはNone。
        :returns: float: 計算されたハリスン則によるスケーリング因子。
        """
        hs = sk_params.get("harrison", None)
        if hs is None:
            return 1.0
        if d <= 0:
            return 1.0

        # d0 (キーごとに上書き可能)
        d0 = float(hs.get("d0", 1.0))
        d0_map = hs.get("d0_map", None)
        if key is not None and isinstance(d0_map, dict) and key in d0_map:
            d0 = float(d0_map[key])

        # 指数 (power) の選択順序:
        # 1) power_map[key]
        # 2) power_default
        # 3) 従来の power
        power_map = hs.get("power_map", None)
        if key is not None and isinstance(power_map, dict) and key in power_map:
            power = float(power_map[key])
        elif "power_default" in hs:
            power = float(hs["power_default"])
        else:
            power = float(hs.get("power", 2.0))  # 従来の形式

        return (d0 / d) ** power

    def _scaled_param(self, d, sk_params, key):
        """ハリスン則でスケーリングされたSlater-Kosterパラメータを返します。

        まず `_need_param` を使用して元のパラメータ値を取得し、
        次に `_harrison_scale` を使用してスケーリング因子を計算します。
        最後に、元のパラメータ値にスケーリング因子を乗じて結果を返します。

        :param d: float: サイト間の距離。
        :param sk_params: dict: Slater-Kosterパラメータを含む辞書。
        :param key: str: スケーリング対象のSlater-Koster積分キー。
        :returns: float: ハリスン則でスケーリングされたパラメータ値。
        """
        v = self._need_param(sk_params, key)
        s = self._harrison_scale(d, sk_params, key=key)
        return v * s

    def _get_sk_elements(self, vec, orb1, orb2, p):
        """2つの軌道間のSlater-Koster行列要素を計算します。

        2つの原子間の相対位置ベクトル `vec` とそれぞれの軌道タイプ `orb1`, `orb2` を基に、
        ハミルトニアンのオフサイト要素を計算します。
        距離 `d`、方向余弦 `l, m, n` を用いて、s-s, s-p, s-d, p-p, p-d の各相互作用を処理します。
        d-d 相互作用は現在未実装です。
        軌道間の対称性に基づいて、特定の相互作用では符号が反転することがあります。

        :param vec: numpy.ndarray: 原子間の相対位置ベクトル `(x, y, z)`。
        :param orb1: str: 1番目の原子の軌道タイプ（例: "s", "px", "dxy"）。
        :param orb2: str: 2番目の原子の軌道タイプ（例: "s", "px", "dxy"）。
        :param p: dict: Slater-Kosterパラメータを含む辞書。
        :returns: float: 計算されたSlater-Koster行列要素の値。
        :raises NotImplementedError: d-d ホッピングまたはサポートされていない軌道ペアが指定された場合。
        """
        d = np.linalg.norm(vec)
        if d < 1e-12:
            return 0.0

        l, m, n = vec / d

        # quick type flags
        is_s1, is_s2 = (orb1 == "s"), (orb2 == "s")
        is_p1, is_p2 = orb1.startswith("p"), orb2.startswith("p")
        is_d1, is_d2 = orb1.startswith("d"), orb2.startswith("d")

        # 未実装チャネルでの高速失敗
        if is_d1 and is_d2:
            raise NotImplementedError("d-d hopping (dd_sigma, dd_pi, dd_delta) is not implemented.")

        # --- 1. s-s ---
        if is_s1 and is_s2:
            return self._scaled_param(d, p, "ss_sigma")

        # --- 2. s-p ---
        if is_s1 and is_p2:
            vsp = self._scaled_param(d, p, "sp_sigma")
            if orb2 == "px":
                return l * vsp
            if orb2 == "py":
                return m * vsp
            if orb2 == "pz":
                return n * vsp

        if is_p1 and is_s2:
            return -self._get_sk_elements(vec, orb2, orb1, p)

        # --- 3. s-d ---
        if is_s1 and is_d2:
            vsd = self._scaled_param(d, p, "sd_sigma")
            if orb2 == "dz2":
                return (n**2 - 0.5 * (l**2 + m**2)) * vsd
            if orb2 == "dx2-y2":
                return 0.5 * np.sqrt(3.0) * (l**2 - m**2) * vsd
            if orb2 == "dxy":
                return np.sqrt(3.0) * l * m * vsd
            if orb2 == "dyz":
                return np.sqrt(3.0) * m * n * vsd
            if orb2 == "dzx":
                return np.sqrt(3.0) * l * n * vsd

        if is_d1 and is_s2:
            return self._get_sk_elements(vec, orb2, orb1, p)

        # --- 4. p-p ---
        if is_p1 and is_p2:
            vpps = self._scaled_param(d, p, "pp_sigma")
            vppp = self._scaled_param(d, p, "pp_pi")
            coord = {"px": l, "py": m, "pz": n}
            c1, c2 = coord[orb1], coord[orb2]
            if orb1 == orb2:
                return c1**2 * vpps + (1.0 - c1**2) * vppp
            else:
                return c1 * c2 * (vpps - vppp)

        # --- 5. p-d ---
        if is_p1 and is_d2:
            vpds = self._scaled_param(d, p, "pd_sigma")
            vpdp = self._scaled_param(d, p, "pd_pi")

            if orb1 == "px":
                if orb2 == "dxy":
                    return m * (l**2 * vpds + (1 - 2 * l**2) * vpdp)
                if orb2 == "dyz":
                    return l * m * n * (vpds - 2 * vpdp)
                if orb2 == "dzx":
                    return n * (l**2 * vpds + (1 - 2 * l**2) * vpdp)
                if orb2 == "dx2-y2":
                    return 0.5 * l * (l**2 - m**2) * vpds + l * (1 - l**2 + m**2) * vpdp
                if orb2 == "dz2":
                    return l * (n**2 - 0.5 * (l**2 + m**2)) * vpds - np.sqrt(3.0) * l * n**2 * vpdp

            if orb1 == "py":
                if orb2 == "dxy":
                    return l * (m**2 * vpds + (1 - 2 * m**2) * vpdp)
                if orb2 == "dyz":
                    return n * (m**2 * vpds + (1 - 2 * m**2) * vpdp)
                if orb2 == "dzx":
                    return l * m * n * (vpds - 2 * vpdp)
                if orb2 == "dx2-y2":
                    return 0.5 * m * (l**2 - m**2) * vpds - m * (1 + l**2 - m**2) * vpdp
                if orb2 == "dz2":
                    return m * (n**2 - 0.5 * (l**2 + m**2)) * vpds - np.sqrt(3.0) * m * n**2 * vpdp

            if orb1 == "pz":
                if orb2 == "dxy":
                    return l * m * n * (vpds - 2 * vpdp)
                if orb2 == "dyz":
                    return m * (n**2 * vpds + (1 - 2 * n**2) * vpdp)
                if orb2 == "dzx":
                    return l * (n**2 * vpds + (1 - 2 * n**2) * vpdp)
                if orb2 == "dx2-y2":
                    return 0.5 * n * (l**2 - m**2) * vpds - n * (l**2 - m**2) * vpdp
                if orb2 == "dz2":
                    return n * (n**2 - 0.5 * (l**2 + m**2)) * vpds + np.sqrt(3.0) * n * (l**2 + m**2) * vpdp

        if is_d1 and is_p2:
            return -self._get_sk_elements(vec, orb2, orb1, p)

        raise NotImplementedError(f"Unsupported orbital pair: ({orb1}, {orb2})")

    def build_hamiltonian(self, sk_params, onsite_energies):
        """システムのハミルトニアン行列を構築します。

        追加された原子とその軌道情報、Slater-Kosterパラメータ、オンサイトエネルギーに基づいて、
        全体のハミルトニアン行列を構築します。
        まず、オンサイト項を対角要素に設定します。次に、異なる原子間のオフサイト項を
        `_get_sk_elements` を使用して計算し、行列を埋めます。
        ハミルトニアンは実対称行列として構築されます。

        :param sk_params: dict: Slater-Kosterパラメータを含む辞書。
        :param onsite_energies: dict: 原子の記号と軌道タイプごとのオンサイトエネルギーを含む辞書。
        :returns: numpy.ndarray: 構築されたハミルトニアン行列。
        """
        total_orbs = sum(len(a["orbs"]) for a in self.atoms)
        H = np.zeros((total_orbs, total_orbs), dtype=float)

        # Onsite terms
        idx = 0
        for a in self.atoms:
            n_orb = len(a["orbs"])
            e_on = onsite_energies.get(a["symbol"], {}).get(a["orb_type"], 0.0)
            for k in range(n_orb):
                H[idx + k, idx + k] = e_on
            idx += n_orb

        # Offsite terms
        idx_i = 0
        for i, atom_i in enumerate(self.atoms):
            idx_j = 0
            for j, atom_j in enumerate(self.atoms):
                if i < j:
                    vec = atom_j["pos"] - atom_i["pos"]
                    for io, orb_i in enumerate(atom_i["orbs"]):
                        for jo, orb_j in enumerate(atom_j["orbs"]):
                            val = self._get_sk_elements(vec, orb_i, orb_j, sk_params)
                            H[idx_i + io, idx_j + jo] = val
                            H[idx_j + jo, idx_i + io] = val
                idx_j += len(atom_j["orbs"])
            idx_i += len(atom_i["orbs"])

        self.hamiltonian = H
        return H

    def orbital_labels(self):
        """各原子の軌道に対応するラベルのリストを生成します。

        各原子の位置と軌道タイプを組み合わせて、ハミルトニアン行列の各要素に対応する
        識別のための文字列ラベルを生成します。
        例: "Ba(0.000,0.000,0.000)_s", "N(0.000,0.000,0.800)_px"

        :returns: list[str]: 各軌道に対応するラベルのリスト。
        """
        labels = []
        for a in self.atoms:
            x, y, z = a["pos"]
            for o in a["orbs"]:
                labels.append(f"{a['symbol']}({x:.3f},{y:.3f},{z:.3f})_{o}")
        return labels


def main():
    """`SlaterKosterTB` クラスの使用例を示します。

    `SlaterKosterTB` クラスを初期化し、複数の原子をモデルに追加します。
    定義済みのSlater-Kosterパラメータとオンサイトエネルギーを使用してハミルトニアンを構築し、
    `scipy.linalg.eigh` を使用してハミルトニアンを対角化し、固有値（軌道エネルギー）と固有ベクトルを計算します。
    結果として得られたエネルギーと、それぞれの固有状態における支配的な軌道成分を表示します。
    """
    tb = SlaterKosterTB()

    tb.add_atom("Ba", "d", 0, 0, 0)
    tb.add_atom("N", "p", 0, 0, 0.8)
    tb.add_atom("N", "p", 1, 0, 0)
    tb.add_atom("N", "p", -1, 0, 0)
    tb.add_atom("N", "p", 0, 1, 0)
    tb.add_atom("N", "p", 0, -1, 0)

    params = {
        "ss_sigma": -1.40,
        "sp_sigma": 1.84,
        "pp_sigma": 3.24,
        "pp_pi": -0.81,
        "sd_sigma": -2.0,
        "pd_sigma": -2.2,
        "pd_pi": 1.0,
        # キーに依存するハリスン則スケーリング:
        "harrison": {
            "d0": 1.0,
            "power_default": 2.0,  # ss/sp/pp はデフォルトで 2.0
            "power_map": {         # 選択された積分でのオーバーライド
                "sd_sigma": 3.5,
                "pd_sigma": 3.5,
                "pd_pi": 3.5,
            },
        },
    }

    onsite = {"Ba": {"d": 0.0}, "N": {"p": -3.0}}

    H = tb.build_hamiltonian(params, onsite)
    energies, vectors = eigh(H)

    idx_sort = energies.argsort()
    evals = energies[idx_sort]
    evecs = vectors[:, idx_sort]
    labels = tb.orbital_labels()

    print("Eigenvalues (Orbital Energies) [eV]:")
    print(evals)

    print("\nNo. |   Energy (eV) |  Dominant components")
    print("-" * 72)
    for i in range(len(evals)):
        coeffs = evecs[:, i]
        w = coeffs**2
        top = np.argsort(w)[::-1][:5]
#        desc = ", ".join([f"{labels[k]}:{coeffs[k]:+.2f}" for k in top])
#        print(f"{i+1:2d} | {evals[i]:8.4f} | {desc}")
        print(f"{i+1:2d} | {evals[i]:8.4f}")
        for l, c in zip(labels, coeffs):
            print(f"  {l}:{c:+.3f}")


if __name__ == "__main__":
    main()