"""
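Ewald sum (pymatgen CIF) + Born–Mayer repulsion (SX-1 parameters) calculator.

Uniform +1% scaling of a, b and c:
python ewaldf_BM.py --cif MgO.cif --sx1 SX-1.txt --scale 1.01 --auto --all-sites --print-lattice

a軸+1%、c軸-1%（六方晶の異方スケール検証に）：
(a-axis +1%, c-axis -1%, e.g. to test anisotropic scaling of a hexagonal cell;
 for a hexagonal cell scale b together with a, e.g. 1.01,1.01,0.99, so that a = b is kept)
python ewaldf_BM.py --cif MgO.cif --sx1 SX-1.txt --scale-abc 1.01,1.00,0.99 --auto --all-sites --print-lattice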
"""

# ewald_bornmayer_cli.py
import sys
import time
import argparse
import numpy as np
from math import sin, cos, exp, sqrt, log10, pi

# ===== tkcrystalbase（元コード相当の関数）=====
from tkcrystalbase import (
    cal_lattice_vectors, cal_metrics, cal_volume,
    cal_reciprocal_lattice_vectors, cal_reciprocal_lattice_parameters,
    distance, distance2
)

# ===== Physical constants (SI) =====
e  = 1.602176634e-19      # elementary charge [C] (exact in the 2019 SI)
e0 = 8.8541878128e-12     # vacuum permittivity [C^2 N^-1 m^-2] (CODATA 2018)
# e^2/(4*pi*e0) [J·m]: multiplying by a potential expressed as (charge in e)/(distance in m)
# gives an energy in J (see q_i * Ke * MP below).
Ke = e*e/(4.0*pi*e0)
torad = pi/180.0          # degrees -> radians
ANG2M = 1e-10             # Å -> m
J_per_eV = e              # 1 eV expressed in J
eV_per_J = 1.0/e          # 1 J expressed in eV

# ======================================================================================
# CIF読み込み（pymatgen）→ lattice_parameters / sites
# ======================================================================================
from pymatgen.core import Structure

def cif_to_lattice_and_sites(filename, charge_map=None, zero_charge=False):
    """
    Load a structure with pymatgen and convert it to plain lattice/site lists.

    Args:
        filename: structure file readable by ``Structure.from_file`` (e.g. a CIF).
        charge_map: optional {element symbol: charge} overriding site charges.
        zero_charge: if True, every site charge is 0.0 (takes precedence).

    Returns:
        (lattice_parameters, sites) where lattice_parameters is
        [a, b, c, alpha, beta, gamma] as floats and each site record is
        [element, label, Z, atomic_mass, charge, radius, color, frac_coords].
        The charge falls back to the species' ``oxi_state`` (0.0 if absent);
        radius (1.0) and color ("gray") are fixed placeholders.
    """
    structure = Structure.from_file(filename)
    lattice_parameters = [float(value) for value in structure.lattice.parameters]

    def charge_of(symbol, specie):
        # Priority: --zero-charge, then the explicit map, then the oxidation state.
        if zero_charge:
            return 0.0
        if charge_map and symbol in charge_map:
            return float(charge_map[symbol])
        return float(getattr(specie, "oxi_state", 0.0) or 0.0)

    sites = []
    for serial, site in enumerate(structure.sites, start=1):
        specie = site.specie
        symbol = specie.symbol
        sites.append([
            symbol,
            f"{symbol}{serial}",
            int(specie.Z),
            float(specie.atomic_mass),
            charge_of(symbol, specie),
            1.0,
            "gray",
            np.array(site.frac_coords, dtype=float),
        ])
    return lattice_parameters, sites

def parse_charge_map(s):
    """
    Parse a charge override string such as ``'Zn:+2,O:-2'``.

    Empty entries (e.g. from a trailing comma) are ignored and whitespace
    around symbols/values is stripped.

    Args:
        s: the option string; ``None`` or ``''`` yields an empty map.

    Returns:
        dict mapping element symbol -> float charge.

    Raises:
        argparse.ArgumentTypeError: if an entry is not ``Element:charge`` or
            the charge is not a number (reported cleanly by argparse).
    """
    if not s:
        return {}
    out = {}
    for entry in s.split(','):
        entry = entry.strip()
        if not entry:
            continue  # tolerate trailing / doubled commas
        key, sep, value = entry.partition(':')
        key = key.strip()
        if not sep or not key:
            raise argparse.ArgumentTypeError(
                f"charge-map entry {entry!r} must be 'Element:charge'")
        try:
            out[key] = float(value.strip())
        except ValueError:
            raise argparse.ArgumentTypeError(
                f"charge-map entry {entry!r}: charge must be a number") from None
    return out

# ======================================================================================
# 格子情報 / Ewald カットオフ / α自動
# ======================================================================================
def build_lattice_info(lattice_parameters):
    """
    Collect real- and reciprocal-lattice quantities from tkcrystalbase.

    Args:
        lattice_parameters: [a, b, c, alpha, beta, gamma] (Å, degrees).

    Returns:
        dict with keys 'aij' (lattice vectors, ndarray), 'gij' (real metric),
        'volume', 'Raij' (reciprocal vectors, ndarray), 'Rlatt' (reciprocal
        lattice parameters), 'Rgij' (reciprocal metric) and 'Rvolume'.
    """
    real_vectors = np.asarray(cal_lattice_vectors(lattice_parameters), float)
    real_metric = cal_metrics(lattice_parameters)['gij']
    real_volume = cal_volume(real_vectors)

    recip_vectors = np.asarray(cal_reciprocal_lattice_vectors(real_vectors), float)
    recip_params = cal_reciprocal_lattice_parameters(recip_vectors)
    recip_metric = cal_metrics(recip_params)['gij']
    recip_volume = cal_volume(recip_vectors)

    return dict(
        aij=real_vectors,
        gij=real_metric,
        volume=real_volume,
        Raij=recip_vectors,
        Rlatt=recip_params,
        Rgij=recip_metric,
        Rvolume=recip_volume,
    )

def estimate_cutoffs(lattice_parameters, lattice_info, alpha=0.3, prec=1e-5):
    """
    Estimate Ewald cutoffs and the lattice-index ranges that cover them.

    Args:
        lattice_parameters: [a, b, c, alpha, beta, gamma]; kept for interface
            compatibility (the index ranges are derived from the metrics).
        lattice_info: dict from build_lattice_info; uses 'gij' (real-space
            metric, Å^2) and 'Rgij' (reciprocal metric, Å^-2).
        alpha: Ewald splitting parameter [1/Å].
        prec: target precision of the truncated sums.

    Returns:
        (rdmax, nrmax, G2max, hgmax):
          rdmax  real-space cutoff radius [Å] = (2.26 + 0.26*(-log10 prec))/alpha
          nrmax  int ndarray(3): translation range per axis covering |r| <= rdmax
          G2max  squared reciprocal cutoff [Å^-2], from exp(-pi^2 G2max/alpha^2) = prec
          hgmax  int ndarray(3): Miller-index range per axis covering G^2 <= G2max
    """
    norder = -log10(prec)
    rdmax = (2.26 + 0.26 * norder) / alpha
    G2max = alpha**2/pi**2 * ( -np.log(prec) )

    # For a positive-definite metric M, x^T M x <= R^2 implies
    # |x_k| <= R * sqrt((M^-1)_kk) for every component. This bound is exact for
    # any cell shape; a product-of-sines plane-spacing estimate undercounts the
    # range for oblique cells (e.g. the a/b axes of a hexagonal cell).
    ginv = np.linalg.inv(np.asarray(lattice_info['gij'], float))
    Rginv = np.linalg.inv(np.asarray(lattice_info['Rgij'], float))

    # +1 keeps the original safety margin (it also covers the offset between two
    # sites inside the cell, whose fractional difference is below 1).
    nrmax = np.array([int(rdmax * sqrt(ginv[k][k])) + 1 for k in range(3)], int)
    hgmax = np.array([int(sqrt(G2max * Rginv[k][k])) + 1 for k in range(3)], int)
    return rdmax, nrmax, G2max, hgmax

def estimate_term_counts(lattice_info, rdmax, G2max, nsites):
    """
    Rough number of terms in the real- and reciprocal-space Ewald sums.

    Each count is the volume of the cutoff sphere divided by the (real or
    reciprocal) cell volume, times ``nsites``.

    Returns:
        (n_real, n_recip) as floats.
    """
    sphere = (4.0/3.0)*pi
    n_real = sphere*(rdmax**3)/lattice_info['volume'] * nsites
    n_recip = sphere*(G2max**1.5)/lattice_info['Rvolume'] * nsites
    return n_real, n_recip

def auto_choose_alpha(lattice_parameters, lattice_info, prec=1e-5, nsites=1,
                      alpha_min=0.1, alpha_max=1.2, n_grid=40):
    """
    Pick the Ewald alpha that minimizes the estimated total number of terms.

    Scans ``n_grid`` evenly spaced values in [alpha_min, alpha_max] (only
    alpha_min when ``n_grid == 1``) and scores each by the sum of
    estimate_term_counts(); the first minimum wins.

    Returns:
        (alpha, (rdmax, nrmax, G2max, hgmax)) for the chosen alpha.

    Raises:
        ValueError: if n_grid < 1.
    """
    if n_grid < 1:
        raise ValueError(f"n_grid must be >= 1 (got {n_grid})")
    if n_grid == 1:
        # A single-point "grid" would otherwise divide by zero below.
        candidates = [alpha_min]
    else:
        candidates = [alpha_min + (alpha_max-alpha_min)*t/(n_grid-1) for t in range(n_grid)]

    best = None
    for alpha in candidates:
        cutoffs = estimate_cutoffs(lattice_parameters, lattice_info, alpha, prec)
        rdmax, _nrmax, G2max, _hgmax = cutoffs
        score = sum(estimate_term_counts(lattice_info, rdmax, G2max, nsites))
        if best is None or score < best[1]:
            best = (alpha, score, cutoffs)
    return best[0], best[2]

# ======================================================================================
# Ewald：ポテンシャル＆力
# ======================================================================================
def ewald_real_space_term_and_force(lattice_info, sites, nrmax, alpha, rmin_ang, i_index):
    """
    Real-space (erfc-screened) part of the Ewald sum at site ``i_index``.

    Loops over translations T in the box [-nrmax, nrmax] (z outermost, x
    innermost) and all sites j; pairs closer than ``rmin_ang`` (including the
    site itself at T = 0) are skipped.

    Args:
        lattice_info: dict from build_lattice_info (uses 'aij').
        sites: site records (charge at index 4, fractional position at index 7).
        nrmax: translation ranges per axis.
        alpha: Ewald splitting parameter [1/Å].
        rmin_ang: minimum pair distance [Å].
        i_index: index of the site being evaluated.

    Returns:
        (UC1, F): sum_j q_j erfc(alpha r)/r with r in metres (charges in e),
        i.e. the real-space potential at site i [1/m], and the real-space
        force on site i [N].
    """
    from math import erfc
    lattice_T = np.asarray(lattice_info['aij'], float).T
    frac_i = sites[i_index][7]
    charge_i = sites[i_index][4]

    alpha_m = alpha * 1e10                      # 1/Å -> 1/m
    gauss_pref = 2.0*alpha_m/np.sqrt(np.pi)
    coulomb_i = charge_i * Ke

    translations = (np.array([ix, iy, iz], float)
                    for iz in range(-nrmax[2], nrmax[2]+1)
                    for iy in range(-nrmax[1], nrmax[1]+1)
                    for ix in range(-nrmax[0], nrmax[0]+1))

    potential = 0.0
    force = np.zeros(3, float)
    for T in translations:
        for site in sites:
            q_j = site[4]
            d_frac = frac_i - (site[7] + T)
            r_vec_ang = lattice_T @ d_frac
            r_ang = np.linalg.norm(r_vec_ang)
            if r_ang < rmin_ang:
                continue  # self term (T = 0) or unphysically close pair
            potential += q_j * erfc(alpha * r_ang) / (r_ang * ANG2M)
            r_vec_m = r_vec_ang * ANG2M
            r_m = np.linalg.norm(r_vec_m)
            G = erfc(alpha_m*r_m)/(r_m**3) + gauss_pref*np.exp(-(alpha_m*r_m)**2)/(r_m**2)
            force += coulomb_i * q_j * G * r_vec_m
    return potential, force

def ewald_reciprocal_space_term_and_force(lattice_info, sites, hgmax, alpha, i_index, halfspace=True):
    """
    Reciprocal-space part of the Ewald sum (potential and force) at site ``i_index``.

    Phases are 2*pi*(h x + k y + l z) and the damping is exp(-pi^2 G^2 / alpha^2),
    with G^2 taken from the reciprocal metric 'Rgij' [Å^-2]; the G = 0 term is skipped.

    Args:
        lattice_info: dict from build_lattice_info ('Raij', 'Rgij', 'volume' [Å^3]).
        sites: site records (charge at index 4, fractional position at index 7).
        hgmax: Miller-index ranges (h, k, l).
        alpha: Ewald splitting parameter [1/Å].
        i_index: index of the site being evaluated.
        halfspace: if True, only l >= 0 is summed and l != 0 terms are doubled.

    Returns:
        (UC2, F): reciprocal potential at site i [1/m, charges in e] and the
        corresponding force on site i [N].
    """
    recip_vectors = np.asarray(lattice_info['Raij'], float)
    recip_metric = lattice_info['Rgij']
    cell_volume_m3 = lattice_info['volume'] * (ANG2M**3)
    frac_i = sites[i_index][7]
    charge_i = sites[i_index][4]
    gauss_coef = (pi*pi)/(alpha*alpha)
    prefactor = 1.0/(pi*cell_volume_m3)
    g_origin = np.array([0.0, 0.0, 0.0], float)

    if halfspace:
        l_values = range(0, hgmax[2]+1)
    else:
        l_values = range(-hgmax[2], hgmax[2]+1)

    potential = 0.0
    force = np.zeros(3, float)
    for l in l_values:
        doubled = halfspace and l != 0
        for k in range(-hgmax[1], hgmax[1]+1):
            for h in range(-hgmax[0], hgmax[0]+1):
                G2 = distance2(g_origin, np.array([h, k, l], float), recip_metric)
                if G2 == 0.0:
                    continue
                phase_i = 2.0*pi*(h*frac_i[0] + k*frac_i[1] + l*frac_i[2])
                cos_i = cos(phase_i)
                sin_i = np.sin(phase_i)
                C = 0.0
                S = 0.0
                for site in sites:
                    q_j = site[4]
                    frac_j = site[7]
                    phase_j = 2.0*pi*(h*frac_j[0] + k*frac_j[1] + l*frac_j[2])
                    C += q_j*np.cos(phase_j)
                    S += q_j*np.sin(phase_j)
                amp = np.exp(-gauss_coef*G2)/(G2*1e20)  # 1/G^2 converted Å^2 -> m^2
                in_phase = cos_i*C + sin_i*S     # sum_j q_j cos(theta_i - theta_j)
                quadrature = S*cos_i - C*sin_i   # sum_j q_j sin(theta_j - theta_i)
                if doubled:
                    in_phase *= 2.0
                    quadrature *= 2.0
                potential += prefactor*amp*in_phase
                G_vec_m = (h*recip_vectors[0] + k*recip_vectors[1] + l*recip_vectors[2])*1e10
                grad_phi = prefactor*amp*(2.0*pi)*quadrature*G_vec_m  # 1/m^2
                force += - charge_i * Ke * grad_phi
    return potential, force

def ewald_self_term(sites, alpha, i_index):
    """
    Ewald self-interaction potential of site ``i_index``: q_i * 2*alpha/sqrt(pi).

    ``alpha`` is given in 1/Å and converted to 1/m, so the result is in 1/m
    (charges in units of e) and is subtracted from the real + reciprocal sums.
    """
    charge = sites[i_index][4]
    alpha_per_m = alpha*1e10
    return charge * 2.0 * alpha_per_m / sqrt(pi)

def ewald_potential_force_at_site(lattice_parameters, sites, i_index=0,
                                  alpha=0.3, prec=1e-5, rmin=0.1,
                                  nrmax_override=None, hgmax_override=None,
                                  halfspace=True, auto_alpha=False, alpha_grid=40):
    """
    Full Ewald potential and force at one site.

    Cutoffs come from estimate_cutoffs (or auto_choose_alpha when
    ``auto_alpha``); any non-falsy component of ``nrmax_override`` /
    ``hgmax_override`` replaces the automatic value on that axis.

    Returns:
        dict with 'alpha', 'MP' (= UC1 + UC2 - UC3, potential at the site in
        1/m with charges in e), the individual terms 'UC1'/'UC2'/'UC3', forces
        'F_real'/'F_reci'/'F_tot' [N], wall-clock 'timing', the 'lattice' info
        dict, the 'cutoffs' actually used and the 'settings'.
    """
    lat = build_lattice_info(lattice_parameters)
    if auto_alpha:
        alpha, cutoffs = auto_choose_alpha(lattice_parameters, lat, prec=prec,
                                           nsites=len(sites), n_grid=alpha_grid)
    else:
        cutoffs = estimate_cutoffs(lattice_parameters, lat, alpha, prec)
    rdmax, nrmax, G2max, hgmax = cutoffs

    def merge(override, automatic):
        # None/0 components of the override keep the automatic range.
        if override is None:
            return automatic
        return np.array([override[k] or automatic[k] for k in range(3)], int)

    nrmax = merge(nrmax_override, nrmax)
    hgmax = merge(hgmax_override, hgmax)

    t_start = time.time()
    UC1, F_real = ewald_real_space_term_and_force(lat, sites, nrmax, alpha, rmin, i_index)
    t_real = time.time()
    UC2, F_reci = ewald_reciprocal_space_term_and_force(lat, sites, hgmax, alpha, i_index, halfspace=halfspace)
    t_recip = time.time()
    UC3 = ewald_self_term(sites, alpha, i_index)
    t_end = time.time()

    return {
        'alpha': alpha,
        'MP': UC1 + UC2 - UC3,          # 1/m
        'UC1': UC1, 'UC2': UC2, 'UC3': UC3,
        'F_real': F_real, 'F_reci': F_reci, 'F_tot': F_real + F_reci,  # N
        'timing': {'real': t_real - t_start, 'recip': t_recip - t_real, 'total': t_end - t_start},
        'lattice': lat,
        'cutoffs': {'nrmax': np.array(nrmax, int), 'hgmax': np.array(hgmax, int)},
        'settings': {'prec': prec, 'rmin': rmin, 'halfspace': halfspace},
    }

# ======================================================================================
# SX-1 Born–Mayer（短距離）: 反発項のみ（既定）
# ======================================================================================
def read_SX1_potential(dbpath: str):
    """
    Read an SX-1 Born–Mayer parameter table.

    Expected layout: one header line (ignored), then one row per element with
    at least seven columns, tab-separated (space-aligned rows are accepted as
    a fallback)::

        element  mass  charge  ai  bi  ci  ARAD

    Blank lines are ignored, fields are whitespace-stripped and rows with fewer
    than seven fields are skipped. This script uses ai, bi and optionally charge.

    Returns:
        dict mapping element symbol -> {'mass', 'charge', 'ai', 'bi', 'ci', 'ratom'};
        empty for an empty file.
    """
    data_dict = {}
    with open(dbpath, 'r', encoding='utf-8') as f:
        rows = [ln.strip() for ln in f if ln.strip()]

    # rows[0] is the column header; it is not interpreted.
    for line in rows[1:]:
        vals = [v.strip() for v in line.split('\t')]
        if len(vals) < 7:
            vals = line.split()  # fall back to generic whitespace separation
        if len(vals) < 7:
            continue
        el = vals[0]
        mass = float(vals[1]); charge = float(vals[2])
        ai = float(vals[3]);   bi = float(vals[4]); ci = float(vals[5])
        AR = float(vals[6])
        data_dict[el] = {'mass':mass, 'charge':charge, 'ai':ai, 'bi':bi, 'ci':ci, 'ratom':AR}
    return data_dict

def born_mayer_params(pdb, el1, el2):
    """
    Combine per-element SX-1 parameters into Born–Mayer pair parameters.

    Args:
        pdb: table from read_SX1_potential.
        el1, el2: element symbols of the pair.

    Returns:
        (aij, bij, charge1, charge2) with aij = ai1 + ai2 [Å] and
        bij = bi1 + bi2 [Å], plus the two table charges.

    Raises:
        KeyError: if an element is not in the table (message names it).
    """
    for el in (el1, el2):
        if el not in pdb:
            raise KeyError(f"element {el!r} not found in SX-1 parameter table "
                           f"(available: {', '.join(sorted(pdb))})")
    p1 = pdb[el1]; p2 = pdb[el2]
    aij = p1['ai'] + p2['ai']       # Å
    bij = p1['bi'] + p2['bi']       # Å
    return aij, bij, p1['charge'], p2['charge']

# Born–Mayer force constant F0, expressed in eV/Å (the value is given in J/Å and
# converted to eV here).
# NOTE(review): 6.947700141e-21 J/Å numerically matches 1 kcal mol^-1 Å^-1
# (4184 J / N_A), the conventional Born–Mayer f0 — confirm against the SX-1 source.
F0_eV_per_A = 6.947700141e-21 * eV_per_J  # eV/Å

def born_mayer_energy_force_eV_A(r_ang, aij, bij):
    """
    Born–Mayer repulsion energy and force magnitude for one pair.

        V(r)  = F0 * bij * exp((aij - r)/bij)   [eV]
        |F|   = -dV/dr = F0 * exp((aij - r)/bij) [eV/Å]  (positive = repulsive)

    Args:
        r_ang: pair distance [Å]; aij, bij: pair parameters [Å].

    Returns:
        (V, |F|) in eV and eV/Å.
    """
    decay = np.exp((aij - r_ang)/bij)
    return F0_eV_per_A * bij * decay, F0_eV_per_A * decay

def compute_short_range_born_mayer(lattice_parameters, sites, sx1_db,
                                   sr_cut=8.0, include_coulomb=False):
    """
    Short-range Born–Mayer repulsion by a direct real-space lattice sum.

    Every site pair (i <= j) is combined with every lattice translation T whose
    image can lie within ``sr_cut``; pairs farther than ``sr_cut`` are skipped.

    Args:
        lattice_parameters: [a, b, c, alpha, beta, gamma] (Å, degrees).
        sites: site records; index 0 = element symbol (key into sx1_db),
            4 = charge, 7 = fractional position.
        sx1_db: parameter table from read_SX1_potential.
        sr_cut: pair cutoff [Å].
        include_coulomb: also add a bare, truncated Coulomb term K0*q_i*q_j/r
            (not recommended: the Coulomb interaction is handled by Ewald).

    Returns:
        (E_site_eV, F_site_N): per-site energies [eV], shape (n,), and forces
        [N], shape (n, 3). Each pair energy is split 1/2-1/2 between its two
        atoms, so sum(E_site_eV) is the energy per cell.
    """
    lat = build_lattice_info(lattice_parameters)
    aij = np.asarray(lat['aij'], float)
    aij_T = aij.T

    n = len(sites)
    E_site_eV = np.zeros(n, float)
    F_site_N = np.zeros((n, 3), float)

    # Translation range. Distances below are |aij.T @ d|, i.e. metric M = aij @ aij.T.
    # |r| <= sr_cut implies |d_k| <= sr_cut*sqrt((M^-1)_kk) for any cell shape, and
    # |T_k| <= |d_k| + |pos_i,k - pos_j,k|, so adding the coordinate span is exact.
    ginv = np.linalg.inv(aij @ aij_T)
    if n:
        frac = np.array([s[7] for s in sites], float)
        span = frac.max(axis=0) - frac.min(axis=0)
    else:
        span = np.zeros(3)
    nr = [int(sr_cut * sqrt(ginv[k][k]) + span[k] + 1e-9) for k in range(3)]

    # Coulomb coefficient e^2/(4*pi*e0) in eV·Å, and the eV/Å -> N factor.
    K0_eV_A = Ke * eV_per_J / 1e-10
    eVA_to_N = J_per_eV / ANG2M

    pair_cache = {}  # (el_i, el_j) -> (aij_bm, bij_bm); looked up once per pair type

    for iz in range(-nr[2], nr[2]+1):
        for iy in range(-nr[1], nr[1]+1):
            for ix in range(-nr[0], nr[0]+1):
                T = np.array([ix, iy, iz], float)
                at_origin = (ix == 0 and iy == 0 and iz == 0)
                for i in range(n):
                    el_i, _, _, _, qi, _, _, pos_i = sites[i]
                    for j in range(i, n):  # i == j contributes through images T != 0
                        same_site = (i == j)
                        if same_site and at_origin:
                            continue  # an atom does not interact with itself
                        el_j, _, _, _, qj, _, _, pos_j = sites[j]
                        # r_vec points from the image of j (at pos_j + T) to i.
                        r_vec_ang = aij_T @ (pos_i - (pos_j + T))
                        r = np.linalg.norm(r_vec_ang)
                        if r < 1e-8 or r > sr_cut:
                            continue

                        key = (el_i, el_j)
                        if key not in pair_cache:
                            pair_cache[key] = born_mayer_params(sx1_db, el_i, el_j)[:2]
                        aij_bm, bij_bm = pair_cache[key]

                        # Born–Mayer repulsion (eV, eV/Å; Fmag > 0 means repulsive)
                        V_eV, Fmag = born_mayer_energy_force_eV_A(r, aij_bm, bij_bm)
                        if include_coulomb:
                            V_eV += K0_eV_A * qi * qj / r
                            # Signed: like charges repel (> 0), opposite charges attract (< 0).
                            Fmag += K0_eV_A * qi * qj / (r*r)

                        if same_site:
                            # Images i+T and i-T are both visited; giving 1/2*V per visit
                            # yields the correct 1/2 * sum_{T != 0} V(|T|) share of atom i.
                            # Their forces cancel pairwise (T vs -T), so none are added.
                            E_site_eV[i] += 0.5 * V_eV
                            continue

                        # Force on i along +r_hat (away from j for Fmag > 0); j gets the opposite.
                        F_pair_N = (Fmag * eVA_to_N) * (r_vec_ang / r)
                        E_site_eV[i] += 0.5 * V_eV
                        E_site_eV[j] += 0.5 * V_eV
                        F_site_N[i] += F_pair_N
                        F_site_N[j] -= F_pair_N

    return E_site_eV, F_site_N

# ======================================================================================
# 表示
# ======================================================================================
def print_site_totals(i, sites, ewald_res, E_bm_eV_i, F_bm_i, show_forces=True):
    """
    Print the Ewald, Born–Mayer and combined energy (and optionally forces) of site ``i``.

    The Ewald energy shown is q_i * Ke * MP (MP from ewald_res, in 1/m); the
    Born–Mayer value is the per-site share passed in (already halved per pair).
    Forces are printed in N and in eV/Å.

    NOTE(review): the Ewald line is the full site energy q_i*Ke*phi_i, whereas the
    Born–Mayer line is a half-pair share (and main() halves the Ewald sum for the
    cell total), so per-site TOTAL values do not add up to the cell total —
    confirm which convention is intended.
    """
    name, label, *_ = sites[i]
    qi = sites[i][4]

    e_ew_J = qi * Ke * ewald_res['MP']
    e_bm_eV = E_bm_eV_i
    e_bm_J = e_bm_eV * J_per_eV
    e_tot_J = e_ew_J + e_bm_J

    f_ew = ewald_res['F_tot']
    f_tot = f_ew + F_bm_i

    def vec3(v):
        return f"[{v[0]: .6e}, {v[1]: .6e}, {v[2]: .6e}]"

    print(f"\n[Site {i}] {name} ({label}) q={qi:+.6g}")
    print(f"  Ewald energy   : {e_ew_J: .6e} J  ({e_ew_J * eV_per_J: .6e} eV)")
    print(f"  Born–Mayer rep.: {e_bm_J: .6e} J  ({e_bm_eV: .6e} eV)")
    print(f"  TOTAL energy   : {e_tot_J: .6e} J  ({e_tot_J * eV_per_J: .6e} eV)")
    if not show_forces:
        return
    print("  Forces [N] (also in eV/Å):")
    print(f"    Ewald     : {vec3(f_ew)}  | {vec3(f_ew*eV_per_J*ANG2M)}")
    print(f"    Born–Mayer: {vec3(F_bm_i)}  | {vec3(F_bm_i*eV_per_J*ANG2M)}")
    print(f"    TOTAL     : {vec3(f_tot)}  | {vec3(f_tot * (eV_per_J*ANG2M))}")

# ======================================================================================
# CLI
# ======================================================================================
def parse_triplet(s):
    """
    Parse an ``'n1,n2,n3'`` range override.

    Empty or ``'0'`` components become ``None`` (meaning "use the automatic
    value"); other components must be non-negative integers.

    Raises:
        argparse.ArgumentTypeError: wrong component count or a negative value.
        ValueError: a component that is not an integer.
    """
    fields = [p.strip() for p in s.split(',')]
    if len(fields) != 3:
        raise argparse.ArgumentTypeError("must be 'n1,n2,n3'")

    def convert(field):
        if field in ('', '0'):
            return None
        value = int(field)
        if value < 0:
            raise argparse.ArgumentTypeError("components must be >=0")
        return value

    return tuple(convert(field) for field in fields)

def parse_abc_scale(s):
    """
    Parse an axis-wise lattice scale ``'sa,sb,sc'`` into a tuple of floats.

    Raises:
        argparse.ArgumentTypeError: wrong component count, a non-numeric
            component, or a component that is not a finite positive number
            (a non-positive scale would produce a degenerate lattice).
    """
    parts = [p.strip() for p in s.split(',')]
    if len(parts) != 3:
        raise argparse.ArgumentTypeError("scale-abc must be 'sa,sb,sc'")
    try:
        scales = tuple(float(x) for x in parts)
    except ValueError:
        raise argparse.ArgumentTypeError("scale-abc components must be float") from None
    if not all(0.0 < v < float('inf') for v in scales):
        raise argparse.ArgumentTypeError("scale-abc components must be finite and > 0")
    return scales

def build_argparser():
    """Build the command-line parser (structure, lattice scaling, Ewald, Born–Mayer and output options)."""
    ap = argparse.ArgumentParser(description="Ewald (pymatgen CIF) + Born–Mayer repulsion (SX-1)")
    # Structure
    ap.add_argument('--cif', type=str, required=True, help='Input CIF (pymatgen)')
    ap.add_argument('--charge-map', type=parse_charge_map, default={}, help="Override charges, e.g. 'Zn:+2,O:-2'")
    ap.add_argument('--zero-charge', action='store_true', help='Set all site charges to 0.0')

    # Lattice scaling
    ap.add_argument('--scale', type=float, default=1.0,
                    help='Uniform scale for a,b,c (default: 1.0)')
    ap.add_argument('--scale-abc', type=str, default=None,
                    help="Axis-wise scale 'sa,sb,sc' (overrides --scale if set)")

    # Ewald
    ap.add_argument('--prec', type=float, default=1e-5,
                    help='Target precision used to size the Ewald cutoffs (default: 1e-5)')
    ap.add_argument('--alpha', type=float, default=None,
                    help='Ewald splitting parameter alpha [1/Å] (default: 0.3, or automatic with --auto)')
    ap.add_argument('--auto', action='store_true',
                    help='Choose alpha automatically by minimizing the estimated number of terms '
                         '(ignored when --alpha is given)')
    ap.add_argument('--alpha-grid', type=int, default=40,
                    help='Number of alpha values scanned by --auto (default: 40)')
    ap.add_argument('--nrmax', type=parse_triplet, default=None,
                    help="Override real-space translation ranges 'n1,n2,n3' (0 or empty = automatic)")
    ap.add_argument('--hgmax', type=parse_triplet, default=None,
                    help="Override reciprocal-space index ranges 'h,k,l' (0 or empty = automatic)")
    ap.add_argument('--rmin', type=float, default=0.1,
                    help='Skip real-space Ewald pairs closer than this distance [Å] (default: 0.1)')
    ap.add_argument('--full-g', action='store_true',
                    help='Sum the full reciprocal-space grid instead of the l >= 0 half-space')
    ap.add_argument('--print-lattice', action='store_true',
                    help='Print the scaled lattice parameters and cell volume')

    # Born–Mayer (SX-1)
    ap.add_argument('--sx1', type=str, required=True, help='SX-1 parameter file (tab-separated)')
    ap.add_argument('--sr-cut', type=float, default=8.0, help='Short-range cutoff [Å]')
    ap.add_argument('--sr-include-coulomb', action='store_true', help='(Not recommended) add Coulomb in SR too')

    # Output
    ap.add_argument('--all-sites', action='store_true',
                    help='Print per-site results for every site (default: site 0 only)')
    return ap

def main():
    """
    Command-line entry point.

    Applies the requested lattice scaling to the CIF structure, evaluates the
    Born–Mayer short-range term and the Ewald sum for every site, then prints
    per-site results (site 0, or every site with --all-sites) and cell totals.
    """
    ap = build_argparser()
    args = ap.parse_args()

    # --- lattice scaling factors: validated before any expensive work ---
    if args.scale_abc:
        try:
            sa, sb, sc = parse_abc_scale(args.scale_abc)
        except argparse.ArgumentTypeError as exc:
            # Report via argparse (usage line + exit status 2) instead of a traceback.
            ap.error(f"argument --scale-abc: {exc}")
    else:
        if not args.scale > 0.0:
            ap.error("argument --scale: must be > 0")
        sa = sb = sc = args.scale

    # Structure
    lattice_parameters, sites = cif_to_lattice_and_sites(
        args.cif, charge_map=args.charge_map, zero_charge=args.zero_charge
    )

    # Scale a, b, c (angles unchanged)
    lattice_parameters[0] *= sa  # a
    lattice_parameters[1] *= sb  # b
    lattice_parameters[2] *= sc  # c

    if args.print_lattice:
        lat_tmp = build_lattice_info(lattice_parameters)
        print(f"\nApplied lattice scale: a×{sa}, b×{sb}, c×{sc}")
        print(f"Scaled lattice: {lattice_parameters}  Volume[Å^3]: {lat_tmp['volume']:.6f}")

    # Ewald settings: --auto only applies when --alpha is not given
    auto_alpha = args.auto and (args.alpha is None)
    alpha = (args.alpha if args.alpha is not None else 0.3)
    common_kwargs = dict(
        alpha=alpha, prec=args.prec, rmin=args.rmin,
        nrmax_override=args.nrmax, hgmax_override=args.hgmax,
        halfspace=(not args.full_g), auto_alpha=auto_alpha, alpha_grid=args.alpha_grid
    )

    # Born–Mayer DB and short-range term (per site). Evaluated before the much
    # slower Ewald loop so that a missing element fails immediately.
    sx1_db = read_SX1_potential(args.sx1)
    E_bm_eV, F_bm_N = compute_short_range_born_mayer(
        lattice_parameters, sites, sx1_db, sr_cut=args.sr_cut, include_coulomb=args.sr_include_coulomb
    )

    # Ewald potential and force for every site
    ewald_results = [ewald_potential_force_at_site(lattice_parameters, sites, i_index=i, **common_kwargs)
                     for i in range(len(sites))]
    if args.print_lattice and ewald_results:
        vol = ewald_results[0]['lattice']['volume']
        print("\nLattice:", lattice_parameters, " Volume[Å^3]:", vol)

    # Per-site output
    site_indices = range(len(sites)) if args.all_sites else [0]
    for i in site_indices:
        print_site_totals(i, sites, ewald_results[i], E_bm_eV[i], F_bm_N[i], show_forces=True)

    # 1) Ewald cell energy [J]: 0.5 * sum_i q_i * Ke * phi_i  (phi_i = res['MP'])
    E_ew_tot_J = 0.0
    for i, res in enumerate(ewald_results):
        qi = sites[i][4]
        E_ew_tot_J += 0.5 * qi * Ke * res['MP']

    # 2) Born–Mayer cell energy [J]: per-site values are already pair-halved
    E_bm_tot_J = np.sum(E_bm_eV) * J_per_eV

    # 3) Totals
    E_tot_J  = E_ew_tot_J + E_bm_tot_J
    E_ew_tot_eV = E_ew_tot_J * eV_per_J
    E_bm_tot_eV = E_bm_tot_J * eV_per_J
    E_tot_eV    = E_tot_J    * eV_per_J

    # 4) Sanity check: the forces should sum to ~0 over the cell
    F_sum_N = np.sum([res['F_tot'] for res in ewald_results], axis=0) + np.sum(F_bm_N, axis=0)

    print("\n=== CELL TOTALS ===")
    print(f"  Ewald total    : {E_ew_tot_J: .6e} J  ({E_ew_tot_eV: .6e} eV)")
    print(f"  Born–Mayer tot.: {E_bm_tot_J: .6e} J  ({E_bm_tot_eV: .6e} eV)")
    print(f"  TOTAL energy   : {E_tot_J: .6e} J  ({E_tot_eV: .6e} eV)")
    print(f"  Sum of forces  : [{F_sum_N[0]: .6e}, {F_sum_N[1]: .6e}, {F_sum_N[2]: .6e}] N")

if __name__ == '__main__':
    main()
