import numpy as np
from math import acos, pi



# -------------------------
# user inputs
# -------------------------
# Central-metal d orbital whose symmetry is compared with the ligand SALCs.
# Accepted names: "d_z2", "d_x2_y2", "d_xy", "d_xz", "d_yz", "d_zx"
# ("d_zx" is meant to be an alias of "d_xz").
# Alternatives that were tried: "d_z2", "d_xz".
d_orbital = "d_xy"
# Ligand positions, one row (x, y, z) per ligand.
# NOTE(review): compute_local_frames() documents these as positions relative
# to the metal centre -- confirm the expected origin.
ligands_pos = np.array([
        [ 1.0, 0.0, 0.0],
        [-1.0, 0.0, 0.0],
        [ 0.0, 1.0, 0.0],
        [ 0.0,-1.0, 0.0],
        [ 0.0, 0.0, 1.0],
# Uncomment the next row to add a sixth ligand on -z:
#        [ 0.0, 0.0,-1.0],
    ], float)

# Ligand basis set: "s", "p", "sigma" or "full" (sigma + pi1 + pi2).
mode = "p"   # "s", "p", "sigma", "full"

# tolerance for symmetry/matching (adjust to the quality of the structure)
tol_match_geom = 5e-3   # point-group detection (positions)
tol_match_D    = 5e-3   # D(g) construction (positions)

# SALC extraction/printing controls
salc_eig_tol = 1e-6      # eigenvalue threshold for projection operator
coeff_tol    = 1e-6      # multiplicity threshold
print_coeffs = True
print_salc_thr = 1e-3    # coefficient threshold to show in SALC
max_salc_per_irrep = None  # e.g. 10 (None means no limit)



# ============================================================
# 0. Utilities
# ============================================================

def normalize(v, eps=1e-12):
    """Return *v* rescaled to unit Euclidean length.

    A vector whose norm is smaller than *eps* is considered degenerate; in
    that case a zero vector of the same shape is returned instead of
    dividing by a vanishing norm.
    """
    length = np.linalg.norm(v)
    return v * 0.0 if length < eps else v / length

def angle_between(u, v):
    """Return the angle in radians (0..pi) between vectors *u* and *v*.

    Both vectors are normalised first; the cosine is clipped to [-1, 1] so
    round-off cannot push it outside the domain of ``acos``.
    """
    cosine = np.dot(normalize(u), normalize(v))
    return acos(float(np.clip(cosine, -1.0, 1.0)))

def best_right_handed_axes(ax, ay, az):
    """Build a right-handed orthonormal frame from three rough axes.

    ``ax`` fixes the x direction, ``ay`` is Gram-Schmidt orthogonalised
    against it, and z is their cross product. If that z points away from
    ``az`` both y and z are negated, which keeps the frame right-handed.

    Returns a 3x3 matrix whose columns are [x, y, z].
    """
    x_hat = normalize(ax)
    y_perp = ay - np.dot(ay, x_hat) * x_hat
    # The original normalises twice; kept so results stay bit-identical.
    y_hat = normalize(normalize(y_perp))
    z_hat = normalize(np.cross(x_hat, y_hat))
    if np.dot(z_hat, az) < 0:
        # z is anti-parallel to the requested az: flip y (and hence z).
        y_hat, z_hat = -y_hat, -z_hat
    return np.column_stack([x_hat, y_hat, z_hat])

def orthonormalize(U):
    """Return the rotation matrix closest to *U* (polar decomposition).

    Computes ``U (U^T U)^(-1/2)`` through the eigen-decomposition of the
    Gram matrix; eigenvalues are floored at 1e-15 to avoid dividing by
    zero. If the result has negative determinant its first column is
    negated so that det = +1.
    """
    gram = U.T @ U
    eigvals, eigvecs = np.linalg.eigh(gram)
    floored = np.clip(eigvals, 1e-15, None)
    gram_inv_sqrt = eigvecs @ np.diag(1.0 / np.sqrt(floored)) @ eigvecs.T
    rotation = U @ gram_inv_sqrt
    if np.linalg.det(rotation) < 0:
        # improper result: flip one column to obtain a proper rotation
        rotation[:, 0] = -rotation[:, 0]
    return rotation

def pair_opposites(vs, cos_tol=-0.90):
    """Greedily pair up vectors that point in (nearly) opposite directions.

    All index pairs (i < j) are ranked by their dot product, most negative
    first. Walking down that ranking, a pair is accepted when neither index
    has been used yet and the dot product is below *cos_tol*.

    Returns ``(pairs, score)`` where ``pairs`` is a list of ``(i, j)``
    tuples and ``score`` is the sum of ``-dot`` over the accepted pairs.
    """
    count = len(vs)
    ranked = sorted(
        (
            (np.dot(vs[a], vs[b]), a, b)
            for a in range(count)
            for b in range(a + 1, count)
        ),
        key=lambda entry: entry[0],  # most anti-parallel first
    )

    taken = set()
    pairs = []
    score = 0.0
    for cosine, a, b in ranked:
        if a in taken or b in taken:
            continue
        if not cosine < cos_tol:
            continue
        taken.update((a, b))
        pairs.append((a, b))
        score -= cosine
    return pairs, score

