# tkpg/C4v.py
import numpy as np
from tkpg import core

def Rz(theta):
    """Return the 3x3 matrix rotating column vectors by theta (radians) about z."""
    cos_t = np.cos(theta)
    sin_t = np.sin(theta)
    # Start from the identity so the z row/column is already in place,
    # then fill in the 2x2 rotation in the xy block.
    rot = np.eye(3)
    rot[0, 0] = cos_t
    rot[0, 1] = -sin_t
    rot[1, 0] = sin_t
    rot[1, 1] = cos_t
    return rot

def mirror_xz():
    """Return the 3x3 reflection through the xz plane (y -> -y)."""
    return np.diag([1.0, -1.0, 1.0])

def mirror_yz():
    """Return the 3x3 reflection through the yz plane (x -> -x)."""
    return np.diag([-1.0, 1.0, 1.0])

def mirror_diag_xy_plus():
    """Return the 3x3 reflection that swaps x and y and leaves z unchanged."""
    swap = np.zeros((3, 3))
    swap[0, 1] = 1.0  # new x takes the old y
    swap[1, 0] = 1.0  # new y takes the old x
    swap[2, 2] = 1.0  # z is untouched
    return swap

def mirror_diag_xy_minus():
    """Return the 3x3 reflection mapping (x, y) -> (-y, -x) and leaving z unchanged."""
    refl = np.zeros((3, 3))
    refl[0, 1] = -1.0  # new x is minus the old y
    refl[1, 0] = -1.0  # new y is minus the old x
    refl[2, 2] = 1.0   # z is untouched
    return refl

def build_group():
    """
    Return the eight C4v operations as a list of dicts.

    Each dict has the keys "name" (label of the operation), "R" (its 3x3
    matrix) and "class" (label of its conjugacy class, matching the column
    keys of the character table).
    """
    # (name, matrix, class) in the order: identity, rotations about z,
    # axial mirrors, diagonal mirrors.
    members = (
        ("E",        np.eye(3),              "E"),
        ("C4",       Rz(np.pi/2),            "C4"),
        ("C2",       Rz(np.pi),              "C2"),
        ("C4^3",     Rz(3*np.pi/2),          "C4"),
        ("sv(xz)",   mirror_xz(),            "sv"),
        ("sv(yz)",   mirror_yz(),            "sv"),
        ("sd(x=y)",  mirror_diag_xy_plus(),  "sd"),
        ("sd(x=-y)", mirror_diag_xy_minus(), "sd"),
    )
    return [
        {"name": label, "R": matrix, "class": klass}
        for label, matrix, klass in members
    ]

# Characters of the C4v irreducible representations, keyed first by irrep
# label and then by class label (columns: E, 2C4, C2, 2σv, 2σd).
C4v_character_table = {
    label: dict(zip(("E", "C4", "C2", "sv", "sd"), chars))
    for label, chars in (
        ("A1", (1,  1,  1,  1,  1)),
        ("A2", (1,  1,  1, -1, -1)),
        ("B1", (1, -1,  1,  1, -1)),
        ("B2", (1, -1,  1, -1,  1)),
        ("E",  (2,  0, -2,  0,  0)),
    )
}

def irreps():
    """Return the irrep labels of the character table, in table order."""
    return [label for label in C4v_character_table]

def irrep_dim(irrep):
    """Return the dimension of `irrep`, i.e. its character on the identity class."""
    characters = C4v_character_table[irrep]
    return characters["E"]

def irrep_char(irrep, element):
    """Return the character of `irrep` on the class named by element["class"]."""
    characters = C4v_character_table[irrep]
    class_label = element["class"]
    return characters[class_label]

# C4v irrep label assigned to each d orbital of the central atom.
d_irreps = dict(
    d_z2="A1",
    d_x2_y2="B1",
    d_xy="B2",
    d_xz="E",
    d_yz="E",
)

