# salc_run.py
import json
import argparse
import numpy as np

import tkpg
from tkpg import core


def load_structure_json(path):
    """Read a structure JSON file and flatten it into a run configuration.

    The file must contain a ``ligands_pos`` entry; everything else is
    optional and falls back to the defaults below. The optional sections
    ``tolerances``, ``salc`` and ``autodetect`` may be missing or given as
    JSON ``null``; both cases are treated as an empty section.

    Args:
        path: Path of the JSON file to read (opened as UTF-8).

    Returns:
        dict with the keys ``ligands_pos`` (float ndarray), ``d_orbital``
        (default ``"d_xy"``), ``mode`` (default ``"full"``),
        ``tol_match_geom`` and ``tol_match_D`` (default 5e-3),
        ``salc_eig_tol`` and ``coeff_tol`` (default 1e-6), ``print_coeffs``
        (default True), ``print_salc_thr`` (default 1e-3),
        ``max_salc_per_irrep`` (default None, passed through unconverted)
        and ``autodetect`` (dict, default empty).

    Raises:
        KeyError: if ``ligands_pos`` is absent from the file.
    """
    with open(path, "r", encoding="utf-8") as f:
        obj = json.load(f)

    # ``dict.get(key, {})`` only covers a *missing* key; an explicit JSON
    # null comes back as None and would break the ``.get`` calls below.
    tol = obj.get("tolerances") or {}
    salc_cfg = obj.get("salc") or {}
    autodetect = obj.get("autodetect") or {}

    return {
        "ligands_pos": np.array(obj["ligands_pos"], float),
        "d_orbital": obj.get("d_orbital", "d_xy"),
        "mode": obj.get("mode", "full"),
        "tol_match_geom": float(tol.get("tol_match_geom", 5e-3)),
        "tol_match_D": float(tol.get("tol_match_D", 5e-3)),
        "salc_eig_tol": float(salc_cfg.get("salc_eig_tol", 1e-6)),
        "coeff_tol": float(salc_cfg.get("coeff_tol", 1e-6)),
        "print_coeffs": bool(salc_cfg.get("print_coeffs", True)),
        "print_salc_thr": float(salc_cfg.get("print_salc_thr", 1e-3)),
        "max_salc_per_irrep": salc_cfg.get("max_salc_per_irrep", None),
        "autodetect": autodetect,
    }


def center_positions(ligands_pos):
    """Return the positions shifted so that their mean along axis 0 is zero.

    The input is not modified; a new array is returned.
    """
    # keepdims=True keeps the mean as a single row so the subtraction
    # broadcasts over every row of the input.
    centroid = np.mean(ligands_pos, axis=0, keepdims=True)
    return ligands_pos - centroid


