# tkpg/C4.py
import numpy as np
from tkpg import core

def Rz(theta):
    """Return the 3x3 float matrix for a rotation by ``theta`` radians about z.

    Acting on column vectors, the matrix turns +x towards +y, i.e. it is a
    counter-clockwise rotation when viewed from +z.  The z row/column is the
    identity.
    """
    cos_t = np.cos(theta)
    sin_t = np.sin(theta)
    mat = np.eye(3)
    mat[0, 0] = cos_t
    mat[0, 1] = -sin_t
    mat[1, 0] = sin_t
    mat[1, 1] = cos_t
    return mat

def build_group():
    """Return the four operations of C4 as a list of dicts.

    Each dict has keys ``"name"`` (operation label), ``"R"`` (3x3 rotation
    matrix about z) and ``"class"`` (the column label used in
    ``C4_character_table``).  C4 and C4^3 share the class label ``"C4"``.
    """
    operations = (
        ("E", np.eye(3), "E"),
        ("C4", Rz(np.pi / 2), "C4"),
        ("C2", Rz(np.pi), "C2"),
        ("C4^3", Rz(3 * np.pi / 2), "C4"),
    )
    return [{"name": label, "R": matrix, "class": cls}
            for label, matrix, cls in operations]

# Real character table of C4, keyed by irrep label, then by the class labels
# that build_group() assigns: "E" (identity), "C4" (shared by C4 and C4^3) and
# "C2".
# "E" as an irrep is the 2D real representation obtained by combining the two
# complex-conjugate 1D irreps of C4 (characters +i / -i on C4).  Its summed
# characters are therefore 2, 0, -2.
C4_character_table = dict(
    A=dict(E=1, C4=1, C2=1),
    B=dict(E=1, C4=-1, C2=1),
    E=dict(E=2, C4=0, C2=-2),
)

def irreps():
    """Return a new list of the irrep labels, in character-table order."""
    return [*C4_character_table]

def irrep_dim(irrep):
    """Return the dimension of ``irrep``.

    The dimension is its character at the identity class.  Note that ``"E"``
    here is the identity *class* label, not the 2D irrep of the same name.
    Raises KeyError for an unknown irrep.
    """
    identity_class = "E"
    return C4_character_table[irrep][identity_class]

def irrep_char(irrep, element):
    """Return the character of ``irrep`` for a group ``element``.

    ``element`` is a dict as produced by ``build_group()``; only its
    ``"class"`` entry is used.  Raises KeyError for an unknown irrep or class.
    """
    row = C4_character_table[irrep]
    return row[element["class"]]

# Irrep spanned by each real d orbital under C4 (4-fold axis along z).
# d_z2 is invariant (A).  d_x2_y2 and d_xy both change sign under the 90-degree
# rotation, so both span B.  The B1/B2 split, and its dependence on the choice
# of mirror planes, only arises in groups with vertical mirrors such as C4v.
# (d_xz, d_yz) transform together as the 2D irrep E.
d_irreps = dict(
    d_z2="A",
    d_x2_y2="B",
    d_xy="B",
    d_xz="E",
    d_yz="E",
)

def _candidate_axes(u):
    """Return trial axes (each with both signs) built from unit ligand directions ``u``.

    The candidates, in order, are:

    1. each ligand direction;
    2. each normalized pairwise difference ``u_i - u_j``;
    3. each normalized pairwise cross product ``u_i x u_j``.

    The cross products supply the normal of the ligand plane.  That normal is
    the 4-fold axis of a square-planar set, and it never appears among the
    first two kinds.  Cross products are appended last so that on score ties
    the earlier kinds keep priority.
    """
    n = len(u)
    pairs = [(i, j) for i in range(n) for j in range(i + 1, n)]
    candidates = []
    for i in range(n):
        candidates.append(u[i])
        candidates.append(-u[i])
    for i, j in pairs:
        w = core.normalize(u[i] - u[j])
        if np.linalg.norm(w) > 0:          # skip coincident directions
            candidates.append(w)
            candidates.append(-w)
    for i, j in pairs:
        c = np.cross(u[i], u[j])
        if np.linalg.norm(c) > 1e-8:       # skip (anti)parallel pairs
            w = core.normalize(c)
            candidates.append(w)
            candidates.append(-w)
    return candidates


