#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Vibrational irreps from a molecule (XYZ) using point-group character tables.

- Input point group in Schoenflies-like symbols (e.g., C2v, D3d, Td, Oh, C3h, C4h, Ch, I, Ih).
- Uses pymatgen (Hermann–Mauguin) when available; otherwise falls back to an abstract mode
  that uses known class sizes and a general-position assumption for Γ_3N.
- Builds Γ_3N from fixed-atom counting (site-basis) or general-position assumption, subtracts
  translations/rotations, and decomposes to irreps with character orthogonality.

Usage:
  python vib_irreps.py --pg Td  --xyz CH4.xyz
  python vib_irreps.py --pg Oh  --xyz SF6.xyz
  python vib_irreps.py --pg D6h --xyz C6H6.xyz
  python vib_irreps.py --pg I   --xyz C60.xyz
"""


import sys
import argparse
import numpy as np
from typing import Dict, List
from collections import defaultdict

import tkpointgroup as pg  # ★ 共通ライブラリ

# ---------- XYZ loader ----------
def load_coords_from_xyz(path: str) -> np.ndarray:
    """Read Cartesian coordinates from an XYZ-style text file.

    The file may begin with the usual two-line XYZ header (an atom count
    followed by a comment line) or consist of atom rows only.  A header is
    assumed when the first token of the first non-blank line parses as an
    integer; that line and the line directly after it (the comment line,
    which may be empty) are then skipped.

    Every remaining row is split on whitespace and commas.  Rows with fewer
    than four tokens are ignored; for all other rows the last three tokens
    are taken as x, y, z.

    Args:
        path: Path of the file to read.  Decoded as UTF-8; a leading BOM is
            tolerated.

    Returns:
        Float array with one ``[x, y, z]`` row per parsed atom.  When no row
        could be parsed the array is empty (``size == 0``).

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If a row with four or more tokens does not end in three
            numbers.
    """
    # "utf-8-sig" reads plain UTF-8 unchanged but strips a BOM, which would
    # otherwise glue itself to the atom count and hide the header.
    with open(path, "r", encoding="utf-8-sig") as f:
        raw_lines = [ln.strip() for ln in f]

    # Position of the first line carrying any text (leading blanks ignored).
    first = next((i for i, ln in enumerate(raw_lines) if ln), len(raw_lines))

    start = first
    if first < len(raw_lines):
        try:
            int(raw_lines[first].split()[0])
        except ValueError:
            # No atom-count header: every line is a candidate atom row.
            pass
        else:
            # Skip the count line and the comment line *by position*.  The
            # comment line is often empty, so it must not be looked up after
            # blank lines were filtered out -- that would drop the first atom.
            start = first + 2

    coords = []
    for ln in raw_lines[start:]:
        toks = ln.replace(",", " ").split()
        if len(toks) < 4:
            continue
        coords.append([float(t) for t in toks[-3:]])
    return np.array(coords, dtype=float)

# ---------- Main ----------
def main():
    """Command-line entry point: print the vibrational irreps of a molecule.

    Parses ``--pg``, ``--xyz`` and ``--tol``, loads the coordinates, obtains
    the class characters of Γ_3N, Γ_T and Γ_R through the ``pg`` helpers,
    forms Γ_v = Γ_3N − Γ_T − Γ_R per class and prints all four character rows
    followed by the irrep decomposition.

    Exits with status 1 when the normalized point group is not a key of
    ``pg.PG_CHAR_TABLES`` or when no coordinates were parsed.
    """
    supported_text = ", ".join(sorted(pg.PG_CHAR_TABLES.keys()))
    parser = argparse.ArgumentParser(
        description="Vibrational irreps from XYZ and a point group (Schoenflies-like)."
    )
    parser.add_argument(
        "--pg",
        required=True,
        help=f"Point group (Schoenflies-like). Supported: {supported_text}",
    )
    parser.add_argument("--xyz", required=True, help="XYZ file path")
    parser.add_argument(
        "--tol", type=float, default=1e-5, help="Tolerance (Å) for fixed atoms"
    )
    args = parser.parse_args()

    symbol = pg.normalize_symbol(args.pg)
    if symbol not in pg.PG_CHAR_TABLES:
        print(f"Error: point group '{args.pg}' not supported. Supported: {sorted(pg.PG_CHAR_TABLES.keys())}")
        sys.exit(1)

    coords = load_coords_from_xyz(args.xyz)
    if coords.size == 0:
        print("Error: no coordinates parsed from XYZ.")
        sys.exit(1)

    classes = pg.character_table(symbol)["classes"]
    abstract = symbol in pg.ABSTRACT_CLASS_SIZES
    # Extra keyword arguments for the decomposition; only abstract mode sets any.
    decompose_kwargs = {}

    if abstract:
        # --- Abstract mode: built from class sizes only (I, Ih, C3h) ---
        class_sizes = pg.ABSTRACT_CLASS_SIZES[symbol]
        # One label per group element: each class label repeated by its size.
        op_labels = [
            label for label, count in class_sizes.items() for _ in range(count)
        ]
        # Class labels map to themselves (identity class map).
        identity_map = {c: c for c in classes}
        trans_by_label = pg.gamma_trans_characters(op_labels, identity_map)
        rot_by_label = pg.gamma_rot_characters(op_labels, identity_map)
        # Only the "E" entry is non-zero: three coordinates per atom.
        chi_3N = dict.fromkeys(classes, 0.0)
        chi_3N["E"] = 3.0 * coords.shape[0]
        decompose_kwargs = {
            "class_sizes_override": class_sizes,
            "order_override": sum(class_sizes.values()),
        }
    else:
        # --- Normal mode: use the actual operation matrices ---
        ops = pg.group_ops(symbol)  # 3x3
        op_labels = [pg.classify_op_for_table(symbol, op) for op in ops]
        classes, class_map = pg.class_aggregation(symbol, op_labels)

        per_label_3N = pg.gamma_3N_characters(coords, ops, op_labels, tol=args.tol)
        trans_by_label = pg.gamma_trans_characters(op_labels, class_map)
        rot_by_label = pg.gamma_rot_characters(op_labels, class_map)

        chi_3N = pg.reduce_to_classes(per_label_3N, classes, class_map)

    # Shared tail: align Γ_T / Γ_R with the class list (missing entries count
    # as 0.0) and subtract them from Γ_3N class by class.
    chi_T = {c: trans_by_label.get(c, 0.0) for c in classes}
    chi_R = {c: rot_by_label.get(c, 0.0) for c in classes}
    chi_v = {c: chi_3N[c] - chi_T[c] - chi_R[c] for c in classes}

    mults = pg.decompose_irreps(chi_v, symbol, **decompose_kwargs)

    def fmt_row(chars):
        """Format one character row as 'class:value' cells in class order."""
        cells = [f"{cl:>7s}:{chars.get(cl,0.0):>6.1f}" for cl in classes]
        return "  ".join(cells)

    mode = "abstract" if abstract else "normal"
    hm = pg.schoenflies_to_hm(symbol)
    header = "  ".join(f"{cl:>7s}" for cl in classes)

    print(f"Point group: {symbol}  (HM = {hm})  mode={mode}")
    print("Classes   :", header)
    print("Γ_3N chars:", fmt_row(chi_3N))
    print("Γ_T  chars:", fmt_row(chi_T))
    print("Γ_R  chars:", fmt_row(chi_R))
    print("Γ_v  chars:", fmt_row(chi_v))
    print("\nVibrational irreps:", pg.pretty_irreps(symbol, mults))

if __name__ == "__main__":
    main()
    # Pause so the output stays visible until the user presses ENTER.
    try:
        input("\nPress ENTER to terminate>>\n")
    except EOFError:
        # stdin is closed or exhausted (e.g. redirected/piped input): there
        # is nothing to wait for, so finish normally instead of ending a
        # successful run with a traceback.
        pass
    
