# tkpg/core.py
import numpy as np

# ============================================================
# 0. Utilities
# ============================================================

def normalize(v, eps=1e-12):
    """
    Return ``v`` scaled to unit Euclidean length as a float array.

    If the norm is below ``eps`` the vector is treated as zero and a zero
    array of the same shape is returned (``v * 0.0``) instead of dividing.
    """
    vec = np.array(v, float)
    length = np.linalg.norm(vec)
    return vec * 0.0 if length < eps else vec / length

def center_positions(ligands_pos):
    """Return the positions (as a float array) shifted so their centroid is at the origin."""
    points = np.array(ligands_pos, dtype=float)
    centroid = points.mean(axis=0, keepdims=True)
    return points - centroid

def match_ligand_id_by_position(ligands_pos, pos_new, tol=1e-4):
    """
    Find the ligand whose position coincides with ``pos_new``.

    ligands_pos: (N,3) candidate positions
    pos_new: (3,) query position
    tol: maximum Euclidean distance for a match

    Returns the index (Python int) of the nearest ligand if it lies within
    ``tol`` of ``pos_new``, else None. An empty ligand set never matches and
    returns None (np.argmin would otherwise raise on an empty array).
    """
    ligands_pos = np.array(ligands_pos, float)
    if ligands_pos.size == 0:
        return None
    pos_new = np.array(pos_new, float)
    # squared distances avoid a sqrt; compare against tol**2
    d2 = np.sum((ligands_pos - pos_new[None, :]) ** 2, axis=1)
    j = int(np.argmin(d2))
    if d2[j] <= tol ** 2:
        return j
    return None


# ============================================================
# 1. Periodic metadata from fractional coordinates
# ============================================================

def wrap_frac_and_T(frac, eps=1e-12):
    """
    Split fractional coordinates into a wrapped part and an integer cell translation.

    frac (N,3): may contain negatives (e.g. -0.5)
    eps: values within ``eps`` below an integer are treated as that integer
         (stabilizes boundary cases like -0.5, 1.0, -1.0 and numerical noise
         such as 1 - 1e-13)

    return:
      frac_wrap in [0,1)  (rounded to 12 decimals)
      cell_T integer translations such that frac ≈ frac_wrap + cell_T
    """
    frac = np.array(frac, float)

    cell_T = np.floor(frac + eps).astype(int)
    frac_wrap = frac - cell_T

    # frac - floor(frac + eps) lies in [-eps, 1 - eps). A negative remainder is
    # noise just below the integer already absorbed into cell_T, so clamp it to
    # 0 rather than wrapping it up to ~1 (which would desynchronize cell_T).
    # "+ 0.0" turns -0.0 into +0.0.
    frac_wrap = np.where(frac_wrap < 0.0, 0.0, frac_wrap) + 0.0

    # clean numerical noise
    frac_wrap = np.round(frac_wrap, 12)

    # Rounding (or a non-positive eps) can leave values at/above 1.0; carry
    # them into cell_T so the invariant and the [0,1) range both hold.
    carry = frac_wrap >= 1.0
    frac_wrap = np.where(carry, frac_wrap - 1.0, frac_wrap)
    cell_T = cell_T + carry.astype(int)

    return frac_wrap, cell_T


