# tkpg/Oh.py
import numpy as np
from math import acos, pi
from tkpg import core

def generate_O_rotations():
    """Return the 24 proper rotations of the cube (point group O).

    All signed permutation matrices ``P @ diag(s)`` are enumerated
    (6 column permutations x 8 sign patterns = 48 matrices). The ones with
    determinant +1 are kept. Duplicates are filtered defensively, and the
    result is asserted to contain exactly 24 matrices.

    Returns:
        list of 24 (3, 3) float ndarrays. They are in enumeration order:
        permutation outermost, then the x, y and z signs (each -1 before +1).
    """
    identity = np.eye(3, dtype=int)
    column_orders = ((0, 1, 2), (0, 2, 1), (1, 0, 2), (1, 2, 0), (2, 0, 1), (2, 1, 0))
    signs = (-1, 1)

    candidates = [
        identity[:, order] @ np.diag([x_sign, y_sign, z_sign])
        for order in column_orders
        for x_sign in signs
        for y_sign in signs
        for z_sign in signs
    ]
    proper = [M.astype(int) for M in candidates if int(round(np.linalg.det(M))) == 1]

    distinct = []
    for M in proper:
        if all(not np.array_equal(M, seen) for seen in distinct):
            distinct.append(M)

    assert len(distinct) == 24
    return [M.astype(float) for M in distinct]

def classify_O_rotation(R, *, tol=1e-6):
    """Return the O conjugacy-class label of a proper rotation matrix.

    The class is read from the trace, tr(R) = 1 + 2 cos(theta):

    * 3 -> "E";
    * 0 -> "C3" (theta = 2pi/3);
    * 1 -> "C4" (theta = pi/2);
    * -1 -> a pi rotation, split into "C2" (axis along x, y or z) and
      "C2p" (any other axis; for cube rotations, a face diagonal).

    The trace is compared directly instead of going through acos().
    acos is ill-conditioned at +-1: a trace error eps turns into an angle
    error of about sqrt(eps) near theta = 0 and theta = pi. Tiny float noise
    on E or C2 matrices would therefore exceed the tolerance.

    Args:
        R: 3x3 array-like, assumed to be a proper rotation of the cube group.
            Only the trace (and, for pi rotations, the diagonal) is inspected,
            so other matrices may be mislabelled or rejected.
        tol: absolute tolerance on the trace and on the squared axis
            component.

    Returns:
        One of "E", "C3", "C4", "C2", "C2p".

    Raises:
        RuntimeError: if the trace matches none of the O classes.
    """
    R = np.asarray(R, dtype=float)
    tr = float(np.trace(R))

    def near(a, b):
        return abs(a - b) < tol

    if near(tr, 3.0):
        return "E"
    if near(tr, 0.0):
        return "C3"
    if near(tr, 1.0):
        return "C4"
    if near(tr, -1.0):
        # A pi rotation about unit axis n is R = 2 n n^T - I, so
        # diag((R + I) / 2) = (n_x^2, n_y^2, n_z^2). A coordinate axis gives
        # one entry ~1; a face diagonal gives two entries ~1/2.
        axis_sq = (np.diag(R) + 1.0) / 2.0
        return "C2" if axis_sq.max() > 1.0 - tol else "C2p"
    raise RuntimeError("Unknown O rotation class")

def build_group():
    """Build the 48 elements of Oh as the direct product O x Ci.

    For each of the 24 proper rotations R (from generate_O_rotations), two
    elements are emitted in this order:

    * the rotation itself, named "Rkk", with "inv": False;
    * its product with inversion, -R, named "iRkk", with "inv": True.

    Both carry the O class label of R (from classify_O_rotation) under
    "class". The inversion flag is kept separately.

    Returns:
        list of 48 dicts with keys "name", "R", "inv", "class".
    """
    inversion = -np.eye(3)
    group = []
    for idx, rot in enumerate(generate_O_rotations()):
        label = classify_O_rotation(rot)
        group += [
            {"name": f"R{idx:02d}", "R": rot, "inv": False, "class": label},
            {"name": f"iR{idx:02d}", "R": inversion @ rot, "inv": True, "class": label},
        ]
    return group

# Character table of the rotation group O.
# Columns, in order: E, 8C3, 6C4, 3C2, 6C2'. "C2" and "C2p" are the
# labels classify_O_rotation returns for pi rotations about a coordinate
# axis and about any other axis, respectively.
O_character_table = {
    label: dict(zip(("E", "C3", "C4", "C2", "C2p"), row))
    for label, row in (
        ("A1", (1, 1, 1, 1, 1)),
        ("A2", (1, 1, -1, 1, -1)),
        ("E", (2, -1, 0, 2, 0)),
        ("T1", (3, 0, 1, -1, -1)),
        ("T2", (3, 0, -1, -1, 1)),
    )
}

def irreps():
    """Return the ten Oh irrep labels.

    Each O irrep (A1, A2, E, T1, T2) is listed twice: the gerade ("g")
    label first, then the ungerade ("u") label.
    """
    return [
        f"{label}{parity}"
        for label in ("A1", "A2", "E", "T1", "T2")
        for parity in ("g", "u")
    ]

