#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import argparse
import json
import re
import sys
from collections import deque

import numpy as np
from pymatgen.symmetry.groups import PointGroup

# The 32 crystallographic point groups in Hermann–Mauguin notation,
# listed by crystal system.
PG_HM = (
    ["1", "-1"]                                              # triclinic
    + ["2", "m", "2/m"]                                      # monoclinic
    + ["222", "mm2", "mmm"]                                  # orthorhombic
    + ["4", "-4", "4/m", "422", "4mm", "-42m", "4/mmm"]      # tetragonal
    + ["3", "-3", "32", "3m", "-3m"]                         # trigonal
    + ["6", "-6", "6/m", "622", "6mm", "-6m2", "6/mmm"]      # hexagonal
    + ["23", "m-3", "432", "-43m", "m-3m"]                   # cubic
)

# Norm below which a vector is treated as zero.
EPS = 1e-12
# Cartesian unit vectors (independent copies of the rows of the identity).
X, Y, Z = (np.array(row) for row in np.eye(3))

# ---------- Numerical stabilization (preserves the sign of det) ----------
def orthogonalize_preserving_det(R):
    """Return the orthogonal matrix nearest to R, keeping the sign of det(R).

    With the SVD R = U S Vt, the product U @ Vt is the closest orthogonal
    matrix.  If its determinant sign differs from that of R (a zero
    determinant counts as +1), the last left-singular vector is negated so
    that proper/improper character is preserved.
    """
    def _det_sign(M):
        return np.sign(np.linalg.det(M)) or 1.0

    u, _sing, vt = np.linalg.svd(R)
    nearest = u @ vt
    if _det_sign(nearest) != _det_sign(R):
        u[:, -1] = -u[:, -1]
        nearest = u @ vt
    return nearest

def canon(R, ndigits=12):
    """Return a hashable key for the operation R.

    R is re-orthogonalized with orthogonalize_preserving_det, flattened in
    row-major order and rounded to ``ndigits`` decimals, so matrices that
    differ only by small floating-point noise normally share one key.
    """
    entries = orthogonalize_preserving_det(R).ravel()
    return tuple(np.round(entries, ndigits))

# ---------- Classification & geometry ----------
def rotation_order(R):
    """Return the order n of a proper rotation R, estimated as round(2*pi/theta).

    theta comes from trace(R) = 1 + 2*cos(theta) and lies in [0, pi], so R
    and its inverse give the same n.  Angles below 1e-8 rad are treated as
    the identity (order 1).
    """
    Q = orthogonalize_preserving_det(R)
    theta = np.arccos(np.clip(0.5 * (np.trace(Q) - 1), -1, 1))
    return 1 if theta < 1e-8 else int(np.round(2 * np.pi / theta))

def eigvec_for_value(R, val, atol=1e-6):
    """Return a unit eigenvector of R for the eigenvalue ``val``, or None.

    Eigenvalues are matched on their full complex value.  Matching only the
    real part would also accept a complex pair cos(t) +/- i*sin(t) whose real
    part lies within tolerance of ``val`` (e.g. val=1 for a rotation by a
    small angle t), and would then return the real part of a complex
    eigenvector, which is not an eigenvector of R.  When several eigenvalues
    match, the first one reported by numpy.linalg.eig is used.

    Returns None when no eigenvalue matches or the vector's norm is <= EPS.
    """
    vals, vecs = np.linalg.eig(R)
    idx = np.flatnonzero(np.isclose(vals, val, atol=atol))
    if idx.size == 0:
        return None
    # For a real eigenvalue of a real matrix the eigenvector is real.
    v = np.real(vecs[:, idx[0]])
    n = np.linalg.norm(v)
    return v / n if n > EPS else None