def build_periodic_metadata_from_frac(ligands_frac, lattice, tol_frac=1e-6):
    """
    Build Cartesian positions and periodic-site metadata from fractional coordinates.

    ligands_frac: (N,3) fractional positions relative to center (may be negative)
    lattice: (3,3) lattice vectors as rows: [a; b; c]
    tol_frac: grid size used to identify equal wrapped fractional positions;
              non-positive values fall back to 1e-6, and the grid size is
              never smaller than 1e-12

    returns dict:
      ligands_pos : (N,3) Cartesian vectors (NOT wrapped; sign kept, used for symmetry ops)
      cell_T      : (N,3) integer translations from wrap_frac_and_T
      site_id     : (N,)  same id for ligands whose wrapped positions share a grid bin,
                          numbered 0,1,2,... in order of first appearance
      frac_wrap   : (N,3) wrapped fractional coordinates
      n_sites     : number of distinct site ids

    NOTE(review): sites are found by grid hashing (round(frac_wrap / tol)).
    Two coordinates on opposite sides of a bin boundary, or a wrapped value
    just below 1.0 versus one at 0.0, land in different bins and so get
    different site ids even though they agree within tol_frac periodically.
    """
    frac = np.array(ligands_frac, float)

    # Geometry vectors used for D(g) must keep their +/- directions: no wrapping here.
    cartesian = frac @ np.array(lattice, float)

    frac_wrap, cell_T = wrap_frac_and_T(frac)

    bin_width = 1e-6 if tol_frac <= 0 else tol_frac
    bin_width = max(float(bin_width), 1e-12)

    # Nearest-bin index per coordinate (floor(x + 0.5) rounds half up).
    bins = np.floor(frac_wrap / bin_width + 0.5).astype(int)

    first_seen = {}
    site_id = np.empty(len(bins), dtype=int)
    for row, bin_key in enumerate(map(tuple, bins.tolist())):
        # setdefault assigns the next free id the first time a bin is seen.
        site_id[row] = first_seen.setdefault(bin_key, len(first_seen))

    return {
        "ligands_pos": cartesian,
        "cell_T": cell_T,
        "site_id": site_id,
        "frac_wrap": frac_wrap,
        "n_sites": len(first_seen),
    }


# ============================================================
# 2. Local frames & ligand basis
# ============================================================

def compute_local_frames(ligands_pos):
    """
    Build a local orthonormal frame for every ligand.

    ligands_pos: (N,3) positions measured from the metal center
    Returns a dict {i: (ex, ey, ez)} where ez points along the ligand position
    (the sigma direction), ey = normalize(ez x seed), ex = normalize(ey x ez),
    giving a right-handed frame (ex x ey = ez) for non-zero positions.
    The seed is the global x axis unless ez is nearly parallel to it
    (|x . ez| > 0.9), in which case the global y axis is used.
    A ligand at the origin yields an all-zero frame.
    """
    x_axis = np.array([1.0, 0.0, 0.0])
    y_axis = np.array([0.0, 1.0, 0.0])

    frames = {}
    for idx, r in enumerate(np.array(ligands_pos, float)):
        ez = normalize(r)
        seed = y_axis if abs(np.dot(x_axis, ez)) > 0.9 else x_axis
        ey = normalize(np.cross(ez, seed))
        ex = normalize(np.cross(ey, ez))
        frames[idx] = (ex, ey, ez)
    return frames


def build_ligand_basis(N_ligands, frames, mode="full"):
    """
    Enumerate ligand basis functions as dicts {"ligand", "type", "name"}.

    mode:
      "s"     : ligand s orbitals (scalar)                 -> type "s",  name "s"
      "p"     : ligand px,py,pz (global Cartesian p)       -> type "p",  names px/py/pz
      "sigma" : ligand p_sigma only (local ez direction)   -> type "lp", name "sigma"
      "full"  : ligand p_sigma + p_pi1 + p_pi2 (local ez/ex/ey) -> type "lp"

    Functions are listed ligand by ligand (0..N_ligands-1). ``frames`` is not used.
    Raises ValueError("Unknown basis mode") for an unknown mode when N_ligands > 0.
    """
    layouts = (
        ("s", (("s", "s"),)),
        ("p", (("p", "px"), ("p", "py"), ("p", "pz"))),
        ("sigma", (("lp", "sigma"),)),
        ("full", (("lp", "sigma"), ("lp", "pi1"), ("lp", "pi2"))),
    )
    layout = next((spec for key, spec in layouts if mode == key), None)

    basis = []
    for lid in range(N_ligands):
        if layout is None:
            raise ValueError("Unknown basis mode")
        basis.extend({"ligand": lid, "type": kind, "name": label} for kind, label in layout)
    return basis


def basis_key(b):
    """Return the hashable identity (ligand, type, name) of a basis-function dict."""
    return tuple(b[field] for field in ("ligand", "type", "name"))


