# salc_run.py
import json
import argparse
import numpy as np

import tkpg
from tkpg import core


def load_structure_json(path):
    """
    Read a structure JSON file and return a flat configuration dict for main().

    Exactly one ligand description is used: "ligands_frac" (fractional
    coordinates; takes precedence when both keys are present) or "ligands_pos"
    (Cartesian positions). The unused one is returned as None.

    The optional sections "tolerances", "salc" and "autodetect" may be omitted
    or null; missing values fall back to the defaults written below.

    Parameters
    ----------
    path : str or path-like
        JSON file to read (decoded as UTF-8).

    Returns
    -------
    dict
        Keys: lattice (float array), ligands_frac / ligands_pos ((N, 3) float
        array or None), center_ligands, d_orbital, mode, tol_match_geom,
        tol_match_D, tol_frac_site, salc_eig_tol, coeff_tol, print_coeffs,
        print_salc_thr, max_salc_per_irrep (int or None), autodetect (dict),
        kpoints_frac (raw JSON value), k_tol.

    Raises
    ------
    ValueError
        If neither ligand key is present, if a ligand array is not (N, 3),
        or if "ligands_frac" is used with a lattice that is not 3x3.
    """
    with open(path, "r", encoding="utf-8") as f:
        obj = json.load(f)

    def _xyz_rows(key):
        # Convert obj[key] to a float array and require one [x, y, z] row per ligand.
        arr = np.array(obj[key], float)
        if arr.ndim != 2 or arr.shape[1] != 3:
            raise ValueError(f"'{key}' must be a list of [x, y, z] triples, got shape {arr.shape}.")
        return arr

    # --- lattice / positions ---
    raw_lattice = obj.get("lattice")
    lattice = np.eye(3) if raw_lattice is None else np.array(raw_lattice, float)

    ligands_frac = None
    ligands_pos = None
    if "ligands_frac" in obj:
        ligands_frac = _xyz_rows("ligands_frac")
        # The lattice is only consumed on the fractional path, so validate it here.
        if lattice.shape != (3, 3):
            raise ValueError(f"'lattice' must be a 3x3 matrix, got shape {lattice.shape}.")
        # keep as-is; main() builds Cartesian ligands_pos + periodic metadata from it
    elif "ligands_pos" in obj:
        ligands_pos = _xyz_rows("ligands_pos")
    else:
        raise ValueError("structure.json must contain either 'ligands_frac' or 'ligands_pos'.")

    d_orbital = obj.get("d_orbital", "d_xy")
    mode = obj.get("mode", "full")

    center_ligands = bool(obj.get("center_ligands", False))

    # --- tolerances (section may be absent or null) ---
    tol = obj.get("tolerances") or {}
    tol_match_geom = float(tol.get("tol_match_geom", 5e-3))
    tol_match_D = float(tol.get("tol_match_D", 5e-3))
    tol_frac_site = float(tol.get("tol_frac_site", 1e-6))  # periodic site clustering in frac

    # --- SALC config (section may be absent or null) ---
    salc_cfg = obj.get("salc") or {}
    max_salc = salc_cfg.get("max_salc_per_irrep")
    # JSON numbers like 2.0 arrive as floats; this value is a count.
    max_salc = None if max_salc is None else int(max_salc)

    kpoints_frac = obj.get("kpoints_frac")
    if kpoints_frac is None:
        kpoints_frac = [[0.0, 0.0, 0.0]]  # Gamma point only

    return {
        "lattice": lattice,
        "ligands_frac": ligands_frac,
        "ligands_pos": ligands_pos,
        "center_ligands": center_ligands,
        "d_orbital": d_orbital,
        "mode": mode,
        "tol_match_geom": tol_match_geom,
        "tol_match_D": tol_match_D,
        "tol_frac_site": tol_frac_site,
        "salc_eig_tol": float(salc_cfg.get("salc_eig_tol", 1e-6)),
        "coeff_tol": float(salc_cfg.get("coeff_tol", 1e-6)),
        "print_coeffs": bool(salc_cfg.get("print_coeffs", True)),
        "print_salc_thr": float(salc_cfg.get("print_salc_thr", 1e-3)),
        "max_salc_per_irrep": max_salc,
        "autodetect": obj.get("autodetect") or {},
        "kpoints_frac": kpoints_frac,
        "k_tol": float(obj.get("k_tol", 1e-6)),
    }