def align_guess(ligands_pos):
    """
    Guess a rotation that brings the ligands into a canonical C4v frame.

    Candidate principal axes are the normalised ligand directions, the
    normalised differences of every pair of ligand directions, and the
    negatives of both. A candidate is scored by how close it comes to
    having exactly 4 "equatorial" ligands (|cos| to the axis < 0.35) and
    1 or 2 "axial" ligands (|cos| > 0.85); the first candidate with the
    highest score becomes z. x is the component of the first equatorial
    ligand perpendicular to z, and y = z cross x completes the frame.

    NOTE(review): only the candidate directions listed above are tried. If
    all ligand directions lie in one plane, every candidate lies in that
    plane too, so an axis perpendicular to it is never tested.

    Parameters:
        ligands_pos: array-like of N ligand positions, one row of 3
            coordinates per ligand (presumably relative to the central
            atom; confirm with the caller).

    Returns:
        (ok, R_align, pos_aligned). On success ok is True, R_align is the
        3x3 matrix whose rows are the new x, y, z axes, and pos_aligned is
        the positions expressed in that frame. On failure (fewer than 4
        ligands, fewer than 4 equatorial ligands for the best axis, or no
        usable in-plane direction) ok is False, R_align is the identity
        and pos_aligned is a copy of the input positions.
    """
    ligands_pos = np.array(ligands_pos, float)
    N = ligands_pos.shape[0]

    def _no_alignment():
        # Shared failure result: identity rotation, positions left as given.
        return False, np.eye(3), ligands_pos.copy()

    if N < 4:
        return _no_alignment()

    u = np.array([core.normalize(v) for v in ligands_pos])

    # Candidate axes: every ligand direction, then every pairwise
    # difference, each followed by its negative.
    candidates = []
    for i in range(N):
        candidates.extend((u[i], -u[i]))
    for i in range(N):
        for j in range(i + 1, N):
            w = core.normalize(u[i] - u[j])
            if np.linalg.norm(w) > 0:
                candidates.extend((w, -w))

    # N >= 4 guarantees at least 2*N candidates, so a best axis always
    # exists once this loop finishes. Strict '>' keeps the earliest of
    # equally scored candidates.
    best_score = None
    z_axis = None
    eq_mask = None
    for a in candidates:
        dots = np.abs(u @ a)
        mask = dots < 0.35
        eq = np.sum(mask)
        ax = np.sum(dots > 0.85)
        score = -abs(eq - 4) - 0.5 * min(abs(ax - 1), abs(ax - 2))
        if best_score is None or score > best_score:
            best_score, z_axis, eq_mask = score, a, mask

    # Reuse the winning candidate's equatorial mask instead of recomputing it.
    eq_ids = np.where(eq_mask)[0]
    if len(eq_ids) < 4:
        return _no_alignment()

    def _in_plane(k):
        # Part of ligand k's position perpendicular to the chosen axis.
        v = ligands_pos[k]
        return v - np.dot(v, z_axis) * z_axis

    # Use the first equatorial ligand for x; fall back to later ones only
    # if its in-plane component is numerically zero.
    x_axis = _in_plane(eq_ids[0])
    if np.linalg.norm(x_axis) < 1e-8:
        for k in eq_ids[1:]:
            x_axis = _in_plane(k)
            if np.linalg.norm(x_axis) > 1e-8:
                break
    if np.linalg.norm(x_axis) < 1e-8:
        return _no_alignment()

    y_axis = np.cross(z_axis, x_axis)

    # Re-orthonormalise so the returned matrix is a clean rotation.
    x = core.normalize(x_axis)
    y = core.normalize(y_axis - np.dot(y_axis, x) * x)
    z = core.normalize(np.cross(x, y))
    R_align = np.column_stack([x, y, z]).T
    pos_aligned = (R_align @ ligands_pos.T).T
    return True, R_align, pos_aligned


# Entry points and data this module exposes under the name "C4v".
PLUGIN = dict(
    name="C4v",
    build_group=build_group,
    irreps=irreps,
    irrep_dim=irrep_dim,
    irrep_char=irrep_char,
    d_irreps=d_irreps,
    align_guess=align_guess,
)