def match_ligand_id_by_position(ligands_pos, pos_new, tol=1e-4):
    """Find the ligand located at *pos_new*.

    ligands_pos: (N,3) array of ligand positions
    pos_new:     (3,) position to look up
    Returns the index of the nearest ligand if it lies within *tol*
    (Euclidean distance) of *pos_new*, otherwise ``None``.
    """
    offsets = ligands_pos - pos_new[None, :]
    sq_dist = np.square(offsets).sum(axis=1)
    nearest = int(np.argmin(sq_dist))
    return nearest if sq_dist[nearest] <= tol**2 else None

# ============================================================
# 1. Group builders (Oh, C4v)
# ============================================================

def generate_O_rotations():
    """Proper rotations of cube (24)."""
    mats = []
    axes = np.eye(3, dtype=int)
    perms = []
    for p in [(0,1,2),(0,2,1),(1,0,2),(1,2,0),(2,0,1),(2,1,0)]:
        perms.append(axes[:, p])

    for P in perms:
        for sx in [-1, 1]:
            for sy in [-1, 1]:
                for sz in [-1, 1]:
                    S = np.diag([sx, sy, sz])
                    R = P @ S
                    if int(round(np.linalg.det(R))) == 1:
                        mats.append(R.astype(int))

    uniq = []
    for R in mats:
        if not any(np.array_equal(R, Q) for Q in uniq):
            uniq.append(R)
    assert len(uniq) == 24
    return [R.astype(float) for R in uniq]

def classify_O_rotation(R):
    """Return the conjugacy class of a proper cubic rotation matrix.

    R: 3x3 proper rotation matrix belonging to the group O.
    Returns one of "E", "C3", "C4", "C2" (pi rotation about a coordinate
    axis, i.e. C4^2) or "C2p" (pi rotation about a face diagonal).
    Raises RuntimeError if the rotation angle matches none of these.

    The rotation angle comes from the trace, ``tr(R) = 1 + 2 cos(theta)``.
    ``acos`` only returns values in [0, pi], so a rotation by 4*pi/3 or
    3*pi/2 already shows up as 2*pi/3 or pi/2 respectively.
    """
    cos_theta = (np.trace(R) - 1.0) / 2.0
    cos_theta = max(-1.0, min(1.0, cos_theta))  # guard acos against round-off
    theta = acos(cos_theta)

    def near(a, b, tol=1e-6):
        return abs(a - b) < tol

    if near(theta, 0.0):
        return "E"
    if near(theta, 2 * pi / 3):
        return "C3"
    if near(theta, pi / 2):
        return "C4"
    if near(theta, pi):
        # For a pi rotation about the unit axis n, R = 2 n n^T - I, so the
        # squared axis components can be read straight off the diagonal.
        axis_sq = (np.diag(R) + 1.0) / 2.0
        # Same criterion as |n_i| > 1 - 1e-6 for some coordinate axis i.
        if float(np.max(axis_sq)) > (1 - 1e-6) ** 2:
            return "C2"
        return "C2p"
    raise RuntimeError("Unknown O rotation class")

def build_Oh_group():
    """Build the 48 elements of Oh as the direct product O x Ci.

    For every proper rotation R of O two entries are emitted, R itself and
    (-I) @ R. Both carry the class label of R; the "inv" flag records
    whether the inversion was applied.
    """
    inversion = -np.eye(3)
    elements = []
    for k, rotation in enumerate(generate_O_rotations()):
        label = classify_O_rotation(rotation)
        elements.extend([
            {"name": f"R{k:02d}", "R": rotation, "inv": False, "class": label},
            {"name": f"iR{k:02d}", "R": inversion @ rotation, "inv": True, "class": label},
        ])
    return elements

# ---- C4v operations (z axis is principal axis) ----
def Rz(theta):
    """Return the 3x3 matrix of a counter-clockwise rotation by *theta* (radians) about z."""
    cos_t = np.cos(theta)
    sin_t = np.sin(theta)
    rotation = np.eye(3)
    rotation[0, 0] = cos_t
    rotation[0, 1] = -sin_t
    rotation[1, 0] = sin_t
    rotation[1, 1] = cos_t
    return rotation

def mirror_xz():
    """Reflection through the xz plane: (x, y, z) -> (x, -y, z)."""
    return np.diag([1.0, -1.0, 1.0])

def mirror_yz():
    """Reflection through the yz plane: (x, y, z) -> (-x, y, z)."""
    return np.diag([-1.0, 1.0, 1.0])

