# tkpg/Td.py
import numpy as np
from tkpg import core

# Td (order 24) consists of exactly those orthogonal integer matrices (signed
# permutation matrices) that map the tetrahedral vertex set below onto itself.
# Vertices (one per row), alternate corners of the cube [-1, 1]^3:
#   v1 = ( 1,  1,  1)
#   v2 = ( 1, -1, -1)
#   v3 = (-1,  1, -1)
#   v4 = (-1, -1,  1)
V = np.array(
    [(1, 1, 1), (1, -1, -1), (-1, 1, -1), (-1, -1, 1)],
    dtype=float,
)

def generate_signed_permutation_matrices():
    """Return all 48 signed 3x3 permutation matrices as float arrays.

    Each matrix is ``P @ diag(sx, sy, sz)``, where ``P`` is the identity with
    its columns reordered and every sign runs over (-1, +1). The enumeration
    order is fixed: the column order is outermost, then sx, sy, sz with sz
    varying fastest. build_group() derives element names from positions in
    this list, so the order must stay stable.
    """
    column_orders = [(0, 1, 2), (0, 2, 1), (1, 0, 2), (1, 2, 0), (2, 0, 1), (2, 1, 0)]
    sign_values = (-1, 1)
    identity = np.eye(3, dtype=int)

    candidates = [
        (identity[:, order] @ np.diag([a, b, c])).astype(int)
        for order in column_orders
        for a in sign_values
        for b in sign_values
        for c in sign_values
    ]

    # Keep only the first occurrence of each matrix (order-preserving dedup).
    distinct = []
    for cand in candidates:
        if all(not np.array_equal(cand, kept) for kept in distinct):
            distinct.append(cand)

    assert len(distinct) == 48
    return [m.astype(float) for m in distinct]

def maps_tetra_to_itself(R, tol=1e-9):
    """Return True if ``R`` permutes the tetrahedron vertices ``V``.

    Every transformed vertex ``R @ v`` must land (within Euclidean distance
    ``tol``) on a vertex of ``V`` that no earlier transformed vertex has
    already claimed. Candidates are scanned in row order of ``V``.
    """
    images = (R @ V.T).T
    remaining = list(range(len(V)))  # indices of vertices not yet matched
    for img in images:
        hit = next(
            (i for i in remaining if np.linalg.norm(img - V[i]) < tol),
            None,
        )
        if hit is None:
            return False
        remaining.remove(hit)
    return True

def build_group():
    """Build the 24 operations of Td as element dicts.

    The candidates are the 48 signed permutation matrices from
    generate_signed_permutation_matrices(). Only those that map the
    tetrahedron ``V`` onto itself (maps_tetra_to_itself) are kept.

    Returns:
        list of 24 dicts with keys
          "name":  "gNN", where NN is the candidate's index among the 48
                   signed permutation matrices (so names are not consecutive),
          "R":     the 3x3 float operation matrix,
          "class": "rot" if det(R) > 0 else "ref".

    Raises:
        RuntimeError: if the filter does not yield exactly 24 operations
            (the order of Td), which would indicate a broken vertex set or
            candidate generator.
    """
    elems = []
    for k, R in enumerate(generate_signed_permutation_matrices()):
        if not maps_tetra_to_itself(R):
            continue
        # Classify by det only (enough for hit_rate / D(g)); this does not
        # distinguish the individual Td conjugacy classes.
        det = float(np.linalg.det(R))
        cls = "rot" if det > 0 else "ref"
        elems.append({"name": f"g{k:02d}", "R": R, "class": cls})

    # Enforce the group order instead of silently returning a wrong group.
    if len(elems) != 24:
        raise RuntimeError(f"expected 24 Td operations, found {len(elems)}")
    return elems

# Td character table.
#
# Td (order 24) has five conjugacy classes: E, 8C3, 3C2, 6S4 and 6sigma_d.
# build_group() tags each element only with a coarse det-based label
# ("rot" for det > 0, "ref" otherwise). That label lumps E, C3 and C2
# together, and S4 with sigma_d, so it cannot index a real character table.
# irrep_char() therefore classifies each element from its matrix instead:
# (det, trace) is invariant under conjugation and takes a distinct value on
# every Td class.