def choose_point_group(plugins, ligands_pos, tol_match_geom, autodetect_cfg):
    """
    Score every point-group plugin against the ligand geometry and pick the best.

    For each plugin, ``align_guess`` is tried first. If it reports success, the
    aligned positions are scored with ``core.symmetry_hit_rate``; otherwise the
    raw ``ligands_pos`` are scored.

    Ranking (best first): hit rate -> hits -> group order -> name, all
    descending. Before ranking, rates are quantized to multiples of
    ``autodetect_cfg["tie_eps"]`` (default 1e-12). This lets floating-point
    noise fall through to the next tie-breakers.

    Parameters
    ----------
    plugins : list of dict
        Plugin records; the keys "align_guess", "build_group" and "name" are read.
    ligands_pos : array
        Ligand geometry given to ``align_guess`` and scored directly when
        alignment fails.
    tol_match_geom : float
        Matching tolerance forwarded to ``core.symmetry_hit_rate``.
    autodetect_cfg : dict or None
        Optional keys "min_rate_strict" (plugin name -> minimum accepted rate)
        and "tie_eps".

    Returns
    -------
    best : tuple
        (plugin, rate, hits, R_align, pos_aligned, strict_tag). When the
        plugin's alignment failed, R_align is the 3x3 identity and
        pos_aligned is ``ligands_pos``, i.e. the geometry that was scored.
        strict_tag is "" or " STRICT (thr=...)" when the rate is below the
        plugin's minimum.
    scored_pretty : list of tuple
        (rate, hits, ok, plugin, R_align, pos, tag) for every candidate, best
        first; tag is " OK" or " STRICT (thr=...)".

    Raises
    ------
    ValueError
        If ``plugins`` is empty.
    """
    if not plugins:
        raise ValueError("choose_point_group: no plugins to score.")

    cfg = autodetect_cfg or {}
    min_rate = cfg.get("min_rate_strict") or {}
    eps_tie = float(cfg.get("tie_eps", 1e-12))

    def _strict_tag(name, rate):
        # "" when the rate meets the plugin's minimum, otherwise a STRICT marker.
        thr = float(min_rate.get(name, 0.0))
        return f" STRICT (thr={thr:.2f})" if rate < thr else ""

    scored = []
    for p in plugins:
        ok, R_align, pos_aligned = p["align_guess"](ligands_pos)
        if not ok:
            # The score below is computed on the unaligned geometry, so report the
            # matching (identity) transform; callers rebuild positions from R_align.
            R_align, pos_aligned = np.eye(3), ligands_pos
        group = p["build_group"]()
        hits, rate = core.symmetry_hit_rate(group, pos_aligned, tol_match=tol_match_geom)
        scored.append((rate, hits, ok, p, R_align, pos_aligned, len(group)))

    def _rank(entry):
        rate, hits, _ok, p, _R, _pos, order = entry
        rate_key = round(float(rate) / eps_tie) if eps_tie > 0 else float(rate)
        return (rate_key, hits, order, p["name"])

    scored.sort(key=_rank, reverse=True)

    best_rate, best_hits, _best_ok, best_p, best_R, best_pos, _order = scored[0]
    strict_tag = _strict_tag(best_p["name"], best_rate)

    scored_pretty = [
        (rate, hits, ok, p, R, pos, _strict_tag(p["name"], rate) or " OK")
        for rate, hits, ok, p, R, pos, _order in scored
    ]

    return (best_p, best_rate, best_hits, best_R, best_pos, strict_tag), scored_pretty