def align_guess(ligands_pos):
    """
    Guess a right-handed frame whose z axis is the 4-fold axis of the ligands.

    The guess is heuristic and does not check for rotational symmetry.  Each
    candidate axis is scored by counting ligands that are near-equatorial
    (|cos| < 0.35) and near-axial (|cos| > 0.85) with respect to it.  The
    scoring prefers exactly 4 equatorial ligands and 1 or 2 axial ones; on
    ties, the first candidate wins.  The x axis is the in-plane projection of
    the first usable equatorial ligand.

    Parameters
    ----------
    ligands_pos : array_like, shape (N, 3)
        Ligand position vectors.  Presumably relative to the metal centre;
        verify against the caller.

    Returns
    -------
    (ok, R_align, pos_aligned)
        ok : bool
            False if fewer than 4 ligands were given, or no usable axis or
            equatorial ligand was found.
        R_align : (3, 3) ndarray
            Rows are the new x, y, z axes in the input frame.  It is the
            identity on failure.
        pos_aligned : (N, 3) ndarray
            ``R_align @ p`` for every ligand.  On failure it is a float copy
            of the input.
    """
    ligands_pos = np.array(ligands_pos, float)
    N = ligands_pos.shape[0]
    if N < 4:
        return False, np.eye(3), ligands_pos.copy()

    u = np.array([core.normalize(v) for v in ligands_pos])

    candidates = _candidate_axes(u)
    if not candidates:
        return False, np.eye(3), ligands_pos.copy()

    # Score all candidates at once.  Column k of `dots` holds |cos| between
    # every ligand and candidate k.  np.argmax returns the first maximum,
    # which matches a first-wins scan over the candidates in order.
    axes = np.array(candidates, float)
    dots = np.abs(u @ axes.T)
    n_eq = np.sum(dots < 0.35, axis=0)     # near-equatorial ligands
    n_ax = np.sum(dots > 0.85, axis=0)     # near-axial ligands
    scores = (-np.abs(n_eq - 4)
              - 0.5 * np.minimum(np.abs(n_ax - 1), np.abs(n_ax - 2)))
    z_axis = axes[int(np.argmax(scores))]

    dots = np.abs(u @ z_axis)
    eq_ids = np.where(dots < 0.35)[0]
    if len(eq_ids) < 2:
        return False, np.eye(3), ligands_pos.copy()

    # x axis: in-plane projection of the first equatorial ligand whose
    # projection is not (numerically) zero.
    x_axis = None
    for k in eq_ids:
        v = ligands_pos[k]
        proj = v - np.dot(v, z_axis) * z_axis
        if np.linalg.norm(proj) >= 1e-8:
            x_axis = proj
            break
    if x_axis is None:
        return False, np.eye(3), ligands_pos.copy()

    y_axis = np.cross(z_axis, x_axis)

    # right-handed orthonormal frame
    x = core.normalize(x_axis)
    y = core.normalize(y_axis - np.dot(y_axis, x) * x)
    z = core.normalize(np.cross(x, y))
    U = np.column_stack([x, y, z])
    R_align = U.T
    pos_aligned = (R_align @ ligands_pos.T).T
    return True, R_align, pos_aligned


# Plugin descriptor for the C4 point group: its display name plus the
# callables and data defined in this module.  Presumably consumed by the tkpg
# plugin loader; verify the expected keys against that loader.
PLUGIN = dict(
    name="C4",
    build_group=build_group,
    irreps=irreps,
    irrep_dim=irrep_dim,
    irrep_char=irrep_char,
    d_irreps=d_irreps,
    align_guess=align_guess,
)