# Legacy det-only table. It is kept for callers that read it directly and as a
# fallback for elements that carry no "R" matrix. Its "rot" column equals
# each irrep's dimension, but it is NOT a valid character table (e.g. it
# gives chi_E(C3) = 2 instead of -1). Do not use it for decomposition or
# projection.
Td_character_table = {
    "A1": {"rot": 1, "ref": 1},
    "A2": {"rot": 1, "ref": -1},
    "E":  {"rot": 2, "ref": 0},
    "T1": {"rot": 3, "ref": -1},
    "T2": {"rot": 3, "ref": 1},
}

# Proper Td character table; columns are the classes E, 8C3, 3C2, 6S4, 6sigma_d.
# Note that the column key "E" (identity class) differs from the irrep row "E".
Td_full_character_table = {
    "A1": {"E": 1, "C3": 1,  "C2": 1,  "S4": 1,  "sd": 1},
    "A2": {"E": 1, "C3": 1,  "C2": 1,  "S4": -1, "sd": -1},
    "E":  {"E": 2, "C3": -1, "C2": 2,  "S4": 0,  "sd": 0},
    "T1": {"E": 3, "C3": 0,  "C2": -1, "S4": 1,  "sd": -1},
    "T2": {"E": 3, "C3": 0,  "C2": -1, "S4": -1, "sd": 1},
}

# (det, trace) of a Td operation -> its conjugacy-class label.
_TD_CLASS_BY_DET_TRACE = {
    (1, 3): "E",
    (1, 0): "C3",
    (1, -1): "C2",
    (-1, -1): "S4",
    (-1, 1): "sd",
}

def _td_class(R, tol=1e-6):
    """Return the Td class label ("E", "C3", "C2", "S4" or "sd") of matrix R.

    Assumes R is an operation of Td. Other orthogonal matrices that share a
    (det, trace) pair with a Td class (e.g. a sigma_h mirror) would be
    mislabelled.

    Raises:
        ValueError: if (det, trace) matches no Td class within ``tol``.
    """
    R = np.asarray(R, dtype=float)
    det = float(np.linalg.det(R))
    trace = float(np.trace(R))
    for (d, t), label in _TD_CLASS_BY_DET_TRACE.items():
        if abs(det - d) < tol and abs(trace - t) < tol:
            return label
    raise ValueError(
        f"matrix with det={det:.6g}, trace={trace:.6g} is not a Td operation"
    )

def irreps():
    """Return the Td irrep labels in table order: A1, A2, E, T1, T2."""
    return list(Td_full_character_table.keys())

def irrep_dim(irrep):
    """Return the dimension of ``irrep``, i.e. its character at the identity."""
    return Td_full_character_table[irrep]["E"]

def irrep_char(irrep, element):
    """Return the character of ``irrep`` at the group element ``element``.

    ``element`` is an element dict as produced by build_group(). Its Td class
    is derived from the matrix ``element["R"]``. Elements without an "R"
    entry fall back to the legacy det-only table via ``element["class"]``.

    Raises:
        KeyError: for an unknown irrep label.
        ValueError: if ``element["R"]`` is not a Td operation.
    """
    try:
        R = element["R"]
    except KeyError:
        # No matrix available: only the coarse det-based label is usable.
        return Td_character_table[irrep][element["class"]]
    return Td_full_character_table[irrep][_td_class(R)]

# Irrep label of each real d orbital under Td, for the orientation used by
# the vertex set V (coordinate axes along the S4/C2 axes):
#   (d_z2, d_x2_y2)     -> E
#   (d_xy, d_xz, d_yz)  -> T2
d_irreps = {
    **dict.fromkeys(("d_z2", "d_x2_y2"), "E"),
    **dict.fromkeys(("d_xy", "d_xz", "d_yz"), "T2"),
}

def align_guess(ligands_pos):
    """Return a trivial alignment for Td ligand positions.

    Td alignment is not unique, so no rotation is attempted. The rotation is
    always the 3x3 identity, and the positions come back unchanged as a
    fresh float array. Detection relies on the hit rate instead.

    Returns:
        (ok, rotation, aligned_positions), with ``ok`` always True.
    """
    rotation = np.eye(3)
    aligned = np.array(ligands_pos, dtype=float).copy()
    return True, rotation, aligned


# Plugin descriptor for the Td point group. It maps hook names to this
# module's implementations; presumably looked up by the tkpg plugin machinery
# (verify against the loader in tkpg.core).
PLUGIN = dict(
    name="Td",
    build_group=build_group,
    irreps=irreps,
    irrep_dim=irrep_dim,
    irrep_char=irrep_char,
    d_irreps=d_irreps,
    align_guess=align_guess,
)
