import sys
import argparse
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches

from pymatgen.core.lattice import Lattice
from pymatgen.core.structure import Structure
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
from pymatgen.symmetry.groups import SpaceGroup
from pymatgen.core.operations import SymmOp


def initialize():
    """
    Build the command-line parser for this tool.

    Options:
        --mode  one of 'list', 'inf', 'stereo', 'expand' (no default)
        --ispg  space group number (int, no default)
        --xyz   comma-separated fractional coordinates (default '0.9,0.1,0.15')
        --rmin  identical-site distance threshold (float, default 1.0e-5)

    RawTextHelpFormatter keeps the multi-line --mode help aligned as written.
    """
    mode_help = "\n".join([
        "'list'   : List all space group numbers and symbols",
        "'inf'    : Display detailed space group info",
        "'stereo' : Plot stereographic projection of expanded points",
        "'expand' : Expand coordinates and display results",
    ])

    ap = argparse.ArgumentParser(
        description="Display space group info or plot stereographic projection.",
        formatter_class=argparse.RawTextHelpFormatter,
    )
    ap.add_argument('--mode', choices=['list', 'inf', 'stereo', 'expand'], help=mode_help)
    ap.add_argument('--ispg', type=int, help="Space group number (1 to 230).")
    ap.add_argument('--xyz', type=str, default='0.9,0.1,0.15',
                    help="Coordinates for expansion, e.g., '0.9,0.1,0.15'.")
    ap.add_argument('--rmin', type=float, default=1.0e-5,
                    help="Minimum distance for judging identical sites.")
    return ap

def _symop_rotation_order(R):
    """Return n such that the proper rotation R turns by 2*pi/n (1 for identity).

    Uses the trace, which is invariant under a change of basis, so it also
    works for rotation matrices written in a non-orthogonal (fractional) basis.
    """
    cos_theta = (np.trace(R) - 1.0) / 2.0
    theta = np.arccos(np.clip(cos_theta, -1.0, 1.0))
    if theta < 1e-12:
        return 1
    return int(np.round(2.0 * np.pi / theta))


def _symop_unit_eigvec(R, target):
    """Return the normalised real eigenvector of R whose eigenvalue is closest to target."""
    vals, vecs = np.linalg.eig(R)
    # Closest eigenvalue rather than an exact-tolerance mask: never IndexErrors.
    vec = np.real(vecs[:, int(np.argmin(np.abs(vals - target)))]).copy()
    norm = np.linalg.norm(vec)
    return vec / norm if norm > 0 else vec


def _symop_axis_tag(vec, tol):
    """Return 'x', 'y' or 'z' if vec is (anti)parallel to that basis vector, else 'd'."""
    for tag, unit in (("x", (1, 0, 0)), ("y", (0, 1, 0)), ("z", (0, 0, 1))):
        if np.allclose(np.abs(vec), unit, atol=tol):
            return tag
    return "d"


def _symop_has_intrinsic_shift(R, t, n, tol):
    """Return True if (R, t) carries a non-lattice screw/glide (intrinsic) translation.

    The intrinsic part is w = (1/n) * sum_{k=0}^{n-1} R^k t with n the order
    of R. Components of t that only move the axis/plane off the origin cancel
    in this average; a w equal to a lattice vector counts as no shift.
    """
    acc = np.zeros(3)
    power = np.eye(3)
    for _ in range(n):
        acc += power @ t
        power = power @ R
    w = acc / n
    return not np.allclose(w, np.round(w), atol=tol)


