"""
# 対応一覧
python point_group_inf.py --list

# C3v の全操作を表示（既定はテキスト）
python point_group_inf.py C3v

# D4h の generator だけ
python point_group_inf.py D4h --generators

# Td の全操作を JSON で取得
python point_group_inf.py Td --json > Td_ops.json

# 4/mmm の全操作を CSV で取得（Excel 等に取り込みやすい）
python point_group_inf.py 4/mmm --csv > 4mmm_ops.csv

# 表示の丸め桁数変更
python point_group_inf.py Oh --round 6

# 1) C3v, 入力点 (1,0,0) の軌道（全等価点と生成操作）
python point_group_inf.py C3v --p 1,0,0

# 2) D4h, 2点の「独立代表点」のみ（どの対称操作で得られるか付き）
python point_group_inf.py D4h --p 1,0,0 --p 0.2,0.3,0.4 --independent-only

# 3) 4/mmm, JSON で受け取り、他ツールに渡す
python point_group_inf.py 4/mmm --p 0,0,1 --independent-only --json > reps.json

# 4) -3m（D3d）で CSV（Excelに読み込み）
python point_group_inf.py D3d --p 0.1,0.2,0.3 --csv

"""

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
pointgroup_cli.py
  - Pure Python (NumPy) point-group operations and orbits tool.
  - Input: Schoenflies (Cnv, Dnh, Td, Oh, ...) or representative
           international symbols (4/mmm, -3m, m-3m, ...).
  - Modes:
      * No --p given: print generators or full group operations (3x3).
      * With --p: compute symmetry orbits for given points; optionally
                  compress to independent representatives and list ops that generate them.

Examples:
  # 一覧
  python pointgroup_cli.py --list

  # 点群の全操作
  python pointgroup_cli.py C3v

  # 生成元のみ
  python pointgroup_cli.py D4h --generators

  # JSON / CSV 出力
  python pointgroup_cli.py Td --json
  python pointgroup_cli.py 4/mmm --csv > 4mmm_ops.csv

  # 座標の軌道
  python pointgroup_cli.py C3v --p 1,0,0

  # 独立代表点のみ（どの操作で得るか付き）
  python pointgroup_cli.py D4h --p 1,0,0 --p 0.2,0.3,0.4 --independent-only

  # JSONで軌道/代表点を取得（"-" で始まる記号はオプションと誤認されないよう "--" の後ろに置く）
  python pointgroup_cli.py --p 0.1,0.2,0.3 --json -- -3m
