#!/usr/bin/env python3
# -*- coding: utf-8 -*-

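"""
Crystal point group utility for the 32 crystallographic point groups
(Hermann–Mauguin symbols).

Usage examples:
python point_group.py list
python point_group.py pg --pg Td
python point_group.py expand --pg C3v --vec "0.1 0.2 0.3"
python point_group.py stereo --pg D2h
python point_group.py pg --pg Oh --no-anno

--pg accepts an H–M symbol (e.g. 4mm, -3m, m-3m) or a Schoenflies alias
(e.g. Td, C3v, D2h); --vec accepts 'x,y,z' or 'x y z'.
"""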


import sys
import argparse
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

import tkpointgroup as tkg

# The 32 crystallographic point groups in Hermann–Mauguin notation, grouped by
# crystal system. The order matters: it is the numbering printed by `list`.
PG_HM = [
    # triclinic
    "1", "-1",
    # monoclinic
    "2", "m", "2/m",
    # orthorhombic
    "222", "mm2", "mmm",
    # tetragonal
    "4", "-4", "4/m", "422", "4mm", "-42m", "4/mmm",
    # trigonal
    "3", "-3", "32", "3m", "-3m",
    # hexagonal
    "6", "-6", "6/m", "622", "6mm", "-6m2", "6/mmm",
    # cubic
    "23", "m-3", "432", "-43m", "m-3m",
]

# Generic numerical threshold (e.g. for rejecting zero-length vectors).
EPS = 1e-10


# ======== H–M -> symbols that tkg.build_group understands ========
# tkg accepts many H–M symbols directly. For those it may reject (e.g. -4, 2/m,
# 422, 32, 622, -6m2, -42m), the equivalent Schoenflies symbol below is used
# instead.
HM_TO_SCHOENFLIES_FALLBACK = {
    "-4":   "S4",
    # -6 (bar 6) is C3h. S6 is the Schoenflies name of -3 (C3i), a different
    # group. NOTE(review): confirm tkg.build_group accepts "C3h".
    "-6":   "C3h",
    "2/m":  "C2h",
    "422":  "D4",
    "32":   "D3",
    "622":  "D6",
    "-42m": "D2d",
    "-6m2": "D3h",
    "23":   "T",
    "m-3":  "Th",
    "432":  "O",
    "-43m": "Td",
    "m-3m": "Oh",
}

# Schoenflies aliases accepted by --pg, mapped to the H–M symbols used in PG_HM.
# Every crystallographic point group has exactly one canonical Schoenflies key
# here (plus a few common synonyms such as S2 / S6).
PG_ALIASES = {
    # triclinic / monoclinic
    "C1": "1",
    "Ci": "-1", "S2": "-1",
    "C2": "2",
    "Cs": "m",
    "C2h": "2/m",

    # orthorhombic
    "D2": "222",
    "C2v": "mm2",
    "D2h": "mmm",

    # tetragonal
    "C4": "4",
    "S4": "-4",
    "C4h": "4/m",
    "D4": "422",
    "C4v": "4mm",
    "D2d": "-42m",
    "D4h": "4/mmm",

    # trigonal
    "C3": "3",
    "C3i": "-3", "S6": "-3",
    "D3": "32",
    "C3v": "3m",
    "D3d": "-3m",

    # hexagonal
    "C6": "6",
    "C3h": "-6",
    "C6h": "6/m",
    "D6": "622",
    "C6v": "6mm",
    "D3h": "-6m2",
    "D6h": "6/mmm",

    # cubic
    "T": "23",
    "Th": "m-3",
    "Td": "-43m",
    "O": "432",
    "Oh": "m-3m",

    # Icosahedral groups are not crystallographic: these targets are not in
    # PG_HM, so a --pg value of "I" or "Ih" is still rejected.
    "I": "532",
    "Ih": "m-3-5",
}


def _pg_symbol_for_tkg(symbol: str) -> str:
    """Return a symbol for `symbol` that tkpointgroup should be able to build.

    The symbol is first normalized with tkg.normalize_symbol. If
    tkg.build_group accepts the normalized symbol, it is returned unchanged.
    Otherwise the Schoenflies fallback from HM_TO_SCHOENFLIES_FALLBACK is
    returned, or the normalized symbol itself if no fallback exists. The
    fallback is not validated here; callers build the group themselves.
    """
    normalized = tkg.normalize_symbol(symbol)
    try:
        tkg.build_group(normalized)
    except Exception:
        # Best-effort probe: any failure just means "try the fallback name".
        return HM_TO_SCHOENFLIES_FALLBACK.get(normalized, normalized)
    return normalized


# --------- 数値安定化ユーティリティ ---------
def orthogonalize(R):
    """Snap R to a clean orthogonal matrix.

    R is converted to a float ndarray and handed to tkg.snap_matrix, which
    (per the original author's note) guarantees an orthogonal result with
    det = ±1. NOTE(review): confirm that contract in tkpointgroup.
    """
    as_float = np.asarray(R, dtype=float)
    return tkg.snap_matrix(as_float)


def unique_dirs(vs, tol=1e-5, hemisphere=True):
    """
    Remove duplicate directions from a collection of 3-vectors.

    Each vector is normalized, and vectors with norm < EPS are dropped. Two
    directions are duplicates when every component agrees within `tol`.

    With hemisphere=True, v and -v are treated as the same direction. The kept
    representative is flipped onto the upper hemisphere (z > 0). For vectors
    on the equator (|z| <= tol), the sign is chosen by y, then by x, so +v and
    -v always select the same representative.

    Parameters
    ----------
    vs : iterable of array_like, each of length 3
        Input directions (need not be normalized).
    tol : float
        Absolute per-component tolerance used to decide equality.
    hemisphere : bool
        Identify ±v and canonicalize onto the upper hemisphere.

    Returns
    -------
    list of np.ndarray
        Unit vectors, in first-seen order.
    """
    uniq = []
    for raw in vs:
        v = np.asarray(raw, dtype=float)
        n = np.linalg.norm(v)
        if n < EPS:
            continue
        v = v / n
        if hemisphere:
            # Decide the sign from the first clearly non-zero component among
            # z, y, x. Using z alone leaves equator vectors ambiguous.
            for comp in (v[2], v[1], v[0]):
                if abs(comp) > tol:
                    if comp < 0:
                        v = -v
                    break
        duplicate = any(
            np.allclose(v, u, rtol=0.0, atol=tol)
            or (hemisphere and np.allclose(v, -u, rtol=0.0, atol=tol))
            for u in uniq
        )
        if not duplicate:
            uniq.append(v)
    return uniq


# --------- 対称操作の分類（点群用） ---------
def classify_symop_pg(R):
    """
    Classify a point-group operation and return a short text label.

    R is passed through orthogonalize first, so slightly noisy input is
    tolerated. Possible labels: "Identity", "Inversion", "Rotation C{n}",
    "Mirror (σ)", "Rotoinversion S{m}", "Unknown".

    Improper operations are labelled with the Schoenflies rotoreflection
    order m. If R = -Q with Q a proper rotation by θ, then
    R = σh·C(θ + π), a rotoreflection by π − θ, so m = 2π / (π − θ).
    Hence -4 → S4, -3 → S6, -6 → S3.
    """
    R = orthogonalize(R)
    det = np.linalg.det(R)
    eye = np.eye(3)

    if np.allclose(R, eye, atol=1e-7):
        return "Identity"

    if np.allclose(R, -eye, atol=1e-7):
        return "Inversion"

    if np.isclose(det, 1.0, atol=1e-7):
        # Proper rotation: tr(R) = 1 + 2 cos θ.
        cos_th = np.clip((np.trace(R) - 1) / 2, -1, 1)
        th = np.arccos(cos_th)
        if th < 1e-7:
            return "Identity"
        n = int(np.round(2 * np.pi / th))
        return f"Rotation C{n}"

    if np.isclose(det, -1.0, atol=1e-7):
        # Reflection: R^2 = I (the inversion, which also satisfies this, was
        # handled above).
        if np.allclose(R @ R, eye, atol=1e-6):
            return "Mirror (σ)"
        # Rotoinversion: -R is a proper rotation by th.
        cos_th = np.clip((np.trace(-R) - 1) / 2, -1, 1)
        th = np.arccos(cos_th)
        # The equivalent rotoreflection angle is π - th. th == π would be a
        # mirror, which guards the division.
        gap = np.pi - th
        if gap < 1e-7:
            return "Mirror (σ)"
        m = int(np.round(2 * np.pi / gap))
        return f"Rotoinversion S{m}"

    return "Unknown"


# --------- 点群情報表示（tkg で取得） ---------
def display_pg_info(symbol=None):
    """Print point-group information to stdout.

    With symbol=None, print the numbered list of PG_HM. Otherwise, build the
    group via tkpointgroup and print every operation matrix with its
    classification from classify_symop_pg.

    Exits the process with status 1 if `symbol` is not in PG_HM or the group
    cannot be built.
    """
    if symbol is None:
        print("--- Available Crystal Point Groups (H–M) ---")
        for number, name in enumerate(PG_HM, start=1):
            print(f"  {number:2d}: {name}")
        return

    if symbol not in PG_HM:
        print(f"Error: Unknown point group '{symbol}'.")
        sys.exit(1)

    mapped = _pg_symbol_for_tkg(symbol)
    try:
        ops = tkg.build_group(mapped)  # sequence of 3x3 matrices
    except Exception as err:
        print(f"Error: cannot build group for '{symbol}' (mapped '{mapped}'): {err}")
        sys.exit(1)

    print(f"--- Point Group: {symbol}  (mapped as '{mapped}') ---")
    print(f"Total symmetry operations: {len(ops)}\n")
    matrix_fmt = {'float_kind': lambda x: f"{x:8.5f}"}
    for idx, M in enumerate(ops, start=1):
        R = orthogonalize(M)
        print(f"Operation {idx:02d}: {classify_symop_pg(R)}")
        print(np.array2string(R, formatter=matrix_fmt))
        print()


# --------- ステレオ投影 ---------
def stereo_project(v):
    """
    Stereographically project a unit vector v = (x, y, z) onto the z = 0 plane.

    The projection centre is the south pole (0, 0, -1), so the upper
    hemisphere z >= 0 lands inside the unit disc:
    x' = x / (1 + z), y' = y / (1 + z).
    Returns None when 1 + z is (numerically) zero, i.e. v is the south pole.
    """
    x, y, z = v
    scale = 1.0 + z
    if not np.isclose(scale, 0.0, atol=1e-12):
        return x / scale, y / scale
    return None


def _eigvec_for(R, target, atol=1e-6):
    """Return a unit eigenvector of R for eigenvalue `target`, or None if absent.

    Only the real part of the eigenvector returned by np.linalg.eig is kept.
    """
    vals, vecs = np.linalg.eig(R)
    idx = np.flatnonzero(np.isclose(vals.real, target, atol=atol))
    if idx.size == 0:
        return None
    vec = np.array(vecs[:, idx[0]].real, dtype=float)
    return vec / (np.linalg.norm(vec) + 1e-15)


def _collect_symmetry_elements(ops, seed_vec):
    """Return (rot_axes, mirror_normals, orbit) for the operation matrices `ops`.

    - rot_axes: axes (eigenvalue-1 eigenvectors) of proper rotations other than
      the identity.
    - mirror_normals: normals (eigenvalue -1 eigenvectors) of mirrors, i.e.
      det = -1 and R^2 = I, excluding the inversion -I, which has no plane.
    - orbit: images R @ seed_vec, normalized and flipped to z >= 0.
    Each list is deduplicated with unique_dirs (± identified).
    """
    eye = np.eye(3)
    rot_axes, mirror_normals, orbit = [], [], []
    for M in ops:
        R = orthogonalize(M)
        det = np.linalg.det(R)

        # Orbit of the seed direction.
        w = R @ seed_vec
        w_norm = np.linalg.norm(w)
        if w_norm > EPS:
            w = w / w_norm
            orbit.append(-w if w[2] < 0 else w)

        is_mirror = (np.isclose(det, -1.0, atol=1e-7)
                     and np.allclose(R @ R, eye, atol=1e-6)
                     and not np.allclose(R, -eye, atol=1e-7))
        if is_mirror:
            nrm = _eigvec_for(R, -1.0)
            if nrm is not None:
                mirror_normals.append(nrm)
        elif np.isclose(det, 1.0, atol=1e-7) and not np.allclose(R, eye, atol=1e-7):
            axis = _eigvec_for(R, 1.0)
            if axis is not None:
                rot_axes.append(axis)

    return (unique_dirs(rot_axes, tol=1e-5, hemisphere=True),
            unique_dirs(mirror_normals, tol=1e-5, hemisphere=True),
            unique_dirs(orbit, tol=1e-6, hemisphere=True))


def _mirror_arc_xy(nrm, num=721, z_tol=1e-9):
    """Stereographic (x, y) samples of the great circle perpendicular to `nrm`.

    Only the upper-hemisphere half (z >= -z_tol) is kept. Samples below it are
    NaN, so matplotlib leaves a gap there. Flipping lower points to their
    antipodes instead would make the line jump across the disc.
    Returns an array of shape (num, 2).
    """
    # Orthonormal basis (u, v) of the plane n·r = 0.
    ref = np.array([1.0, 0.0, 0.0])
    if np.isclose(abs(np.dot(ref, nrm)), 1.0, atol=1e-6):
        ref = np.array([0.0, 1.0, 0.0])
    u = ref - np.dot(ref, nrm) * nrm
    u = u / (np.linalg.norm(u) + 1e-15)
    v = np.cross(nrm, u)
    v = v / (np.linalg.norm(v) + 1e-15)

    ts = np.linspace(0.0, 2.0 * np.pi, num)
    pts = np.outer(np.cos(ts), u) + np.outer(np.sin(ts), v)  # (num, 3)
    z = pts[:, 2]
    upper = z >= -z_tol
    xy = np.full((num, 2), np.nan)
    xy[upper] = pts[upper, :2] / (1.0 + z[upper])[:, None]
    return xy


def plot_pg_projection(symbol, seed_vec=None, annotate=True):
    """
    Draw the stereographic projection of a point group, showing:
      - rotation axes (eigenvalue-1 eigenvectors) as triangles,
      - mirror planes as dashed great-circle arcs on the upper hemisphere,
      - the orbit of `seed_vec` under the group, as numbered dots.

    Parameters
    ----------
    symbol : str
        H–M symbol. If it is not in PG_HM, the process exits with status 1.
    seed_vec : array_like of 3 floats, optional
        Seed direction, default (0.8, 0.1, 0.05). It is normalized and flipped
        onto z >= 0 before use.
    annotate : bool
        Number the orbit points in the figure.

    Raises
    ------
    ValueError
        If seed_vec has (near-)zero length.

    Opens a non-blocking matplotlib figure and prints the orbit directions.
    """
    if symbol not in PG_HM:
        print(f"Error: Unknown point group '{symbol}'.")
        sys.exit(1)

    if seed_vec is None:
        seed_vec = np.array([0.8, 0.1, 0.05], dtype=float)
    seed_vec = np.asarray(seed_vec, dtype=float)

    # Normalize and move onto the upper hemisphere.
    n = np.linalg.norm(seed_vec)
    if n < EPS:
        raise ValueError("--vec must be non-zero")
    seed_vec = seed_vec / n
    if seed_vec[2] < 0:
        seed_vec = -seed_vec

    s = _pg_symbol_for_tkg(symbol)
    try:
        ops = tkg.build_group(s)
    except Exception as e:
        print(f"Error: cannot build group for '{symbol}' (mapped '{s}'): {e}")
        sys.exit(1)

    rot_axes, mirror_normals, orbit = _collect_symmetry_elements(ops, seed_vec)

    # --- Figure ---
    _fig, ax = plt.subplots(figsize=(6, 6))
    ax.add_artist(plt.Circle((0, 0), 1.0, fill=False, linewidth=1.5))

    # Mirror planes as great circles. A fixed colour matches the legend entry.
    for nrm in mirror_normals:
        xy = _mirror_arc_xy(nrm)
        ax.plot(xy[:, 0], xy[:, 1], linestyle='--', linewidth=1, alpha=0.8, color='C0')

    # Rotation axes (upper hemisphere).
    for axis in rot_axes:
        a = axis / (np.linalg.norm(axis) + 1e-15)
        if a[2] < 0:
            a = -a
        sp = stereo_project(a)
        if sp is not None:
            ax.scatter(*sp, marker='^', s=90, edgecolor='k', facecolor='white')

    # Orbit of the seed direction.
    for i, w in enumerate(orbit, start=1):
        sp = stereo_project(w)
        if sp is None:
            continue
        ax.scatter(*sp, marker='o', s=60, edgecolor='k', facecolor='#1f77b4')
        if annotate:
            ax.text(sp[0], sp[1], f"{i}", ha='center', va='center',
                    color='white', fontsize=9, weight='bold')

    ax.set_aspect('equal', 'box')
    ax.set_xlim(-1.05, 1.05)
    ax.set_ylim(-1.05, 1.05)
    ax.axis('off')
    ax.set_title(f"Stereographic Projection — Point Group {symbol}\n"
                 f"Seed direction (normalized): {tuple(np.round(seed_vec, 3))}")

    legend_handles = [
        Line2D([], [], linestyle='--', color='C0', label='Mirror great circle'),
        Line2D([], [], marker='^', linestyle='None', markeredgecolor='k',
               markerfacecolor='white', markersize=9, label='Rotation axis (dir.)'),
        Line2D([], [], marker='o', linestyle='None', markeredgecolor='k',
               markerfacecolor='#1f77b4', markersize=8, label='Expanded directions')
    ]
    ax.legend(handles=legend_handles, loc='upper right', bbox_to_anchor=(1.25, 1.0))

    plt.tight_layout()
    plt.show(block=False)

    # Also list the coordinates so the points can be identified numerically.
    print(f"\n--- Expanded directions (hemisphere, unique) for PG {symbol} (mapped '{s}') from seed {tuple(np.round(seed_vec,3))} ---")
    for i, w in enumerate(orbit, start=1):
        print(f"{i:02d}: ({w[0]: .6f}, {w[1]: .6f}, {w[2]: .6f})")


# --------- 方向（座標）展開：CLI向け ---------
def expand_direction(symbol, vec):
    """
    Apply every operation of point group `symbol` to the direction `vec`.

    The seed and each image are normalized and flipped onto z >= 0. The images
    are then deduplicated with ± identified (unique_dirs, tol=1e-6).

    Raises
    ------
    RuntimeError
        If tkpointgroup cannot build the group.
    ValueError
        If `vec` has (near-)zero length.
    """
    mapped = _pg_symbol_for_tkg(symbol)
    try:
        ops = tkg.build_group(mapped)
    except Exception as e:
        raise RuntimeError(f"cannot build group for '{symbol}' (mapped '{mapped}'): {e}")

    seed = np.array(vec, dtype=float)
    length = np.linalg.norm(seed)
    if length < EPS:
        raise ValueError("Zero vector is not allowed for direction expansion.")
    seed = seed / length
    if seed[2] < 0:
        seed = -seed

    images = []
    for M in ops:
        w = orthogonalize(M) @ seed
        w_len = np.linalg.norm(w)
        if w_len < EPS:
            continue
        w = w / w_len
        images.append(-w if w[2] < 0 else w)

    return unique_dirs(images, tol=1e-6, hemisphere=True)


# --------- CLI ---------
def _parse_vec(text):
    """
    Parse a --vec string such as '0.8,0.1,0.05' or '0.8 0.1 0.05'.

    Commas and any run of whitespace are both accepted as separators.

    Returns
    -------
    np.ndarray of shape (3,), dtype float

    Raises
    ------
    ValueError
        If the text does not contain exactly three numbers.
    """
    parts = text.replace(',', ' ').split()
    if len(parts) != 3:
        raise ValueError(f"expected 3 components, got {len(parts)}")
    return np.array([float(p) for p in parts], dtype=float)


def main():
    """Command-line entry point: parse arguments and dispatch to the chosen mode."""
    # prog is left to argparse so the usage line shows the real script name.
    parser = argparse.ArgumentParser(
        description="Crystal Point Group Utility (robust axes/planes & stereographic projection)",
        formatter_class=argparse.RawTextHelpFormatter
    )
    parser.add_argument('mode', choices=['list', 'pg', 'stereo', 'expand'],
                        help=("list   : List 32 crystal point groups (H–M)\n"
                              "pg     : Print symmetry matrices & classification\n"
                              "stereo : Plot stereographic projection (axes, mirrors, expanded dirs)\n"
                              "expand : Print expanded directions numerically"))
    parser.add_argument('--pg', '-p', help="Point group (H–M), e.g. 4mm, -3m, m-3m")
    parser.add_argument('--vec', type=str, default='0.8,0.1,0.05',
                        help="Seed direction for 'stereo'/'expand' (comma-separated), default '0.8,0.1,0.05'")
    parser.add_argument('--no-anno', action='store_true',
                        help="Disable numbering labels for expanded directions")
    args = parser.parse_args()

    if args.mode == 'list':
        display_pg_info(None)
        return

    # Accept Schoenflies aliases (e.g. Td, C3v) by translating them to H–M.
    if args.pg is not None and args.pg in PG_ALIASES:
        args.pg = PG_ALIASES[args.pg]

    if args.pg is None or args.pg not in PG_HM:
        print("Error: Please specify --pg with a valid H–M symbol (e.g. 4mm, -3m, m-3m).")
        sys.exit(1)

    if args.mode == 'pg':
        display_pg_info(args.pg)
        return

    if args.mode not in ('stereo', 'expand'):
        parser.print_help()
        return

    # Both remaining modes need the seed direction.
    try:
        vec = _parse_vec(args.vec)
    except ValueError:
        print(f"Error: --vec must be like 'x,y,z' or 'x y z' [args.vec={args.vec}]")
        input("\nPress ENTER to terminate>>\n")
        sys.exit(1)

    if args.mode == 'stereo':
        try:
            plot_pg_projection(args.pg, seed_vec=vec, annotate=not args.no_anno)
        except ValueError as e:  # e.g. zero-length seed vector
            print(f"Error: {e}")
            sys.exit(1)
        if 'matplotlib.pyplot' in sys.modules and plt.get_fignums():
            input("\nPress ENTER to close>>")
    else:  # 'expand'
        try:
            imgs = expand_direction(args.pg, vec)
        except (ValueError, RuntimeError) as e:
            print(f"Error: {e}")
            sys.exit(1)
        print(f"--- Expanded directions (hemisphere, unique) for PG {args.pg} from {tuple(np.round(vec/np.linalg.norm(vec),3))} ---")
        for i, v in enumerate(imgs, 1):
            print(f"{i:02d}: ({v[0]: .6f}, {v[1]: .6f}, {v[2]: .6f})")


if __name__ == '__main__':
    main()
    # Pause before exiting, presumably so the output stays visible when the
    # script runs in a console window that closes on exit.
    input("\nPress ENTER to terminate>>\n")