def classify_symop_spg(op, tol=1e-5):
    """
    Classify a space-group symmetry operation and return a descriptive label.

    ``op`` is read only through ``op.rotation_matrix`` (3x3) and
    ``op.translation_vector`` (length 3), e.g. a pymatgen ``SymmOp`` in
    fractional coordinates.

    Screw/glide character is decided from the *intrinsic* translation (see
    ``_symop_has_intrinsic_shift``), not from the raw translation. A rotation
    axis or mirror plane that merely does not pass through the origin is
    still reported as a plain rotation / mirror.

    Parameters
    ----------
    op : object with ``rotation_matrix`` and ``translation_vector``
    tol : float
        Absolute tolerance for all floating-point comparisons.

    Returns
    -------
    str
        The label is one of the following:

        * "Identity" or "Pure Translation"
        * "Inversion"
        * "Rotation Cn (a-axis)" or "Screw Rotation Cn (a-axis)"
        * "Rotoinversion -n (a-axis)"
        * "Mirror (σ) (P-plane)" or "Glide Reflection (σ) (P-plane)"
        * "Unknown"

        Here ``a`` is x/y/z when the axis lies along a basis vector and d
        otherwise. ``P`` is mx/my/mz for a mirror normal along a basis vector
        and d otherwise.
    """
    R = np.asarray(op.rotation_matrix, dtype=float)
    t = np.asarray(op.translation_vector, dtype=float)
    det = np.linalg.det(R)
    eye = np.eye(3)

    if np.allclose(R, eye, atol=tol):
        return "Identity" if np.allclose(t, 0.0, atol=tol) else "Pure Translation"

    if np.allclose(R, -eye, atol=tol):
        # An inversion centre displaced from the origin is still an inversion.
        return "Inversion"

    if np.isclose(det, 1.0, atol=tol):
        n = _symop_rotation_order(R)
        axis = _symop_axis_tag(_symop_unit_eigvec(R, 1.0), tol)
        if _symop_has_intrinsic_shift(R, t, n, tol):
            return f"Screw Rotation C{n} ({axis}-axis)"
        return f"Rotation C{n} ({axis}-axis)"

    if np.isclose(det, -1.0, atol=tol):
        if np.allclose(R @ R, eye, atol=tol):
            # Mirror: the plane normal is the eigenvector with eigenvalue -1.
            tag = _symop_axis_tag(_symop_unit_eigvec(R, -1.0), tol)
            plane = "d" if tag == "d" else f"m{tag}"
            if _symop_has_intrinsic_shift(R, t, 2, tol):
                return f"Glide Reflection (σ) ({plane}-plane)"
            return f"Mirror (σ) ({plane}-plane)"
        # Rotoinversion -n: -R is a proper n-fold rotation about the same axis,
        # which is the eigenvector of R with eigenvalue -1.
        n = _symop_rotation_order(-R)
        axis = _symop_axis_tag(_symop_unit_eigvec(R, -1.0), tol)
        return f"Rotoinversion -{n} ({axis}-axis)"

    return "Unknown"

def format_transform_spg(matrix, translation_vector):
    """
    Render a symmetry operation (R, t) as a coordinate-triplet string.

    Component i is row i of ``matrix`` applied to (x, y, z) plus
    ``translation_vector[i]``. Formatting rules:

    * A coefficient of +1 or -1 prints as ``x`` or ``-x``.
    * Other non-zero coefficients print with two decimals (``0.50x``).
    * Zero coefficients are omitted.
    * A non-zero translation is appended with three decimals.
    * A component with no terms prints as ``0``.

    Example: R = diag(-1, -1, 1), t = (0.5, 0, 0.5) -> "(-x+0.500, -y, z+0.500)".
    """
    axis_names = ('x', 'y', 'z')

    def render(coeffs, shift):
        pieces = []
        for col, c in enumerate(coeffs):
            if np.isclose(c, 0):
                continue
            name = axis_names[col]
            if np.isclose(c, 1):
                piece = name
            elif np.isclose(c, -1):
                piece = '-' + name
            else:
                piece = f"{c:.2f}{name}"
            # Join to the previous term with '+' unless the term carries its own '-'.
            if pieces and not piece.startswith('-'):
                piece = '+' + piece
            pieces.append(piece)
        if not np.isclose(shift, 0):
            shift_txt = f"{shift:.3f}"
            pieces.append('+' + shift_txt if pieces and shift > 0 else shift_txt)
        return ''.join(pieces) or '0'

    components = [render(row, translation_vector[k]) for k, row in enumerate(matrix)]
    return f"({components[0]}, {components[1]}, {components[2]})"