def mirror_diag_xy_plus():
    """Reflection through the plane x = y: (x, y, z) -> (y, x, z)."""
    # Swapping the first two rows of the identity exchanges x and y.
    return np.eye(3)[[1, 0, 2]]

def mirror_diag_xy_minus():
    """Reflection through the plane x = -y: (x, y, z) -> (-y, -x, z)."""
    matrix = np.zeros((3, 3))
    matrix[0, 1] = -1.0
    matrix[1, 0] = -1.0
    matrix[2, 2] = 1.0
    return matrix

def build_C4v_group():
    """Return the eight operations of C4v with z as the principal axis.

    Each element is a dict with its display "name", its 3x3 matrix "R" and
    the "class" label used to look up characters.
    """
    operations = [
        ("E",        np.eye(3),              "E"),
        ("C4",       Rz(np.pi/2),            "C4"),
        ("C2",       Rz(np.pi),              "C2"),
        ("C4^3",     Rz(3*np.pi/2),          "C4"),
        ("sv(xz)",   mirror_xz(),            "sv"),
        ("sv(yz)",   mirror_yz(),            "sv"),
        ("sd(x=y)",  mirror_diag_xy_plus(),  "sd"),
        ("sd(x=-y)", mirror_diag_xy_minus(), "sd"),
    ]
    return [
        {"name": label, "R": matrix, "class": cls}
        for label, matrix, cls in operations
    ]

# ============================================================
# 2. Character tables + center d irreps
# ============================================================

# --- Oh (via O x parity) ---
# Characters of the rotation group O, one row per irrep; columns follow the
# class labels produced for the rotations: E, C3, C4, C2 (= C4^2), C2p (= C2').
O_character_table = {
    label: dict(zip(("E", "C3", "C4", "C2", "C2p"), row))
    for label, row in (
        ("A1", (1,  1,  1,  1,  1)),
        ("A2", (1,  1, -1,  1, -1)),
        ("E",  (2, -1,  0,  2,  0)),
        ("T1", (3,  0,  1, -1, -1)),
        ("T2", (3,  0, -1, -1,  1)),
    )
}

def oh_irreps():
    """List the ten Oh irrep labels: each O irrep with a 'g' and a 'u' suffix."""
    return [
        base + parity
        for base in ("A1", "A2", "E", "T1", "T2")
        for parity in ("g", "u")
    ]

def oh_irrep_dim(irrep):
    """Return the dimension of an Oh irrep (character of E of its O parent)."""
    base_label = irrep[:-1]  # strip the trailing g/u parity letter
    return O_character_table[base_label]["E"]

def oh_irrep_char(irrep, element):
    """Return the character of *irrep* (e.g. "T2g") for a group element.

    The character of the underlying O rotation is looked up by the
    element's class; it changes sign when the element includes the
    inversion and the irrep is ungerade (does not end in "g").
    """
    rotation_char = O_character_table[irrep[:-1]][element["class"]]
    if element.get("inv", False) and not irrep.endswith("g"):
        return -rotation_char
    return rotation_char

# Irrep of each central-metal d orbital in Oh.
# Both spellings "d_xz" and "d_zx" are accepted: the C4v table and the alias
# handling in main() use "d_xz", so a table with only "d_zx" made every
# xz-orbital lookup fail for Oh.
d_orbital_irreps_Oh = {
    "d_z2":     "Eg",
    "d_x2_y2":  "Eg",
    "d_xy":     "T2g",
    "d_yz":     "T2g",
    "d_xz":     "T2g",
    "d_zx":     "T2g",
}

# --- C4v ---
# Characters of C4v, one row per irrep; columns are the classes
# E, C4, C2, sv (xz and yz mirrors), sd (diagonal mirrors).
C4v_character_table = {
    label: dict(zip(("E", "C4", "C2", "sv", "sd"), row))
    for label, row in (
        ("A1", (1,  1,  1,  1,  1)),
        ("A2", (1,  1,  1, -1, -1)),
        ("B1", (1, -1,  1,  1, -1)),
        ("B2", (1, -1,  1, -1,  1)),
        ("E",  (2,  0, -2,  0,  0)),
    )
}

def c4v_irreps():
    """Return the C4v irrep labels in character-table order."""
    return [label for label in C4v_character_table]

def c4v_irrep_dim(irrep):
    """Return the dimension of a C4v irrep, i.e. its character for E."""
    characters = C4v_character_table[irrep]
    return characters["E"]

def c4v_irrep_char(irrep, element):
    """Return the character of *irrep* for a C4v element, looked up by its class."""
    characters = C4v_character_table[irrep]
    return characters[element["class"]]

# Irrep of each central-metal d orbital in C4v (z is the principal axis),
# written irrep-first so the degenerate (d_xz, d_yz) pair is visible.
d_orbital_irreps_C4v = {
    orbital: irrep
    for irrep, orbitals in (
        ("A1", ("d_z2",)),
        ("B1", ("d_x2_y2",)),
        ("B2", ("d_xy",)),
        ("E",  ("d_xz", "d_yz")),
    )
    for orbital in orbitals
}

