import sys
import argparse
import numpy as np
import matplotlib.pyplot as plt

from pymatgen.symmetry.groups import PointGroup  # crystallographic point-group class

# The 32 crystallographic point groups in Hermann–Mauguin notation,
# numbered 1-32 in the conventional order and grouped by crystal system.
PG_HM = [
    # triclinic
    "1", "-1",
    # monoclinic
    "2", "m", "2/m",
    # orthorhombic
    "222", "mm2", "mmm",
    # tetragonal
    "4", "-4", "4/m", "422", "4mm", "-42m", "4/mmm",
    # trigonal
    "3", "-3", "32", "3m", "-3m",
    # hexagonal
    "6", "-6", "6/m", "622", "6mm", "-6m2", "6/mmm",
    # cubic
    "23", "m-3", "432", "-43m", "m-3m",
]

def initialize():
    """
    Build the command-line argument parser.

    Modes:
        pg     -- list the 32 point groups, or (with --pg) print every
                  symmetry operation of one group
        pgproj -- plot the stereographic projection of --pg
        stereo -- accepted as an alias of 'pgproj'

    Returns:
        argparse.ArgumentParser: parser with a positional ``mode`` and an
        optional ``--pg`` / ``-p`` point-group symbol.
    """
    parser = argparse.ArgumentParser(
        prog='pg_util.py',
        usage=(
            "\n"
            "  pg_util.py pg [--pg=<PointGroup>]\n"
            "  pg_util.py pgproj --pg=<PointGroup>"
        ),
        description="Crystal Point Group Utility",
        formatter_class=argparse.RawTextHelpFormatter,
    )
    # The usage/help text advertises 'pgproj'; 'stereo' was the previously
    # accepted spelling, so it stays valid as an alias.
    parser.add_argument(
        'mode',
        choices=['pg', 'pgproj', 'stereo'],
        help=("'pg'    : Display point group info\n"
              "'pgproj': Plot stereographic projection (alias: 'stereo')"),
    )
    # argparse treats a separate token such as '-42m' as an option flag, so
    # symbols starting with '-' must be given as --pg=-42m.
    parser.add_argument(
        '--pg', '-p',
        help=("Point group symbol (H–M), e.g. '4mm'.\n"
              "Use the '=' form for symbols starting with '-', e.g. --pg=-42m"),
    )
    return parser

def classify_symop(op):
    """
    Classify a point-group symmetry operation and return a label string.

    Only basis-independent quantities of the 3x3 rotation part are used
    (determinant, trace, and whether R @ R == I). The result is therefore the
    same whether the matrix is given in Cartesian or lattice (e.g. hexagonal)
    coordinates.

    Args:
        op: SymmOp-like object exposing a 3x3 ``rotation_matrix``.

    Returns:
        str: "Identity", "Inversion", "Rotation C<n>", "Mirror (σ)",
        "Rotoinversion S<n>" or "Unknown". For improper operations, n is the
        Schoenflies rotoreflection order. The Hermann–Mauguin rotoinversions
        therefore map as -4 -> S4, -3 -> S6 and -6 -> S3.
    """
    R = np.asarray(op.rotation_matrix, dtype=float)
    det = np.linalg.det(R)
    trace = np.trace(R)
    eye = np.eye(3)

    def order(cos_phi):
        # n = 2*pi/phi for rotation angle phi; None if phi is numerically 0
        # (avoids 2*pi/0 -> inf -> OverflowError in int()).
        phi = np.arccos(np.clip(cos_phi, -1.0, 1.0))
        if phi < 1e-8:
            return None
        return int(np.round(2 * np.pi / phi))

    if np.allclose(R, eye):
        return "Identity"
    if np.allclose(R, -eye):
        return "Inversion"

    if np.isclose(det, 1.0):
        # Proper rotation C_n by phi = 2*pi/n: tr(R) = 1 + 2 cos(phi).
        n = order((trace - 1) / 2)
        if n is not None:
            return f"Rotation C{n}"
    elif np.isclose(det, -1.0):
        # Reflection: an improper involution (R^2 = I, R != -I).
        if np.allclose(R @ R, eye):
            return "Mirror (σ)"
        # Rotoreflection S_n = reflection composed with a rotation by
        # phi = 2*pi/n, so tr(R) = -1 + 2 cos(phi). (Using the angle of -R
        # instead would give pi - phi and swap S3 and S6.)
        n = order((trace + 1) / 2)
        if n is not None:
            return f"Rotoinversion S{n}"

    return "Unknown"


