import numpy as np
from scipy.linalg import eigh


class SlaterKosterTB:
    """
    Minimal Slater–Koster tight-binding (s, p, d) for finite clusters (molecules).

    - Orthonormal basis: S = I
    - Real symmetric Hamiltonian H
    - Two-center hopping integrals follow the Slater–Koster table, evaluated
      with the direction cosines (l, m, n) of the bond vector pointing from
      the first atom of a pair to the second.
    - d-d hopping (dd_sigma, dd_pi, dd_delta) is not implemented and raises
      NotImplementedError.

    Harrison scaling:
      sk_params["harrison"] can specify exponent per integral key.
      Each integral V_key is multiplied by (d0 / d) ** power.

      Example:
        "harrison": {
            "d0": 1.0,
            "power_default": 2.0,
            "power_map": {
                "sd_sigma": 3.5,
                "pd_sigma": 3.5,
                "pd_pi": 3.5,
            },
            # optional: per-key d0
            # "d0_map": {"pd_sigma": 2.0}
        }

      If omitted, no scaling (scale=1).
      Backward compatible:
        if only {"d0":..., "power":...} is given, uses that as default.
    """

    def __init__(self):
        # Orbital names per shell; the order defines the basis order of each atom.
        self.orbitals = {
            "s": ["s"],
            "p": ["px", "py", "pz"],
            "d": ["dxy", "dyz", "dzx", "dx2-y2", "dz2"],
        }
        self.atoms = []
        self.hamiltonian = None

    def add_atom(self, symbol, orb_type, x, y, z):
        """Append an atom carrying a single shell of type ``orb_type``.

        Args:
            symbol: Element label used to look up onsite energies.
            orb_type: One of "s", "p", "d".
            x, y, z: Cartesian position (same length unit as Harrison ``d0``).

        Raises:
            ValueError: If ``orb_type`` is not a known shell type.
        """
        if orb_type not in self.orbitals:
            raise ValueError(f"Unknown orb_type='{orb_type}'. Choose from {list(self.orbitals.keys())}")
        self.atoms.append(
            {
                "symbol": symbol,
                "orb_type": orb_type,
                "pos": np.array([x, y, z], dtype=float),
                # Copy so that mutating one atom's list cannot corrupt the
                # shared shell definition (or other atoms of the same type).
                "orbs": list(self.orbitals[orb_type]),
            }
        )

    @staticmethod
    def _need_param(p, key):
        """Return ``p[key]`` or raise a KeyError naming the missing SK parameter."""
        if key not in p:
            raise KeyError(f"Missing SK parameter '{key}' in sk_params.")
        return p[key]

    @staticmethod
    def _harrison_scale(d, sk_params, key=None):
        """Return the Harrison factor (d0 / d) ** power for integral ``key``.

        Returns 1.0 when ``sk_params`` has no "harrison" section or d <= 0.
        """
        hs = sk_params.get("harrison", None)
        if hs is None:
            return 1.0
        if d <= 0:
            return 1.0

        # d0 (can be overridden per key)
        d0 = float(hs.get("d0", 1.0))
        d0_map = hs.get("d0_map", None)
        if key is not None and isinstance(d0_map, dict) and key in d0_map:
            d0 = float(d0_map[key])

        # exponent (power) selection order:
        # 1) power_map[key]
        # 2) power_default
        # 3) legacy power
        power_map = hs.get("power_map", None)
        if key is not None and isinstance(power_map, dict) and key in power_map:
            power = float(power_map[key])
        elif "power_default" in hs:
            power = float(hs["power_default"])
        else:
            power = float(hs.get("power", 2.0))  # legacy

        return (d0 / d) ** power

    def _scaled_param(self, d, sk_params, key):
        """Return V_key scaled by Harrison rule for this key."""
        v = self._need_param(sk_params, key)
        s = self._harrison_scale(d, sk_params, key=key)
        return v * s

    def _get_sk_elements(self, vec, orb1, orb2, p):
        """Return the hopping <orb1 on atom A | H | orb2 on atom B>.

        Args:
            vec: Bond vector pos(B) - pos(A).
            orb1, orb2: Orbital names (see ``self.orbitals``).
            p: SK parameter dict (ss_sigma, sp_sigma, sd_sigma, pp_sigma,
               pp_pi, pd_sigma, pd_pi, optional "harrison").

        Returns:
            The matrix element as a float; 0.0 if the atoms coincide.

        Raises:
            KeyError: If a required SK parameter is missing.
            NotImplementedError: For d-d pairs or unsupported orbital names.
        """
        d = np.linalg.norm(vec)
        if d < 1e-12:
            return 0.0

        l, m, n = vec / d
        s3 = np.sqrt(3.0)
        # Angular factor of the dz2 (3z^2 - r^2) orbital, shared by s-d and p-d.
        fz2 = n**2 - 0.5 * (l**2 + m**2)

        # quick type flags
        is_s1, is_s2 = (orb1 == "s"), (orb2 == "s")
        is_p1, is_p2 = orb1.startswith("p"), orb2.startswith("p")
        is_d1, is_d2 = orb1.startswith("d"), orb2.startswith("d")

        # Fail-fast for unimplemented channels
        if is_d1 and is_d2:
            raise NotImplementedError("d-d hopping (dd_sigma, dd_pi, dd_delta) is not implemented.")

        # --- 1. s-s ---
        if is_s1 and is_s2:
            return self._scaled_param(d, p, "ss_sigma")

        # --- 2. s-p ---
        if is_s1 and is_p2:
            vsp = self._scaled_param(d, p, "sp_sigma")
            if orb2 == "px":
                return l * vsp
            if orb2 == "py":
                return m * vsp
            if orb2 == "pz":
                return n * vsp

        if is_p1 and is_s2:
            # p is odd under inversion: E_{p,s}(R) = E_{s,p}(-R) = -E_{s,p}(R).
            return -self._get_sk_elements(vec, orb2, orb1, p)

        # --- 3. s-d ---
        if is_s1 and is_d2:
            vsd = self._scaled_param(d, p, "sd_sigma")
            if orb2 == "dz2":
                return fz2 * vsd
            if orb2 == "dx2-y2":
                return 0.5 * s3 * (l**2 - m**2) * vsd
            if orb2 == "dxy":
                return s3 * l * m * vsd
            if orb2 == "dyz":
                return s3 * m * n * vsd
            if orb2 == "dzx":
                return s3 * l * n * vsd

        if is_d1 and is_s2:
            # s and d are both even: E_{d,s}(R) = E_{s,d}(R).
            return self._get_sk_elements(vec, orb2, orb1, p)

        # --- 4. p-p ---
        if is_p1 and is_p2:
            vpps = self._scaled_param(d, p, "pp_sigma")
            vppp = self._scaled_param(d, p, "pp_pi")
            coord = {"px": l, "py": m, "pz": n}
            c1, c2 = coord[orb1], coord[orb2]
            if orb1 == orb2:
                return c1**2 * vpps + (1.0 - c1**2) * vppp
            return c1 * c2 * (vpps - vppp)

        # --- 5. p-d (Slater-Koster Table I; note the sqrt(3) on sigma terms) ---
        if is_p1 and is_d2:
            vpds = self._scaled_param(d, p, "pd_sigma")
            vpdp = self._scaled_param(d, p, "pd_pi")

            if orb1 == "px":
                if orb2 == "dxy":
                    return s3 * l**2 * m * vpds + m * (1 - 2 * l**2) * vpdp
                if orb2 == "dyz":
                    return l * m * n * (s3 * vpds - 2 * vpdp)
                if orb2 == "dzx":
                    return s3 * l**2 * n * vpds + n * (1 - 2 * l**2) * vpdp
                if orb2 == "dx2-y2":
                    return 0.5 * s3 * l * (l**2 - m**2) * vpds + l * (1 - l**2 + m**2) * vpdp
                if orb2 == "dz2":
                    return l * fz2 * vpds - s3 * l * n**2 * vpdp

            if orb1 == "py":
                if orb2 == "dxy":
                    return s3 * m**2 * l * vpds + l * (1 - 2 * m**2) * vpdp
                if orb2 == "dyz":
                    return s3 * m**2 * n * vpds + n * (1 - 2 * m**2) * vpdp
                if orb2 == "dzx":
                    return l * m * n * (s3 * vpds - 2 * vpdp)
                if orb2 == "dx2-y2":
                    return 0.5 * s3 * m * (l**2 - m**2) * vpds - m * (1 + l**2 - m**2) * vpdp
                if orb2 == "dz2":
                    return m * fz2 * vpds - s3 * m * n**2 * vpdp

            if orb1 == "pz":
                if orb2 == "dxy":
                    return l * m * n * (s3 * vpds - 2 * vpdp)
                if orb2 == "dyz":
                    return s3 * n**2 * m * vpds + m * (1 - 2 * n**2) * vpdp
                if orb2 == "dzx":
                    return s3 * n**2 * l * vpds + l * (1 - 2 * n**2) * vpdp
                if orb2 == "dx2-y2":
                    return 0.5 * s3 * n * (l**2 - m**2) * vpds - n * (l**2 - m**2) * vpdp
                if orb2 == "dz2":
                    return n * fz2 * vpds + s3 * n * (l**2 + m**2) * vpdp

        if is_d1 and is_p2:
            # p-d elements are odd in the bond direction: E_{d,p}(R) = -E_{p,d}(R).
            return -self._get_sk_elements(vec, orb2, orb1, p)

        raise NotImplementedError(f"Unsupported orbital pair: ({orb1}, {orb2})")

    def build_hamiltonian(self, sk_params, onsite_energies):
        """Assemble the real symmetric TB Hamiltonian of the current cluster.

        Args:
            sk_params: SK integral dict (see ``_get_sk_elements``).
            onsite_energies: Mapping symbol -> {orb_type: energy}. Missing
                entries default to 0.0. All orbitals of a shell share the
                same onsite energy.

        Returns:
            The (N_orb, N_orb) Hamiltonian, also stored in ``self.hamiltonian``.

        Raises:
            KeyError: If a required SK parameter is missing.
            NotImplementedError: If the cluster contains a d-d atom pair.
        """
        # Offset of each atom's first orbital in the global basis.
        offsets = []
        total_orbs = 0
        for atom in self.atoms:
            offsets.append(total_orbs)
            total_orbs += len(atom["orbs"])

        H = np.zeros((total_orbs, total_orbs), dtype=float)

        # Onsite terms
        for atom, off in zip(self.atoms, offsets):
            e_on = onsite_energies.get(atom["symbol"], {}).get(atom["orb_type"], 0.0)
            for k in range(len(atom["orbs"])):
                H[off + k, off + k] = e_on

        # Offsite terms: visit each unordered atom pair once, mirror for symmetry.
        n_atoms = len(self.atoms)
        for i in range(n_atoms):
            atom_i, off_i = self.atoms[i], offsets[i]
            for j in range(i + 1, n_atoms):
                atom_j, off_j = self.atoms[j], offsets[j]
                vec = atom_j["pos"] - atom_i["pos"]
                for io, orb_i in enumerate(atom_i["orbs"]):
                    for jo, orb_j in enumerate(atom_j["orbs"]):
                        val = self._get_sk_elements(vec, orb_i, orb_j, sk_params)
                        H[off_i + io, off_j + jo] = val
                        H[off_j + jo, off_i + io] = val

        self.hamiltonian = H
        return H

    def orbital_labels(self):
        """Return one label "Sym(x,y,z)_orb" per basis orbital, in basis order."""
        labels = []
        for a in self.atoms:
            x, y, z = a["pos"]
            for o in a["orbs"]:
                labels.append(f"{a['symbol']}({x:.3f},{y:.3f},{z:.3f})_{o}")
        return labels