# ============================================================
# 3. Geometry: compute local frames & basis
# ============================================================

def compute_local_frames(ligands_pos):
    """Attach a local orthonormal frame to every ligand.

    ligands_pos: (N,3) positions from metal center
    Returns a dict ``{i: (ex, ey, ez)}`` where ez is the unit vector along
    the ligand position (sigma direction) and ex, ey span the plane
    perpendicular to it.
    """
    x_seed = np.array([1.0, 0.0, 0.0])
    y_seed = np.array([0.0, 1.0, 0.0])

    frames = {}
    for index, position in enumerate(ligands_pos):
        radial = normalize(position)
        # Fall back to the y seed when the radial direction is nearly along x,
        # otherwise the cross product below would be ill-conditioned.
        seed = y_seed if abs(np.dot(x_seed, radial)) > 0.9 else x_seed
        tangent_y = normalize(np.cross(radial, seed))
        tangent_x = normalize(np.cross(tangent_y, radial))
        frames[index] = (tangent_x, tangent_y, radial)
    return frames

def build_ligand_basis(N_ligands, frames, mode="full"):
    """List the ligand basis functions for the chosen *mode*.

    Each basis function is a dict with keys "ligand" (index), "type" and
    "name"; functions are ordered ligand by ligand. ``frames`` is accepted
    for signature compatibility but not used here.

    Raises ValueError("Unknown basis mode") for an unrecognised mode (the
    check happens while visiting the first ligand).
    """
    per_ligand = {
        "s":     [("s", "s")],
        "p":     [("p", "px"), ("p", "py"), ("p", "pz")],
        "sigma": [("lp", "sigma")],
        "full":  [("lp", "sigma"), ("lp", "pi1"), ("lp", "pi2")],
    }

    basis = []
    for lid in range(N_ligands):
        if mode not in per_ligand:
            raise ValueError("Unknown basis mode")
        basis.extend(
            {"ligand": lid, "type": kind, "name": label}
            for kind, label in per_ligand[mode]
        )
    return basis

def basis_key(b):
    """Return the hashable identifier (ligand, type, name) of a basis function."""
    return tuple(b[field] for field in ("ligand", "type", "name"))

def basis_vector_global(b, frames):
    """Return the direction of basis function *b* in global coordinates.

    - type "s":  no direction, returns None
    - type "p":  the global Cartesian unit vector for px / py / pz
    - type "lp": the ligand's local frame vector (sigma -> ez,
                 pi1 -> ex, pi2 -> ey) taken from ``frames``

    Raises ValueError for an unknown type or orbital name.
    """
    kind = b["type"]
    if kind == "s":
        return None

    if kind == "p":
        cartesian = {
            "px": (1.0, 0.0, 0.0),
            "py": (0.0, 1.0, 0.0),
            "pz": (0.0, 0.0, 1.0),
        }
        label = b["name"]
        if label not in cartesian:
            raise ValueError("bad p name")
        return np.array(cartesian[label])

    if kind == "lp":
        ex, ey, ez = frames[b["ligand"]]
        local = {"sigma": ez, "pi1": ex, "pi2": ey}
        label = b["name"]
        if label not in local:
            raise ValueError("bad lp name")
        return local[label]

    raise ValueError("bad basis type")

def destination_components(mode, frames, dest_ligand_id):
    """List the (name, direction) basis components on the destination ligand.

    For mode "s" the single component has no direction (None); for "p" the
    directions are the global Cartesian axes; for "sigma" / "full" they are
    the destination ligand's local frame vectors.

    Raises ValueError("Unknown mode") for any other mode.
    """
    # The frame is looked up first for every mode, as in the original.
    ex, ey, ez = frames[dest_ligand_id]

    if mode == "s":
        return [("s", None)]
    if mode == "p":
        cartesian = (
            ("px", [1.0, 0, 0]),
            ("py", [0, 1.0, 0]),
            ("pz", [0, 0, 1.0]),
        )
        return [(label, np.array(direction)) for label, direction in cartesian]
    if mode == "sigma":
        return [("sigma", ez)]
    if mode == "full":
        return list(zip(("sigma", "pi1", "pi2"), (ez, ex, ey)))
    raise ValueError("Unknown mode")

# ============================================================
# 4. Representation matrix with real-coordinate ligand matching
# ============================================================

