import sys
import argparse
import numpy as np
import matplotlib.pyplot as plt

from pymatgen.symmetry.groups import PointGroup

# The 32 crystallographic point groups in Hermann–Mauguin notation, ordered
# by crystal system. The position in this list is the index used by --ipg.
PG_HM = [
    # triclinic
    "1", "-1",
    # monoclinic
    "2", "m", "2/m",
    # orthorhombic
    "222", "mm2", "mmm",
    # tetragonal
    "4", "-4", "4/m", "422", "4mm", "-42m", "4/mmm",
    # trigonal
    "3", "-3", "32", "3m", "-3m",
    # hexagonal
    "6", "-6", "6/m", "622", "6mm", "-6m2", "6/mmm",
    # cubic
    "23", "m-3", "432", "-43m", "m-3m",
]

def initialize():
    """Build the command-line parser for the point-group utility.

    Defines --mode, --pg/-p, --ipg/-i and --xyz/-x (default '0.9,0.1,0.15').
    Nothing is parsed here; the caller invokes parse_args() itself.

    Returns:
        argparse.ArgumentParser
    """
    usage_lines = [
        "",
        "  pg_util.py --mode list",
        "  pg_util.py --mode inf --pg=<PointGroup> | --ipg=<Index>",
        "  pg_util.py --mode stereo --pg=<PointGroup> | --ipg=<Index> --xyz=<x,y,z>",
        "  pg_util.py --mode expand --pg=<PointGroup> | --ipg=<Index> --xyz=<x,y,z>",
    ]
    mode_help_lines = [
        "'list'   : List all available point groups",
        "'inf'    : Display point group info",
        "'stereo' : Plot stereographic projection of expanded points",
        "'expand' : Expand a point by symmetry operations",
    ]

    cli = argparse.ArgumentParser(
        prog='pg_util.py',
        usage="\n".join(usage_lines),
        description="Crystal Point Group Utility",
        # Raw formatter so the explicit line breaks in the help text survive.
        formatter_class=argparse.RawTextHelpFormatter,
    )
    # --mode is not marked required here; the caller validates its presence.
    cli.add_argument('--mode', choices=['list', 'inf', 'stereo', 'expand'],
                     help="\n".join(mode_help_lines))
    cli.add_argument('--pg', '-p',
                     help="Point group symbol (H–M), e.g. '4mm'")
    cli.add_argument('--ipg', '-i', type=int,
                     help="Index of the point group (0-31)")
    cli.add_argument('--xyz', '-x', default='0.9,0.1,0.15',
                     help="Coordinates to expand, e.g., '0.9,0.1,0.15'")
    return cli

def _direction_label(vec, labels, default="d"):
    """Return labels[i] if *vec* is parallel to Cartesian axis i, else *default*.

    The sign of *vec* is ignored, so +axis and -axis get the same label.
    """
    unit = np.abs(vec) / np.linalg.norm(vec)
    for label, basis in zip(labels, np.eye(3)):
        if np.allclose(unit, basis):
            return label
    return default


def _eigvec_for(R, eigval):
    """Return the (real part of the) eigenvector of *R* for eigenvalue *eigval*."""
    vals, vecs = np.linalg.eig(R)
    return vecs[:, np.isclose(vals, eigval)].real[:, 0]


def classify_symop(op):
    """Classify a symmetry operation (e.g. a pymatgen SymmOp) by its rotation matrix.

    Args:
        op: object exposing a 3x3 ``rotation_matrix``.

    Returns:
        One of:
          "Identity"
          "Inversion"
          "Rotation C<n> (<x|y|z|d>-axis)"
          "Mirror (σ) (<mx|my|mz|d>-plane)"
          "Rotoinversion S<n>"  -- n is the Schoenflies rotoreflection order,
                                    so -4 -> S4, -3 -> S6, -6 -> S3
          "Unknown"             -- determinant is neither +1 nor -1
        "d" marks an axis / mirror normal not parallel to x, y or z.
    """
    R = np.asarray(op.rotation_matrix, dtype=float)
    det = np.linalg.det(R)

    if np.allclose(R, np.eye(3)):
        return "Identity"
    if np.allclose(R, -np.eye(3)):
        return "Inversion"

    if np.isclose(det, 1.0):
        # Proper rotation by theta: tr R = 1 + 2 cos(theta). The trace is
        # similarity-invariant, so this also works for non-Cartesian bases.
        theta = np.arccos(np.clip((np.trace(R) - 1) / 2, -1, 1))
        n = int(np.round(2 * np.pi / theta))
        axis = _eigvec_for(R, 1.0)
        return f"Rotation C{n} ({_direction_label(axis, 'xyz')}-axis)"

    if np.isclose(det, -1.0):
        if np.allclose(R @ R, np.eye(3)):
            # Mirror: the eigenvector with eigenvalue -1 is the plane normal.
            normal = _eigvec_for(R, -1.0)
            plane = _direction_label(normal, ('mx', 'my', 'mz'))
            return f"Mirror (σ) ({plane}-plane)"

        # Any other improper operation is a rotoreflection S_n = σ_h·C(phi)
        # with phi = 2π/n and tr R = -1 + 2 cos(phi). Taking n from -R instead
        # would give the rotoinversion order, which mislabels -3 (= S6) and
        # -6 (= S3). phi is never 0 or π here: those cases are the mirror and
        # the inversion, both handled above.
        phi = np.arccos(np.clip((np.trace(R) + 1) / 2, -1, 1))
        n = int(np.round(2 * np.pi / phi))
        return f"Rotoinversion S{n}"

    return "Unknown"


