import sys
import argparse
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches

from pymatgen.core.lattice import Lattice
from pymatgen.core.structure import Structure
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
from pymatgen.symmetry.groups import SpaceGroup


# ----------------------------------------------------
# A. 空間群の情報を取得・表示する関数
# ----------------------------------------------------
def display_spg_info(ispg=None):
    """
    pymatgenを用いて空間群の情報を表示します。
    ispgが指定されなければ全空間群のリスト、
    指定されればその詳細を表示します。
    """
    if ispg is None:
        print("--- All Space Groups (1 to 230) ---")
        for i in range(1, 231):
            try:
                spg_group = SpaceGroup.from_int_number(i)
                print(f"[{i:3d}] {spg_group.symbol}")
            except Exception as e:
                print(f"[{i:3d}] (Could not retrieve info: {e})")
        print("\nUse -ispg [number] to get detailed info.")
    else:
        if not (1 <= ispg <= 230):
            print("Error: Invalid space group number. Must be between 1 and 230.")
            sys.exit(1)

        try:
            # ダミー構造を生成（立方格子、1つの C 原子を原点に配置）
            lattice = Lattice.cubic(1)
            dummy_structure = Structure.from_spacegroup(
                ispg, lattice, ["C"], [[0, 0, 0]]
            )
            sga = SpacegroupAnalyzer(dummy_structure)
        except Exception as e:
            print(f"Error: Could not retrieve data for space group {ispg}. {e}")
            sys.exit(1)

        print(f"--- Detailed Information for Space Group {ispg} ---")
        print(f"Hermann-Mauguin Symbol: {sga.get_space_group_symbol()}")
        print(f"Crystal System:          {sga.get_crystal_system()}")
        print(f"Point Group:             {sga.get_point_group_symbol()}")

        print("\n--- Symmetry Operations ---")
        symm_ops = sga.get_symmetry_operations()
        print(f"Total operations: {len(symm_ops)}\n")
        for i, op in enumerate(symm_ops[:5]):
            print(f"Operation {i + 1}:")
            print("  Rotation Matrix (R):")
            print(op.rotation_matrix)
            print("  Translation Vector (t):")
            print(op.translation_vector)
            print()
        if len(symm_ops) > 5:
            print("...")

# ----------------------------------------------------
# B. ステレオ投影図を描画する関数
# ----------------------------------------------------
def plot_unit_cell_projection(ispg):
    """
    空間群の一般位置を XY 単位格子にプロットし、
    ・z>0: 赤、z<0: 青 で散布
    ・xy 面に垂直なミラー面（x=const, y=const）を鎖線で描画
    ・z 軸回転軸を縦長楕円、▲、◆、六角形でプロット
    ・xy 面内の軸も同様にプロット
    """
    if not (1 <= ispg <= 230):
        print("Error: space group must be 1–230.")
        sys.exit(1)

    # 構造生成
    lattice = Lattice.cubic(1)
    dummy = Structure.from_spacegroup(ispg, lattice, ["X"], [[0.1, 0.2, 0.3]])
    sga = SpacegroupAnalyzer(dummy)
    sym_struct = sga.get_symmetrized_structure()
    # 一般位置（最大 orbit）
    general = max(sym_struct.equivalent_sites, key=len)
    coords = np.array([site.frac_coords for site in general])
    xs, ys, zs = coords[:,0], coords[:,1], coords[:,2]

    # ミラー面検出
    mirror_planes = []
    for op in sga.get_symmetry_operations():
        R = op.rotation_matrix
        t = op.translation_vector
        # x=const plane? 反射行列が diag(-1,1,1)
        if np.allclose(R, np.diag([-1,1,1])):
            mirror_planes.append(("x", t[0]/2))
        # y=const plane? diag(1,-1,1)
        if np.allclose(R, np.diag([1,-1,1])):
            mirror_planes.append(("y", t[1]/2))

    # 回転軸検出
    rot_axes = []
    for op in sga.get_symmetry_operations():
        R = op.rotation_matrix
        # proper rotation?
        if np.linalg.det(R) > 0 and not np.allclose(R, np.identity(3)):
            vals, vecs = np.linalg.eig(R)
            mask = np.isclose(vals, 1.0)
            if np.any(mask):
                axis = vecs[:,mask][:,0].real
                # 一意化
                if axis[2] < 0: axis *= -1
                rot_axes.append(tuple(np.round(axis,5)))
    rot_axes = sorted(set(rot_axes))

    # プロット開始
    fig, ax = plt.subplots(figsize=(6,6))
    # 単位格子枠
    box = np.array([[0,0],[1,0],[1,1],[0,1],[0,0]])
    ax.plot(box[:,0], box[:,1], 'k-')

    # ミラー面を鎖線で
    for typ, val in mirror_planes:
        if typ == "x":
            ax.vlines(val, 0, 1, linestyles='--')
        else:
            ax.hlines(val, 0, 1, linestyles='--')

    # 一般位置を色分け散布
    colors = ['red' if z>0 else 'blue' for z in zs]
    ax.scatter(xs, ys, c=colors, s=50, edgecolors='k')

    # 回転軸を記号でプロット
    # z 軸方向のもの → 縦長楕円
    offset = 0.5  # プロット位置の例
    for axis in rot_axes:
        x0, y0, z0 = axis
        # z 軸回転軸
        if np.allclose([x0,y0], [0,0]):
            ellipse = mpatches.Ellipse((0.5,0.5), width=0.15, height=0.3,
                                       fill=False, linewidth=2)
            ax.add_patch(ellipse)
        # xy 面内の軸
        elif np.isclose(z0, 0):
            # まず、その軸の方向ベクトル (x0,y0) を正規化してオフセットに
            v = np.array([x0,y0])
            v = v / np.linalg.norm(v) * 0.3
            # ▲
            ax.scatter(0.5+v[0], 0.5+v[1], marker='^', s=100)
            # ◆
            ax.scatter(0.5-v[1], 0.5+v[0], marker='D', s=80)
            # 六角形
            ax.scatter(0.5+v[1], 0.5-v[0], marker='h', s=80)

    ax.set_aspect('equal', 'box')
    ax.set_xlim(0,1)
    ax.set_ylim(0,1)
    ax.set_xlabel('x (fractional)')
    ax.set_ylabel('y (fractional)')
    ax.set_title(f"Space Group {ispg} — {SpaceGroup.from_int_number(ispg).symbol}")
    plt.tight_layout()
    plt.show()