def representation_matrix_for_element(group_element, ligands_pos, frames, basis, mode, tol_match=1e-4):
    """Build the matrix D(g) of one symmetry operation in the ligand basis.

    For every basis function (column) the ligand position is transformed by
    the operation's matrix and matched to an existing ligand within
    *tol_match*. s functions map with coefficient 1; directional functions
    are rotated and projected onto the destination ligand's components.
    Columns whose ligand has no matching image are left as zeros.

    Returns an (n, n) float array with n = len(basis).
    """
    rotation = group_element["R"]
    size = len(basis)
    D = np.zeros((size, size), float)
    row_of = {basis_key(fn): position for position, fn in enumerate(basis)}
    # Destination basis functions are typed "p" in p mode and "lp" otherwise.
    dest_type = "p" if mode == "p" else "lp"

    for col, fn in enumerate(basis):
        image = rotation @ ligands_pos[fn["ligand"]]
        target = match_ligand_id_by_position(ligands_pos, image, tol=tol_match)
        if target is None:
            continue  # operation does not map this ligand onto the set

        if fn["type"] == "s":
            row = row_of.get((target, "s", "s"), None)
            if row is not None:
                D[row, col] = 1.0
            continue

        rotated = rotation @ basis_vector_global(fn, frames)
        for name, direction in destination_components(mode, frames, target):
            overlap = float(np.dot(rotated, direction))
            row = row_of.get((target, dest_type, name), None)
            # skip numerically-zero overlaps
            if row is not None and abs(overlap) > 1e-12:
                D[row, col] = overlap

    return D

def character_of_rep(group_element, ligands_pos, frames, basis, mode, tol_match=1e-4):
    """Return the character (trace of D(g)) of one operation in the ligand basis."""
    matrix = representation_matrix_for_element(
        group_element, ligands_pos, frames, basis, mode, tol_match=tol_match
    )
    return float(np.trace(matrix))

def build_reducible_characters(group, ligands_pos, frames, basis, mode, tol_match=1e-4):
    """Return the characters of the reducible ligand representation, one per group element."""
    characters = []
    for element in group:
        characters.append(
            character_of_rep(element, ligands_pos, frames, basis, mode, tol_match=tol_match)
        )
    return np.array(characters, float)

# ============================================================
# 5. Decomposition / Projection / SALCs (generic)
# ============================================================

def decompose_generic(Gamma, group, irreps, irrep_char_func):
    """Decompose a reducible representation into irrep multiplicities.

    Gamma:           characters of the reducible rep, one per group element
    group:           list of group elements (same order as Gamma)
    irreps:          irrep labels to test
    irrep_char_func: callable (irrep, element) -> character

    Returns ``{irrep: (1/|G|) * sum_g Gamma(g) * chi_irrep(g)}`` as floats.
    """
    order = len(group)
    return {
        label: float(
            np.dot(Gamma, np.array([irrep_char_func(label, element) for element in group], float))
            / order
        )
        for label in irreps
    }

def projection_operator(irrep, group, irrep_dim_func, irrep_char_func,
                        ligands_pos, frames, basis, mode, tol_match=1e-4):
    """Build the character projection operator onto *irrep*.

    Computes ``P = (l / |G|) * sum_g chi(g) * D(g)`` where l is the irrep
    dimension and D(g) the ligand-basis representation matrix of g.
    Returns an (n, n) float array with n = len(basis).
    """
    dimension = irrep_dim_func(irrep)
    order = len(group)
    size = len(basis)

    accumulator = np.zeros((size, size), float)
    for element in group:
        weight = irrep_char_func(irrep, element)
        accumulator += weight * representation_matrix_for_element(
            element, ligands_pos, frames, basis, mode, tol_match=tol_match
        )
    accumulator *= (dimension / order)
    return accumulator

def extract_salc(irrep, group, irrep_dim_func, irrep_char_func,
                 ligands_pos, frames, basis, mode, tol_match=1e-4, tol=1e-6):
    """Extract the SALCs of *irrep* from its projection operator.

    The projector is diagonalised with ``np.linalg.eigh``; every eigenvector
    whose eigenvalue exceeds *tol* is kept, normalised, and sign-fixed so
    that its largest-magnitude coefficient is positive.

    Returns a list of 1-D coefficient vectors over ``basis``.
    """
    projector = projection_operator(irrep, group, irrep_dim_func, irrep_char_func,
                                    ligands_pos, frames, basis, mode, tol_match=tol_match)
    eigenvalues, eigenvectors = np.linalg.eigh(projector)

    salcs = []
    for value, column in zip(eigenvalues, eigenvectors.T):
        if not value > tol:
            continue  # eigenvector lies outside the irrep subspace
        unit = column / np.linalg.norm(column)
        dominant = int(np.argmax(np.abs(unit)))
        if unit[dominant] < 0:
            unit = -unit
        salcs.append(unit)
    return salcs