def classify_and_geometry(R):
    """Classify the operation R and extract its axis / plane normal.

    R is orthogonalized first.  Returns ``(kind, order, axis_or_normal)``:

    kind
        'I' (identity), 'Inv' (inversion), 'Rot' (proper rotation),
        'Mir' (mirror), 'Rinv' (rotoinversion) or 'Unknown' (det not +/-1).
    order
        'Rot': rotation order of R; 'Rinv': rotation order of the proper
        part -R; 0 otherwise.
    axis_or_normal
        'Rot': eigenvector of R for eigenvalue 1 (rotation axis);
        'Rinv': eigenvector of -R for eigenvalue 1;
        'Mir': eigenvector of R for eigenvalue -1 (plane normal);
        None otherwise or when no eigenvector was found.
    """
    Q = orthogonalize_preserving_det(R)
    eye = np.eye(3)
    if np.allclose(Q, eye, atol=1e-7):
        return 'I', 0, None
    if np.allclose(Q, -eye, atol=1e-7):
        return 'Inv', 0, None

    d = np.linalg.det(Q)
    if np.isclose(d, 1.0, atol=1e-7):
        order = rotation_order(Q)
        if order == 1:
            # Rotation angle is numerically zero.
            return 'I', 0, None
        return 'Rot', order, eigvec_for_value(Q, 1.0)

    if not np.isclose(d, -1.0, atol=1e-7):
        return 'Unknown', 0, None

    # Improper: an involution (Q^2 = E, Q != -E) is a mirror.
    if np.allclose(Q @ Q, eye, atol=1e-6):
        return 'Mir', 0, eigvec_for_value(Q, -1.0)
    proper = -Q
    return 'Rinv', rotation_order(proper), eigvec_for_value(proper, 1.0)

def axis_label(v):
    """Return a short text label for an axis / normal direction ``v``.

    ``v`` is normalized first.  A direction within ~1e-3 of a Cartesian
    axis (either sign) is labeled 'z', 'x' or 'y', checked in that order.
    Any other direction is printed as its components rounded to two
    decimals.  An axis has no intrinsic sign, so the general form is
    reported with its first non-zero rounded component positive, and
    negative zeros are normalized so the label never shows "-0.0".
    Returns "?" for None.
    """
    if v is None:
        return "?"
    v = v / (np.linalg.norm(v) + 1e-15)
    # v[i] equals the dot product with the i-th Cartesian unit vector.
    for name, i in (('z', 2), ('x', 0), ('y', 1)):
        if np.isclose(abs(v[i]), 1.0, atol=1e-3):
            return name
    comps = np.round(v, 2) + 0.0          # "+ 0.0" turns -0.0 into 0.0
    nonzero = comps[comps != 0.0]
    if nonzero.size and nonzero[0] < 0:
        comps = -comps + 0.0              # canonical sign for the same axis
    return ",".join(str(float(c)) for c in comps)

def _rotoinversion_to_rotoreflection_order(n):
    """Map the order n of a rotoinversion -n to the order m of the equal S_m."""
    # -1 = S2, -2 = S1, -3 = S6, -4 = S4, -6 = S3
    if n % 2 == 1:
        return 2 * n
    if n % 4 == 2:
        return n // 2
    return n


def default_label_of(R):
    """Return a default human-readable label for the operation R.

    Labels: "E" (identity), "i" (inversion), "Cn(axis)" for proper
    rotations, "m(⊥normal)" for mirrors and "Sm(axis)" for the other
    improper rotations.  The axis part is omitted when no axis was found.
    classify_and_geometry reports a rotoinversion by the order n of its
    proper part -R (Hermann–Mauguin "-n").  The Schoenflies rotoreflection
    S_m that equals it has a different m for odd n and n = 2 mod 4
    (-3 = S6, -6 = S3), so n is converted before printing.
    Returns "?" for unclassifiable matrices.
    """
    kind, n, v = classify_and_geometry(R)
    if kind == 'I':
        return "E"
    if kind == 'Inv':
        return "i"
    if kind == 'Rot':
        return f"C{n}({axis_label(v)})" if v is not None else f"C{n}"
    if kind == 'Mir':
        return f"m(⊥{axis_label(v)})" if v is not None else "m"
    if kind == 'Rinv':
        m = _rotoinversion_to_rotoreflection_order(n)
        return f"S{m}({axis_label(v)})" if v is not None else f"S{m}"
    return "?"