def basis_vector_global(b, frames):
    """
    Direction vector of a basis function in GLOBAL coordinates.

      s  : None (scalar, no direction)
      p  : px/py/pz -> global x/y/z unit vectors (new arrays)
      lp : sigma/pi1/pi2 -> the ligand's local ez/ex/ey from ``frames``
           (the frame arrays themselves are returned, not copies)

    Raises ValueError for an unknown type or name.
    """
    kind = b["type"]
    if kind == "s":
        return None

    if kind == "p":
        name = b["name"]
        for label, axis in (("px", 0), ("py", 1), ("pz", 2)):
            if name == label:
                unit = np.zeros(3)
                unit[axis] = 1.0
                return unit
        raise ValueError("bad p name")

    if kind == "lp":
        ex, ey, ez = frames[b["ligand"]]
        name = b["name"]
        for label, direction in (("sigma", ez), ("pi1", ex), ("pi2", ey)):
            if name == label:
                return direction
        raise ValueError("bad lp name")

    raise ValueError("bad basis type")


def destination_components(mode, frames, dest_ligand_id):
    """
    List the (name, direction) components a rotated orbital is projected onto
    at the destination ligand.

      "s"     : [("s", None)]
      "p"     : global px/py/pz unit vectors
      "sigma" : [("sigma", ez)] of the destination ligand
      "full"  : [("sigma", ez), ("pi1", ex), ("pi2", ey)] of the destination ligand

    ``frames[dest_ligand_id]`` is looked up for every mode.
    Raises ValueError("Unknown mode") otherwise.
    """
    ex, ey, ez = frames[dest_ligand_id]
    if mode == "s":
        return [("s", None)]
    if mode == "p":
        return [(label, np.eye(3)[axis].copy()) for axis, label in enumerate(("px", "py", "pz"))]

    local = [("sigma", ez), ("pi1", ex), ("pi2", ey)]
    if mode == "sigma":
        return local[:1]
    if mode == "full":
        return local
    raise ValueError("Unknown mode")


# ============================================================
# 3. Representation matrices & characters
# ============================================================

def representation_matrix_for_element(group_element, ligands_pos, frames, basis, mode, tol_match=1e-4):
    """
    Build the representation matrix D(g) of one group element in ``basis``.

    Convention: D[j, i] is the coefficient of basis function j in the image of
    basis function i (column i describes where basis i is sent).

    - the ligand carrying basis i is moved by R = group_element["R"] and matched
      to a ligand position within ``tol_match``; if nothing matches, column i
      stays zero
    - s functions are scalars: they map onto the destination ligand's s function
      with coefficient 1
    - other functions are direction vectors (basis_vector_global); the rotated
      vector is projected onto the destination ligand's components
      (destination_components), so p and local-p functions can mix
    - destination keys use type "p" when mode == "p", otherwise "lp"
    - projections with |coefficient| <= 1e-12 are not stored
    """
    R = group_element["R"]
    positions = np.array(ligands_pos, float)

    size = len(basis)
    D = np.zeros((size, size), float)
    row_of = {basis_key(entry): pos for pos, entry in enumerate(basis)}
    dest_type = "p" if mode == "p" else "lp"
    # A ligand's image does not depend on which of its orbitals we look at,
    # so the position match is done once per ligand.
    image_ligand = {}

    for col, entry in enumerate(basis):
        src = entry["ligand"]
        if src not in image_ligand:
            image_ligand[src] = match_ligand_id_by_position(
                positions, R @ positions[src], tol=tol_match
            )
        dest = image_ligand[src]
        if dest is None:
            continue

        if entry["type"] == "s":
            row = row_of.get((dest, "s", "s"))
            if row is not None:
                D[row, col] = 1.0
            continue

        rotated = R @ basis_vector_global(entry, frames)
        for label, axis in destination_components(mode, frames, dest):
            weight = 1.0 if axis is None else float(np.dot(rotated, axis))
            row = row_of.get((dest, dest_type, label))
            if row is not None and abs(weight) > 1e-12:
                D[row, col] = weight

    return D