def print_salc(irrep, salcs, basis, thr=1e-3):
    """Print the SALCs of one irrep as lists of coefficients.

    Coefficients are shown per basis function, ordered by (ligand, type,
    name); entries with magnitude below *thr* are omitted.
    """
    def ordering(pair):
        fn = pair[1]
        return (fn["ligand"], fn["type"], fn["name"])

    print(f"\n[SALCs for {irrep}]  (count={len(salcs)})")
    for number, vec in enumerate(salcs, start=1):
        print(f"  SALC #{number}:")
        pairs = list(zip(vec, basis))
        pairs.sort(key=ordering)
        for coeff, fn in pairs:
            if not abs(coeff) < thr:
                print(f"    {coeff:+.4f}  L{fn['ligand']}:{fn['name']}")

# ============================================================
# 6. Auto-detect point group (Oh vs C4v) + align to canonical axes
# ============================================================

def align_guess_Oh(ligands_pos):
    """
    Try to find a canonical orientation for an (approx.) octahedral environment.

    Procedure:
    - requires N >= 6; only the 6 ligands with the largest |r| are used to
      find the axes (the returned positions still contain all N ligands)
    - those 6 directions must form exactly three (nearly) opposite pairs
    - the three pair axes must be roughly mutually orthogonal
    - the pair axis closest to the global z direction becomes z; the other
      two become x and y of a right-handed frame

    ligands_pos: (N,3) array of ligand positions.
    Returns (ok, R_align, pos_aligned). When ok is False, R_align is the
    identity and pos_aligned is an unmodified copy of ligands_pos.
    """
    N = ligands_pos.shape[0]
    if N < 6:
        return False, np.eye(3), ligands_pos.copy()

    # select the 6 ligands with the largest distance from the origin
    # NOTE(review): this comment used to say "6 strongest bonds", which would
    # normally mean the *shortest* distances -- confirm which is intended.
    norms = np.linalg.norm(ligands_pos, axis=1)
    idx6 = np.argsort(-norms)[:6]
    r6 = ligands_pos[idx6]
    u6 = np.array([normalize(v) for v in r6])

    # need exactly three trans pairs (cosine below -0.92)
    pairs, _ = pair_opposites(u6, cos_tol=-0.92)
    if len(pairs) != 3:
        return False, np.eye(3), ligands_pos.copy()

    axes = []
    for i, j in pairs:
        a = normalize(u6[i] - u6[j])  # along pair axis
        axes.append(a)

    # check near orthogonality
    dots = [abs(np.dot(axes[0], axes[1])),
            abs(np.dot(axes[0], axes[2])),
            abs(np.dot(axes[1], axes[2]))]
    if max(dots) > 0.25:  # |cos| > 0.25: some pair of axes deviates from 90 deg by more than ~14.5 deg
        return False, np.eye(3), ligands_pos.copy()

    # build alignment: map found axes -> canonical x,y,z
    # choose one axis as z (arbitrary but stable): the one closest to global z
    gz = np.array([0.0, 0.0, 1.0])
    k_z = int(np.argmax([abs(np.dot(a, gz)) for a in axes]))
    az = axes[k_z]
    rem = [axes[i] for i in range(3) if i != k_z]
    ax = rem[0]
    ay = rem[1]
    U = best_right_handed_axes(ax, ay, az)     # columns are local axes in original frame
    R_align = U.T                               # so that R_align @ r puts axes to xyz
    pos_aligned = (R_align @ ligands_pos.T).T
    return True, R_align, pos_aligned

