"""
このモジュールは、結晶構造の対称性を解析し、特に中心金属のd軌道と配位子の対称適応型線形結合 (SALC) のk点における互換性を評価します。
ユーザーが指定した構造JSONファイルから情報を読み込み、点群の自動検出、SALCの計算、およびk点におけるSALCの振る舞いを解析します。

関連リンク: :doc:`analyze_symmetry_k_usage`
"""

import json
import argparse
import numpy as np

import tkpg
from tkpg import core


def load_structure_json(path):
    """
    構造設定をJSONファイルから読み込み、設定辞書を構築します。

    JSONファイルから格子、配位子座標、d軌道情報、解析モード、許容誤差、SALC設定などを読み込みます。
    `ligands_frac` (分数座標) または `ligands_pos` (デカルト座標) のいずれかが必須です。

    :param path: 設定を読み込むJSONファイルへのパス。
    :type path: str
    :returns: 読み込んだ設定情報を含む辞書。
    :rtype: dict
    :raises ValueError: JSONファイルに 'ligands_frac' または 'ligands_pos' のいずれも含まれていない場合。
    """
    with open(path, "r", encoding="utf-8") as f:
        obj = json.load(f)

    # --- lattice / positions ---
    lattice = np.array(obj.get("lattice", np.eye(3)), float)

    if "ligands_frac" in obj:
        ligands_frac = np.array(obj["ligands_frac"], float)
        # keep as-is; core will build ligands_pos + periodic metadata
        ligands_pos = None
    elif "ligands_pos" in obj:
        ligands_pos = np.array(obj["ligands_pos"], float)
        ligands_frac = None
    else:
        raise ValueError("structure.json must contain either 'ligands_frac' or 'ligands_pos'.")

    d_orbital = obj.get("d_orbital", "d_xy")
    mode = obj.get("mode", "full")

    center_ligands = bool(obj.get("center_ligands", False))

    # --- tolerances ---
    tol = obj.get("tolerances", {})
    tol_match_geom = float(tol.get("tol_match_geom", 5e-3))
    tol_match_D = float(tol.get("tol_match_D", 5e-3))
    tol_frac_site = float(tol.get("tol_frac_site", 1e-6))  # periodic site clustering in frac

    # --- SALC config ---
    salc_cfg = obj.get("salc", {})
    cfg = {
        "lattice": lattice,
        "ligands_frac": ligands_frac,
        "ligands_pos": ligands_pos,
        "center_ligands": center_ligands,
        "d_orbital": d_orbital,
        "mode": mode,
        "tol_match_geom": tol_match_geom,
        "tol_match_D": tol_match_D,
        "tol_frac_site": tol_frac_site,
        "salc_eig_tol": float(salc_cfg.get("salc_eig_tol", 1e-6)),
        "coeff_tol": float(salc_cfg.get("coeff_tol", 1e-6)),
        "print_coeffs": bool(salc_cfg.get("print_coeffs", True)),
        "print_salc_thr": float(salc_cfg.get("print_salc_thr", 1e-3)),
        "max_salc_per_irrep": salc_cfg.get("max_salc_per_irrep", None),
        "autodetect": obj.get("autodetect", {}),
        "kpoints_frac": obj.get("kpoints_frac", [[0.0, 0.0, 0.0]]),
        "k_tol": float(obj.get("k_tol", 1e-6)),
    }
    return cfg


