#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
点群の指標表を用いて分子の振動既約表現を計算するスクリプト。

概要:
    指定されたXYZファイルと点群シンボルに基づき、分子の振動モードの既約表現を決定します。

詳細説明:
    本スクリプトは、分子の対称性を表す点群（Schoenflies表記、例: C2v, D3d, Td, Oh, C3h, C4h, Ch, I, Ih）と
    分子の構造情報を含むXYZファイルを引数として受け取ります。
    `pymatgen` が利用可能な環境では、ヘルマン・モーガン表記を内部で使用することがありますが、
    それ以外の場合や特定の点群（I, Ih, C3hなど）については、既知のクラスサイズと一般位置の仮定に基づく
    「抽象モード」にフォールバックします。

    計算は、まず全自由度に対応する指標 (Γ_3N) を、固定された原子のカウント（サイトベース）または
    一般位置の仮定から構築します。その後、分子の並進運動 (Γ_T) と回転運動 (Γ_R) の指標を差し引き、
    純粋な振動モードに対応する指標 (Γ_v) を導出します。
    最終的に、このΓ_vを指標の直交性定理を用いて、各既約表現への分解係数（多重度）を決定します。

関連リンク:
    :doc:`vib_irreps_usage`

使用例:
  python vib_irreps.py --pg Td  --xyz CH4.xyz
  python vib_irreps.py --pg Oh  --xyz SF6.xyz
  python vib_irreps.py --pg D6h --xyz C6H6.xyz
  python vib_irreps.py --pg I   --xyz C60.xyz