def align_guess_C4v(ligands_pos):
    """
    Guess a C4v canonical orientation.

    The principal axis is searched among candidate directions (each ligand
    direction, each normalised difference of two ligand directions, and
    their negatives). The candidate that comes closest to having 4
    "equatorial" ligands (roughly perpendicular to it) and 1 or 2 "axial"
    ligands (roughly parallel to it) wins; on ties the first candidate is
    kept. x is then taken from an equatorial ligand projected onto the
    plane perpendicular to that axis.

    Intended for N=5 (square pyramidal) and N=6 (elongated / missing
    inversion) C4v-like environments; requires N >= 5.

    ligands_pos: (N,3) array of ligand positions.
    Returns (ok, R_align, pos_aligned). When ok is False, R_align is the
    identity and pos_aligned is an unmodified copy of ligands_pos.
    """
    N = ligands_pos.shape[0]
    if N < 5:
        return False, np.eye(3), ligands_pos.copy()

    u = np.array([normalize(v) for v in ligands_pos])
    candidates = []

    # candidate axes from ligand directions and pair differences
    for i in range(N):
        candidates.append(u[i])
        candidates.append(-u[i])
    for i in range(N):
        for j in range(i+1, N):
            w = normalize(u[i] - u[j])
            # normalize() returns a zero vector for (near-)identical directions
            if np.linalg.norm(w) > 0:
                candidates.append(w)
                candidates.append(-w)

    # score: want ~4 ligands near perpendicular to axis and remaining aligned
    best = None
    for a in candidates:
        dots = np.abs(u @ a)
        # equatorial: small |dot| ; axial: large |dot|
        eq = np.sum(dots < 0.35)      # within ~20.5 deg of the perpendicular plane
        ax = np.sum(dots > 0.85)      # within ~32 deg of the axis (either sense)
        # typical C4v: eq ≈ 4, ax ≈ 1 (square pyramid) or 2 (elongated)
        score = -abs(eq - 4) - 0.5*min(abs(ax - 1), abs(ax - 2))
        # strict '>' keeps the first candidate among equal scores
        if best is None or score > best[0]:
            best = (score, a, eq, ax)

    if best is None:
        return False, np.eye(3), ligands_pos.copy()

    _, z_axis, eq, ax = best
    if eq < 4:
        return False, np.eye(3), ligands_pos.copy()

    # choose x axis from one equatorial ligand projected onto plane ⟂ z
    dots = np.abs(u @ z_axis)
    eq_ids = np.where(dots < 0.35)[0]
    if len(eq_ids) < 4:
        return False, np.eye(3), ligands_pos.copy()

    v0 = ligands_pos[eq_ids[0]]
    x_axis = v0 - np.dot(v0, z_axis) * z_axis
    if np.linalg.norm(x_axis) < 1e-8:
        # try another equatorial
        for k in eq_ids[1:]:
            v0 = ligands_pos[k]
            x_axis = v0 - np.dot(v0, z_axis) * z_axis
            if np.linalg.norm(x_axis) > 1e-8:
                break
    if np.linalg.norm(x_axis) < 1e-8:
        return False, np.eye(3), ligands_pos.copy()

    # y completes the right-handed frame; R_align = U^T rotates it onto xyz
    y_axis = np.cross(z_axis, x_axis)
    U = best_right_handed_axes(x_axis, y_axis, z_axis)
    R_align = U.T
    pos_aligned = (R_align @ ligands_pos.T).T
    return True, R_align, pos_aligned

def symmetry_hit_rate(group, ligands_pos_aligned, tol_match=1e-3):
    """Return the fraction of group operations that preserve the ligand set.

    An operation counts as a hit when the image of every ligand position
    coincides with some ligand position within *tol_match*. Only positions
    are tested, not orbital orientations.
    """
    ligand_count = ligands_pos_aligned.shape[0]

    def maps_set_onto_itself(rotation):
        return all(
            match_ligand_id_by_position(
                ligands_pos_aligned, rotation @ ligands_pos_aligned[i], tol=tol_match
            ) is not None
            for i in range(ligand_count)
        )

    hits = sum(1 for element in group if maps_set_onto_itself(element["R"]))
    return hits / len(group)

def detect_point_group(ligands_pos, tol_match=1e-3):
    """Auto-detect between Oh and C4v.

    Both alignments are attempted and scored with symmetry_hit_rate() (a
    failed alignment scores 0). Decision order:
      1. Oh if its rate is >= 0.70 and not lower than the C4v rate
      2. C4v if its rate is >= 0.80
      3. otherwise whichever rate is higher (Oh wins ties)

    Returns (point_group, R_align, pos_aligned, group_spec_dict); the spec
    holds the group elements, irrep labels, dimension/character lookups and
    the d-orbital irrep table for the chosen group.
    """
    ok_Oh, R_Oh, pos_Oh = align_guess_Oh(ligands_pos)
    Oh_group = build_Oh_group()
    rate_Oh = symmetry_hit_rate(Oh_group, pos_Oh, tol_match=tol_match) if ok_Oh else 0.0

    ok_C4v, R_C4v, pos_C4v = align_guess_C4v(ligands_pos)
    C4v_group = build_C4v_group()
    rate_C4v = symmetry_hit_rate(C4v_group, pos_C4v, tol_match=tol_match) if ok_C4v else 0.0

    def oh_result():
        spec = {
            "name": "Oh",
            "group": Oh_group,
            "irreps": oh_irreps(),
            "irrep_dim": oh_irrep_dim,
            "irrep_char": oh_irrep_char,
            "d_irreps": d_orbital_irreps_Oh,
        }
        # Without a usable alignment fall back to identity / untouched copy.
        rotation = R_Oh if ok_Oh else np.eye(3)
        positions = pos_Oh if ok_Oh else ligands_pos.copy()
        return "Oh", rotation, positions, spec

    def c4v_result():
        spec = {
            "name": "C4v",
            "group": C4v_group,
            "irreps": c4v_irreps(),
            "irrep_dim": c4v_irrep_dim,
            "irrep_char": c4v_irrep_char,
            "d_irreps": d_orbital_irreps_C4v,
        }
        rotation = R_C4v if ok_C4v else np.eye(3)
        positions = pos_C4v if ok_C4v else ligands_pos.copy()
        return "C4v", rotation, positions, spec

    oh_not_worse = rate_Oh >= rate_C4v

    # Oh is stricter: require a decent hit rate; otherwise pick C4v if it fits
    if rate_Oh >= 0.70 and oh_not_worse:
        return oh_result()
    if rate_C4v >= 0.80:
        return c4v_result()

    # fallback: choose the higher rate (still return something)
    return oh_result() if oh_not_worse else c4v_result()