def choose_point_group(plugins, ligands_pos, tol_match_geom, autodetect_cfg):
    """Score every point-group plugin against the geometry and pick the best.

    For each plugin, ``align_guess`` is applied to ``ligands_pos``. When it
    reports success, the aligned positions are scored with
    ``core.symmetry_hit_rate``; a failed alignment scores ``hits=0`` and
    ``rate=0.0``.

    Selection rule: smaller groups tend to reach a high hit rate more
    easily, so the winner is not simply the highest rate. Candidates are
    ranked by most hits, then largest group order |G|, then highest rate.
    Candidates whose rate reaches their per-group threshold from
    ``autodetect_cfg["min_rate_strict"]`` (default threshold 0.0) are
    preferred; if none qualifies, all candidates compete. On a complete tie
    the plugin that comes first in ``plugins`` wins.

    Args:
        plugins: Sequence of plugin mappings providing ``"name"``,
            ``"align_guess"`` (callable returning ``(ok, R_align,
            pos_aligned)``) and ``"build_group"`` (callable returning a
            sized collection of group operations).
        ligands_pos: Ligand positions handed to each ``align_guess``.
        tol_match_geom: Matching tolerance forwarded to
            ``core.symmetry_hit_rate``.
        autodetect_cfg: Optional mapping; only ``"min_rate_strict"``
            (e.g. ``{"Oh": 0.70, "C4v": 0.80}``) is read. ``None`` and a
            ``None`` value for ``"min_rate_strict"`` both mean "no thresholds".

    Returns:
        ``(best, scored)`` where ``scored`` holds one record per plugin, in
        the same order as ``plugins``, and ``best`` is the selected record.
        Each record has the keys ``plugin``, ``ok_align``, ``R_align``,
        ``pos_aligned``, ``hits``, ``rate``, ``Gorder``, ``strict_thr`` and
        ``strict_ok``.

    Raises:
        IndexError: if ``plugins`` is empty.
    """
    autodetect_cfg = autodetect_cfg or {}
    min_rate = autodetect_cfg.get("min_rate_strict") or {}

    scored = []
    for p in plugins:
        ok, R_align, pos_aligned = p["align_guess"](ligands_pos)
        Gorder = len(p["build_group"]())

        if ok:
            hits, rate = core.symmetry_hit_rate(
                p["build_group"]() if False else _group_for(p, Gorder),
                pos_aligned,
                tol_match=tol_match_geom,
            ) if False else _score_aligned(p, pos_aligned, tol_match_geom)
        else:
            hits, rate = 0, 0.0

        # Strict acceptance is rate based (kept for backward compatibility).
        thr = float(min_rate.get(p["name"], 0.0))

        scored.append({
            "plugin": p,
            "ok_align": ok,
            "R_align": R_align,
            "pos_aligned": pos_aligned,
            "hits": hits,
            "rate": rate,
            "Gorder": Gorder,
            "strict_thr": thr,
            "strict_ok": rate >= thr,
        })

    # 1) Prefer candidates that pass their strict threshold, if any do.
    strict = [s for s in scored if s["strict_ok"]]
    pool = strict if strict else scored

    # 2) Rank by hits, then |G|, then rate (all descending). ``sorted``
    #    works on a copy: when ``pool`` is ``scored`` itself, an in-place
    #    sort would reorder the list handed back to the caller.
    ranked = sorted(
        pool,
        key=lambda s: (s["hits"], s["Gorder"], s["rate"]),
        reverse=True,
    )

    return ranked[0], scored


def _score_aligned(plugin, pos_aligned, tol_match_geom):
    """Return ``(hits, rate)`` of the plugin's group on the aligned positions."""
    return core.symmetry_hit_rate(
        plugin["build_group"](), pos_aligned, tol_match=tol_match_geom
    )


def _group_for(plugin, _order):
    """Return a freshly built group for ``plugin`` (helper kept for symmetry)."""
    return plugin["build_group"]()