"""


import sys
import argparse
import numpy as np
from typing import Dict, List
from collections import defaultdict

import tkpointgroup as pg  # ★ 共通ライブラリ

# ---------- XYZ loader ----------
def load_coords_from_xyz(path: str) -> np.ndarray:
    """
    指定されたXYZファイルから原子の座標を読み込みます。

    概要:
        XYZファイルから原子の座標をNumPy配列として読み込みます。

    詳細説明:
        ファイル内の空行やコメント行、または座標情報を含まない行をスキップし、
        数値データのみを解析してNumPy配列として返します。
        最初の数行はヘッダー（原子数、コメント行）であると想定し、
        適切な行から座標の読み込みを開始します。
        もし最初の行が整数でない場合（例えばコメント行で始まる場合）は、
        行の読み込み開始位置を自動的に調整します。

    :param path: str: 読み込むXYZファイルのパス。
    :returns: numpy.ndarray: 読み込まれた原子座標の配列 (N, 3)。ファイルが空の場合、空の配列を返します。
    """
    coords = []
    with open(path, "r", encoding="utf-8") as f:
        lines = [ln.strip() for ln in f if ln.strip()]
    start = 0
    try:
        int(lines[0].split()[0]); start = 2
    except Exception:
        start = 0
    for ln in lines[start:]:
        toks = ln.replace(",", " ").split()
        if len(toks) < 4:
            continue
        x, y, z = map(float, toks[-3:])
        coords.append([x, y, z])
    return np.array(coords, dtype=float)

# ---------- Main ----------
def main():
    """
    スクリプトの主処理を実行し、分子の振動既約表現を計算して表示します。

    概要:
        コマンドライン引数から点群とXYZファイルパスを受け取り、
        分子の振動モードの既約表現を計算して結果を表示します。

    詳細説明:
        1. コマンドライン引数（点群シンボル、XYZファイルパス、許容誤差）を解析します。
        2. 指定された点群シンボルを正規化し、サポートされている点群であることを確認します。
        3. XYZファイルから原子座標を読み込み、座標が解析できなかった場合はエラーで終了します。
        4. 点群の特性表と、点群が「抽象モード」（例: I, Ih, C3h など、操作行列が直接定義されていない高対称点群）
           であるかどうかに応じて、処理を分岐します。
        5. **通常モードの場合:**
            - 実際の対称操作行列 (`ops`) を取得します。
            - 各操作の分類 (`raw_labels`) とクラスへの集約 (`class_map`) を行います。
            - Γ_3N（全自由度）、Γ_T（並進）、Γ_R（回転）の指標を計算します。
            - これらの指標を各対称クラスごとに集約し、Γ_v（振動）の指標を導出します。
            - 最終的に、Γ_vを既約表現に分解します。
        6. **抽象モードの場合:**
            - 事前に定義されたクラスサイズ (`pg.ABSTRACT_CLASS_SIZES`) を使用して、
              Γ_3N、Γ_T、Γ_Rの指標（特にΓ_3Nは'E'クラスで3N、他は0）を計算します。
            - これらからΓ_v（振動）の指標を導出します。
            - Γ_vを既約表現に分解します。この際、クラスサイズと群の位数をオーバーライドして使用します。
        7. 計算された各指標のキャラクターと最終的な振動既約表現の結果を標準出力に整形して表示します。

    :returns: None: 計算結果を標準出力に表示します。
    """
    sup = ", ".join(sorted(pg.PG_CHAR_TABLES.keys()))
    ap = argparse.ArgumentParser(
        description="Vibrational irreps from XYZ and a point group (Schoenflies-like)."
    )
    ap.add_argument("--pg", required=True, help=f"Point group (Schoenflies-like). Supported: {sup}")
    ap.add_argument("--xyz", required=True, help="XYZ file path")
    ap.add_argument("--tol", type=float, default=1e-5, help="Tolerance (Å) for fixed atoms")
    args = ap.parse_args()

    symbol = pg.normalize_symbol(args.pg)
    if symbol not in pg.PG_CHAR_TABLES:
        print(f"Error: point group '{args.pg}' not supported. Supported: {sorted(pg.PG_CHAR_TABLES.keys())}")
        sys.exit(1)

    coords = load_coords_from_xyz(args.xyz)
    if coords.size == 0:
        print("Error: no coordinates parsed from XYZ.")
        sys.exit(1)

    tbl = pg.character_table(symbol)
    classes = tbl["classes"]
    abstract = symbol in pg.ABSTRACT_CLASS_SIZES

    if not abstract:
        # --- 通常モード: 実操作行列を使う ---
        ops = pg.group_ops(symbol)  # 3x3
        raw_labels = [pg.classify_op_for_table(symbol, R) for R in ops]
        classes, class_map = pg.class_aggregation(symbol, raw_labels)

        chi_3N_lab = pg.gamma_3N_characters(coords, ops, raw_labels, tol=args.tol)
        chi_T_lab  = pg.gamma_trans_characters(raw_labels, class_map)
        chi_R_lab  = pg.gamma_rot_characters(raw_labels, class_map)

        chi_3N = pg.reduce_to_classes(chi_3N_lab, classes, class_map)
        chi_T  = {c: chi_T_lab.get(c, 0.0) for c in classes}
        chi_R  = {c: chi_R_lab.get(c, 0.0) for c in classes}
        chi_v  = {c: chi_3N[c] - chi_T[c] - chi_R[c] for c in classes}

        mults = pg.decompose_irreps(chi_v, symbol)

    else:
        # --- 抽象モード: クラスサイズのみから構成（I, Ih, C3h） ---
        class_sizes = pg.ABSTRACT_CLASS_SIZES[symbol]
        raw_labels = []
        for cl, sz in class_sizes.items():
            raw_labels.extend([cl]*sz)
        # identity class_map
        chi_T_lab  = pg.gamma_trans_characters(raw_labels, {c:c for c in classes})
        chi_R_lab  = pg.gamma_rot_characters(raw_labels, {c:c for c in classes})
        chi_T  = {c: chi_T_lab.get(c, 0.0) for c in classes}
        chi_R  = {c: chi_R_lab.get(c, 0.0) for c in classes}
        chi_3N = {c: 0.0 for c in classes}; chi_3N["E"] = 3.0 * coords.shape[0]
        chi_v  = {c: chi_3N[c] - chi_T[c] - chi_R[c] for c in classes}

        mults = pg.decompose_irreps(chi_v, symbol,
                                    class_sizes_override=class_sizes,
                                    order_override=sum(class_sizes.values()))

    def row(d): return "  ".join(f"{cl:>7s}:{d.get(cl,0.0):>6.1f}" for cl in classes)

    print(f"Point group: {symbol}  (HM = {pg.schoenflies_to_hm(symbol)})  mode={'abstract' if abstract else 'normal'}")
    print("Classes   :", "  ".join(f"{cl:>7s}" for cl in classes))
    print("Γ_3N chars:", row(chi_3N))
    print("Γ_T  chars:", row(chi_T))
    print("Γ_R  chars:", row(chi_R))
    print("Γ_v  chars:", row(chi_v))
    print("\nVibrational irreps:", pg.pretty_irreps(symbol, mults))

if __name__ == "__main__":
    main()
    input("\nPress ENTER to terminate>>\n")