# ============================================================
# 7. Main: unified run
# ============================================================

def main():
    """Run the full analysis for the module-level user inputs.

    Steps: centre the ligand positions, detect the point group (Oh or C4v)
    and align the geometry, look up the irrep of the chosen central d
    orbital, decompose the ligand representation, report whether mixing
    with the d orbital is symmetry-allowed, and print the SALCs of every
    irrep contained in the ligand representation.

    Reads the module-level settings (d_orbital, mode, tolerances, printing
    controls) and rebinds the module-level ``ligands_pos`` to the centred
    positions.

    Raises ValueError if the d orbital is not defined for the detected
    point group.
    """
    global ligands_pos

    # -------------------------
    # normalize / center (safety)
    # -------------------------
    # NOTE(review): this subtracts the ligand centroid. For geometries whose
    # centroid is not at the origin (e.g. a 5-ligand square pyramid) the
    # origin moves, so the local sigma/pi frames are measured from the
    # centroid rather than from the original origin -- confirm intent.
    ligands_pos = ligands_pos - np.mean(ligands_pos, axis=0, keepdims=True)

    # -------------------------
    # auto-detect point group + align
    # -------------------------
    pg, R_align, pos_aligned, spec = detect_point_group(ligands_pos, tol_match=tol_match_geom)

    print(f"\n[Auto-detected point group]  {pg}")
    print("[Symmetry axis convention]  canonical axes are used internally (aligned geometry).")

    # "d_xz" and "d_zx" name the same orbital. Use whichever spelling the
    # selected table defines, so the lookup works for either spelling in
    # both point groups (a one-way d_zx -> d_xz rename fails for a table
    # that only lists "d_zx").
    d_irreps = spec["d_irreps"]
    alias = {"d_zx": "d_xz", "d_xz": "d_zx"}
    if d_orbital in d_irreps:
        d_orbital_use = d_orbital
    else:
        d_orbital_use = alias.get(d_orbital, d_orbital)

    if d_orbital_use not in d_irreps:
        raise ValueError(f"d orbital '{d_orbital}' not defined for point group {pg}")

    d_irrep = d_irreps[d_orbital_use]
    print("\n[Central metal d orbital]")
    print(f"  Orbital     : {d_orbital}")
    print(f"  Irrep       : {d_irrep}")

    # -------------------------
    # build frames & basis
    # -------------------------
    frames = compute_local_frames(pos_aligned)
    basis  = build_ligand_basis(len(pos_aligned), frames, mode=mode)

    group = spec["group"]

    # -------------------------
    # reducible rep & decomposition
    # -------------------------
    Gamma = build_reducible_characters(group, pos_aligned, frames, basis, mode=mode, tol_match=tol_match_D)
    coeffs = decompose_generic(Gamma, group, spec["irreps"], spec["irrep_char"])

    if print_coeffs:
        print(f"\n[Ligand reducible rep decomposition ({pg})]")
        for ir in spec["irreps"]:
            if coeffs.get(ir, 0.0) > coeff_tol:
                print(f"  {ir}: {coeffs[ir]:.3f}")

    # -------------------------
    # hybridization check (still show)
    # -------------------------
    print("\n[Hybridization check]")
    if coeffs.get(d_irrep, 0.0) > coeff_tol:
        print(f"  ✔ Hybridization allowed by symmetry: ligand contains {d_irrep}")
    else:
        print(f"  ✘ Symmetry forbidden: ligand does NOT contain {d_irrep}")

    # -------------------------
    # SALCs: show ALL irreps with positive multiplicity
    # -------------------------
    print(f"\n[All ligand SALCs ({pg})]  mode={mode}")
    any_printed = False

    for ir in spec["irreps"]:
        mult = coeffs.get(ir, 0.0)
        if mult <= coeff_tol:
            continue

        salcs = extract_salc(
            ir, group,
            spec["irrep_dim"], spec["irrep_char"],
            pos_aligned, frames, basis, mode=mode,
            tol_match=tol_match_D, tol=salc_eig_tol
        )

        if max_salc_per_irrep is not None:
            salcs = salcs[:max_salc_per_irrep]

        if len(salcs) == 0:
            # Even with multiplicity > 0 the SALCs can be missed because of
            # numerical error, so warn the user instead of staying silent.
            print(f"\n[SALCs for {ir}]  (expected mult≈{mult:.3f}, but none extracted; try lowering salc_eig_tol)")
            continue

        any_printed = True
        print_salc(ir, salcs, basis, thr=print_salc_thr)

    if not any_printed:
        print("  (No SALCs extracted. Try increasing tol_match_D or lowering salc_eig_tol.)")


if __name__ == "__main__":
    main()