def irrep_dim(irrep):
    """Return the dimension of an Oh irrep label such as "T2g".

    The dimension is the identity character of the underlying O irrep,
    i.e. the label with its trailing parity letter removed.
    """
    rotation_row = O_character_table[irrep[:-1]]
    return rotation_row["E"]

def irrep_char(irrep, element):
    """Return the character of a group element in an Oh irrep.

    Oh = O x Ci, so an element carrying inversion gets the O character
    multiplied by the irrep parity: +1 for labels ending in "g", otherwise -1.
    Elements without inversion get the plain O character.

    Args:
        irrep: label such as "A1g" or "T2u".
        element: dict with a "class" key holding an O class label and an
            optional "inv" flag (missing counts as False).

    Returns:
        The integer character.
    """
    sign = 1 if irrep.endswith("g") else -1
    chi_rotation = O_character_table[irrep[:-1]][element["class"]]
    return sign * chi_rotation if element.get("inv", False) else chi_rotation

# Oh irrep label of each real d orbital: (d_z2, d_x2_y2) -> Eg,
# (d_xy, d_yz, d_xz) -> T2g.
d_irreps = {
    **dict.fromkeys(("d_z2", "d_x2_y2"), "Eg"),
    **dict.fromkeys(("d_xy", "d_yz", "d_xz"), "T2g"),
}

def align_guess(ligands_pos, *, pair_cos=-0.92):
    """
    Rough octahedral alignment guess.

    Steps:
    - take the 6 positions with the largest |r|;
    - greedily pair them into three opposite pairs, most anti-parallel first.
      A pair is accepted only when the cosine between its unit vectors is
      below ``pair_cos``;
    - turn each pair (i, j) into an axis along u_i - u_j. The axis most
      parallel to global z is the z reference. The other two are
      Gram-Schmidt orthonormalized into a right-handed frame (x, y, z) whose
      z has the same sign as that reference.

    Args:
        ligands_pos: (N, 3) array-like of ligand positions.
            NOTE(review): presumably relative to the metal centre; confirm
            against callers.
        pair_cos: cosine threshold below which two directions count as
            opposite. The default -0.92 (about 157 degrees) reproduces the
            previous hard-coded behaviour.

    Returns:
        (ok, R_align, pos_aligned):
        ok -- True if three opposite pairs were found and they span a
            non-degenerate frame.
        R_align -- (3, 3) rotation whose rows are the new x, y, z axes, so
            each aligned position is R_align @ r. Identity on failure.
        pos_aligned -- (N, 3) rotated positions. A copy of the input on
            failure.
    """
    ligands_pos = np.array(ligands_pos, float)
    N = ligands_pos.shape[0]
    if N < 6:
        return False, np.eye(3), ligands_pos.copy()

    # NOTE(review): argsort(-norms) keeps the six *farthest* positions. If
    # callers pass more than the six donor atoms, the nearest six may be the
    # intended choice; confirm.
    norms = np.linalg.norm(ligands_pos, axis=1)
    idx6 = np.argsort(-norms)[:6]
    r6 = ligands_pos[idx6]
    u6 = np.array([core.normalize(v) for v in r6])

    # Greedy opposite pairing: visit all 15 pairs from the most negative
    # cosine upwards, and keep each disjoint pair below the threshold.
    dots = sorted(
        ((np.dot(u6[i], u6[j]), i, j) for i in range(6) for j in range(i + 1, 6)),
        key=lambda t: t[0],
    )
    used = set()
    pairs = []
    for d, i, j in dots:
        if i in used or j in used:
            continue
        if d < pair_cos:
            used.update((i, j))
            pairs.append((i, j))
        if len(pairs) == 3:
            break

    if len(pairs) != 3:
        return False, np.eye(3), ligands_pos.copy()

    axes = [core.normalize(u6[i] - u6[j]) for i, j in pairs]

    # The axis most parallel to global z becomes the z reference. The other
    # two, in pairing order, seed x and y.
    gz = np.array([0.0, 0.0, 1.0])
    k_z = int(np.argmax([abs(np.dot(a, gz)) for a in axes]))
    az = axes[k_z]
    ax, ay = (axes[k] for k in range(3) if k != k_z)

    # Gram-Schmidt: x from the first remaining axis, y from the second with
    # its x component removed, then z = x cross y (right-handed).
    x = core.normalize(ax)
    y_perp = ay - np.dot(ay, x) * x
    if np.linalg.norm(y_perp) < 1e-8:
        # Two pair axes are (anti)parallel, so they do not define a frame.
        return False, np.eye(3), ligands_pos.copy()
    y = core.normalize(y_perp)
    z = core.normalize(np.cross(x, y))
    # Flipping y and z together keeps det = +1 while making z agree with az.
    if np.dot(z, az) < 0:
        y = -y
        z = -z

    R_align = np.column_stack([x, y, z]).T
    pos_aligned = (R_align @ ligands_pos.T).T
    return True, R_align, pos_aligned


# Plugin descriptor: exposes this module's Oh group data and helpers under
# fixed string keys. NOTE(review): presumably consumed by the tkpg plugin
# loader; confirm the expected key set there.
PLUGIN = dict(
    name="Oh",
    build_group=build_group,
    irreps=irreps,
    irrep_dim=irrep_dim,
    irrep_char=irrep_char,
    d_irreps=d_irreps,
    align_guess=align_guess,
)