# ---------- "Symmetry-first" priority score (used for generator extraction) ----------
def symmetry_priority_score(R):
    """Return a priority score for choosing generators (higher = preferred).

    Base scores (n = order reported by classify_and_geometry):
      * proper rotation about z ........... 1000 + 10*n
      * proper rotation about x or y ...... 600 + 2*n
      * other proper rotation ............. 200 + n
      * mirror with normal along z ........ 900
      * mirror with normal along x or y ... 550
      * other mirror ...................... 150
      * rotoinversion ..................... 120 + n
      * inversion 50, identity 0, anything else 10
    Rotations, mirrors and rotoinversions whose axis/normal was not found
    fall into "anything else".  A tie-breaker 0.001 * trace of the
    orthogonalized R is added.
    """
    kind, n, v = classify_and_geometry(R)

    def along(e):
        # Axis/normal is within ~2.6 degrees of the unit vector e.
        return abs(np.dot(v, e)) > 0.999

    if v is not None and kind == 'Rot':
        if along(Z):
            base = 1000 + 10 * n   # higher rotation order preferred
        elif along(X) or along(Y):
            base = 600 + 2 * n
        else:
            base = 200 + n
    elif v is not None and kind == 'Mir':
        if along(Z):
            base = 900             # mirror plane is the xy plane
        elif along(X) or along(Y):
            base = 550
        else:
            base = 150
    elif v is not None and kind == 'Rinv':
        base = 120 + n
    else:
        base = {'Inv': 50, 'I': 0}.get(kind, 10)

    return base + 0.001 * np.trace(orthogonalize_preserving_det(R))

# ---------- Snapping to known elements (always maps to the nearest one) ----------
def build_snapper(known_ops):
    """Build ``snap(R) -> (matrix, key)`` mapping R onto the nearest known operation.

    All matrices are orthogonalized.  "Nearest" is the Frobenius distance.
    There is no distance threshold: the closest element of ``known_ops`` is
    returned however far away it is.  ``key`` is that element's canon() key.
    Raises (from np.stack) if ``known_ops`` is empty.
    """
    ref_mats = [orthogonalize_preserving_det(M) for M in known_ops]
    ref_keys = [canon(M) for M in ref_mats]
    flat_refs = np.stack(ref_mats, axis=0).reshape(len(ref_mats), -1)

    def snap(R):
        probe = orthogonalize_preserving_det(R).reshape(-1)
        nearest = int(np.argmin(np.linalg.norm(flat_refs - probe, axis=1)))
        return ref_mats[nearest], ref_keys[nearest]

    return snap

# ---------- BFS closure (with words, using snapping) ----------
def closure_with_words(generators, target_size, snap):
    """Enumerate the group generated by ``generators`` with breadth-first search.

    Starting from the identity, each found element A is right-multiplied by
    every generator G_j (new = A @ G_j).  Products are snapped and
    deduplicated by the snap key.  Because the search is breadth-first,
    every element is recorded with a shortest word, stored as a list of
    generator indices.

    Args:
        generators: 3x3 matrices (snapped before use).
        target_size: stop as soon as a newly found element brings the count
            to this size.
        snap: callable mapping a matrix to ``(snapped_matrix, hashable_key)``.

    Returns:
        (mats, words_idx, key_to_idx): elements in discovery order, their
        words (the identity has the empty word), and snap key -> index.
    """
    identity, identity_key = snap(np.eye(3))
    mats = [identity]
    words_idx = [[]]
    key_to_idx = {identity_key: 0}
    gen_mats = [snap(g)[0] for g in generators]

    frontier = deque([0])
    while frontier:
        cur = frontier.popleft()
        base, base_word = mats[cur], words_idx[cur]
        for g_idx, g in enumerate(gen_mats):
            prod, key = snap(base @ g)
            if key in key_to_idx:
                continue
            key_to_idx[key] = len(mats)
            frontier.append(len(mats))
            mats.append(prod)
            words_idx.append(base_word + [g_idx])  # right multiplication
            if len(mats) >= target_size:
                return mats, words_idx, key_to_idx
    return mats, words_idx, key_to_idx