def main():
    """Command-line entry point: detect the point group and report SALCs.

    Reads the structure JSON named by ``--infile``, centers the ligand
    positions and selects a point group among the loaded tkpg plugins. It
    then prints, in order: the selected group, the table of candidates, the
    irrep of the requested metal d orbital, the decomposition of the ligand
    representation (when ``print_coeffs`` is set), a symmetry-based
    hybridization verdict, and the SALCs of every irrep whose coefficient
    exceeds ``coeff_tol``.

    Raises:
        RuntimeError: if ``tkpg.load_plugins()`` returns nothing.
        ValueError: if the requested d orbital has no entry in the selected
            plugin's ``d_irreps`` table.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--infile",
        default="structure.json",
        help="structure JSON file (default: structure.json)",
    )
    opts = parser.parse_args()

    cfg = load_structure_json(opts.infile)
    centered = center_positions(cfg["ligands_pos"])
    orbital_name = cfg["d_orbital"]
    basis_mode = cfg["mode"]
    geom_tol = cfg["tol_match_geom"]
    rep_tol = cfg["tol_match_D"]
    coeff_tol = cfg["coeff_tol"]

    available = tkpg.load_plugins()
    if not available:
        raise RuntimeError("No tkpg plugins found. Ensure tkpg/Oh.py, tkpg/C4v.py exist.")

    winner, candidates = choose_point_group(
        available, centered, geom_tol, cfg["autodetect"]
    )
    pg = winner["plugin"]
    aligned = winner["pos_aligned"]

    print(f"\n[Auto-detected point group]  {pg['name']}   "
          f"(hits={winner['hits']}/{winner['Gorder']}, hit_rate={winner['rate']:.3f})")

    # Candidate table: same ranking as the selection, with the strict flag
    # as the leading sort criterion.
    print("[Candidates]")

    def _rank(entry):
        return (entry["strict_ok"], entry["hits"], entry["Gorder"], entry["rate"])

    for entry in sorted(candidates, key=_rank, reverse=True):
        marker = "STRICT" if entry["strict_ok"] else "     "
        print(f"  {entry['plugin']['name']:>4s} : hits={entry['hits']:>2d}/{entry['Gorder']:<2d}  "
              f"rate={entry['rate']:.3f}  "
              f"ok_align={entry['ok_align']}  {marker} (thr={entry['strict_thr']:.2f})")

    # "d_zx" is accepted as an alias of "d_xz" for the table lookup only;
    # messages keep the name exactly as the user wrote it.
    lookup_name = orbital_name if orbital_name != "d_zx" else "d_xz"
    orbital_table = pg["d_irreps"]
    if lookup_name not in orbital_table:
        raise ValueError(f"d orbital '{orbital_name}' not defined for point group {pg['name']}")
    metal_irrep = orbital_table[lookup_name]

    print("\n[Central metal d orbital]")
    print(f"  Orbital : {orbital_name}")
    print(f"  Irrep   : {metal_irrep}")

    # Ligand basis on the aligned geometry.
    frames = core.compute_local_frames(aligned)
    basis = core.build_ligand_basis(len(aligned), frames, mode=basis_mode)

    operations = pg["build_group"]()
    irrep_names = pg["irreps"]()
    dims = pg["irrep_dim"]
    chars = pg["irrep_char"]

    # Decompose the ligand representation over the irreps of the group.
    reducible = core.build_reducible_characters(
        operations, aligned, frames, basis, mode=basis_mode, tol_match=rep_tol
    )
    coeffs = core.decompose_generic(reducible, operations, irrep_names, chars)

    if cfg["print_coeffs"]:
        print(f"\n[Ligand reducible rep decomposition ({pg['name']})]")
        for label in irrep_names:
            if coeffs.get(label, 0.0) > coeff_tol:
                print(f"  {label}: {coeffs[label]:.3f}")

    # Mixing is reported as allowed when the metal irrep has a coefficient
    # above coeff_tol in the ligand decomposition.
    print("\n[Hybridization check]")
    allowed = coeffs.get(metal_irrep, 0.0) > coeff_tol
    verdict = (
        f"  ✔ Hybridization allowed by symmetry: ligand contains {metal_irrep}"
        if allowed
        else f"  ✘ Symmetry forbidden: ligand does NOT contain {metal_irrep}"
    )
    print(verdict)

    # SALCs for every irrep present in the decomposition.
    print(f"\n[All ligand SALCs ({pg['name']})]  mode={basis_mode}")
    printed_any = False
    for label in irrep_names:
        expected = coeffs.get(label, 0.0)
        if expected <= coeff_tol:
            continue

        found = core.extract_salc(
            label, operations,
            dims, chars,
            aligned, frames, basis, mode=basis_mode,
            tol_match=rep_tol, tol=cfg["salc_eig_tol"],
        )

        if len(found) == 0:
            print(f"\n[SALCs for {label}]  (expected mult≈{expected:.3f}, but none extracted; try lowering salc_eig_tol)")
        else:
            printed_any = True
            core.print_salc(
                label, found, basis,
                thr=cfg["print_salc_thr"],
                max_salc=cfg["max_salc_per_irrep"],
            )

    if not printed_any:
        print("  (No SALCs extracted. Try increasing tol_match_D or lowering salc_eig_tol.)")


# Run the command-line workflow only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
