# find_point_group.py (refactored to use tkpointgroup.py)
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
XYZ 分子から点群を推定し、対称操作を直交化→整列→ラベル付け（C3v/C4v/C6v/C2v/Dn/T/Td/Th/O/Oh）して表示。

主な機能
- XYZ 読み込み（pymatgen）
- （任意）質量中心へ平行移動、慣性主軸整列
- （任意）検出した最大次数の Cn 回転軸を z へ自動整列
- PointGroupAnalyzer で点群推定（Schoenflies 記号）
- 対称操作（回転行列）を SVD で直交化して分類・ラベル付け
- Cnv（n=3,4,6）について σv/σd の区別、C2v は σv(xz)/σv'(yz) 区別
- Dn（n=3,4,6）について Cn^k(z) と xy 面内の C2'(x系) を分類

変更点（tkpointgroup.py の活用）
- 代表値スナップ/直交化の最終調整に tkpointgroup.snap_matrix を使用
- 記号正規化に tkpointgroup.normalize_symbol を使用
- ラベリング失敗時のフォールバック（任意）に tkpointgroup.classify_label を利用可能
- そのほか、共通ユーティリティの再実装を削減

必要:
  pip install pymatgen

python find_point_group.py --xyz h2o.xyz --center --align --dump-ops --write-aligned h2o_aligned.xyz
python find_point_group.py --xyz nh3.xyz --center --align --tolerance 2e-3 --eigen-tol 1e-2
python find_point_group.py --xyz ch4.xyz --center --align --dump-ops
python find_point_group.py --xyz h2o.xyz --center --align 
python find_point_group.py --xyz SF6.xyz --center --align 
python find_point_group.py --xyz C6H6.xyz --center --align 
# Recommended: keep the default options (--tolerance 0.03, --eigen-tol 1e-2)
"""


import argparse
import numpy as np
from pathlib import Path
from typing import List, Tuple, Optional

from pymatgen.core import Molecule
from pymatgen.core.periodic_table import Element
from pymatgen.symmetry.analyzer import PointGroupAnalyzer

import tkpointgroup as pg  # ← 追加：共通ライブラリ

# ====== 数値ユーティリティ（pg を用いつつ、このスクリプト特有の処理を保持） ======

def _orth_preserve_det(R: np.ndarray) -> np.ndarray:
    """Orthogonalise R via SVD while keeping the sign of its determinant.

    The orthogonal factor ``U @ Vt`` of the SVD is formed; when the sign of
    its determinant differs from that of the input, the last column of ``U``
    is negated before the product is rebuilt. The final snapping to
    representative values is delegated to ``tkpointgroup.snap_matrix``.
    """
    left, _, right_t = np.linalg.svd(R)
    wanted_sign = np.sign(np.linalg.det(R))
    if np.sign(np.linalg.det(left @ right_t)) != wanted_sign:
        # Flipping one singular vector flips the determinant of the product.
        left[:, -1] *= -1
    return pg.snap_matrix(left @ right_t)

def _vec_to_unit(v):
    n = np.linalg.norm(v)
    return v / n if n > 1e-15 else v

def rot_axis_angle(R: np.ndarray) -> Tuple[Optional[np.ndarray], float]:
    """Return the rotation axis and the rotation angle of R.

    R is orthogonalised first. The angle is obtained from the trace through
    ``arccos`` and therefore lies in [0, pi]. The axis is the (normalised)
    eigenvector belonging to an eigenvalue whose real part is 1; its sign is
    whatever ``np.linalg.eig`` produces.

    Returns:
        ``(None, 0.0)`` for the identity, ``(None, angle)`` when no eigenvalue
        with real part 1 exists, otherwise ``(axis, angle)``.
    """
    R = _orth_preserve_det(R)
    if np.allclose(R, np.eye(3), atol=1e-10):
        return None, 0.0

    cos_angle = np.clip((np.trace(R) - 1) / 2, -1, 1)
    angle = float(np.arccos(cos_angle))

    eigvals, eigvecs = np.linalg.eig(R)
    fixed = np.flatnonzero(np.isclose(eigvals.real, 1.0, atol=1e-6))
    if fixed.size == 0:
        return None, angle

    axis = eigvecs[:, fixed[0]].real
    axis = axis / (np.linalg.norm(axis) + 1e-15)
    return axis, angle

def mirror_normal(R: np.ndarray) -> Optional[np.ndarray]:
    """Return the unit normal of the mirror plane described by R, or None.

    R is orthogonalised first. A matrix counts as a reflection only if all of
    the following hold:

    * its determinant is negative,
    * it is an involution (``R @ R`` is the identity), and
    * its trace is +1, i.e. its eigenvalues are (+1, +1, -1).

    The trace test is what separates a reflection from the inversion ``-I``
    (eigenvalues (-1, -1, -1), trace -3), which also satisfies the first two
    conditions but has no mirror plane.

    Returns:
        The normalised eigenvector belonging to eigenvalue -1 (sign as given
        by ``np.linalg.eig``), or None when R is not a reflection.
    """
    R = _orth_preserve_det(R)
    is_involution = np.allclose(R @ R, np.eye(3), atol=1e-6)
    if not (np.linalg.det(R) < 0 and is_involution):
        return None
    if not np.isclose(np.trace(R), 1.0, atol=1e-6):
        # Improper involution that is not a reflection (the inversion).
        return None
    vals, vecs = np.linalg.eig(R)
    idx = np.where(np.isclose(vals.real, -1.0, atol=1e-6))[0]
    if idx.size == 0:
        return None
    n = vecs[:, idx[0]].real
    n = n / (np.linalg.norm(n) + 1e-15)
    return n

def pretty_matrix(R: np.ndarray) -> str:
    """Format R for display with six decimals per entry.

    Entries with magnitude below 1e-10 are replaced by 0.0 before rounding so
    that they never show up as ``-0.000000``.
    """
    negligible = np.abs(R) < 1e-10
    cleaned = np.round(np.where(negligible, 0.0, R), 6)

    def _fixed(value):
        return f"{value:9.6f}"

    return np.array2string(cleaned, formatter={'float_kind': _fixed})

# Representative directions (all normalised to unit length).
# U100: the six signed coordinate axes, in the order +x, -x, +y, -y, +z, -z.
U100 = [
    _vec_to_unit(np.array(direction))
    for direction in (
        [1, 0, 0], [-1, 0, 0],
        [0, 1, 0], [0, -1, 0],
        [0, 0, 1], [0, 0, -1],
    )
]

# U110: the twelve signed face diagonals, grouped as xy, yz, then xz plane.
U110: List[np.ndarray] = (
    [_vec_to_unit(np.array([sx, sy, 0])) for sx in (1, -1) for sy in (1, -1)]
    + [_vec_to_unit(np.array([0, sy, sz])) for sy in (1, -1) for sz in (1, -1)]
    + [_vec_to_unit(np.array([sx, 0, sz])) for sx in (1, -1) for sz in (1, -1)]
)

# U111: the eight signed body diagonals.
U111 = [
    _vec_to_unit(np.array([sx, sy, sz]))
    for sx in (1, -1)
    for sy in (1, -1)
    for sz in (1, -1)
]

def align_score(axis: np.ndarray, fam: List[np.ndarray]) -> float:
    """Return the largest |cos| between ``axis`` and any direction in ``fam``.

    ``axis`` is normalised first; the members of ``fam`` are used as given.
    A value close to 1 means ``axis`` is (anti)parallel to one of them.
    """
    unit = _vec_to_unit(axis)
    overlaps = [abs(float(np.dot(unit, direction))) for direction in fam]
    return max(overlaps)

# ====== 幾何前処理（共通関数はこのスクリプトで保持） ======

def center_of_mass(mol: Molecule) -> np.ndarray:
    """Return the mass-weighted mean of the site coordinates of ``mol``.

    The mass of each site is looked up with
    ``Element(site.species_string).atomic_mass``.
    """
    weights = np.array(
        [float(Element(site.species_string).atomic_mass) for site in mol.sites]
    )
    positions = np.array([site.coords for site in mol.sites])
    weighted_sum = (weights[:, None] * positions).sum(axis=0)
    return weighted_sum / weights.sum()

def inertia_align_matrix(mol: Molecule) -> np.ndarray:
    """Return the proper rotation that maps ``mol`` onto its principal axes.

    The inertia tensor about the centre of mass is diagonalised with
    ``np.linalg.eigh``, whose eigenvectors are the *columns* of the returned
    matrix V. Expressing a position r in the principal-axis frame requires
    ``V.T @ r``, so the matrix returned here is ``V.T`` (its rows are the
    principal axes, in the eigenvalue order delivered by ``eigh``). This is
    the form expected by ``apply_rigid_transform``, which evaluates
    ``R @ (x - t)`` for every site. Returning V itself (as the previous
    implementation did) rotates the molecule the wrong way round and leaves
    the inertia tensor non-diagonal in general.

    Returns:
        A 3x3 orthogonal matrix with determinant +1.
    """
    masses = np.array([float(Element(s.species_string).atomic_mass) for s in mol.sites])
    coords = np.array([s.coords for s in mol.sites])
    r = coords - center_of_mass(mol)

    # I = sum_i m_i * (|r_i|^2 * 1 - r_i r_i^T)
    r_squared = np.einsum("ij,ij->i", r, r)
    tensor = np.dot(masses, r_squared) * np.eye(3) - (r * masses[:, None]).T @ r

    _, vecs = np.linalg.eigh(tensor)
    if np.linalg.det(vecs) < 0:
        # Keep the transform a proper rotation so chirality is not inverted.
        vecs[:, 0] *= -1
    return vecs.T.copy()

def apply_rigid_transform(mol: Molecule, R: np.ndarray, t: np.ndarray) -> Molecule:
    """Return a new Molecule with every site moved to ``R @ (x - t)``.

    Only the species and the transformed coordinates are passed to the new
    ``Molecule``; nothing else is carried over from ``mol`` here.
    """
    shifted = np.array([site.coords for site in mol.sites]) - t
    rotated = shifted @ R.T  # row-vector form of R @ (x - t)
    return Molecule([site.species for site in mol.sites], rotated)

def write_xyz(mol: Molecule, path: Path, comment: str = ""):
    """Write ``mol`` to ``path`` as an XYZ file (UTF-8).

    The first line is the atom count, the second the ``comment``, followed by
    one ``element x y z`` line per site with six decimals.
    """
    with open(path, "w", encoding="utf-8") as out:
        out.write(f"{len(mol)}\n{comment}\n")
        for el, (x, y, z) in ((s.species_string, s.coords) for s in mol.sites):
            out.write(f"{el:2s} {x: .6f} {y: .6f} {z: .6f}\n")

# ====== 点群推定 ======

def guess_point_group(mol: Molecule, tolerance=0.03, eigen_tolerance=1e-2):
    """Run pymatgen's PointGroupAnalyzer on ``mol``.

    Returns:
        ``(sch_symbol, symmetry_operations)`` as reported by the analyzer,
        e.g. a Schoenflies symbol such as Td, Oh, T, Th or C3v.
    """
    analyzer = PointGroupAnalyzer(
        mol, tolerance=tolerance, eigen_tolerance=eigen_tolerance
    )
    return analyzer.sch_symbol, analyzer.get_symmetry_operations()

# ====== ラベル付け（各ファミリー） ======

def label_c2v(R: np.ndarray) -> str:
    """Label an operation in the C2v convention (principal axis along z).

    Possible results: ``E``, ``C2(z)``, ``C2`` (two-fold axis not along z),
    ``σv(xz)``, ``σv'(yz)``, and - because this family is also used as the
    fallback for other low-symmetry groups - ``i`` and ``σh``. ``?`` is
    returned for anything else.

    Improper operations are only ever labelled as inversion or as a mirror;
    previously the inversion was reported as a vertical mirror and improper
    non-mirror operations could be reported as ``C2``.
    """
    R = _orth_preserve_det(R)
    if np.allclose(R, np.eye(3), atol=1e-8):
        return "E"
    if np.allclose(R, -np.eye(3), atol=1e-8):
        return "i"

    if np.linalg.det(R) < 0:
        nrm = mirror_normal(R)
        if nrm is None:
            return "?"
        if abs(nrm[2]) > 0.999:
            # Normal along z: the mirror plane is the xy plane.
            return "σh"
        # A normal closer to y means the plane is xz; closer to x means yz.
        return "σv(xz)" if abs(nrm[1]) >= abs(nrm[0]) else "σv'(yz)"

    axis, th = rot_axis_angle(R)
    if np.isclose(th, np.pi, atol=1e-2):
        if axis is not None and abs(axis[2]) > 0.999:
            return "C2(z)"
        return "C2"
    return "?"

def label_cnv(n: int, R: np.ndarray) -> str:
    """Label an operation in the Cnv convention (n-fold axis along z).

    Rotations about z are labelled ``C{n}^{k}(z)`` with k in 1..n-1, where k
    counts counter-clockwise steps of 2*pi/n about +z (``C2(z)`` is used for
    the half turn when n is even). The step count is taken from the signed
    rotation angle read off the matrix, because the unsigned angle returned
    by ``rot_axis_angle`` lies in [0, pi] and cannot distinguish k from n-k.

    Vertical mirrors:
      * odd n: every vertical mirror is ``σv`` (there is no σd class);
      * even n: planes at even multiples of pi/n from the xz plane are
        ``σv``, those at odd multiples are ``σd``. For n = 4 this is the
        familiar axis / diagonal split.

    A mirror whose normal is along z is ``σh`` and the inversion is ``i``
    (both can occur when this family is used as a fallback). ``C2`` denotes a
    two-fold axis not along z; ``?`` is returned for anything else.
    """
    R = _orth_preserve_det(R)
    if np.allclose(R, np.eye(3), atol=1e-8):
        return "E"
    if np.allclose(R, -np.eye(3), atol=1e-8):
        return "i"

    if np.linalg.det(R) < 0:
        nrm = mirror_normal(R)
        if nrm is None:
            return "?"
        if abs(nrm[2]) > 0.999:
            return "σh"
        if n % 2 == 1:
            return "σv"
        # Azimuth of the mirror normal; the plane itself is 90 degrees away.
        phi = np.arctan2(nrm[1], nrm[0])
        step = np.pi / n
        v_normals = np.pi / 2 + 2 * step * np.arange(n // 2)
        d_normals = v_normals + step
        vscore = np.abs(np.cos(phi - v_normals)).max()
        dscore = np.abs(np.cos(phi - d_normals)).max()
        return "σv" if vscore >= dscore else "σd"

    axis, th = rot_axis_angle(R)
    if axis is not None and abs(axis[2]) > 0.999:
        # Signed angle about +z: R ~ [[c, -s, 0], [s, c, 0], [0, 0, 1]].
        phi = np.arctan2(R[1, 0] - R[0, 1], R[0, 0] + R[1, 1])
        k = int(round(phi / (2 * np.pi / n))) % n
        if n % 2 == 0 and k == n // 2:
            return "C2(z)"
        if k == 0:
            return "E"
        return f"C{n}^{k}(z)"
    if np.isclose(th, np.pi, atol=1e-2):
        return "C2"
    return "?"

def label_dn(n: int, R: np.ndarray) -> str:
    """Label a proper rotation in the Dn convention (n-fold axis along z).

    Rotations about z are labelled ``C{n}^{k}(z)`` with k in 1..n-1, counting
    counter-clockwise steps of 2*pi/n about +z (``C2(z)`` for the half turn
    when n is even). The signed angle is read off the matrix because the
    unsigned angle from ``rot_axis_angle`` lies in [0, pi] and would map both
    k and n-k onto the same label. Two-fold rotations about an axis in the xy
    plane are labelled ``C2'(xy)``. Improper operations and everything else
    yield ``?``.
    """
    R = _orth_preserve_det(R)
    if np.allclose(R, np.eye(3), atol=1e-8):
        return "E"
    if not np.linalg.det(R) > 0:
        return "?"

    axis, th = rot_axis_angle(R)
    if axis is None:
        return "?"
    if abs(axis[2]) > 0.999:
        # Signed angle about +z: R ~ [[c, -s, 0], [s, c, 0], [0, 0, 1]].
        phi = np.arctan2(R[1, 0] - R[0, 1], R[0, 0] + R[1, 1])
        k = int(round(phi / (2 * np.pi / n))) % n
        if n % 2 == 0 and k == n // 2:
            return "C2(z)"
        if k == 0:
            return "E"
        return f"C{n}^{k}(z)"
    if abs(axis[2]) < 1e-3 and np.isclose(th, np.pi, atol=1e-2):
        return "C2'(xy)"
    return "?"

def _oriented_axis_angle(R: np.ndarray) -> Tuple[Optional[np.ndarray], float]:
    """Return (axis, angle in degrees within [0, 360)) of the proper rotation R.

    ``rot_axis_angle`` only yields an unsigned angle in [0, pi] and an axis of
    arbitrary sign, so a rotation and its inverse look the same. Here the axis
    is first put into a canonical orientation (its first component with
    magnitude above 0.3 is made positive) and the angle is then measured
    counter-clockwise about that oriented axis. The sense is taken from the
    antisymmetric part of R, whose axial vector equals ``2*sin(angle)`` times
    the right-handed rotation axis.

    The axis is None when ``rot_axis_angle`` cannot provide one.
    """
    axis, th = rot_axis_angle(R)
    deg = float(np.degrees(th))
    if axis is None:
        return None, deg
    lead = axis[np.flatnonzero(np.abs(axis) > 0.3)[0]]
    if lead < 0:
        axis = -axis
    w = np.array([R[2, 1] - R[1, 2], R[0, 2] - R[2, 0], R[1, 0] - R[0, 1]])
    if float(np.dot(w, axis)) < -1e-6:
        deg = 360.0 - deg
    return axis, deg


def label_t(R: np.ndarray) -> str:
    """Label an operation of T (23): E, C3 / C3^2 about <111>, C2 about <100>/<110>.

    ``C3`` is the 120 degree and ``C3^2`` the 240 degree counter-clockwise
    rotation about the canonically oriented axis. Improper operations and
    anything unrecognised yield ``?``.
    """
    R = _orth_preserve_det(R)
    if np.allclose(R, np.eye(3), atol=1e-8):
        return "E"
    if np.linalg.det(R) < 0:
        return "?"
    axis, th_deg = _oriented_axis_angle(R)
    if axis is None:
        return "?"
    if (np.isclose(th_deg, 120.0, atol=1.0) or np.isclose(th_deg, 240.0, atol=1.0)) \
            and align_score(axis, U111) > 0.99:
        return "C3" if th_deg < 180.0 else "C3^2"
    if np.isclose(th_deg, 180.0, atol=1.0) \
            and (align_score(axis, U110) > 0.99 or align_score(axis, U100) > 0.99):
        return "C2"
    return "?"


def label_th(R: np.ndarray) -> str:
    """Label an operation of Th (m-3): the operations of T plus i, S6 / S6^5 and σh.

    For an improper operation R the proper rotation ``-R`` is analysed. A
    rotoreflection S(theta) about an axis equals ``-C(theta + 180)`` about the
    same axis, so the rotoreflection angle is the angle of ``-R`` shifted by
    180 degrees; a mirror corresponds to ``-R`` being a half turn about the
    mirror normal.
    """
    R = _orth_preserve_det(R)
    if np.allclose(R, np.eye(3), atol=1e-8):
        return "E"
    if np.allclose(R, -np.eye(3), atol=1e-8):
        return "i"
    if np.linalg.det(R) > 0:
        return label_t(R)
    axis, rot_deg = _oriented_axis_angle(-R)
    if axis is None:
        return "?"
    if np.isclose(rot_deg, 180.0, atol=1.0):
        if align_score(axis, U110) > 0.99 or align_score(axis, U100) > 0.99:
            return "σh"
        return "?"
    s_deg = (rot_deg + 180.0) % 360.0
    if align_score(axis, U111) > 0.99:
        if np.isclose(s_deg, 60.0, atol=1.0):
            return "S6"
        if np.isclose(s_deg, 300.0, atol=1.0):
            return "S6^5"
    return "?"


def label_td(R: np.ndarray) -> str:
    """Label an operation of Td.

    Classes: E, C3 / C3^2 about <111>, C2 about <100>/<110>, S4 / S4^3 about
    <100>, and σd with normal along <110>. A mirror whose normal is not along
    <110> is reported as ``mirror``; anything unrecognised as ``?``.
    """
    R = _orth_preserve_det(R)
    if np.allclose(R, np.eye(3), atol=1e-8):
        return "E"
    det = np.linalg.det(R)
    nrm = mirror_normal(R)
    if nrm is not None:
        return "σd" if align_score(nrm, U110) > 0.99 else "mirror"
    if det > 0:
        axis, th_deg = _oriented_axis_angle(R)
        if axis is None:
            return "?"
        if (np.isclose(th_deg, 120.0, atol=1.0) or np.isclose(th_deg, 240.0, atol=1.0)) \
                and align_score(axis, U111) > 0.99:
            return "C3" if th_deg < 180.0 else "C3^2"
        if np.isclose(th_deg, 180.0, atol=1.0) \
                and (align_score(axis, U100) > 0.99 or align_score(axis, U110) > 0.99):
            return "C2"
        return "?"
    # Improper, not a mirror: analyse -R (see label_th for the angle relation).
    axis, rot_deg = _oriented_axis_angle(-R)
    if axis is not None and align_score(axis, U100) > 0.99:
        s_deg = (rot_deg + 180.0) % 360.0
        if np.isclose(s_deg, 90.0, atol=1.0):
            return "S4"
        if np.isclose(s_deg, 270.0, atol=1.0):
            return "S4^3"
    return "?"


def label_o(R: np.ndarray) -> str:
    """Label an operation of O (432).

    Classes: E, C3 / C3^2 about <111>, C4 / C4^3 and C2(⟨100⟩) about <100>,
    C2(⟨110⟩) about <110>. Angles are measured counter-clockwise about the
    canonically oriented axis. Improper operations yield ``?``.
    """
    R = _orth_preserve_det(R)
    if np.allclose(R, np.eye(3), atol=1e-8):
        return "E"
    if np.linalg.det(R) < 0:
        return "?"
    axis, th_deg = _oriented_axis_angle(R)
    if axis is None:
        return "?"
    if (np.isclose(th_deg, 120.0, atol=1.0) or np.isclose(th_deg, 240.0, atol=1.0)) \
            and align_score(axis, U111) > 0.99:
        return "C3" if th_deg < 180.0 else "C3^2"
    if align_score(axis, U100) > 0.99:
        if np.isclose(th_deg, 90.0, atol=1.0):
            return "C4"
        if np.isclose(th_deg, 270.0, atol=1.0):
            return "C4^3"
        if np.isclose(th_deg, 180.0, atol=1.0):
            return "C2(⟨100⟩)"
    if np.isclose(th_deg, 180.0, atol=1.0) and align_score(axis, U110) > 0.99:
        return "C2(⟨110⟩)"
    return "?"


def label_oh(R: np.ndarray) -> str:
    """Label an operation of Oh (m-3m): the operations of O plus the improper ones.

    Improper classes: i, S4 / S4^3 about <100>, S6 / S6^5 about <111>, σh
    (normal along <100>) and σd (normal along <110>). The improper operation
    R is analysed through the proper rotation ``-R``: a mirror corresponds to
    a half turn about its normal, and a rotoreflection angle is the angle of
    ``-R`` shifted by 180 degrees.
    """
    R = _orth_preserve_det(R)
    if np.allclose(R, np.eye(3), atol=1e-8):
        return "E"
    if np.allclose(R, -np.eye(3), atol=1e-8):
        return "i"
    if np.linalg.det(R) > 0:
        return label_o(R)
    axis, rot_deg = _oriented_axis_angle(-R)
    if axis is None:
        return "?"
    if np.isclose(rot_deg, 180.0, atol=1.0):
        if align_score(axis, U100) > 0.99:
            return "σh"
        if align_score(axis, U110) > 0.99:
            return "σd"
        return "?"
    s_deg = (rot_deg + 180.0) % 360.0
    if align_score(axis, U100) > 0.99:
        if np.isclose(s_deg, 90.0, atol=1.0):
            return "S4"
        if np.isclose(s_deg, 270.0, atol=1.0):
            return "S4^3"
    if align_score(axis, U111) > 0.99:
        if np.isclose(s_deg, 60.0, atol=1.0):
            return "S6"
        if np.isclose(s_deg, 300.0, atol=1.0):
            return "S6^5"
    return "?"

# ====== メイン ======

def _select_family(sch: str) -> str:
    """Choose the labelling family for a normalised Schoenflies symbol.

    Cubic symbols and the supported Cnv symbols map onto themselves. Any
    symbol starting with ``D`` uses the Dn family with the same order when
    that order is 3, 4 or 6 (so D6h or D3d are labelled with n = 6 or 3
    rather than always n = 3); other D symbols fall back to D3. Remaining
    symbols pick a Cnv family from the digit they contain, defaulting to C2v.
    """
    if sch in {"Td", "Oh", "O", "T", "Th"}:
        return sch
    if "v" in sch:
        return sch if sch in {"C3v", "C4v", "C6v", "C2v"} else "C2v"
    if sch.startswith("D"):
        order = "".join(ch for ch in sch if ch.isdigit())
        return f"D{order}" if order in {"3", "4", "6"} else "D3"
    if "3" in sch:
        return "C3v"
    if "4" in sch:
        return "C4v"
    if "6" in sch:
        return "C6v"
    return "C2v"


def _label_operation(family: str, R: np.ndarray) -> str:
    """Label R within ``family``; fall back to ``pg.classify_label`` on ``?``."""
    fixed = {
        "Td": label_td, "O": label_o, "Oh": label_oh,
        "T": label_t, "Th": label_th, "C2v": label_c2v,
    }
    if family in fixed:
        name = fixed[family](R)
    elif family in ("C3v", "C4v", "C6v"):
        name = label_cnv(int(family[1]), R)
    elif family in ("D3", "D4", "D6"):
        name = label_dn(int(family[1:]), R)
    else:
        name = "?"
    # Optional fallback: attach the library's basic classification when the
    # family-specific labeller could not decide.
    if name == "?":
        name = pg.classify_label(R)
    return name


def main():
    """Command-line entry point: read an XYZ file, guess the point group and
    print the labelled symmetry operations (optionally writing the
    transformed geometry)."""
    ap = argparse.ArgumentParser(
        description="XYZ 分子から点群を推定し、対称操作を Cnv/Dn/T/Td/Th/O/Oh 記号でラベル付け。"
    )
    ap.add_argument("--xyz", required=True, help="入力 XYZ")
    ap.add_argument("--tolerance", type=float, default=0.03, help="座標許容差 [Å]")
    ap.add_argument("--eigen-tol", type=float, default=1e-2, help="固有値縮退の許容（無次元）")
    ap.add_argument("--center", action="store_true", help="質量中心へ平行移動")
    ap.add_argument("--align", action="store_true", help="慣性主軸整列")
    ap.add_argument("--dump-ops", action="store_true", help="直交化前の行列も表示")
    ap.add_argument("--write-aligned", metavar="OUT_XYZ", help="整列後 XYZ を保存")
    ap.add_argument("--assume",
        choices=["C2v","C3v","C4v","C6v","D3","D4","D6","T","Th","Td","O","Oh"],
        help="分類ファミリー（未指定なら推定記号に基づいて自動）")
    args = ap.parse_args()

    mol = Molecule.from_file(args.xyz)

    # Either flag moves the centre of mass to the origin; --align additionally
    # rotates with the matrix from inertia_align_matrix.
    if args.center or args.align:
        com = center_of_mass(mol)
        R0 = inertia_align_matrix(mol) if args.align else np.eye(3)
        mol = apply_rigid_transform(mol, R0, com)

    sch_raw, ops0 = guess_point_group(mol, tolerance=args.tolerance, eigen_tolerance=args.eigen_tol)
    sch = pg.normalize_symbol(sch_raw)
    # Analyzer matrices -> orthogonalised and snapped representatives.
    ops = [_orth_preserve_det(op.rotation_matrix) for op in ops0]

    # Use the user's choice if given, otherwise derive it from the symbol.
    assume = args.assume if args.assume is not None else _select_family(sch)

    labeled = [(_label_operation(assume, R), R) for R in ops]

    # Report
    print(f"File            : {args.xyz}")
    print(f"Atoms           : {len(mol)}")
    print(f"Guessed PG      : {sch_raw}")
    print(f"Assume family   : {assume}")
    print(f"Group order     : {len(ops)} operations")

    if args.dump_ops:
        print("\n--- Raw rotation matrices (from analyzer) ---")
        for i, op in enumerate(ops0, 1):
            print(f"[{i:02d}]")
            print(np.array2string(op.rotation_matrix, formatter={'float_kind': lambda x: f"{x:8.5f}"}))

    # Display order of the labels; labels not listed here sort with key 8.
    order_key = {
        "E": 0, "i": 0.5,
        # O/Oh
        "C3": 1, "C3^2": 2, "C4": 1.5, "C4^3": 1.6, "C2(⟨100⟩)": 2.5, "C2(⟨110⟩)": 2.6,
        "S4": 3.0, "S4^3": 3.1, "S6": 3.2, "S6^5": 3.3,
        # Td/T/Th generic
        "C2": 2.7, "σd": 4.0, "σh": 4.05, "mirror": 4.1,
        # Cnv/Dn
        "C2(z)": 1.2, "C3^1(z)": 1.3, "C3^2(z)": 1.4, "C4^1(z)": 1.5, "C4^2(z)": 1.6,
        "σv": 4.2, "σd(Cnv)": 4.21, "σv(xz)": 4.22, "σv'(yz)": 4.23, "C2'(xy)": 5,
        "?": 9
    }

    def sort_key(item):
        name, _ = item
        return (order_key.get(name, 8), name)

    print("\n--- Canonicalized & labeled operations ---")
    for i, (name, R) in enumerate(sorted(labeled, key=sort_key), 1):
        print(f"[{i:02d}] {name}\n{pretty_matrix(R)}\n")

    if args.write_aligned:
        out = Path(args.write_aligned)
        write_xyz(mol, out, comment=f"aligned; PG≈{sch_raw}; family={assume}")
        print(f"Aligned XYZ written to: {out}")

if __name__ == "__main__":
    main()
    # Keep the console window open until the user confirms. When stdin is
    # closed or exhausted (pipe, CI, redirected input) input() raises
    # EOFError; there is nobody to wait for then, so just finish normally.
    try:
        input("\nPress ENTER to terminate>>\n")
    except EOFError:
        pass