def compare_sets(setA, setB):
    """Return True when both collections contain the same operations (by canon key)."""
    keys_a = set(map(canon, setA))
    keys_b = set(map(canon, setB))
    return keys_a == keys_b

# ---------- Generator extraction ----------
def find_generators(all_ops):
    """Choose a small, "symmetry-first" generating set for the group ``all_ops``.

    Steps:
      1. Sort the operations by symmetry_priority_score, highest first
         (ties keep the input order).
      2. Greedily accept an operation whenever it enlarges the subgroup
         generated so far.  Stop once the whole group is reached.
      3. Drop redundant generators, trying them front to back.
      4. Reorder: z-axis rotations first, then mirrors normal to z, then
         the rest.  Within each class, higher order comes first, then
         higher priority score.

    Args:
        all_ops: list of 3x3 matrices forming the group.

    Returns:
        (gens, snap): the generator matrices and the snapping function
        built from ``all_ops``, for reuse in later closures.
    """
    target_size = len(all_ops)
    snap = build_snapper(all_ops)
    target = {canon(R) for R in all_ops}

    def span(gen_list):
        # Canonical keys of the subgroup generated by gen_list.
        mats, _, _ = closure_with_words(gen_list, target_size=target_size, snap=snap)
        return {canon(M) for M in mats}

    ops_sorted = sorted(all_ops, key=lambda R: -symmetry_priority_score(R))

    gens = []
    have = span(gens)
    for R in ops_sorted:
        trial_gens = gens + [R]
        trial_have = span(trial_gens)
        if len(trial_have) > len(have):
            # The closure is deterministic, so the trial span is reused
            # instead of recomputing it for the accepted generator list.
            gens, have = trial_gens, trial_have
        if have == target:
            break

    # Remove redundant generators.
    i = 0
    while i < len(gens):
        reduced = gens[:i] + gens[i + 1:]
        if span(reduced) == target:
            gens = reduced
        else:
            i += 1

    # Preferred order: z-axis rotations, then mirrors normal to z, then the rest.
    def prefer_key(R):
        kind, n, v = classify_and_geometry(R)
        along_z = v is not None and abs(np.dot(v, Z)) > 0.999
        if kind == 'Rot' and along_z:
            rank = 0
        elif kind == 'Mir' and along_z:
            rank = 1
        else:
            rank = 2
        return (rank, -n, -symmetry_priority_score(R))

    gens = sorted(gens, key=prefer_key)
    return gens, snap

# ---------- Order estimation (per generator) ----------
def estimate_orders(generators, snap, max_order=12):
    """Estimate the order of each generator.

    The order of G is the smallest k in 1..max_order with G^k = E.  Each
    power is snapped after every multiplication, and equality with E is
    decided by the key ``snap`` returns.  This avoids recomputing an SVD-based
    canonical key on every step.  The default limit of 12 is ample for
    crystallographic point groups.

    Args:
        generators: 3x3 matrices.
        snap: callable mapping a matrix to ``(snapped_matrix, key)``.
        max_order: largest power to try.

    Returns:
        A list with the order of each generator, or None where no power up
        to ``max_order`` gave the identity.
    """
    identity, identity_key = snap(np.eye(3))
    orders = []
    for G in generators:
        step = snap(G)[0]
        power = identity
        for k in range(1, max_order + 1):
            power, key = snap(power @ step)
            if key == identity_key:
                orders.append(k)
                break
        else:
            orders.append(None)  # no power up to max_order returned to E
    return orders