def display_pg_info(symbol=None, print_level=1):
    """Show crystal point-group information.

    With ``symbol=None`` the indexed list of PG_HM is printed and None is
    returned. Otherwise every symmetry operation of ``PointGroup(symbol)``
    is classified with classify_symop(). When ``print_level`` is truthy, a
    header plus each operation's label and rotation matrix are printed.

    Returns:
        list of (operation, label) tuples for a named group; None when listing.

    Exits with status 1 (sys.exit) if *symbol* is not in PG_HM.
    """
    if symbol is None:
        print("--- Available Crystal Point Groups (H–M) ---")
        for index, name in enumerate(PG_HM):
            print(f"  {index:2d}: {name}")
        return None

    if symbol not in PG_HM:
        print(f"Error: Unknown point group '{symbol}'.")
        sys.exit(1)

    operations = PointGroup(symbol).symmetry_ops
    verbose = bool(print_level)
    if verbose:
        print(f"--- Point Group: {symbol} ---")
        print(f"Total symmetry operations: {len(operations)}\n")

    classified = []
    for number, operation in enumerate(operations, 1):
        label = classify_symop(operation)
        classified.append((operation, label))
        if not verbose:
            continue
        print(f"Operation {number}: {label}")
        print(operation.rotation_matrix)
        print()
    return classified

def plot_pg_projection(symbol, expanded_points, filename='pg_proj.png'):
    """Draw the expanded points of a point group inside a unit circle.

    Points sharing the same (x, y) (rounded to 5 decimals) are merged into
    one marker whose style encodes the z values found there:

      * any |z| <= 1e-5        -> gray marker, no label
      * both z > 0 and z < 0   -> yellow marker, "+-"
      * only z > 0             -> white marker, "+"
      * only z < 0             -> black marker, "-"

    NOTE(review): x and y are plotted exactly as given. No stereographic
    mapping (e.g. x / (1 + |z|) on a normalised vector) is applied, so points
    with x**2 + y**2 > 1 fall outside the circle. Confirm this is intended.

    Args:
        symbol: point-group symbol, used only in the title.
        expanded_points: iterable of (x, y, z) triples.
        filename: path the figure is saved to (default 'pg_proj.png').
    """
    tol = 1.0e-5

    _, ax = plt.subplots(figsize=(6, 6))
    ax.add_artist(plt.Circle((0, 0), 1, color='black', fill=False, linewidth=3))

    # Merge points that land on the same (x, y) so their z signs can be combined.
    z_by_xy = {}
    for x, y, z in expanded_points:
        z_by_xy.setdefault((round(x, 5), round(y, 5)), []).append(z)

    for (x, y), z_list in z_by_xy.items():
        has_pos_z = any(z > tol for z in z_list)
        has_neg_z = any(z < -tol for z in z_list)
        # '<=' so that |z| == tol falls in a category instead of matching none.
        has_zero_z = any(abs(z) <= tol for z in z_list)

        face, label = 'k', ''  # fallback, e.g. for NaN z values
        if has_zero_z:
            face, label = 'gray', ''
        elif has_pos_z and has_neg_z:
            face, label = 'yellow', '+-'
        elif has_pos_z:
            face, label = 'white', '+'
        elif has_neg_z:
            face, label = 'black', '-'

        # The label must contrast with the marker face; black text on the
        # black "-" marker would be invisible.
        text_color = 'white' if face == 'black' else 'k'
        ax.scatter(x, y, marker='o', s=200, facecolor=face, edgecolor='black')
        ax.text(x, y, label, ha='center', va='center', color=text_color, fontsize=12)

    ax.set_aspect('equal')
    ax.set_xlim(-1.1, 1.1)
    ax.set_ylim(-1.1, 1.1)
    ax.axis('off')
    ax.set_title(f"Stereographic Projection of Point Group {symbol}")

    plt.savefig(filename)
    # Non-blocking show: the window stays open only while the process lives.
    plt.show(block=False)