def choose_point_group(plugins, ligands_pos, tol_match_geom, autodetect_cfg):
    """
    複数の点群プラグインを評価し、配位子の配置に最も適切な点群を自動検出します。

    各プラグインの `align_guess` と `symmetry_hit_rate` を利用して適合度を計算し、
    ヒット率、ヒット数、群の位数、名前の順でソートして最良のものを選択します。
    `autodetect_cfg` に `min_rate_strict` が指定されている場合、特定の点群に対して
    厳密なヒット率の閾値が適用されます。

    :param plugins: 評価する `tkpg` プラグインのリスト。各プラグインは、点群の名前、
                    `align_guess` メソッド、`build_group` メソッドなどを含む辞書である必要があります。
    :type plugins: list[dict]
    :param ligands_pos: 配位子の座標 (デカルト座標) のNumPy配列。形状は (N, 3) です。
    :type ligands_pos: numpy.ndarray
    :param tol_match_geom: 幾何学的なマッチングに使用する許容誤差。
    :type tol_match_geom: float
    :param autodetect_cfg: 自動検出に関する追加設定を含む辞書。
                           `min_rate_strict` (点群名と閾値の辞書) や `tie_eps` (タイブレークの許容誤差)
                           などのキーを含めることができます。`None` の場合、デフォルト値が使用されます。
    :type autodetect_cfg: dict or None
    :returns: 以下の2つの要素を含むタプルを返します。
              - `best`: 検出された最良の点群に関する情報。以下の要素を含むタプルです。
                - `plugin`: 最良と判断された点群プラグイン辞書。
                - `rate`: 最良点群のヒット率。
                - `hits`: 最良点群のヒット数。
                - `R_align`: 配位子を標準的な軸にアラインメントするための回転行列。
                - `pos_aligned`: アラインメント後の配位子座標。
                - `strict_tag`: 厳密な受け入れ基準が適用された場合のタグ文字列。
              - `scored_pretty`: 評価された全プラグインの情報（整形済み文字列含む）のリスト。
                                各要素は `(rate, hits, ok, p, R, pos, tag)` のタプルです。
    :rtype: tuple[tuple[dict, float, int, numpy.ndarray, numpy.ndarray, str], list[tuple]]
    """
    min_rate = (autodetect_cfg.get("min_rate_strict", {}) if autodetect_cfg else {})
    eps_tie = float((autodetect_cfg.get("tie_eps", 1e-12)) if autodetect_cfg else 1e-12)

    scored = []
    for p in plugins:
        ok, R_align, pos_aligned = p["align_guess"](ligands_pos)
        group = p["build_group"]()
        hits, rate = core.symmetry_hit_rate(group, pos_aligned if ok else ligands_pos, tol_match=tol_match_geom)
        scored.append((rate, hits, ok, p, R_align, pos_aligned))

    scored.sort(key=lambda x: (x[0], x[1], len(x[3]["build_group"]()), x[3]["name"]), reverse=True)
    best_rate, best_hits, best_ok, best_p, best_R, best_pos = scored[0]

    # strict acceptance if specified
    thr = float(min_rate.get(best_p["name"], 0.0))
    strict_tag = ""
    if best_rate < thr:
        strict_tag = f" STRICT (thr={thr:.2f})"

    # reformat for printing
    scored_pretty = []
    for rate, hits, ok, p, R, pos in scored:
        thrp = float(min_rate.get(p["name"], 0.0))
        tag = " OK"
        if rate < thrp:
            tag = f" STRICT (thr={thrp:.2f})"
        scored_pretty.append((rate, hits, ok, p, R, pos, tag))

    return (best_p, best_rate, best_hits, best_R, best_pos, strict_tag), scored_pretty