def display_spg_info(ispg=None):
    """
    Print space-group information obtained from pymatgen.

    With ``ispg=None``, print the number and Hermann-Mauguin symbol of all
    230 space groups and return an empty list.

    Otherwise print, for space group ``ispg``:

    * its symbol, crystal system and point group;
    * every symmetry operation (rotation matrix, translation vector, kind
      from ``classify_symop_spg`` and x,y,z form from ``format_transform_spg``).

    In that case return a list of ``(SymmOp, kind)`` tuples.

    The data come directly from ``SpaceGroup`` (standard setting, fractional
    coordinates), the same source ``expand_coordinates`` uses. The previous
    approach, running ``SpacegroupAnalyzer`` on a one-atom dummy structure in a
    cubic cell, reported that structure's own (generally higher) symmetry:
    P1 came out as Pm-3m. It also failed for trigonal/hexagonal groups, whose
    lattices are incompatible with a cubic cell.

    Exits the process with status 1 if ``ispg`` is outside 1-230 or the group
    cannot be loaded.
    """
    if ispg is None:
        print("--- All Space Groups (1 to 230) ---")
        for i in range(1, 231):
            try:
                spg_group = SpaceGroup.from_int_number(i)
                print(f"[{i:3d}] {spg_group.symbol}")
            except Exception as e:
                # Best effort: report this number and keep listing the rest.
                print(f"[{i:3d}] (Could not retrieve info: {e})")
        return []

    if not (1 <= ispg <= 230):
        print("Error: Invalid space group number. Must be between 1 and 230.")
        sys.exit(1)

    try:
        spg_group = SpaceGroup.from_int_number(ispg)
        symm_ops = list(spg_group.symmetry_ops)
        crystal_system = spg_group.crystal_system
        point_group = spg_group.point_group
    except Exception as e:
        print(f"Error: Could not retrieve data for space group {ispg}. {e}")
        sys.exit(1)

    print(f"--- Detailed Information for Space Group {ispg} ---")
    print(f"Hermann-Mauguin Symbol: {spg_group.symbol}")
    print(f"Crystal System:          {crystal_system}")
    print(f"Point Group:             {point_group}")

    print("\n--- Symmetry Operations ---")
    print(f"Total operations: {len(symm_ops)}\n")

    op_inf = []
    for i, op in enumerate(symm_ops):
        kind = classify_symop_spg(op)
        op_inf.append((op, kind))

        R = op.rotation_matrix
        t = op.translation_vector
        print(f"Operation {i + 1}: {kind}")
        print("  Rotation Matrix (R):")
        print("\n".join(f"    [{row[0]:7.4f}, {row[1]:7.4f}, {row[2]:7.4f}]" for row in R))
        print("  Translation Vector (t):")
        print(f"    [{t[0]:7.4f}, {t[1]:7.4f}, {t[2]:7.4f}]")
        print(f"  Transformation: {format_transform_spg(R, t)}")
        print()
    return op_inf


def expand_coordinates(ispg, xyz_str, rmin):
    """
    Expand a fractional coordinate by the symmetry operations of a space group.

    The operations come directly from pymatgen's ``SpaceGroup`` (standard
    setting, fractional coordinates), so no origin/setting shift is introduced
    by an intermediate structure. The function first prints every image, then
    the de-duplicated set.

    Parameters
    ----------
    ispg : int
        Space group number.
    xyz_str : str
        Comma-separated fractional coordinates, e.g. ``"0.9,0.1,0.15"``.
    rmin : float
        Images whose minimum-image separation (Euclidean norm, in fractional
        units) is below this value are treated as the same site.

    Returns
    -------
    list of numpy.ndarray
        The unique images, each of length 3 with components in [0, 1).

    Exits the process with status 1 if ``xyz_str`` is not three numbers or
    the space group cannot be loaded.
    """
    # Parse the input point.
    try:
        initial_point = np.array([float(s) for s in xyz_str.split(',')], dtype=float)
        if initial_point.shape != (3,):
            raise ValueError
    except (AttributeError, TypeError, ValueError):
        print("Error: Invalid format for --xyz. Please use 'x,y,z'.")
        sys.exit(1)

    # Symmetry operations taken straight from SpaceGroup to avoid origin/setting shifts.
    try:
        spg_group = SpaceGroup.from_int_number(ispg)
        symm_ops = spg_group.symmetry_ops  # all operations act on fractional coordinates
    except Exception as e:
        print(f"Error: Could not retrieve symmetry operations for space group {ispg}. {e}")
        sys.exit(1)

    # np.mod turns tiny negative rounding errors (e.g. -1e-17) into 1.0 and can
    # leave 0.999999...; coordinates this close to 1.0 are folded back to 0.0 so
    # equivalent positions print and group identically.
    wrap_tol = 1.0e-8

    expanded_points = []
    print(f"--- Space Group {ispg} ({spg_group.symbol}), Expanded points from "
          f"({initial_point[0]:.4f}, {initial_point[1]:.4f}, {initial_point[2]:.4f}) ---")

    # Apply each operation (fractional -> fractional) and wrap into [0, 1).
    for i, op in enumerate(symm_ops):
        transformed = np.mod(op.operate(initial_point), 1.0)
        transformed[np.isclose(transformed, 1.0, rtol=0.0, atol=wrap_tol)] = 0.0
        expanded_points.append(transformed)

        kind = classify_symop_spg(op)
        print(f"Operation {i + 1:02d} ({kind}): "
              f"({transformed[0]:7.4f}, {transformed[1]:7.4f}, {transformed[2]:7.4f})")

    def min_image_distance(a, b):
        # Shortest separation on the periodic unit torus.
        d = np.abs(a - b)
        return np.linalg.norm(np.minimum(d, 1.0 - d))

    # Keep the first occurrence of each distinct site.
    unique_points = []
    for new_p in expanded_points:
        if all(min_image_distance(new_p, old_p) >= rmin for old_p in unique_points):
            unique_points.append(new_p)

    print(f"\nTotal expanded points (including duplicates): {len(expanded_points)}")
    print(f"Total unique expanded points (rmin={rmin}): {len(unique_points)}")
    print("\n--- Unique Expanded Points ---")
    for i, p in enumerate(unique_points):
        print(f"{i + 1:02d}: ({p[0]:7.4f}, {p[1]:7.4f}, {p[2]:7.4f})")

    return unique_points