def character_of_rep(group_element, ligands_pos, frames, basis, mode, tol_match=1e-4):
    """
    Character chi(g) of the ligand-orbital representation, i.e. Tr D(g).

    All arguments are forwarded to representation_matrix_for_element.
    Returns a Python float.
    """
    matrix = representation_matrix_for_element(
        group_element, ligands_pos, frames, basis, mode, tol_match=tol_match
    )
    return float(matrix.trace())


def build_reducible_characters(group, ligands_pos, frames, basis, mode, tol_match=1e-4):
    """
    Character vector of the (reducible) ligand-orbital representation.

    Returns a float array whose k-th entry is chi(g_k) = Tr D(g_k), following
    the iteration order of ``group``.
    """
    chars = []
    for element in group:
        chars.append(
            character_of_rep(element, ligands_pos, frames, basis, mode, tol_match=tol_match)
        )
    return np.array(chars, float)


# ============================================================
# 4. Decomposition / Projection / SALCs (generic)
# ============================================================

def decompose_generic(Gamma, group, irreps, irrep_char_func):
    """
    Reduce a representation into irreps with the character orthogonality relation:

        n_irrep = (1/|G|) sum_g chi_Gamma(g) * conj(chi_irrep(g))

    Gamma: characters of the reducible representation, ordered like ``group``.
    irrep_char_func(ir, g): character of irrep ``ir`` at element ``g``.

    The complex conjugate matters only for complex characters (e.g. cyclic
    groups or k-point little groups). For real characters the result is the
    same as the plain dot product, and that path is computed exactly as before.
    For complex input the real part is returned, since multiplicities are real.

    Returns {irrep: multiplicity (float)}.
    """
    Gorder = len(group)
    gamma = np.asarray(Gamma)
    coeffs = {}
    for ir in irreps:
        chis = np.array([irrep_char_func(ir, g) for g in group])
        if np.iscomplexobj(chis) or np.iscomplexobj(gamma):
            # np.vdot conjugates its FIRST argument: sum conj(chi_irrep) * chi_Gamma
            coeffs[ir] = float(np.real(np.vdot(chis, gamma)) / Gorder)
        else:
            coeffs[ir] = float(np.dot(Gamma, chis.astype(float)) / Gorder)
    return coeffs


def projection_operator(irrep, group, irrep_dim_func, irrep_char_func,
                        ligands_pos, frames, basis, mode, tol_match=1e-4):
    """
    Projection operator onto ``irrep`` in the ligand basis:

        P = (l/|G|) sum_g conj(chi_irrep(g)) D(g),   l = irrep_dim_func(irrep)

    D(g) comes from representation_matrix_for_element. The conjugate makes the
    formula valid for complex characters; for real characters it is the plain sum.

    Returns an (n,n) float matrix when the result has no imaginary part
    (always the case for real characters), otherwise an (n,n) complex matrix.
    """
    l = irrep_dim_func(irrep)
    Gorder = len(group)
    n = len(basis)
    # Accumulate in complex so complex characters work. With real characters
    # the imaginary part stays exactly 0 and the real part equals the float sum.
    P = np.zeros((n, n), complex)

    for g in group:
        chi = irrep_char_func(irrep, g)
        Dg = representation_matrix_for_element(g, ligands_pos, frames, basis, mode, tol_match=tol_match)
        P += np.conj(chi) * Dg

    P *= (l / Gorder)
    if np.any(P.imag):
        return P
    return P.real.copy()


def extract_salc(irrep, group, irrep_dim_func, irrep_char_func,
                 ligands_pos, frames, basis, mode, tol_match=1e-4, tol=1e-6):
    """
    Extract the SALCs of ``irrep`` as eigenvectors of its projection operator.

    P (from projection_operator) is diagonalized with np.linalg.eigh, which
    treats P as symmetric/Hermitian. Every eigenvector whose eigenvalue
    exceeds ``tol`` is returned as one SALC, in ascending eigenvalue order.

    Each SALC has unit norm and a deterministic overall phase:
      - real vectors: the largest-magnitude component is made positive
      - complex vectors (complex characters): the vector is rotated by a global
        phase so that component is real and positive

    Returns a list of 1-D arrays of length len(basis).
    """
    P = projection_operator(irrep, group, irrep_dim_func, irrep_char_func,
                            ligands_pos, frames, basis, mode, tol_match=tol_match)
    w, v = np.linalg.eigh(P)

    salcs = []
    for eigval, col in zip(w, v.T):
        if eigval <= tol:
            continue
        vec = col / np.linalg.norm(col)
        imax = int(np.argmax(np.abs(vec)))
        if np.iscomplexobj(vec):
            # |z|/z == conj(z)/|z|: removes the arbitrary global phase of vec.
            vec = vec * (np.abs(vec[imax]) / vec[imax])
        elif vec[imax] < 0:
            vec = -vec
        salcs.append(vec)
    return salcs