"""
import math
import json
import argparse
import numpy as np
from typing import List, Tuple

# ========= 基本ユーティリティ =========
def _norm(v):
    """Return *v* as a float unit vector.

    Raises:
        ValueError: if *v* has zero length.
    """
    arr = np.asarray(v, dtype=float)
    length = np.linalg.norm(arr)
    if length == 0:
        raise ValueError("Zero-length vector")
    return arr / length


def rot(axis, angle_deg):
    """Return the 3x3 proper rotation by *angle_deg* degrees about *axis*.

    Uses the Rodrigues formula; *axis* need not be normalised.
    """
    ux, uy, uz = _norm(axis)
    theta = math.radians(angle_deg)
    cos_t = math.cos(theta)
    sin_t = math.sin(theta)
    one_c = 1 - cos_t
    return np.array(
        [
            [cos_t + ux*ux*one_c,   ux*uy*one_c - uz*sin_t, ux*uz*one_c + uy*sin_t],
            [uy*ux*one_c + uz*sin_t, cos_t + uy*uy*one_c,   uy*uz*one_c - ux*sin_t],
            [uz*ux*one_c - uy*sin_t, uz*uy*one_c + ux*sin_t, cos_t + uz*uz*one_c],
        ],
        dtype=float,
    )


def mirror(normal):
    """Return the reflection through the plane (through the origin) with the given normal."""
    unit = _norm(normal)
    return np.eye(3) - 2.0 * np.outer(unit, unit)


def inversion():
    """Return the inversion operation -I (kept as -np.eye so off-diagonals are -0.0)."""
    return -np.eye(3)

# Representative matrix-entry values (0, ±1/2, ±1/√2, ±√3/2, ±1) that
# snap_matrix rounds onto; cos/sin of multiples of 30° and 45° are all in here.
_SNAP_VALUES = [
    0.0, 0.5, -0.5,
    1.0/math.sqrt(2), -1.0/math.sqrt(2),
    math.sqrt(3)/2, -math.sqrt(3)/2,
    1.0, -1.0,
]


def _closest(x):
    """Return the entry of _SNAP_VALUES nearest to *x* (the first one wins on ties)."""
    best = _SNAP_VALUES[0]
    best_gap = abs(best - x)
    for candidate in _SNAP_VALUES[1:]:
        gap = abs(candidate - x)
        if gap < best_gap:
            best, best_gap = candidate, gap
    return best


def snap_matrix(M: np.ndarray) -> np.ndarray:
    """Clean up a numerically computed 3x3 orthogonal matrix.

    Entries within 5e-8 of a value in _SNAP_VALUES are replaced by that exact
    value.  If the result is not orthogonal (atol 1e-6) it is projected onto the
    nearest orthogonal matrix via SVD and snapped again.  If |det| then still
    differs from 1 by 1e-6 or more, an SVD projection carrying the sign of det
    is returned instead.
    """
    def snap_entries(A):
        rows = []
        for row in A:
            snapped_row = []
            for value in row:
                target = _closest(value)
                snapped_row.append(target if abs(target - value) < 5e-8 else value)
            rows.append(snapped_row)
        return np.array(rows, dtype=float)

    X = snap_entries(M)
    if not np.allclose(X.T @ X, np.eye(3), atol=1e-6):
        U, _, Vt = np.linalg.svd(X)
        X = snap_entries(U @ Vt)
    det = np.linalg.det(X)
    if abs(abs(det) - 1.0) < 1e-6:
        return X
    U, _, Vt = np.linalg.svd(X)
    sign = 1.0 if det >= 0 else -1.0
    return U @ np.diag([1, 1, sign]) @ Vt

def mat_key(M) -> tuple:
    """Return a hashable key for a matrix: its row-major entries rounded to 10 decimals."""
    entries = np.asarray(M, dtype=float).ravel()
    return tuple(entries.round(10))

def unique_closure(generators: List[np.ndarray]) -> List[np.ndarray]:
    """Return every element of the finite group generated by *generators*.

    All matrices are passed through snap_matrix and deduplicated with mat_key.
    The result lists the identity first, then the (snapped) generators, then the
    remaining elements in discovery order.  An empty generator list yields [I].

    In a finite group every element is a product of generators (each inverse is
    a positive power), so closing under right-multiplication by the generators
    alone reaches the whole group.  This needs O(|G| * len(generators)) products
    instead of the O(|G|^2) of multiplying every pair of known elements.
    """
    gens = [snap_matrix(G) for G in generators]
    if not gens:
        return [np.eye(3)]

    elements = {}  # mat_key -> matrix, in discovery order
    pending = []   # discovered elements not yet multiplied by the generators

    def add(E):
        key = mat_key(E)
        if key not in elements:
            elements[key] = E
            pending.append(E)

    add(np.eye(3))
    for G in gens:
        add(G)
    while pending:
        A = pending.pop()
        for G in gens:
            add(snap_matrix(A @ G))
    return list(elements.values())

# Standard Cartesian unit vectors and vertical mirror-plane normals.
ex, ey, ez = (row.copy() for row in np.eye(3))


def vertical_plane(phi_deg):
    """Return the unit normal of the vertical plane at azimuth *phi_deg* degrees.

    The plane contains the z axis and the in-plane direction (cos phi, sin phi, 0);
    the returned normal is (sin phi, -cos phi, 0).
    """
    phi = math.radians(phi_deg)
    return np.array([math.sin(phi), -math.cos(phi), 0.0])


def diagonal_plane(phi_deg):
    """Return the normal of the vertical plane at azimuth *phi_deg* + 90 degrees."""
    phi = math.radians(phi_deg + 90.0)
    return np.array([math.sin(phi), -math.cos(phi), 0.0])

# ========= 点群ビルダー =========
def build_Cn(n):
    """C_n (order n): rotations about z by multiples of 360/n degrees."""
    step = rot(ez, 360.0/n)
    return unique_closure([step])


def build_Cnv(n):
    """C_nv (order 2n): C_n plus the vertical mirror whose plane contains the x axis."""
    step = rot(ez, 360.0/n)
    sigma_v = mirror(vertical_plane(0.0))
    return unique_closure([step, sigma_v])


def build_Cnh(n):
    """C_nh (order 2n): C_n plus the horizontal mirror sigma_h (normal z)."""
    step = rot(ez, 360.0/n)
    sigma_h = mirror(ez)
    return unique_closure([step, sigma_h])


def build_Sn(n):
    """S_n: generated by the improper rotation sigma_h * C_n about z.

    For odd n the generator's n-th power is sigma_h, so the group has order 2n.
    """
    improper = mirror(ez) @ rot(ez, 360.0/n)
    return unique_closure([improper])


def build_Dn(n):
    """D_n (order 2n): C_n about z plus a C2 about x."""
    step = rot(ez, 360.0/n)
    c2_x = rot(ex, 180.0)
    return unique_closure([step, c2_x])


def build_Dnh(n):
    """D_nh (order 4n): D_n plus the horizontal mirror sigma_h."""
    step = rot(ez, 360.0/n)
    c2_x = rot(ex, 180.0)
    sigma_h = mirror(ez)
    return unique_closure([step, c2_x, sigma_h])
def build_Dnd(n):
    """D_nd (order 4n): C_n about z, a C2 about x, and a dihedral mirror sigma_d.

    sigma_d must bisect neighbouring C2 axes.  The C2 axes generated from C2(x)
    by C_n lie at multiples of 180/n degrees in the xy plane, so sigma_d is placed
    at 90/n degrees.  A plane through a C2 axis would generate D_nh instead; the
    yz plane does exactly that for every even n (it contains a C2 axis), which
    would make D2d/D4d/D6d come out as D2h/D4h/D6h.
    """
    return unique_closure([
        rot(ez, 360.0/n),
        rot(ex, 180.0),
        mirror(vertical_plane(90.0/n)),
    ])
def build_T():
    """T (order 12): generated by C3 about [1,1,1] and C2 about x."""
    c3_body = rot([1, 1, 1], 120.0)
    c2_x = rot(ex, 180.0)
    return unique_closure([c3_body, c2_x])


def build_Th():
    """T_h (order 24): T plus inversion."""
    c3_body = rot([1, 1, 1], 120.0)
    c2_x = rot(ex, 180.0)
    return unique_closure([c3_body, c2_x, inversion()])


def build_Td():
    """T_d (order 24): T plus the improper rotation S4 = sigma_h * C4 about z."""
    c3_body = rot([1, 1, 1], 120.0)
    c2_x = rot(ex, 180.0)
    s4_z = mirror(ez) @ rot(ez, 90.0)
    return unique_closure([c3_body, c2_x, s4_z])
def build_O():
    """Return the 24 proper rotations of the cube (group O).

    They are exactly the signed 3x3 permutation matrices with determinant +1.
    """
    from itertools import permutations, product

    rotations = {}
    for order in permutations(range(3)):
        perm_matrix = np.eye(3)[list(order)]
        for signs in product([-1, 1], repeat=3):
            candidate = np.diag(signs) @ perm_matrix
            if round(np.linalg.det(candidate)) != 1:
                continue
            snapped = snap_matrix(candidate)
            rotations[mat_key(snapped)] = snapped
    return list(rotations.values())


def build_Oh():
    """Return the 48 operations of O_h: each rotation of O, alone and composed with inversion."""
    proper = build_O()
    inv = inversion()
    candidates = [*proper, *(inv @ R for R in proper)]
    return list({mat_key(M): snap_matrix(M) for M in candidates}.values())

def normalize_symbol(s: str) -> str:
    """Canonicalise a user-typed point-group symbol.

    Strips surrounding whitespace, deletes spaces and underscores, and maps
    Unicode minus / en dash / em dash to an ASCII hyphen.
    """
    table = str.maketrans({" ": None, "−": "-", "–": "-", "—": "-", "_": None})
    return s.strip().translate(table)

# Symbols advertised by --list, in display order.
SUPPORTED = [
    # Schoenflies: low-symmetry groups
    "C1", "Ci", "Cs",
    # Schoenflies axial families (Cn, Cnv, Cnh, Sn, Dn, Dnh, Dnd) for n = 2, 3, 4, 6
    *[f"{family}{n}{suffix}"
      for family, suffix in (("C", ""), ("C", "v"), ("C", "h"), ("S", ""),
                             ("D", ""), ("D", "h"), ("D", "d"))
      for n in (2, 3, 4, 6)],
    # Schoenflies: cubic groups
    "T", "Th", "Td",
    "O", "Oh",
    # Representative international (Hermann-Mauguin) notations
    "2", "m", "-1",
    "mm2", "mmm",
    "4", "4/m", "4mm", "4/mmm",
    "3", "-3", "3m", "-3m",
    "6", "6/m", "6mm", "6/mmm",
    "23", "m-3", "432", "-43m", "m-3m",
]

def build_group(symbol: str) -> List[np.ndarray]:
    """Return all 3x3 operation matrices of the point group named *symbol*.

    *symbol* is normalised with normalize_symbol and may be:
      * a Schoenflies symbol: C1, Ci, Cs, Cn, Cnv, Cnh, Sn, Dn, Dnh, Dnd,
        T, Th, Td, O, Oh;
      * one of the representative international symbols in the table below
        (plus the aliases "1", "-1"/"1bar", "m", "3bar", "3barm").

    Raises:
        ValueError: for an unknown symbol, or when a parametric symbol has n = 0.
    """
    s = normalize_symbol(symbol)

    if s in ("C1", "1"):
        return [np.eye(3)]
    if s in ("Ci", "-1", "1bar"):
        return [np.eye(3), inversion()]
    if s in ("Cs", "m"):
        return unique_closure([mirror(ez)])

    # Parametric Schoenflies families: <prefix><digits><suffix>.
    for prefix, suffix, builder in (
        ("C", "v", build_Cnv),
        ("C", "h", build_Cnh),
        ("C", "", build_Cn),
        ("S", "", build_Sn),
        ("D", "h", build_Dnh),
        ("D", "d", build_Dnd),
        ("D", "", build_Dn),
    ):
        if s.startswith(prefix) and s.endswith(suffix):
            digits = s[len(prefix):len(s) - len(suffix)]
            if digits.isdigit():
                n = int(digits)
                if n < 1:
                    # n = 0 would otherwise fail with ZeroDivisionError (360/n).
                    raise ValueError(f"Rotation order must be >= 1 in point group symbol: {symbol}")
                return builder(n)

    named = {
        # Cubic Schoenflies symbols
        "T": build_T,
        "Th": build_Th,
        "Td": build_Td,
        "O": build_O,
        "Oh": build_Oh,
        # International (Hermann-Mauguin) symbols
        "2": lambda: build_Cn(2),
        "mm2": lambda: build_Cnv(2),
        "mmm": lambda: build_Dnh(2),
        "4": lambda: build_Cn(4),
        "4/m": lambda: build_Cnh(4),
        "4mm": lambda: build_Cnv(4),
        "4/mmm": lambda: build_Dnh(4),
        "3": lambda: build_Cn(3),
        # -3 is S6 (= C3i, contains inversion); C3h would be -6.
        "-3": lambda: build_Sn(6),
        "3bar": lambda: build_Sn(6),
        "3m": lambda: build_Cnv(3),
        # -3m is D3d; D3h would be -6m2.
        "-3m": lambda: build_Dnd(3),
        "3barm": lambda: build_Dnd(3),
        "6": lambda: build_Cn(6),
        "6/m": lambda: build_Cnh(6),
        "6mm": lambda: build_Cnv(6),
        "6/mmm": lambda: build_Dnh(6),
        "23": build_T,
        "m-3": build_Th,
        "432": build_O,
        "-43m": build_Td,
        "m-3m": build_Oh,
    }
    builder = named.get(s)
    if builder is None:
        raise ValueError(f"Unsupported / unknown point group symbol: {symbol}")
    return builder()

# ========= ラベリング（簡易） =========
def _rotation_axis(R):
    """Return the unit axis of a proper rotation matrix R, or None when R ~ identity."""
    anti = np.array([R[2, 1] - R[1, 2], R[0, 2] - R[2, 0], R[1, 0] - R[0, 1]])
    length = np.linalg.norm(anti)
    if length > 1e-8:
        return anti / length
    # sin(theta) ~ 0: theta ~ 0 (no axis) or theta ~ 180 deg, where the
    # antisymmetric part vanishes and the axis is the eigenvector for eigenvalue +1.
    if np.trace(R) > 1.0:
        return None
    w, v = np.linalg.eig(R)
    a = np.real(v[:, np.argmin(np.abs(w - 1.0))])
    return a / np.linalg.norm(a)


def _axis_name(a):
    """Name an axis direction: "z", "x", "y", "111" ([1,1,1]) or "u" (unclassified)."""
    if a is None:
        return "u"
    if abs(a[2]) > 0.9:
        return "z"
    if abs(a[0]) > 0.9:
        return "x"
    if abs(a[1]) > 0.9:
        return "y"
    if abs(a @ np.ones(3)) / math.sqrt(3) > 0.9:
        return "111"
    return "u"


def classify_label(M: np.ndarray) -> str:
    """Return a short heuristic label for a 3x3 orthogonal matrix *M*.

    Labels:
      "E", "i"             identity / inversion
      "σ_h", "σ_v", "σ"    mirror with normal ~z / ~x or ~y / anything else
      "C(axis,angle)"      proper rotation
      "S(axis,angle)"      improper rotation S(theta) = sigma_h * C(theta)
    axis is "z", "x", "y", "111" or "u"; angle is theta in whole degrees (0-180).
    The sense of rotation is not distinguished (C3 and C3^2 share a label).
    """
    M = np.asarray(M, dtype=float)
    if np.allclose(M, np.eye(3), atol=1e-8):
        return "E"
    if np.allclose(M, -np.eye(3), atol=1e-8):
        return "i"
    det = np.linalg.det(M)

    # Mirror = improper involution; its normal is the eigenvector for eigenvalue -1.
    if abs(det + 1.0) < 1e-6 and np.allclose(M @ M, np.eye(3), atol=1e-6):
        w, v = np.linalg.eig(M)
        n = np.real(v[:, np.argmin(np.abs(w + 1))])
        n = n / np.linalg.norm(n)
        if abs(n[2]) > 0.9:
            return "σ_h"
        if abs(n[0]) > 0.9 or abs(n[1]) > 0.9:
            return "σ_v"
        return "σ"

    tr = np.trace(M)
    if abs(det - 1.0) < 1e-6:
        # Proper rotation by theta: trace = 1 + 2 cos(theta).
        prefix, proper, cos_theta = "C", M, (tr - 1.0) / 2.0
    else:
        # Improper S(theta): trace = 2 cos(theta) - 1, and -M = C(theta + 180 deg)
        # is a proper rotation about the same axis.
        prefix, proper, cos_theta = "S", -M, (tr + 1.0) / 2.0
    ang = math.degrees(math.acos(max(-1.0, min(1.0, cos_theta))))
    return f"{prefix}({_axis_name(_rotation_axis(proper))},{int(round(ang))})"


def label_elements(mats: List[np.ndarray]) -> List[Tuple[str, np.ndarray]]:
    """Pair every matrix with its classify_label label, preserving order."""
    return [(classify_label(M), M) for M in mats]

def generators_for(symbol: str) -> List[Tuple[str, np.ndarray]]:
    """Return the conventional, labelled generators this CLI uses for *symbol*.

    Only Schoenflies symbols (and the aliases "1", "-1"/"1bar", "m") are covered.
    International symbols and anything unrecognised return [] so that main()
    falls back to the full operation list.  C1 also returns [] because it has no
    non-trivial generator.  Matrices are passed through snap_matrix.

    Raises:
        ValueError: if a parametric symbol (Cn, Cnv, Cnh, Sn, Dn, Dnh, Dnd) has n = 0.
    """
    s = normalize_symbol(symbol)

    def order(digits):
        # Reject n = 0, which would otherwise fail with ZeroDivisionError (360/n).
        n = int(digits)
        if n < 1:
            raise ValueError(f"Rotation order must be >= 1 in point group symbol: {symbol}")
        return n

    def principal(n):
        # n-fold rotation about z; the label angle is truncated to whole degrees.
        return (f"C(z,{360//n})", rot(ez, 360.0/n))

    def c2x():
        return ("C2(x)", rot(ex, 180.0))

    if s in ("C1", "1"):
        gens = []
    elif s in ("Ci", "-1", "1bar"):
        gens = [("i", inversion())]
    elif s in ("Cs", "m"):
        gens = [("σ_h", mirror(ez))]
    elif s.startswith("C") and s.endswith("v") and s[1:-1].isdigit():
        n = order(s[1:-1])
        gens = [principal(n), ("σ_v(φ=0)", mirror(vertical_plane(0.0)))]
    elif s.startswith("C") and s.endswith("h") and s[1:-1].isdigit():
        n = order(s[1:-1])
        gens = [principal(n), ("σ_h", mirror(ez))]
    elif s.startswith("C") and s[1:].isdigit():
        gens = [principal(order(s[1:]))]
    elif s.startswith("S") and s[1:].isdigit():
        n = order(s[1:])
        gens = [(f"S(z,{360//n})", mirror(ez) @ rot(ez, 360.0/n))]
    elif s.startswith("D") and s.endswith("h") and s[1:-1].isdigit():
        n = order(s[1:-1])
        gens = [principal(n), c2x(), ("σ_h", mirror(ez))]
    elif s.startswith("D") and s.endswith("d") and s[1:-1].isdigit():
        n = order(s[1:-1])
        # sigma_d must bisect neighbouring C2 axes (at multiples of 180/n degrees),
        # so it sits at 90/n degrees.  A plane through a C2 axis (e.g. the yz
        # plane for even n) would generate D_nh instead of D_nd.
        gens = [principal(n), c2x(), ("σ_d", mirror(vertical_plane(90.0/n)))]
    elif s.startswith("D") and s[1:].isdigit():
        gens = [principal(order(s[1:])), c2x()]
    elif s == "T":
        gens = [("C(111,120)", rot([1, 1, 1], 120.0)), c2x()]
    elif s == "Th":
        gens = [("C(111,120)", rot([1, 1, 1], 120.0)), c2x(), ("i", inversion())]
    elif s == "Td":
        gens = [("C(111,120)", rot([1, 1, 1], 120.0)), c2x(), ("S4(z)", mirror(ez) @ rot(ez, 90.0))]
    elif s == "O":
        gens = [("C4(z)", rot(ez, 90.0)), ("C3(111)", rot([1, 1, 1], 120.0))]
    elif s == "Oh":
        gens = [("C4(z)", rot(ez, 90.0)), ("C3(111)", rot([1, 1, 1], 120.0)), ("i", inversion())]
    else:
        # International symbols are delegated to build_group (no conventional generators here).
        return []
    return [(lab, snap_matrix(M)) for lab, M in gens]

# ========= 軌道・独立代表点 =========
def point_key(p: np.ndarray, tol=1e-8):
    """Return a hashable identity key for a point.

    Coordinates are divided by a grid spacing (tol, floored at 1e-15) and rounded
    to integers.  Points that round to the same grid node share a key; two points
    closer than tol can still get different keys if they straddle a rounding boundary.
    """
    spacing = max(tol, 1e-15)
    cells = np.round(np.asarray(p, float) / spacing)
    return tuple(cells.astype(int))


def orbit_for_point(ops: List[np.ndarray], p: np.ndarray, tol=1e-8):
    """Return the orbit of *p* under *ops* as [(image, [op_index, ...]), ...].

    Images sharing a point_key are merged.  The first image computed is kept as
    the representative coordinate, and op indices are listed in ascending order.
    """
    buckets = {}
    for op_index, M in enumerate(ops):
        image = M @ p
        entry = buckets.setdefault(point_key(image, tol), (image, []))
        entry[1].append(op_index)
    return list(buckets.values())


def dedup_points(orbits: List[Tuple[np.ndarray, List[int]]], tol=1e-8):
    """Merge (point, op_ids) pairs whose points share a point_key.

    Returns [(point as float array, sorted op ids), ...] in first-seen order.
    """
    merged = {}
    for q, op_ids in orbits:
        key = point_key(q, tol)
        if key in merged:
            merged[key][1].update(op_ids)
        else:
            merged[key] = (q, set(op_ids))
    return [(np.asarray(q, float), sorted(ids)) for q, ids in merged.values()]

# ========= I/O 補助 =========
def parse_point(arg: str) -> np.ndarray:
    """Parse an "x,y,z" command-line string into a float 3-vector.

    Used as an argparse ``type=`` callable, so failures are raised as
    argparse.ArgumentTypeError and reported as a usage error.

    Raises:
        argparse.ArgumentTypeError: unless *arg* has exactly three comma-separated
            numbers, or if any coordinate is NaN or infinite (such values would
            break the integer grid keys used for point equivalence).
    """
    try:
        x, y, z = (float(part) for part in arg.split(","))
    except ValueError:
        # Covers both unparsable numbers and the wrong number of components.
        raise argparse.ArgumentTypeError(f"Point must be x,y,z (got '{arg}')") from None
    if not all(math.isfinite(c) for c in (x, y, z)):
        raise argparse.ArgumentTypeError(f"Point coordinates must be finite (got '{arg}')")
    return np.array([x, y, z], float)

def fmt_vec(v, prec=6):
    """Format a vector as "(a, b, c)" with *prec* decimals per component."""
    parts = [format(x, f".{prec}f") for x in v]
    return f"({', '.join(parts)})"


def print_matrix(M, prec=3):
    """Return (not print) a matrix as bracketed rows, one per line.

    Each entry uses *prec* decimals and a leading space in place of a plus sign.
    """
    spec = f" .{prec}f"
    lines = []
    for row in M:
        cells = " ".join(format(v, spec) for v in row)
        lines.append(f"[{cells}]")
    return "\n".join(lines)

# ========= CLI =========
def _recover_dash_symbol(ap, args, extras):
    """Accept a point-group symbol that argparse mistook for an unknown option.

    Symbols such as "-3m" or "-43m" start with "-" but are not negative numbers,
    so argparse classifies them as unknown optionals and leaves `symbol` unset.
    A single leftover token of the form "-<digit>..." is taken as the symbol.
    Any other leftovers are rejected with the same message parse_args would use.
    """
    if not extras:
        return
    token = extras[0]
    if (args.symbol is None and len(extras) == 1
            and token.startswith("-") and token[1:2].isdigit()):
        args.symbol = token
        return
    ap.error("unrecognized arguments: " + " ".join(extras))


def _operations_to_show(args):
    """Return (labelled operations, mode label) for the listing mode."""
    if args.generators:
        gens = generators_for(args.symbol)
        if gens:
            return gens, "Generators"
        # International symbols (and C1) have no conventional generators here:
        # fall back to the full operation list and report that mode truthfully.
    return label_elements(build_group(args.symbol)), "All operations"


def _emit_operations(args, ops, mode):
    """Write the operation list as JSON, CSV or human-readable text."""
    if args.json:
        payload = {
            "symbol": normalize_symbol(args.symbol),
            "count": len(ops),
            "operations": [
                {"label": lab, "matrix": np.asarray(M, float).round(10).tolist()}
                for lab, M in ops
            ]
        }
        print(json.dumps(payload, ensure_ascii=False, indent=2))
        return

    if args.csv:
        import csv
        import sys
        w = csv.writer(sys.stdout)
        w.writerow(["symbol", normalize_symbol(args.symbol)])
        w.writerow(["count", len(ops)])
        w.writerow(["label", "m11","m12","m13","m21","m22","m23","m31","m32","m33"])
        for lab, M in ops:
            flat = list(np.asarray(M, float).round(10).flatten())
            w.writerow([lab] + flat)
        return

    print(f"# Point group: {args.symbol}  (operations: {len(ops)})")
    print(f"# Mode: {mode}")
    for i, (lab, M) in enumerate(sorted(ops, key=lambda t: (t[0], mat_key(t[1]))), 1):
        print(f"\n[{i:02d}] {lab}")
        print(print_matrix(M, prec=args.round))


def _orbit_records(args, ops):
    """Compute, for every --p point, its orbit (or its representatives).

    Each record is {"input_index", "input_point", KEY}, where KEY is
    "representatives" with --independent-only and "orbit" otherwise.  KEY holds
    a list of {"point", "op_indices", "ops"} entries; "op_indices" index into *ops*.
    """
    labels = [lab for lab, _ in ops]
    mats = [M for _, M in ops]
    key = "representatives" if args.independent_only else "orbit"
    records = []
    for idx, p in enumerate(args.points):
        entries = orbit_for_point(mats, p, tol=args.tol)  # [(q, [op_ids])]
        if args.independent_only:
            # NOTE(review): orbit_for_point already merges images that share a
            # point_key with the same tol, so dedup_points returns the same point
            # set here.  Confirm the intended meaning of "independent" representatives.
            entries = dedup_points(entries, tol=args.tol)
        records.append({
            "input_index": idx,
            "input_point": p,
            key: [
                {"point": q, "op_indices": op_ids, "ops": [labels[i] for i in op_ids]}
                for q, op_ids in entries
            ],
        })
    return records


def _emit_orbits(args, ops, records):
    """Write orbit/representative records as JSON, CSV or human-readable text."""
    key = "representatives" if args.independent_only else "orbit"
    # Normalised like the operation-listing mode, so both modes report the same symbol.
    symbol = normalize_symbol(args.symbol)

    if args.json:
        payload = {
            "symbol": symbol,
            "ops_count": len(ops),
            "inputs": [
                {
                    "input_index": rec["input_index"],
                    "input_point": np.asarray(rec["input_point"], float).tolist(),
                    key: [
                        {"point": np.asarray(ent["point"], float).tolist(),
                         "op_indices": ent["op_indices"],
                         "ops": ent["ops"]}
                        for ent in rec[key]
                    ]
                } for rec in records
            ]
        }
        print(json.dumps(payload, ensure_ascii=False, indent=2))
        return

    if args.csv:
        import csv
        import sys
        prefix = "rep" if args.independent_only else "orbit"
        w = csv.writer(sys.stdout)
        w.writerow(["symbol", symbol])
        w.writerow(["ops_count", len(ops)])
        w.writerow(["input_index","input_x","input_y","input_z",
                    f"{prefix}_x", f"{prefix}_y", f"{prefix}_z","op_indices","ops_labels"])
        for rec in records:
            ip = rec["input_point"]
            for ent in rec[key]:
                q = ent["point"]
                w.writerow([rec["input_index"], ip[0], ip[1], ip[2],
                            q[0], q[1], q[2],
                            ";".join(map(str, ent["op_indices"])),
                            ";".join(ent["ops"])])
        return

    heading = "Representatives" if args.independent_only else "Orbit size"
    print(f"# Point group: {args.symbol} (ops: {len(ops)})")
    for rec in records:
        print(f"\nInput[{rec['input_index']}]: {fmt_vec(rec['input_point'], args.round)}")
        entries = rec[key]
        print(f"  {heading}: {len(entries)}")
        for j, ent in enumerate(entries, 1):
            print(f"   [{j:02d}] {fmt_vec(ent['point'], args.round)}  |  ops: {', '.join(ent['ops'])}")


def main():
    """Command-line entry point: list operations, or compute orbits with --p."""
    ap = argparse.ArgumentParser(description="Point-group operations & symmetry orbits (pure NumPy).")
    ap.add_argument("symbol", nargs="?", help="Point group symbol (e.g., C3v, D4h, Td, Oh, 4/mmm, -3m, m-3m)")
    ap.add_argument("--generators", action="store_true", help="Show generators only (no --p).")
    ap.add_argument("--json", action="store_true", help="Output as JSON (machine readable).")
    ap.add_argument("--csv", action="store_true", help="Output as CSV (machine readable).")
    ap.add_argument("--round", type=int, default=3, help="Display precision for text output (default: 3).")
    ap.add_argument("--list", action="store_true", help="List supported point-group symbols and exit.")
    # Orbits options
    ap.add_argument("--p", dest="points", action="append", type=parse_point,
                    help="Point x,y,z (repeatable). Example: --p 1,0,0 --p 0.2,0.3,0.4 "
                         "(write --p=-1,0,0 when the first coordinate is negative)")
    ap.add_argument("--independent-only", action="store_true",
                    help="With --p: show only independent representatives.")
    ap.add_argument("--tol", type=float, default=1e-8, help="Tolerance for point equivalence (default 1e-8).")
    # parse_known_args (not parse_args) so that symbols such as "-3m" / "-43m",
    # which argparse treats as unknown options, can be recovered.
    args, extras = ap.parse_known_args()
    _recover_dash_symbol(ap, args, extras)

    if args.list:
        print("# Supported point groups:")
        print(", ".join(SUPPORTED))
        return

    if not args.symbol:
        ap.error("symbol is required (try --list)")

    # Report bad symbols as usage errors instead of tracebacks.
    try:
        if args.points:
            ops = label_elements(build_group(args.symbol))  # [(label, M)]
        else:
            ops, mode = _operations_to_show(args)
    except ValueError as exc:
        ap.error(str(exc))
    except ZeroDivisionError:
        # e.g. "C0v": a zero rotation order reaches 360/n in the builders.
        ap.error(f"Rotation order must be >= 1 in point group symbol: {args.symbol}")

    if args.points:
        _emit_orbits(args, ops, _orbit_records(args, ops))
    else:
        _emit_operations(args, ops, mode)


if __name__ == "__main__":
    main()
