#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
point_group_inf.py  (CLI using tkpointgroup)
  - Pure Python (NumPy) point-group operations and orbits tool.
  - Input: Schoenflies (Cnv, Dnh, Td, Oh, ...) or representative
           international symbols (4/mmm, -3m, m-3m, ...).
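  - Modes:
      * No --p given: print generators or full group operations (3x3).
      * With --p: compute symmetry orbits for given points; optionally
                  compress to independent representatives and list ops that generate them.
      * --to-hm / --to-s: only convert the given symbol between Schoenflies and
                          Hermann–Mauguin notation, then exit.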

# 対応一覧
python point_group_inf.py --list

# C3v の全操作を表示（既定はテキスト）
python point_group_inf.py C3v

# D4h の generator だけ
python point_group_inf.py D4h --generators

# Td の全操作を JSON で取得
python point_group_inf.py Td --json > Td_ops.json

# 4/mmm の全操作を CSV で取得（Excel 等に取り込みやすい）
python point_group_inf.py 4/mmm --csv > 4mmm_ops.csv

# 表示の丸め桁数変更
python point_group_inf.py Oh --round 6

# 1) C3v, 入力点 (1,0,0) の軌道（全等価点と生成操作）
python point_group_inf.py C3v --p 1,0,0

# 2) D4h, 2点の「独立代表点」のみ（どの対称操作で得られるか付き）
python point_group_inf.py D4h --p 1,0,0 --p 0.2,0.3,0.4 --independent-only

# 3) 4/mmm, JSON で受け取り、他ツールに渡す
python point_group_inf.py 4/mmm --p 0,0,1 --independent-only --json > reps.json

# 4) -3m（D3d）で CSV（Excelに読み込み）
python point_group_inf.py D3d --p 0.1,0.2,0.3 --csv

"""

import json
import csv
import sys
import argparse
import numpy as np

import tkpointgroup as pg


# ========= I/O 補助 =========
def parse_point(arg: str) -> np.ndarray:
    """Parse an ``x,y,z`` command-line string into a length-3 float vector.

    Used as an argparse ``type=`` converter, so malformed input is reported
    as :class:`argparse.ArgumentTypeError`, which argparse turns into a
    usage error for the offending option.

    Raises:
        argparse.ArgumentTypeError: if *arg* does not consist of exactly three
            comma-separated values accepted by ``float``, or if any value is
            not finite (nan/inf would make orbit comparisons meaningless).
    """
    try:
        x, y, z = [float(s) for s in arg.split(",")]
    except ValueError:
        # Both a wrong field count (unpacking) and a non-numeric field
        # (float()) raise ValueError; suppress the chain so the user sees
        # only the friendly message.
        raise argparse.ArgumentTypeError(f"Point must be x,y,z (got '{arg}')") from None
    point = np.array([x, y, z], float)
    if not np.isfinite(point).all():
        raise argparse.ArgumentTypeError(f"Point coordinates must be finite (got '{arg}')")
    return point

def fmt_vec(v, prec=6):
    """Render vector *v* as ``(a, b, c)`` with *prec* digits after the point."""
    parts = [format(component, f".{prec}f") for component in v]
    return f"({', '.join(parts)})"

def print_matrix(M, prec=3):
    """Return matrix *M* as text, one ``[...]`` row per line.

    Despite the name, nothing is printed here; the caller prints the
    returned string. Entries use the ``' .{prec}f'`` format spec, so
    non-negative values get a leading space and line up with negative ones.
    """
    spec = f" .{prec}f"
    lines = []
    for row in M:
        cells = " ".join(format(value, spec) for value in row)
        lines.append(f"[{cells}]")
    return "\n".join(lines)


# ========= CLI =========
def main():
    """Command-line entry point (reads ``sys.argv``).

    Modes, checked in this order:

    * ``--list``: print the symbols reported by ``pg.supported_symbols()``.
    * ``--to-hm`` / ``--to-s``: convert the symbol's notation and exit.
    * no ``--p``: list the group's generators (``--generators``) or all of
      its operations as 3x3 matrices.
    * one or more ``--p``: list the symmetry orbit of each point, or only the
      representatives from ``pg.dedup_points`` with ``--independent-only``.

    Output is human-readable text by default, JSON with ``--json`` or CSV
    with ``--csv``; if both flags are given, JSON wins.
    """
    ap = _build_parser()
    args = ap.parse_args()

    if args.list:
        print("# Supported point groups:")
        print(", ".join(pg.supported_symbols()))
        return

    if not args.symbol:
        ap.error("symbol is required (try --list)")  # exits via SystemExit

    # Notation-conversion short paths.
    if args.to_hm:
        print(pg.schoenflies_to_hm(args.symbol))
        return
    if args.to_s:
        print(pg.hm_to_schoenflies(args.symbol))
        return

    if args.points:
        _run_orbits(args)
    else:
        _run_operations(args)


def _build_parser():
    """Build the argparse parser for this CLI."""
    ap = argparse.ArgumentParser(description="Point-group operations & symmetry orbits (pure NumPy).")
    ap.add_argument("symbol", nargs="?", help="Point group symbol (e.g., C3v, D4h, Td, Oh, 4/mmm, -3m, m-3m)")
    ap.add_argument("--generators", action="store_true", help="Show generators only (no --p).")
    ap.add_argument("--json", action="store_true", help="Output as JSON (machine readable).")
    ap.add_argument("--csv", action="store_true", help="Output as CSV (machine readable).")
    ap.add_argument("--round", type=int, default=3, help="Display precision for text output (default: 3).")
    ap.add_argument("--list", action="store_true", help="List supported point-group symbols and exit.")
    # Orbit options
    ap.add_argument("--p", dest="points", action="append", type=parse_point,
                    help="Point x,y,z (repeatable). Example: --p 1,0,0 --p 0.2,0.3,0.4")
    ap.add_argument("--independent-only", action="store_true",
                    help="With --p: show only independent representatives.")
    ap.add_argument("--tol", type=float, default=1e-8, help="Tolerance for point equivalence (default 1e-8).")
    # Simple notation conversion
    ap.add_argument("--to-hm", action="store_true", help="Convert given symbol (Schoenflies) to Hermann–Mauguin.")
    ap.add_argument("--to-s",  action="store_true", help="Convert given symbol (Hermann–Mauguin) to Schoenflies.")
    return ap


def _clean(a, decimals=None):
    """Return *a* as a float ndarray, optionally rounded, with -0.0 mapped to 0.0."""
    arr = np.asarray(a, float)
    if decimals is not None:
        arr = arr.round(decimals)
    # IEEE: -0.0 + 0.0 == +0.0, every other value (incl. nan/inf) is unchanged.
    # Avoids "-0.0" / "-0.000" noise from rotations and mirrors in the output.
    return arr + 0.0


def _csv_writer():
    """Return a csv.writer on stdout that ends rows with a bare LF.

    csv's default CRLF terminator, written through text-mode stdout on
    Windows, becomes CR CR LF, which spreadsheets show as blank rows; stdout
    itself translates LF to the platform newline.
    """
    return csv.writer(sys.stdout, lineterminator="\n")


def _run_operations(args):
    """Print the generators or all operations of ``args.symbol`` (no ``--p`` given)."""
    ops = pg.generators_for(args.symbol) if args.generators else None
    # True only when generators are really shown (False after the fallback).
    showing_generators = bool(ops)
    if not ops:
        # Either all operations were requested, or generators_for() returned
        # nothing; in both cases list every operation of the group.
        ops = pg.get_all_operations(args.symbol)

    if args.json:
        payload = {
            "symbol": pg.normalize_symbol(args.symbol),
            "count": len(ops),
            "operations": [
                {"label": lab, "matrix": _clean(M, 10).tolist()}
                for lab, M in ops
            ],
        }
        print(json.dumps(payload, ensure_ascii=False, indent=2))
        return

    if args.csv:
        w = _csv_writer()
        w.writerow(["symbol", pg.normalize_symbol(args.symbol)])
        w.writerow(["count", len(ops)])
        w.writerow(["label", "m11", "m12", "m13", "m21", "m22", "m23", "m31", "m32", "m33"])
        for lab, M in ops:
            w.writerow([lab, *_clean(M, 10).ravel().tolist()])
        return

    print(f"# Point group: {args.symbol}  (operations: {len(ops)})")
    print(f"# Mode: {'Generators' if showing_generators else 'All operations'}")
    ordered = sorted(ops, key=lambda t: (t[0], pg.mat_key(t[1])))
    for i, (lab, M) in enumerate(ordered, 1):
        print(f"\n[{i:02d}] {lab}")
        print(print_matrix(_clean(M), prec=args.round))


def _run_orbits(args):
    """Print the orbit (or independent representatives) of every ``--p`` point."""
    ops = pg.label_elements(pg.build_group(args.symbol))  # [(label, M)]
    labels = [lab for lab, _ in ops]
    mats = [M for _, M in ops]

    # Key of the per-input entry list; the JSON/CSV/text layouts only differ
    # in this name, the CSV column prefix and the text title.
    key = "representatives" if args.independent_only else "orbit"

    items = []
    for idx, p in enumerate(args.points):
        orbit = pg.orbit_for_point(mats, p, tol=args.tol)  # [(q, [op_ids])]
        entries = pg.dedup_points(orbit, tol=args.tol) if args.independent_only else orbit
        items.append({
            "input_index": idx,
            "input_point": p,
            key: [
                {
                    "point": q,
                    # int() keeps NumPy integer indices JSON-serializable.
                    "op_indices": [int(i) for i in op_ids],
                    "ops": [labels[i] for i in op_ids],
                }
                for q, op_ids in entries
            ],
        })

    if args.json:
        payload = {
            "symbol": pg.normalize_symbol(args.symbol),
            "ops_count": len(ops),
            "inputs": [
                {
                    "input_index": it["input_index"],
                    "input_point": _clean(it["input_point"]).tolist(),
                    key: [
                        {"point": _clean(e["point"]).tolist(),
                         "op_indices": e["op_indices"],
                         "ops": e["ops"]}
                        for e in it[key]
                    ],
                }
                for it in items
            ],
        }
        print(json.dumps(payload, ensure_ascii=False, indent=2))
        return

    if args.csv:
        prefix = "rep" if args.independent_only else "orbit"
        w = _csv_writer()
        w.writerow(["symbol", pg.normalize_symbol(args.symbol)])
        w.writerow(["ops_count", len(ops)])
        w.writerow(["input_index", "input_x", "input_y", "input_z",
                    f"{prefix}_x", f"{prefix}_y", f"{prefix}_z", "op_indices", "ops_labels"])
        for it in items:
            ip = _clean(it["input_point"]).tolist()
            for e in it[key]:
                q = _clean(e["point"]).tolist()
                w.writerow([it["input_index"], ip[0], ip[1], ip[2],
                            q[0], q[1], q[2],
                            ";".join(map(str, e["op_indices"])),
                            ";".join(e["ops"])])
        return

    # Human-readable output
    title = "Representatives" if args.independent_only else "Orbit size"
    print(f"# Point group: {args.symbol} (ops: {len(ops)})")
    for it in items:
        print(f"\nInput[{it['input_index']}]: {fmt_vec(_clean(it['input_point']), args.round)}")
        entries = it[key]
        print(f"  {title}: {len(entries)}")
        for j, e in enumerate(entries, 1):
            print(f"   [{j:02d}] {fmt_vec(_clean(e['point']), args.round)}  |  ops: {', '.join(e['ops'])}")


if __name__ == "__main__":
    main()
    # Pause before exiting, presumably so a console window opened by
    # double-clicking the script stays visible. Only do this in a real
    # terminal: input() writes its prompt to stdout, which would append junk
    # to JSON/CSV redirected into a file, and on a redirected or closed stdin
    # it would block or raise EOFError.
    if (sys.stdin is not None and sys.stdout is not None
            and sys.stdin.isatty() and sys.stdout.isatty()):
        input("\nPress ENTER to terminate>>\n")
    