# ---------- Word formatting (exponent notation, positive exponents only) ----------
def format_word(words_idx, gen_labels, gen_orders):
    """Render a generator word as text using positive exponents.

    Args:
        words_idx: generator indices in right-multiplication order.
        gen_labels: label of each generator.
        gen_orders: order of each generator, or None if unknown.

    Consecutive occurrences of one generator are merged into ``label^p``.
    When the generator's order n is known, p is reduced modulo n.  A run
    that reduces to the identity is dropped, and the runs around it are
    merged if they use the same generator: g1 g2 g2 g1 with ord(g2) = 2
    gives "g1^2", not "g1 * g1".  Negative exponents are never produced.
    Returns "E" for the empty (identity) word.
    """
    runs = []  # [generator index, exponent]; adjacent indices always differ
    for j in words_idx:
        if runs and runs[-1][0] == j:
            runs[-1][1] += 1
        else:
            runs.append([j, 1])
        n = gen_orders[j]
        if n is not None and n > 0 and runs[-1][1] % n == 0:
            runs.pop()  # g^n = E: the run vanishes
    parts = [f"{gen_labels[j]}" if p == 1 else f"{gen_labels[j]}^{p}" for j, p in runs]
    return " * ".join(parts) if parts else "E"

# ---------- Main processing ----------
def compute_generators_with_words(symbol, user_gen_labels=None):
    """Find generators of point group ``symbol`` and write each element as a word.

    Args:
        symbol: Hermann–Mauguin symbol; must be one of PG_HM.
        user_gen_labels: optional {0-based generator index: label} that
            overrides the automatically derived labels.

    Returns:
        (gens, gen_labels, gen_orders, all_ops, elems, ok).  ``elems`` holds
        (matrix, word_string, default_label) for every non-identity
        operation reached by the closure, in ``all_ops`` order.  ``ok`` is
        True when the closure of ``gens`` equals ``all_ops``.

    Raises:
        ValueError: if ``symbol`` is not in PG_HM.
    """
    if symbol not in PG_HM:
        raise ValueError(f"Unknown point group: {symbol}")

    # Orthogonalized rotation parts, deduplicated by canonical key
    # (first occurrence wins; dict keeps insertion order).
    unique_ops = {}
    for op in PointGroup(symbol).symmetry_ops:
        M = orthogonalize_preserving_det(op.rotation_matrix)
        unique_ops.setdefault(canon(M), M)
    all_ops = list(unique_ops.values())

    gens, snap = find_generators(all_ops)

    auto_labels = [default_label_of(G) for G in gens]
    if user_gen_labels:
        gen_labels = [user_gen_labels.get(i, lbl) for i, lbl in enumerate(auto_labels)]
    else:
        gen_labels = auto_labels

    # Orders are used to normalize exponents in the printed words.
    gen_orders = estimate_orders(gens, snap)

    mats, words_idx, key_to_idx = closure_with_words(gens, target_size=len(all_ops), snap=snap)
    ok = compare_sets(mats, all_ops)

    elems = []
    for R in all_ops:
        idx = key_to_idx.get(canon(R))
        if idx is None or not words_idx[idx]:
            continue  # not reached by the closure, or the identity (empty word)
        word_str = format_word(words_idx[idx], gen_labels, gen_orders)
        elems.append((R, word_str, default_label_of(R)))
    return gens, gen_labels, gen_orders, all_ops, elems, ok

def to_jsonable(R):
    """Convert a 2-D array-like to nested lists of floats rounded to 10 decimals.

    Values that round to zero are emitted as 0.0, not -0.0.  Round-off such
    as -1e-17 would otherwise serialize as "-0.0" in the JSON output.
    """
    return [[float(f"{v:.10f}") + 0.0 for v in row] for row in R]

def parse_gen_labels(s):
    """Parse user-supplied generator labels into {0-based index: label}.

    Example: "g1=C6(z),g2=m(⊥z)" -> {0: "C6(z)", 1: "m(⊥z)"}.

    Keys have the form g<N> (case-insensitive, N >= 1).  Items are split
    only at commas followed by the next "g<N>=" key, so labels may contain
    commas themselves, e.g. "g1=C2(0.71,0.71,0.0)" (the format of the
    automatic labels).  Items without "=", keys not of the form g<N>, and
    N < 1 are ignored.  Empty or None input gives {}.
    """
    mapping = {}
    if not s:
        return mapping
    pieces = re.split(r",(?=\s*g\s*\d+\s*=)", s, flags=re.IGNORECASE)
    for it in (p.strip() for p in pieces):
        key, sep, value = it.partition('=')
        if not sep:
            continue
        key = key.strip().lower()
        if not key.startswith('g'):
            continue
        try:
            idx = int(key[1:]) - 1
        except ValueError:
            continue
        if idx < 0:
            continue  # "g0" or negative numbers do not name a generator
        mapping[idx] = value.strip()
    return mapping