# ============================================================
# 5. Printing SALCs (with optional Bloch phase)
# ============================================================

def _fmt_complex(z, digits=4, eps=1e-12):
    """
    Robust complex formatter.
    - snap tiny real/imag to 0
    - avoid '-0.0000'
    - keep sign correctly
    """
    re = float(np.real(z))
    im = float(np.imag(z))

    if abs(re) < eps: re = 0.0
    if abs(im) < eps: im = 0.0

    # avoid negative zero
    if re == 0.0: re = 0.0
    if im == 0.0: im = 0.0

    if im == 0.0:
        return f"{re:+.{digits}f}"
    if re == 0.0:
        return f"{im:+.{digits}f}j"
    return f"{re:+.{digits}f}{im:+.{digits}f}j"


def print_salc(irrep, salcs, basis, thr=1e-3, max_salc=None, k_frac=None, cell_T=None,
               show_phase=True, show_effective=True):
    """
    Print SALC coefficients, one line per basis function.

    Parameters
    ----------
    irrep : label shown in the header line.
    salcs : sequence of coefficient vectors, each aligned with ``basis``.
    basis : list of basis dicts with "ligand", "type" and "name" keys; lines
        are ordered by (ligand, type, name).
    thr : coefficients with ``abs(c) < thr`` are not printed.
    max_salc : if given, only the first ``max_salc`` SALCs are printed.
    k_frac : optional fractional k-point (3,). When given, each coefficient is
        also shown dressed with the Bloch phase
            c_i -> c_i * exp(i 2π k·T_i),   T_i = cell_T[ligand of basis i]
        and printed as complex numbers.
    cell_T : per-ligand integer cell translations; required with ``k_frac``.
    show_phase, show_effective : columns printed when ``k_frac`` is given
        (raw+phase+eff, raw+phase, or eff only).

    Raises
    ------
    ValueError
        If ``k_frac`` is given without ``cell_T``. Checked before anything is
        printed, so the error does not depend on which coefficients pass ``thr``.
    """
    if k_frac is not None and cell_T is None:
        raise ValueError("print_salc: cell_T must be provided when k_frac is given.")

    if max_salc is not None:
        salcs = salcs[:max_salc]

    k = None if k_frac is None else np.array(k_frac, float)
    if k is None:
        print(f"\n[SALCs for {irrep}]  (count={len(salcs)})")
    else:
        print(f"\n[SALCs for {irrep} at k=({k[0]:.3f},{k[1]:.3f},{k[2]:.3f})]  (count={len(salcs)})")

    for kidx, vec in enumerate(salcs, 1):
        print(f"  SALC #{kidx}:")
        items = sorted(zip(vec, basis), key=lambda x: (x[1]["ligand"], x[1]["type"], x[1]["name"]))

        for c, b in items:
            if abs(c) < thr:
                continue

            # complex() keeps an imaginary part if the SALC is complex.
            raw = complex(c)

            if k is None:
                if raw.imag == 0.0:
                    print(f"    {raw.real:+.4f}  L{b['ligand']}:{b['name']}")
                else:
                    print(f"    {_fmt_complex(raw)}  L{b['ligand']}:{b['name']}")
                continue

            T = np.array(cell_T[b["ligand"]], int)
            phase = np.exp(2j * np.pi * float(np.dot(k, T)))

            # Lightly clean numerical noise: snap phases close to ±1 or ±i.
            if abs(phase.imag) < 1e-12:
                phase = complex(round(phase.real), 0.0)
            if abs(phase.real) < 1e-12:
                phase = complex(0.0, round(phase.imag))

            eff = raw * phase

            s_raw   = _fmt_complex(raw)
            s_phase = _fmt_complex(phase)
            s_eff   = _fmt_complex(eff)

            if show_phase and show_effective:
                # Show raw and phase side by side so differing raw coefficients
                # can be checked to give matching effective (eff) values.
                print(f"    raw={s_raw:>10s}  phase={s_phase:>10s}  eff={s_eff:>10s}  "
                      f"L{b['ligand']}:{b['name']}  (T={T.tolist()})")
            elif show_phase:
                print(f"    raw={s_raw:>10s}  phase={s_phase:>10s}  "
                      f"L{b['ligand']}:{b['name']}  (T={T.tolist()})")
            else:
                print(f"    {s_eff:>16s}  L{b['ligand']}:{b['name']}  (T={T.tolist()})")