def expand_point(symbol, xyz_str, print_level=1):
    """Expand a point into its orbit under a crystal point group.

    Duplicate images are removed by pymatgen's ``PointGroup.get_orbit``
    (its default tolerance); no extra de-duplication is done here.

    Args:
        symbol: Hermann–Mauguin symbol; must be one of PG_HM.
        xyz_str: comma-separated coordinates, e.g. "0.9,0.1,0.15".
        print_level: when truthy, print a header and every orbit point.

    Returns:
        list of [x, y, z] lists, one per orbit point.

    Exits with status 1 (sys.exit) on an unknown symbol or a malformed
    coordinate string.
    """
    if symbol not in PG_HM:
        print(f"Error: Unknown point group '{symbol}'.")
        sys.exit(1)

    try:
        x, y, z = (float(token) for token in xyz_str.split(','))
        point = np.array([x, y, z])
    except (ValueError, TypeError):
        print("Error: Invalid format for --xyz. Please use 'x,y,z'.")
        sys.exit(1)

    orbit = [image.tolist() for image in PointGroup(symbol).get_orbit(point)]

    if print_level:
        print(f"--- Point Group: {symbol}, Expanded points from ({x}, {y}, {z}) ---")
        print(f"Total unique expanded points: {len(orbit)}\n")
        for number, (px, py, pz) in enumerate(orbit, start=1):
            print(f"{number:02d}: ({px:7.4f}, {py:7.4f}, {pz:7.4f})")

    return orbit
    
def main():
    """Command-line entry point: parse arguments and dispatch on --mode.

    'list' prints PG_HM. 'inf', 'stereo' and 'expand' need a point group,
    given either by --pg (symbol) or --ipg (index into PG_HM); --ipg takes
    precedence when both are given. Invalid input prints an "Error: ..."
    message and exits with status 1.
    """
    parser = initialize()

    # No arguments at all: show help and exit successfully.
    if len(sys.argv) <= 1:
        parser.print_help()
        sys.exit(0)

    args = parser.parse_args()

    # Resolve --ipg <-> --pg. Only look up the index when a symbol was
    # actually given: '--mode list' needs no group, and an unknown symbol
    # must produce an error message rather than a ValueError traceback.
    if args.ipg is not None:
        if not 0 <= args.ipg < len(PG_HM):
            print(f"Error: Invalid index for --ipg. Must be between 0 and {len(PG_HM)-1}.")
            sys.exit(1)
        args.pg = PG_HM[args.ipg]
    elif args.pg is not None:
        if args.pg not in PG_HM:
            print(f"Error: Unknown point group '{args.pg}'.")
            sys.exit(1)
        args.ipg = PG_HM.index(args.pg)

    # --mode is required in practice even though argparse does not enforce it.
    if args.mode is None:
        parser.print_help()
        sys.exit(1)

    print()
    print(f"mode: {args.mode}")
    print(f"Point group: #{args.ipg}  {args.pg}")
    print(f"xyz: {args.xyz}")

    # Every mode except 'list' operates on a specific point group.
    if args.mode != 'list' and args.pg is None:
        print(f"Error: --pg or --ipg is required for '{args.mode}' mode.")
        sys.exit(1)

    if args.mode == 'list':
        display_pg_info(None)
    elif args.mode == 'inf':
        display_pg_info(args.pg)
    elif args.mode == 'stereo':
        expanded_points = expand_point(args.pg, args.xyz)
        plot_pg_projection(args.pg, expanded_points)
    elif args.mode == 'expand':
        expand_point(args.pg, args.xyz)

    # The plot is shown non-blocking, so wait here to keep its window open.
    if 'matplotlib.pyplot' in sys.modules and plt.get_fignums():
        try:
            input("\nPress ENTER to terminate>>")
        except EOFError:
            # Non-interactive stdin: nothing to wait for.
            pass


if __name__ == '__main__':
    main()