def plot_stereographic_projection(ispg):
    """
    指定された空間群の一般位置の等価点を
    2D単位格子上にプロットします。
    """
    if not (1 <= ispg <= 230):
        print("Error: Invalid space group number. Must be between 1 and 230.")
        sys.exit(1)

    # 空間群とダミー構造の準備
    try:
        lattice = Lattice.cubic(1)
        # 一点だけ(0.1,0.2,0.3) を置くことで一般位置を生成
        dummy_structure = Structure.from_spacegroup(
            ispg, lattice, ["X"], [[0.1, 0.2, 0.3]]
        )
        sga = SpacegroupAnalyzer(dummy_structure)
        spg_group = SpaceGroup.from_int_number(ispg)
        hm_symbol = spg_group.symbol
    except Exception as e:
        print(f"Error: Could not build structure for space group {ispg}: {e}")
        sys.exit(1)

    # 対称化構造から軌道（orbit）を取得
    sym_struct = sga.get_symmetrized_structure()
    orbits = sym_struct.equivalent_sites
    # もっとも大きな orbit が一般位置
    general_orbit = max(orbits, key=lambda orb: len(orb))
    # 各サイトの fractional coords を取り出し XY 成分だけ
    coords = np.array([site.frac_coords[:2] for site in general_orbit])

    # プロット
    fig, ax = plt.subplots(figsize=(6,6))
    # 単位格子の枠線
    box = np.array([[0,0],[1,0],[1,1],[0,1],[0,0]])
    ax.plot(box[:,0], box[:,1], 'k-')
    # 等価位置を散布
    ax.scatter(coords[:,0], coords[:,1], s=50)
    ax.set_aspect('equal', 'box')
    ax.set_xlim(0,1)
    ax.set_ylim(0,1)
    ax.set_xlabel('x (fractional)')
    ax.set_ylabel('y (fractional)')
    ax.set_title(f"Equivalent Positions in Unit Cell (Space Group {ispg} — {hm_symbol})")
    plt.tight_layout()
    plt.show()


# ----------------------------------------------------
# C. メインプログラム
# ----------------------------------------------------
def main():
    parser = argparse.ArgumentParser(
        description="Display space group info or plot stereographic projection.",
        formatter_class=argparse.RawTextHelpFormatter
    )
    parser.add_argument('mode', choices=['spg', 'stereo'],
                        help="'spg' : Display space group info.\n'stereo': Plot stereographic projection.")
    parser.add_argument('--ispg', type=int,
                        help="Space group number (1 to 230).")

    args = parser.parse_args()
    if args.mode == 'spg':
        display_spg_info(args.ispg)
    else:
        if args.ispg is None:
            print("Error: -ispg is required for 'stereo' mode.")
            sys.exit(1)
        plot_unit_cell_projection(args.ispg)
#        plot_stereographic_projection(args.ispg)

if __name__ == '__main__':
    main()