# ============================================================
# 6. k-compatibility filter (periodic constraint)
# ============================================================

def filter_salcs_by_k(salcs, basis, site_id, cell_T, k_frac, tol=1e-6):
    """
    Check Bloch periodic compatibility for SALC coefficients across periodic images.

    Condition:
      For each site s, values  c_i * exp(i 2π k·T_i)  should be equal (up to global factor)
      for all ligands i belonging to site s.

    Returns:
      compatible_salcs (list of vectors)
      n_ok
    """
    site_id = np.array(site_id, int)
    cell_T = np.array(cell_T, int)
    k = np.array(k_frac, float)

    # group basis indices by ligand (because basis has multiple orbitals per ligand)
    # We'll enforce constraint per (ligand) aggregated amplitude for each orbital separately.
    # Practically: constraint is applied per basis coefficient directly.
    # That is fine because each basis coefficient belongs to one ligand.

    # precompute phase per ligand
    phase_lig = np.exp(2j * np.pi * (cell_T @ k))

    # sites -> ligand indices
    sites = {}
    for lid, sid in enumerate(site_id):
        sites.setdefault(int(sid), []).append(lid)

    compatible = []
    for vec in salcs:
        ok = True
        vec = np.array(vec, float)

        # For each site, check all basis components on ligands in that site
        # For each basis function type/name, compare the phase-corrected coefficient.
        # Easiest: compare per basis coefficient after multiplying by ligand phase.
        # Pick first occurrence as reference for that site & that basis-orbital label (type+name).
        ref = {}  # key=(site, type, name) -> complex ref value
        for bi, b in enumerate(basis):
            lid = b["ligand"]
            sid = int(site_id[lid])
            key = (sid, b["type"], b["name"])

            val = complex(vec[bi], 0.0) * phase_lig[lid]

            if key not in ref:
                ref[key] = val
            else:
                if abs(val - ref[key]) > tol:
                    ok = False
                    break
        if ok:
            compatible.append(vec)

    return compatible, len(compatible)


# ============================================================
# 7. Symmetry hit rate (for autodetect)
# ============================================================

def symmetry_hit_rate(group, ligands_pos, tol_match=1e-3):
    """
    Count the group operations that map the ligand set onto itself.

    An operation g "hits" when, for every ligand position p, the image
    g["R"] @ p lies within ``tol_match`` of some ligand position. Only
    positions are tested, not orbital orientations. An empty ligand set is
    mapped onto itself by every operation.

    Returns (hits, rate) with rate = hits / len(group), or 0.0 for an empty group.
    """
    points = np.array(ligands_pos, float)
    tol_sq = tol_match ** 2

    def _maps_onto_set(R):
        # nearest-ligand test for each image (same check as match_ligand_id_by_position)
        for p in points:
            image = np.array(R @ p, float)
            d2 = np.sum((points - image[None, :]) ** 2, axis=1)
            if not d2[int(np.argmin(d2))] <= tol_sq:
                return False
        return True

    hits = sum(1 for g in group if _maps_onto_set(g["R"]))
    rate = hits / len(group) if len(group) > 0 else 0.0
    return hits, rate