def plot_unit_cell_projection(ispg, expanded_points):
    """
    Plot the expanded points on the xy face of the unit cell, choosing each
    marker's colour and label from the z values found at that (x, y).

    Points are grouped by (x, y) rounded to 5 decimals. One circular marker is
    drawn per group, and one diagnostic line per group is printed to stdout.
    The figure is shown with ``plt.show(block=False)``.

    Parameters
    ----------
    ispg : int
        Space group number (1-230); used for validation and the plot title.
    expanded_points : iterable of length-3 sequences
        Fractional coordinates (x, y, z), e.g. the list returned by
        ``expand_coordinates``.

    Exits the process with status 1 if ``ispg`` is outside 1-230.
    """
    if not (1 <= ispg <= 230):
        print("Error: space group must be 1–230.")
        sys.exit(1)

    fig, ax = plt.subplots(figsize=(6,6))

    # Unit-cell outline.
    box = np.array([[0,0],[1,0],[1,1],[0,1],[0,0]])
    ax.plot(box[:,0], box[:,1], 'k-')

    # Group the z values by (x, y) rounded to 5 decimals, so points stacked
    # along z share one marker.
    points_by_xy = {}
    for p in expanded_points:
        x, y, z = p
        key = (round(x, 5), round(y, 5))
        if key not in points_by_xy:
            points_by_xy[key] = []
        points_by_xy[key].append(z)

    # Plot one marker per xy group.
    eps = 1.0e-5
    i = 0
    for (x, y), z_list in points_by_xy.items():
        z_set = sorted(set(round(z, 5) for z in z_list))
        # Pick one representative z: the minimum, since z_set is sorted.
        # Could be changed to e.g. the median if needed.
        z0 = float(z_set[0])

        print(f"{i+1:03d}: (x,y)=({x:.5f},{y:.5f}) z_set={z_set} :", end="")

        # Only the minimum z0 is tested: "z = 0" (or z = 1, which requires
        # every z in the group to be ~1).
        has_z_eq_0 = (abs(z0) < eps or abs(z0 - 1.0) < eps)
        has_pair = False
        has_pair_p05 = False
        has_pair_m05 = False
        has_only_1_minus_z = False

        # Check over the whole z_set whether 1-z0 and z0±0.5 are also present.
        for _z in z_set:
            # NOTE(review): the outer condition is implied by the inner one and
            # is redundant.
            if abs(_z - z0) < eps or abs((_z + z0) - 1.0) < eps:
                if abs(_z - z0) > eps and abs((_z + z0) - 1.0) < eps:
                    # 1-z0 is present as a value distinct from z0.
                    has_pair = True
            # NOTE(review): z0 is the minimum of z_set, so _z - z0 >= 0 and
            # has_pair_m05 can only become True through the eps tolerance;
            # the '-1/2' branch below is effectively unreachable.
            if abs((_z - z0) - 0.5) < eps or abs((_z - z0) + 0.5) < eps:
                has_pair_p05 = has_pair_p05 or abs((_z - z0) - 0.5) < eps
                has_pair_m05 = has_pair_m05 or abs((_z - z0) + 0.5) < eps

        # Original note: to handle strictly the case where only 1-z exists
        # (e.g. z0 itself absent and only 1-z0 present), one could extend this
        # by re-running the check with the swapped representative.
        # NOTE(review): z0 is taken from z_set, so it is always present. When
        # has_pair is False, any _z matching 1-z0 lies within eps of z0, so in
        # practice this flag is set only when z0 is about 0.5. Confirm intent.
        if not has_pair:
            for _z in z_set:
                if abs(_z - (1.0 - z0)) < eps:
                    has_only_1_minus_z = True
                    break

        print(f" has_pair={has_pair} has_pair_±0.5={has_pair_p05}/{has_pair_m05}")

        # Marker style, first matching rule wins:
        #   red, no label     -> lowest z is 0 (or 1)
        #   yellow '+-'       -> 1-z0 also present
        #   yellow '+1/2'     -> z0+0.5 also present
        #   yellow '-1/2'     -> z0-0.5 also present (see note above)
        #   gray '-'          -> has_only_1_minus_z
        #   white, no label   -> none of the above
        marker_color = 'white'
        label_text = ''

        if has_z_eq_0:
            marker_color = 'red'
            label_text = ''
        elif has_pair:
            marker_color = 'yellow'
            label_text = '+-'
        elif has_pair_p05:
            marker_color = 'yellow'
            label_text = '+1/2'
        elif has_pair_m05:
            marker_color = 'yellow'
            label_text = '-1/2'
        elif has_only_1_minus_z:
            marker_color = 'gray'
            label_text = '-'

        ax.scatter(x, y, marker='o', s=200, facecolor=marker_color, edgecolor='black')
        ax.text(x, y, label_text, ha='center', va='center', color='k', fontsize=12)

        i += 1

    ax.set_aspect('equal', 'box')
    ax.set_xlim(-0.1,1.1)
    ax.set_ylim(-0.1,1.1)
    ax.set_xlabel('x (fractional)')
    ax.set_ylabel('y (fractional)')
    spg_group = SpaceGroup.from_int_number(ispg)
    ax.set_title(f"Equivalent Positions in Unit Cell (Space Group {ispg} — {spg_group.symbol})")
    plt.tight_layout()
    # Non-blocking: the caller (main) waits for ENTER while the figure is open.
    plt.show(block=False)