def main():
    """Demo: build a Ba(d) + 5 N(p) cluster, diagonalize H, and print the spectrum.

    Prints the eigenvalues followed by every eigenvector, expanded on the
    labelled orbital basis.
    """
    tb = SlaterKosterTB()

    tb.add_atom("Ba", "d", 0, 0, 0)
    tb.add_atom("N", "p", 0, 0, 0.8)
    tb.add_atom("N", "p", 1, 0, 0)
    tb.add_atom("N", "p", -1, 0, 0)
    tb.add_atom("N", "p", 0, 1, 0)
    tb.add_atom("N", "p", 0, -1, 0)

    params = {
        "ss_sigma": -1.40,
        "sp_sigma": 1.84,
        "pp_sigma": 3.24,
        "pp_pi": -0.81,
        "sd_sigma": -2.0,
        "pd_sigma": -2.2,
        "pd_pi": 1.0,
        # Key-dependent Harrison scaling:
        "harrison": {
            "d0": 1.0,
            "power_default": 2.0,  # ss/sp/pp -> 2.0 by default
            "power_map": {         # override for selected integrals
                "sd_sigma": 3.5,
                "pd_sigma": 3.5,
                "pd_pi": 3.5,
            },
        },
    }

    onsite = {"Ba": {"d": 0.0}, "N": {"p": -3.0}}

    H = tb.build_hamiltonian(params, onsite)
    # scipy.linalg.eigh already returns eigenvalues in ascending order with
    # the matching eigenvectors as columns, so no extra sort is needed.
    evals, evecs = eigh(H)
    labels = tb.orbital_labels()

    print("Eigenvalues (Orbital Energies) [eV]:")
    print(evals)

    print("\nNo. |   Energy (eV) |  Dominant components")
    print("-" * 72)
    for i, (energy, coeffs) in enumerate(zip(evals, evecs.T), start=1):
        print(f"{i:2d} | {energy:8.4f}")
        for label, c in zip(labels, coeffs):
            print(f"  {label}:{c:+.3f}")


# Run the demo only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