def main(argv=None):
    """Command-line entry point: print generators and element words for a point group.

    Args:
        argv: argument list to parse; None (the default) parses sys.argv[1:].

    On failure the error message goes to stderr and the process exits with
    status 1.
    """
    ap = argparse.ArgumentParser(
        description=(
            "Point-group generators with words using positive exponents (e.g., C3 = C6^2, C6^5). "
            "You can predefine generator symbols like 'g1=C6(z),g2=m(⊥z)'."
        ),
        formatter_class=argparse.RawTextHelpFormatter
    )
    ap.add_argument("--pg", "-p", required=True, help="Point group symbol (H–M), e.g. 6, 4mm, 6mm, -3m")
    ap.add_argument("--gen-labels", type=str, default="",
                    help="Comma-separated labels, e.g. 'g1=C6(z),g2=m(⊥z)'")
    ap.add_argument("--json", action="store_true", help="Also output JSON")
    args = ap.parse_args(argv)

    user_labels = parse_gen_labels(args.gen_labels)

    try:
        gens, gen_labels, gen_orders, all_ops, elems, ok = compute_generators_with_words(
            args.pg, user_gen_labels=user_labels
        )
    except Exception as e:  # top-level boundary: report the failure and exit non-zero
        print(f"Error: {e}", file=sys.stderr)
        sys.exit(1)

    def fmt_entry(x):
        # Fixed-width matrix entry.  Round-off such as -1e-17 would print as
        # "-0.000000", so anything that prints as zero is shown as +0.
        s = f"{x:9.6f}"
        return f"{0.0:9.6f}" if float(s) == 0.0 else s

    matrix_fmt = {'float_kind': fmt_entry}

    print(f"--- Point Group {args.pg} ---")
    print(f"Total operations: {len(all_ops)}")
    print(f"Verification (⟨gens⟩ == group): {'OK' if ok else 'NG'}\n")

    # Generators
    print("Generators (preferred order):")
    for i, (lbl, G, n) in enumerate(zip(gen_labels, gens, gen_orders), 1):
        ord_txt = f"  (order={n})" if n is not None else ""
        print(f"  g{i} = {lbl}{ord_txt}")
        print(np.array2string(G, formatter=matrix_fmt))
        print()

    # Elements (identity excluded)
    print("Generated elements (excluding identity):")
    for i, (R, word, cls_lbl) in enumerate(elems, 1):
        print(f"  O{i:02d}: {word}   =>  {cls_lbl}")
        print(np.array2string(R, formatter=matrix_fmt))
        print()

    if args.json:
        out = {
            "point_group": args.pg,
            "total_ops": len(all_ops),
            "verified": bool(ok),
            "generators": [
                {"name": f"g{i+1}", "label": lbl, "order": gen_orders[i], "matrix": to_jsonable(G)}
                for i, (lbl, G) in enumerate(zip(gen_labels, gens))
            ],
            "elements": [
                {"name": f"O{i+1}", "word": word, "class": cls_lbl, "matrix": to_jsonable(R)}
                for i, (R, word, cls_lbl) in enumerate(elems)
            ],
        }
        print(json.dumps(out, ensure_ascii=False, indent=2))


if __name__ == "__main__":
    main()
    # Pause before exiting (presumably to keep a console window open when the
    # script is launched directly).  Only pause when stdin is an interactive
    # terminal: with redirected or closed stdin, input() would block a
    # pipeline or raise EOFError.
    if sys.stdin is not None and sys.stdin.isatty():
        input("\nPress ENTER to terminate>>\n")
    