def main():
    """
    Command-line entry point: parse the arguments and dispatch on --mode.

    When --mode is not given (including when there are no arguments at all),
    print the help text and exit with status 0. --ispg is required for every
    mode except 'list'; if it is missing, print an error and exit with
    status 1. After a plotting mode, wait for ENTER while figures are open.
    """
    parser = initialize()
    args = parser.parse_args()

    # Without a mode there is nothing to do; previously e.g. "--ispg 14"
    # alone fell through silently.
    if args.mode is None:
        parser.print_help()
        sys.exit(0)

    if args.mode == 'list':
        display_spg_info(None)
    else:
        if args.ispg is None:
            print(f"Error: --ispg is required for '{args.mode}' mode.")
            sys.exit(1)
        if args.mode == 'inf':
            display_spg_info(args.ispg)
        elif args.mode == 'stereo':
            # Expand the coordinates, then hand the unique points to the plot.
            expanded_points = expand_coordinates(args.ispg, args.xyz, args.rmin)
            plot_unit_cell_projection(args.ispg, expanded_points)
        elif args.mode == 'expand':
            expand_coordinates(args.ispg, args.xyz, args.rmin)

    # The plot is shown non-blocking, so keep the process alive while it is open.
    if plt.get_fignums():
        input("\nPress ENTER to terminate>>")


if __name__ == '__main__':
    main()
    # main() already waits for ENTER while figures are open; pause here only
    # when it did not, so the console stays open without a second prompt.
    if not plt.get_fignums():
        try:
            input("\nPress ENTER to terminate>>\n")
        except EOFError:
            # stdin is closed (non-interactive run): nothing to wait for.
            pass
    