def display_pg_info(symbol=None):
    """
    Print crystal point-group information.

    With ``symbol=None`` the numbered list of PG_HM is printed. Otherwise
    every symmetry operation of that group is printed with its
    classification (from classify_symop) and its rotation matrix. A symbol
    not in PG_HM prints an error and exits with status 1.
    """
    if symbol is None:
        print("--- Available Crystal Point Groups (H–M) ---")
        for number, name in enumerate(PG_HM, start=1):
            print(f"  {number:2d}: {name}")
        return

    if symbol not in PG_HM:
        print(f"Error: Unknown point group '{symbol}'.")
        sys.exit(1)

    operations = PointGroup(symbol).symmetry_ops

    print(f"--- Point Group: {symbol} ---")
    print(f"Total symmetry operations: {len(operations)}\n")
    for index, operation in enumerate(operations, start=1):
        print(f"Operation {index}: {classify_symop(operation)}")
        print(operation.rotation_matrix)
        print()
        
# Stereogram marker per rotation-axis order: 2-fold ellipse-like 'o',
# 3-fold triangle '^', 4-fold diamond 'D', 6-fold hexagon 'h'.
_AXIS_MARKERS = {2: 'o', 3: '^', 4: 'D', 6: 'h'}

# Columns are the Cartesian components of the hexagonal lattice vectors
# a = (1, 0, 0), b = (-1/2, sqrt(3)/2, 0), c = (0, 0, 1).
_HEX_TO_CART = np.array([
    [1.0, -0.5, 0.0],
    [0.0, np.sqrt(3.0) / 2.0, 0.0],
    [0.0, 0.0, 1.0],
])


def _is_orthogonal(R):
    """Return True if R is orthogonal (R @ R.T == I within tolerance)."""
    return np.allclose(R @ R.T, np.eye(3))


def _cartesian_rotations(ops):
    """Return the rotation parts of *ops* as 3x3 Cartesian matrices.

    If every matrix is already orthogonal, the matrices are returned as-is.
    Otherwise all of them are converted with M @ R @ inv(M), M = _HEX_TO_CART.
    If that conversion does not yield orthogonal matrices, the raw matrices
    are returned unchanged.
    """
    mats = [np.asarray(op.rotation_matrix, dtype=float) for op in ops]
    if all(_is_orthogonal(R) for R in mats):
        return mats
    # NOTE(review): assumes non-orthogonal matrices are expressed in the
    # hexagonal lattice basis (a = b, gamma = 120 deg) acting on column
    # vectors of fractional coordinates; confirm against installed pymatgen.
    m_inv = np.linalg.inv(_HEX_TO_CART)
    converted = [_HEX_TO_CART @ R @ m_inv for R in mats]
    if all(_is_orthogonal(R) for R in converted):
        return converted
    return mats


def _eigvec(R, value):
    """Return a real eigenvector of R for eigenvalue *value*, or None."""
    vals, vecs = np.linalg.eig(R)
    mask = np.isclose(vals, value)
    if not np.any(mask):
        return None
    return vecs[:, mask].real[:, 0]


def _canonical_direction(v):
    """Normalize v and fix its sign so that v and -v give the same result.

    The first component, checked in z, y, x order, whose magnitude exceeds
    1e-6 is made positive. The result therefore has z >= 0 up to that
    tolerance.
    """
    v = v / np.linalg.norm(v)
    for comp in (v[2], v[1], v[0]):
        if abs(comp) > 1e-6:
            return v if comp > 0 else -v
    return v


def _collect_symmetry_elements(mats):
    """Extract rotation axes and mirror-plane normals from Cartesian matrices.

    Returns:
        (axes, normals): ``axes`` is a list of (unit_axis, order) pairs, where
        order is the highest proper-rotation order found about that axis.
        ``normals`` is a list of unit mirror-plane normals. Both lists are
        deduplicated on the sign-canonicalized direction rounded to 5 decimals.
    """
    eye = np.eye(3)
    axes = {}
    normals = {}
    for R in mats:
        # Identity and inversion define no axis or plane. Note that -I also
        # satisfies det = -1 and R^2 = I, so it must be excluded before the
        # mirror test.
        if np.allclose(R, eye) or np.allclose(R, -eye):
            continue
        det = np.linalg.det(R)
        if np.isclose(det, -1.0) and np.allclose(R @ R, eye):
            # Mirror: the plane normal is the eigenvector for eigenvalue -1.
            n = _eigvec(R, -1.0)
            if n is not None:
                n = _canonical_direction(n)
                normals.setdefault(tuple(np.round(n, 5)), n)
        elif np.isclose(det, 1.0):
            # Proper rotation by phi, with tr(R) = 1 + 2 cos(phi); the axis is
            # the eigenvector for eigenvalue 1.
            phi = np.arccos(np.clip((np.trace(R) - 1) / 2, -1.0, 1.0))
            axis = _eigvec(R, 1.0)
            if axis is None or phi < 1e-8:
                continue
            axis = _canonical_direction(axis)
            order = int(np.round(2 * np.pi / phi))
            key = tuple(np.round(axis, 5))
            # C2 = C4^2 etc. share an axis: keep the highest order.
            if key not in axes or order > axes[key][1]:
                axes[key] = (axis, order)
    return list(axes.values()), list(normals.values())