def main():
    """
    コマンドライン引数からJSONファイルを読み込み、配位子SALCのk点互換性解析を実行します。

    主な処理フローは以下の通りです。
    1. 構造JSONを読み込み、解析に必要な設定を抽出します。
    2. 周期的なメタデータ（`ligands_frac` が指定された場合）を構築し、必要に応じて配位子を中央に配置します。
    3. `tkpg` プラグインをロードし、`choose_point_group` を使用して最適な点群を自動検出します。
    4. 検出された点群の標準的な軸の慣例に従って配位子をアラインメントします。
    5. 中心金属のd軌道に対応する既約表現を特定します。
    6. 配位子の局所フレームと基底を構築します。
    7. 点群の既約表現分解（k点独立）を実行し、d軌道とのハイブリダイゼーションの可能性をチェックします。
    8. 各指定k点について、点群SALCを抽出し、Bloch位相を考慮したk点互換性フィルターを適用し、結果を出力します。
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--infile", default="structure_k.json", help="structure JSON file (default: structure.json)")
    args = ap.parse_args()

    cfg = load_structure_json(args.infile)

    # ------------------------------------------------------------
    # Build geometry + periodic metadata
    # ------------------------------------------------------------
    lattice = cfg["lattice"]

    if cfg["ligands_frac"] is not None:
        meta = core.build_periodic_metadata_from_frac(
            ligands_frac=cfg["ligands_frac"],
            lattice=lattice,
            tol_frac=cfg["tol_frac_site"],
        )
        ligands_pos = meta["ligands_pos"]  # geometry vectors (cart)
        site_id = meta["site_id"]
        cell_T = meta["cell_T"]
        frac_wrap = meta["frac_wrap"]
    else:
        ligands_pos = np.array(cfg["ligands_pos"], float)
        # non-periodic fallback: every ligand is its own site, T=0
        site_id = np.arange(len(ligands_pos), dtype=int)
        cell_T = np.zeros((len(ligands_pos), 3), dtype=int)
        frac_wrap = None

    # Optional centering (NOT recommended for periodic use)
    if cfg["center_ligands"]:
        ligands_pos = core.center_positions(ligands_pos)

    # ------------------------------------------------------------
    # Load plugins & choose point group
    # ------------------------------------------------------------
    plugins = tkpg.load_plugins()
    if not plugins:
        raise RuntimeError("No tkpg plugins found. Ensure tkpg/Oh.py, tkpg/C4v.py, etc exist.")

    best, scored = choose_point_group(
        plugins, ligands_pos, cfg["tol_match_geom"], cfg["autodetect"]
    )
    plugin, rate, hits, R_align, pos_aligned, strict_tag = best

    print(f"\n[Auto-detected point group]  {plugin['name']}   (hits={hits}/{len(plugin['build_group']())}, hit_rate={rate:.3f}){strict_tag}")
    print("[Candidates]")
    for r, h, ok, p, *_rest in scored:
        tag = _rest[-1]
        print(f"  {p['name']:>4s} : hits={h:2d}/{len(p['build_group']()):<2d}  rate={r:.3f}  ok_align={ok}{tag}")

    # ------------------------------------------------------------
    # Align geometry for canonical axis conventions
    # NOTE: periodic metadata (site_id, cell_T) is defined from input frac,
    #       and remains valid; alignment is for symmetry operations only.
    # ------------------------------------------------------------
    pos_aligned = (R_align @ ligands_pos.T).T

    # ------------------------------------------------------------
    # Center d orbital irrep
    # ------------------------------------------------------------
    d_orbital = cfg["d_orbital"]
    d_orbital_use = "d_xz" if d_orbital == "d_zx" else d_orbital

    dmap = plugin["d_irreps"]
    if d_orbital_use not in dmap:
        raise ValueError(f"d orbital '{d_orbital}' not defined for point group {plugin['name']}")
    d_irrep = dmap[d_orbital_use]

    print("\n[Central metal d orbital]")
    print(f"  Orbital : {d_orbital}")
    print(f"  Irrep   : {d_irrep}")

    # ------------------------------------------------------------
    # Build basis
    # ------------------------------------------------------------
    mode = cfg["mode"]
    frames = core.compute_local_frames(pos_aligned)
    basis = core.build_ligand_basis(len(pos_aligned), frames, mode=mode)

    group = plugin["build_group"]()
    irreps = plugin["irreps"]()
    irrep_dim = plugin["irrep_dim"]
    irrep_char = plugin["irrep_char"]

    # ------------------------------------------------------------
    # Reducible rep decomposition (point group only; k-independent)
    # ------------------------------------------------------------
    Gamma = core.build_reducible_characters(
        group, pos_aligned, frames, basis, mode=mode, tol_match=cfg["tol_match_D"]
    )
    coeffs = core.decompose_generic(Gamma, group, irreps, irrep_char)

    if cfg["print_coeffs"]:
        print(f"\n[Ligand reducible rep decomposition ({plugin['name']})]")
        for ir in irreps:
            if coeffs.get(ir, 0.0) > cfg["coeff_tol"]:
                print(f"  {ir}: {coeffs[ir]:.6f}")

    # Symmetry-only hybridization check
    print("\n[Hybridization check (symmetry-only)]")
    if coeffs.get(d_irrep, 0.0) > cfg["coeff_tol"]:
        print(f"  ✔ Allowed by point-group symmetry: ligand contains {d_irrep}")
    else:
        print(f"  ✘ Forbidden by point-group symmetry: ligand does NOT contain {d_irrep}")

    # ------------------------------------------------------------
    # SALCs (point-group) -> then k-compatibility filtering per k
    # ------------------------------------------------------------
    kpoints = np.array(cfg["kpoints_frac"], float)
    k_tol = float(cfg["k_tol"])
    
    print(f"\n[All ligand SALCs with Bloch phase]  point_group={plugin['name']}  mode={mode}")
    for ir in irreps:
        mult = coeffs.get(ir, 0.0)
        if mult <= cfg["coeff_tol"]:
            continue

        # Extract point-group SALCs (real vectors)
        salcs = core.extract_salc(
            ir, group,
            irrep_dim, irrep_char,
            pos_aligned, frames, basis, mode=mode,
            tol_match=cfg["tol_match_D"],
            tol=cfg["salc_eig_tol"]
        )

        if len(salcs) == 0:
            print(f"\n[SALCs for {ir}]  (expected mult≈{mult:.3f}, but none extracted; try lowering salc_eig_tol)")
            continue

        # For each k: filter by k-compatibility, then print with phase
        for k_frac in kpoints:
            k_frac = np.array(k_frac, float)
            compatible, n_ok = core.filter_salcs_by_k(
                salcs=salcs,
                basis=basis,
                site_id=site_id,
                cell_T=cell_T,
                k_frac=k_frac,
                tol=k_tol
            )
            print(f"\n[k-compatibility] irrep={ir}  k_frac=({k_frac[0]:.3f},{k_frac[1]:.3f},{k_frac[2]:.3f})  compatible={n_ok}/{len(salcs)}")

            if n_ok == 0:
                print(f"  (No k-compatible SALCs for irrep {ir} at k=({k_frac[0]:.3f},{k_frac[1]:.3f},{k_frac[2]:.3f}))")
                continue

            core.print_salc(
                ir, compatible, basis,
                thr=cfg["print_salc_thr"],
                max_salc=cfg["max_salc_per_irrep"],
                k_frac=k_frac,
                cell_T=cell_T
            )


if __name__ == "__main__":
    main()