def main():
    """
    Command-line entry point: detect the ligand point group and print SALCs.

    Steps:
      1. Read the structure JSON named by ``--infile`` via ``load_structure_json``.
      2. Build ligand positions and per-ligand ``site_id`` / ``cell_T`` metadata.
         When fractional coordinates are given they come from
         ``core.build_periodic_metadata_from_frac``; otherwise each ligand is
         its own site with zero cell translation.
      3. Auto-detect the point group among the tkpg plugins.
      4. Report the irrep of the chosen central d orbital, decompose the ligand
         reducible representation, and check whether that irrep occurs.
      5. For every contributing irrep, extract SALCs. For every k-point, print
         the SALCs that ``core.filter_salcs_by_k`` reports as compatible.

    Raises
    ------
    RuntimeError
        If no tkpg plugins are found.
    ValueError
        If the requested d orbital is not defined for the detected point group,
        or if ``kpoints_frac`` cannot be reshaped into rows of 3 components.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument(
        "--infile",
        default="structure_k.json",
        help="structure JSON file (default: %(default)s)",
    )
    args = ap.parse_args()

    cfg = load_structure_json(args.infile)

    # ------------------------------------------------------------
    # Build geometry + periodic metadata
    # ------------------------------------------------------------
    lattice = cfg["lattice"]

    if cfg["ligands_frac"] is not None:
        meta = core.build_periodic_metadata_from_frac(
            ligands_frac=cfg["ligands_frac"],
            lattice=lattice,
            tol_frac=cfg["tol_frac_site"],
        )
        ligands_pos = meta["ligands_pos"]  # geometry vectors (cart)
        site_id = meta["site_id"]
        cell_T = meta["cell_T"]
    else:
        ligands_pos = np.array(cfg["ligands_pos"], float)
        # Non-periodic fallback: every ligand is its own site, T = 0.
        n_lig = len(ligands_pos)
        site_id = np.arange(n_lig, dtype=int)
        cell_T = np.zeros((n_lig, 3), dtype=int)

    # Optional centering (NOT recommended for periodic use: site_id / cell_T
    # above were derived from the uncentered input).
    if cfg["center_ligands"]:
        ligands_pos = core.center_positions(ligands_pos)

    # ------------------------------------------------------------
    # Load plugins & choose point group
    # ------------------------------------------------------------
    plugins = tkpg.load_plugins()
    if not plugins:
        raise RuntimeError("No tkpg plugins found. Ensure tkpg/Oh.py, tkpg/C4v.py, etc exist.")

    best, scored = choose_point_group(
        plugins, ligands_pos, cfg["tol_match_geom"], cfg["autodetect"]
    )
    plugin, rate, hits, R_align, _pos_scored, strict_tag = best

    # Build the chosen group once; it is reused for printing and all analysis below.
    group = plugin["build_group"]()

    print(f"\n[Auto-detected point group]  {plugin['name']}   (hits={hits}/{len(group)}, hit_rate={rate:.3f}){strict_tag}")
    print("[Candidates]")
    for r, h, ok, p, *rest in scored:
        tag = rest[-1]  # last element of each candidate tuple is its " OK"/" STRICT" tag
        print(f"  {p['name']:>4s} : hits={h:2d}/{len(p['build_group']()):<2d}  rate={r:.3f}  ok_align={ok}{tag}")

    # ------------------------------------------------------------
    # Align geometry for canonical axis conventions
    # NOTE: periodic metadata (site_id, cell_T) is defined from the input frac
    #       coordinates and remains valid; alignment only serves the symmetry
    #       operations.
    # ------------------------------------------------------------
    pos_aligned = (R_align @ ligands_pos.T).T

    # ------------------------------------------------------------
    # Center d orbital irrep
    # ------------------------------------------------------------
    d_orbital = cfg["d_orbital"]
    # Accept "d_zx" as an alias for the "d_xz" key used in the plugin maps.
    d_orbital_use = "d_xz" if d_orbital == "d_zx" else d_orbital

    dmap = plugin["d_irreps"]
    if d_orbital_use not in dmap:
        raise ValueError(f"d orbital '{d_orbital}' not defined for point group {plugin['name']}")
    d_irrep = dmap[d_orbital_use]

    print("\n[Central metal d orbital]")
    print(f"  Orbital : {d_orbital}")
    print(f"  Irrep   : {d_irrep}")

    # ------------------------------------------------------------
    # Build basis
    # ------------------------------------------------------------
    mode = cfg["mode"]
    frames = core.compute_local_frames(pos_aligned)
    basis = core.build_ligand_basis(len(pos_aligned), frames, mode=mode)

    irreps = plugin["irreps"]()
    irrep_dim = plugin["irrep_dim"]
    irrep_char = plugin["irrep_char"]

    # ------------------------------------------------------------
    # Reducible rep decomposition (point group only; k-independent)
    # ------------------------------------------------------------
    Gamma = core.build_reducible_characters(
        group, pos_aligned, frames, basis, mode=mode, tol_match=cfg["tol_match_D"]
    )
    coeffs = core.decompose_generic(Gamma, group, irreps, irrep_char)

    if cfg["print_coeffs"]:
        print(f"\n[Ligand reducible rep decomposition ({plugin['name']})]")
        for ir in irreps:
            if coeffs.get(ir, 0.0) > cfg["coeff_tol"]:
                print(f"  {ir}: {coeffs[ir]:.6f}")

    # Symmetry-only hybridization check
    print("\n[Hybridization check (symmetry-only)]")
    if coeffs.get(d_irrep, 0.0) > cfg["coeff_tol"]:
        print(f"  ✔ Allowed by point-group symmetry: ligand contains {d_irrep}")
    else:
        print(f"  ✘ Forbidden by point-group symmetry: ligand does NOT contain {d_irrep}")

    # ------------------------------------------------------------
    # SALCs (point-group) -> then k-compatibility filtering per k
    # ------------------------------------------------------------
    # reshape(-1, 3) accepts a single flat [kx, ky, kz] as well as a list of
    # k-points; an empty list yields zero rows (no k-points).
    kpoints = np.asarray(cfg["kpoints_frac"], float).reshape(-1, 3)
    k_tol = float(cfg["k_tol"])

    print(f"\n[All ligand SALCs with Bloch phase]  point_group={plugin['name']}  mode={mode}")
    for ir in irreps:
        mult = coeffs.get(ir, 0.0)
        if mult <= cfg["coeff_tol"]:
            continue

        # Extract point-group SALCs (real vectors)
        salcs = core.extract_salc(
            ir, group,
            irrep_dim, irrep_char,
            pos_aligned, frames, basis, mode=mode,
            tol_match=cfg["tol_match_D"],
            tol=cfg["salc_eig_tol"]
        )

        if len(salcs) == 0:
            print(f"\n[SALCs for {ir}]  (expected mult≈{mult:.3f}, but none extracted; try lowering salc_eig_tol)")
            continue

        # For each k: filter by k-compatibility, then print with phase
        for k_row in kpoints:
            k_frac = k_row.copy()  # independent array per k, as downstream calls receive it
            k_str = f"({k_frac[0]:.3f},{k_frac[1]:.3f},{k_frac[2]:.3f})"
            compatible, n_ok = core.filter_salcs_by_k(
                salcs=salcs,
                basis=basis,
                site_id=site_id,
                cell_T=cell_T,
                k_frac=k_frac,
                tol=k_tol
            )
            print(f"\n[k-compatibility] irrep={ir}  k_frac={k_str}  compatible={n_ok}/{len(salcs)}")

            if n_ok == 0:
                print(f"  (No k-compatible SALCs for irrep {ir} at k={k_str})")
                continue

            core.print_salc(
                ir, compatible, basis,
                thr=cfg["print_salc_thr"],
                max_salc=cfg["max_salc_per_irrep"],
                k_frac=k_frac,
                cell_T=cell_T
            )


# Run the command-line analysis only when executed as a script, not on import.
if __name__ == "__main__":
    main()