def _stereo_xy(v):
    """Project a unit vector (z >= 0) stereographically onto the plane.

    The projection point is the south pole, so the equator maps onto the
    unit (primitive) circle and +z maps to the origin.
    """
    x, y, z = v
    return x / (1.0 + z), y / (1.0 + z)


def _great_circle_xy(normal, num=200):
    """Return x, y arrays tracing the projected great circle of a plane.

    The plane passes through the origin with the given normal. Only its
    upper-hemisphere half is traced, which is the part that projects inside
    the primitive circle.
    """
    n = normal / np.linalg.norm(normal)
    u = np.cross([0.0, 0.0, 1.0], n)
    if np.linalg.norm(u) < 1e-8:
        # Horizontal plane: its great circle is the primitive circle itself.
        t = np.linspace(0.0, 2.0 * np.pi, num)
        return np.cos(t), np.sin(t)
    # u: horizontal unit vector in the plane; w: in-plane, perpendicular to u,
    # pointing upward. cos(t) u + sin(t) w, t in [0, pi], spans z >= 0.
    u = u / np.linalg.norm(u)
    w = np.cross(n, u)
    if w[2] < 0:
        w = -w
    t = np.linspace(0.0, np.pi, num)
    pts = np.outer(np.cos(t), u) + np.outer(np.sin(t), w)
    denom = 1.0 + pts[:, 2]
    return pts[:, 0] / denom, pts[:, 1] / denom


def _format_vector(v):
    """Format a 3-vector as '(x, y, z)' with 3 decimals and no '-0.000'."""
    return "(" + ", ".join(f"{round(float(c), 3) + 0.0:.3f}" for c in v) + ")"


def plot_pg_projection(symbol):
    """
    Draw the stereographic projection of a crystal point group.

    - Rotation axes: the eigenvalue-1 eigenvector of each proper rotation.
      Each axis is marked by its highest order (2: 'o', 3: '^', 4: 'D',
      6: 'h'). Equatorial axes are marked at both ends on the primitive circle.
    - Mirror planes: the eigenvalue -1 eigenvector of each reflection is the
      plane normal. The plane is drawn (dashed) as its projected great circle.
    - Identity, inversion and rotoinversions are not drawn as separate
      symbols.

    Args:
        symbol: Hermann–Mauguin symbol. It must be in PG_HM; otherwise an
            error is printed and the process exits with status 1.
    """
    if symbol not in PG_HM:
        print(f"Error: Unknown point group '{symbol}'.")
        sys.exit(1)

    pg = PointGroup(symbol)
    mats = _cartesian_rotations(pg.symmetry_ops)
    rot_axes, mirror_normals = _collect_symmetry_elements(mats)

    _, ax = plt.subplots(figsize=(6, 6))
    # Primitive circle (projection of the equator).
    ax.add_artist(plt.Circle((0, 0), 1, fill=False))

    for normal in mirror_normals:
        xs, ys = _great_circle_xy(normal)
        ax.plot(xs, ys, linestyle='--')

    for axis, order in rot_axes:
        # An equatorial axis pierces the primitive at two antipodal points.
        ends = [axis, -axis] if abs(axis[2]) < 1e-6 else [axis]
        xs, ys = zip(*(_stereo_xy(e) for e in ends))
        ax.scatter(xs, ys, marker=_AXIS_MARKERS.get(order, 'o'), s=100,
                   label=f"{order}-fold axis {_format_vector(axis)}")

    ax.set_aspect('equal')
    ax.set_xlim(-1.1, 1.1)
    ax.set_ylim(-1.1, 1.1)
    ax.axis('off')
    if rot_axes:  # avoid matplotlib's "no labeled artists" legend warning
        ax.legend(loc='upper right', bbox_to_anchor=(1.3, 1))
    ax.set_title(f"Stereographic Projection of Point Group {symbol}")
    plt.show()

def main():
    """Parse the command line and dispatch to the selected mode.

    Mode 'pg' prints point-group information (``--pg`` optional). Any other
    accepted mode ('pgproj' or its alias 'stereo') plots the stereographic
    projection and requires ``--pg``; without it, an error is printed and the
    process exits with status 1.
    """
    parser = initialize()
    args = parser.parse_args()

    if args.mode == 'pg':
        display_pg_info(args.pg)
        return

    if args.pg is None:
        # The option is spelled --pg (short form -p); '-pg' is not a valid flag.
        print("Error: --pg is required for 'pgproj' mode.")
        sys.exit(1)
    plot_pg_projection(args.pg)


if __name__ == '__main__':
    main()
