#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
MD main: pymatgenでCIFを読み込み、Ewald(C) + Born–Mayer(SX-1, Python)でMDを実行。
- 依存: numpy, ewaldf_BM.py, ewald_c_ctypes.py（ctypesラッパ）
- Ewaldは ./ewald_c.dll を既定でロード（--ewald-dll で変更可）。失敗時はPython実装に自動フォールバック。
- 単位: 位置[Å], 速度[Å/fs], 力[eV/Å], 質量[eV·fs^2/Å^2], 時間[fs]（ただし CLI の --dt は秒[s]で指定し、内部で fs に換算する）

例:
python md.py --cif MgO.cif --sx1 SX-1.txt --steps 2000 --dt 1e-15 \
  --temperature 300 --thermostat berendsen --tau 1000 \
  --sr-cut 8.0 --prec 1e-5 --auto --alpha-grid 40 --save traj.xyz --print-breakdown
"""

import argparse
import sys
import os
import math
import numpy as np

# --- 自作ライブラリ ---
import ewaldf_BM as ebm  # Born–Mayer, CIF読込, 格子情報, α設計 等

# ctypesラッパ（存在しない場合は後でフォールバック）
try:
    # Optional ctypes wrapper around the C Ewald library. If it cannot be
    # imported, EwaldSX1Engine falls back to the pure-Python Ewald path.
    # NOTE(review): this catches any Exception (not only ImportError),
    # presumably so that errors raised while the wrapper module itself is
    # being imported also trigger the fallback; confirm before narrowing.
    import ewald_c_ctypes as ewc
    HAS_CTYPES_WRAPPER = True
except Exception:
    # Keep the name defined; the engine checks HAS_CTYPES_WRAPPER before using it.
    ewc = None
    HAS_CTYPES_WRAPPER = False

# ===================== 単位・ユーティリティ =====================
# Internal MD units: length Å, time fs, energy eV, mass eV·fs^2/Å^2.
kB_eV = 8.617333262145e-5          # Boltzmann constant [eV/K]
amu_to_eVA2_fs2 = 103.6426965      # 1 amu = 103.6426965 eV·fs^2/Å^2
# Force conversion N -> eV/Å (1 N = 1 J/m).
# NOTE(review): assumes ebm.eV_per_J is eV per joule and ebm.ANG2M is metres per Å.
N_to_eV_per_A = ebm.eV_per_J * ebm.ANG2M  # N → eV/Å

def frac_to_cart(aij, frac):
    """Convert fractional coordinates to Cartesian coordinates.

    Each row of ``frac`` is expanded on the rows of ``aij`` (the lattice
    vectors), i.e. the result equals ``frac @ aij``.
    """
    lattice_t = aij.T
    cart_t = lattice_t @ frac.T
    return cart_t.T

def cart_to_frac(aij, cart):
    """Convert Cartesian coordinates to fractional coordinates.

    Inverse of ``frac_to_cart``: solves ``aij.T @ frac.T = cart.T`` with a
    linear solve instead of forming an explicit matrix inverse.
    """
    system_matrix = aij.T
    rhs = cart.T
    frac_t = np.linalg.solve(system_matrix, rhs)
    return frac_t.T

def wrap_frac(frac):
    """Wrap fractional coordinates into the half-open interval [0, 1).

    Parameters
    ----------
    frac : array_like
        Fractional coordinates of any shape.

    Returns
    -------
    ndarray
        Same shape, each value reduced modulo 1 into [0, 1).
    """
    frac = np.asarray(frac)
    wrapped = frac - np.floor(frac)
    # For tiny negative inputs (e.g. -1e-17), x - floor(x) = 1 - 1e-17 rounds to
    # exactly 1.0 in float64; map such values back to 0.0 so the result is
    # strictly below 1 as the wrapping contract requires.
    return np.where(wrapped >= 1.0, 0.0, wrapped)

def instantaneous_temperature(mass_amu, velocities, ndof=None):
    """Return the instantaneous kinetic temperature [K].

    T = 2 * E_kin / (ndof * k_B), with E_kin = sum_i 1/2 m_i |v_i|^2.

    Parameters
    ----------
    mass_amu : 1-D array
        Per-atom masses in amu.
    velocities : array
        Per-atom velocities in Å/fs, one row per atom.
    ndof : int, optional
        Number of degrees of freedom. Defaults to 3 * N (unchanged legacy
        behavior). Pass 3 * N - 3 when the total momentum is constrained to
        zero (as maxwell_boltzmann_velocities does), which is the usual
        convention for such systems.
    """
    m = mass_amu[:, None] * amu_to_eVA2_fs2
    Ekin = 0.5 * float(np.sum(m * velocities**2))
    if ndof is None:
        ndof = 3 * len(mass_amu)
    return (2.0 * Ekin) / (ndof * kB_eV)

def kinetic_energy(mass_amu, velocities):
    """Return the total kinetic energy [eV].

    E_kin = sum_i 1/2 m_i |v_i|^2 with masses in amu (converted to
    eV·fs^2/Å^2) and velocities in Å/fs.
    """
    masses = mass_amu * amu_to_eVA2_fs2
    per_component = masses[:, None] * velocities**2
    return 0.5 * float(per_component.sum())

def maxwell_boltzmann_velocities(mass_amu, T):
    """Draw random velocities [Å/fs] from a Maxwell–Boltzmann distribution.

    Each Cartesian component of atom i is Gaussian with standard deviation
    sqrt(k_B * T / m_i) (T in K, m_i in eV·fs^2/Å^2). The mass-weighted
    centre-of-mass velocity is then subtracted so the total momentum is zero.
    Random numbers come from the global ``np.random`` state.
    """
    masses = mass_amu * amu_to_eVA2_fs2
    n_atoms = len(mass_amu)
    vel = np.random.normal(size=(n_atoms, 3))
    vel *= np.sqrt(kB_eV * T / masses)[:, None]
    # Remove the centre-of-mass drift (zero total momentum).
    weights = masses[:, None]
    v_com = (weights * vel).sum(axis=0) / weights.sum()
    vel -= v_com[None, :]
    return vel

def berendsen_scale(vel, mass_amu, T_target, tau_fs, dt_fs):
    """Return velocities rescaled by one Berendsen weak-coupling step.

    lambda = sqrt(1 + (dt/tau) * (T_target / T_inst - 1))

    ``vel`` itself is returned unchanged when tau_fs <= 0 (coupling disabled)
    or when the instantaneous temperature is ~0, where T_target / T_inst is
    undefined.

    The coupling ratio dt/tau is capped at 1. For dt_fs <= tau_fs this is the
    standard formula. For dt_fs > tau_fs the uncapped radicand can become
    negative (``math.sqrt`` would raise "math domain error") or overshoot the
    target. A ratio of 1 instead rescales straight to T_target. The radicand is
    also clamped at 0 so a negative T_target cannot raise either.
    """
    if tau_fs <= 0:
        return vel
    T_inst = instantaneous_temperature(mass_amu, vel)
    if T_inst <= 1e-30:
        return vel
    coupling = min(dt_fs / tau_fs, 1.0)
    radicand = 1.0 + coupling * (T_target / T_inst - 1.0)
    c = math.sqrt(max(radicand, 0.0))
    return vel * c

def write_xyz(path, positions, symbols, comment, append=False):
    """Write one XYZ frame: atom count, comment line, then 'symbol x y z' rows.

    With ``append=True`` the frame is added to the end of an existing file
    (trajectory); otherwise the file is truncated first. Coordinates use
    16-character fields with 8 decimals; the file is UTF-8 encoded.
    """
    open_mode = "a" if append else "w"
    with open(path, open_mode, encoding="utf-8") as fh:
        fh.write(f"{len(symbols)}\n{comment}\n")
        fh.writelines(
            f"{sym:2s} {x:16.8f} {y:16.8f} {z:16.8f}\n"
            for sym, (x, y, z) in zip(symbols, positions)
        )

# ===================== Ewald + Born–Mayer エンジン =====================
class EwaldSX1Engine:
    """
    Energy/force engine: Ewald electrostatics + SX-1 Born–Mayer short range.

    Ewald: all sites at once through the C DLL (ewald_c.dll) via the ctypes
    wrapper; falls back to the Python implementation in ewaldf_BM when the
    wrapper or the DLL is unavailable.
    Born–Mayer: ewaldf_BM.compute_short_range_born_mayer (per-site energies and
    forces [N]).

    The cell is assumed fixed during MD: lattice info and the Ewald parameters
    (alpha, nrmax, hgmax) are determined once in __init__.
    """
    def __init__(self, lattice_parameters, sites, sx1_db,
                 prec=1e-5, alpha=None, auto=False, alpha_grid=40,
                 nrmax=None, hgmax=None, rmin=0.1, halfspace=True,
                 sr_cut=8.0, sr_include_coulomb=False,
                 use_ctypes=True, dll_path="./ewald_c.dll"):
        """
        Parameters:
          lattice_parameters: passed to ebm.build_lattice_info (copied to a list).
          sites: per-site records; index 4 is the charge and index 7 the
            fractional coordinates (index 7 is overwritten by
            update_positions_cart).
          sx1_db: SX-1 parameter table passed to the Born–Mayer routine.
          alpha: Ewald alpha [1/Å]; 0.3 when None. Replaced by the automatic
            choice when auto=True.
          nrmax, hgmax: optional (n1, n2, n3) overrides for the real/reciprocal
            space repetition counts. Components given as None fall back to the
            value estimated from alpha/prec (matching the CLI, where '0' or an
            empty component means "auto").
          use_ctypes, dll_path: try the C DLL first; fall back to Python.
        Raises:
          RuntimeError: when nrmax/hgmax cannot be determined.
        """
        self.lattice_parameters = list(lattice_parameters)
        self.sites = sites
        self.sx1_db = sx1_db

        self.prec = float(prec)
        self.alpha = (0.3 if alpha is None else float(alpha))
        self.auto = bool(auto)
        self.alpha_grid = int(alpha_grid)
        self.nrmax = nrmax
        self.hgmax = hgmax
        self.rmin = float(rmin)
        self.halfspace = bool(halfspace)

        self.sr_cut = float(sr_cut)
        self.sr_include_coulomb = bool(sr_include_coulomb)

        # Load the DLL (only when requested and the wrapper module imported).
        self.use_ctypes = bool(use_ctypes) and HAS_CTYPES_WRAPPER
        self.lib = None
        if bool(use_ctypes) and not HAS_CTYPES_WRAPPER:
            # Previously this fallback happened silently.
            print("# [notice] ewald_c_ctypes を読み込めませんでした。\n# => Python実装にフォールバックします。", file=sys.stderr)
        if self.use_ctypes:
            try:
                # Relative paths such as "./ewald_c.dll" are accepted.
                self.lib = ewc.load_library(dll_path)
            except Exception as e:
                print(f"# [notice] ctypes DLL を読み込めませんでした: {e}\n# => Python実装にフォールバックします。", file=sys.stderr)
                self.use_ctypes = False

        # Lattice info and cutoffs are determined only once.
        self.lat = ebm.build_lattice_info(self.lattice_parameters)
        self.aij = np.asarray(self.lat['aij'], float)
        if self.auto:
            # alpha / cutoff design based on the number of sites.
            self.alpha, (_, nr_est, _, hg_est) = ebm.auto_choose_alpha(
                self.lattice_parameters, self.lat, prec=self.prec, nsites=len(self.sites), n_grid=self.alpha_grid
            )
        else:
            _, nr_est, _, hg_est = ebm.estimate_cutoffs(
                self.lattice_parameters, self.lat, self.alpha, self.prec
            )
        # User overrides win over the estimates (they used to be overwritten
        # unconditionally, which made --nrmax/--hgmax ineffective).
        self.nrmax = self._merge_cutoffs(nrmax, nr_est)
        self.hgmax = self._merge_cutoffs(hgmax, hg_est)
        if self.nrmax is None or self.hgmax is None:
            raise RuntimeError("nrmax/hgmax が決定できませんでした。")

        # From here on the Ewald settings are fixed (fixed cell during MD).
        self._common_kwargs = dict(
            alpha=self.alpha, prec=self.prec, rmin=self.rmin,
            nrmax_override=self.nrmax, hgmax_override=self.hgmax,
            halfspace=self.halfspace, auto_alpha=False, alpha_grid=self.alpha_grid
        )

    @staticmethod
    def _merge_cutoffs(user, estimated):
        """Combine an optional user (n1, n2, n3) override with an estimate.

        - user is None: return ``estimated`` unchanged (it may itself be None).
        - otherwise, per component: a user value that is not None wins; a None
          component falls back to the estimate. Returns a tuple of ints, or
          None when some component is still undetermined.
        """
        if user is None:
            return estimated
        if estimated is None:
            estimated = (None,) * len(user)
        merged = tuple(u if u is not None else e for u, e in zip(user, estimated))
        if any(v is None for v in merged):
            return None
        return tuple(int(v) for v in merged)

    def update_positions_cart(self, positions_cart):
        """Store MD Cartesian positions [Å] into the sites as wrapped fractional coordinates."""
        frac = cart_to_frac(self.aij, positions_cart)
        frac = wrap_frac(frac)
        for i in range(len(self.sites)):
            self.sites[i][7] = frac[i]

    def _ewald_ctypes(self):
        """Ewald for all sites via ctypes; returns (phi_i [1/m], F_i [N])."""
        aij  = np.asarray(self.lat['aij'], float)
        Raij = np.asarray(self.lat['Raij'], float)
        gij  = np.asarray(self.lat['gij'], float)
        Rgij = np.asarray(self.lat['Rgij'], float)
        vol  = float(self.lat['volume'])
        frac = np.vstack([s[7] for s in self.sites]).astype(float)
        q    = np.array([s[4] for s in self.sites], float)
        nr   = np.array(self.nrmax, int)
        hg   = np.array(self.hgmax, int)

        mp, F_N = ewc.ewald_all_sites(self.lib, aij, Raij, gij, Rgij, vol,
                                      frac, q, float(self.alpha), nr, hg,
                                      rmin_A=float(self.rmin), halfspace=self.halfspace)
        return mp, F_N

    def _ewald_python(self):
        """Python fallback: Ewald site by site; returns (phi_i [1/m], F_i [N])."""
        res_list = [ebm.ewald_potential_force_at_site(
                        self.lattice_parameters, self.sites, i_index=i, **self._common_kwargs
                    ) for i in range(len(self.sites))]
        mp = np.array([r['MP'] for r in res_list], float)
        F  = np.vstack([r['F_tot'] for r in res_list])
        return mp, F

    def compute_energy_forces(self):
        """
        Returns:
          E_pot_eV: float, Ewald + Born–Mayer energy [eV]
          forces_eV_per_A: (N, 3) ndarray [eV/Å]
          breakdown: dict with E_ew_eV, E_bm_eV and the net force F_sum_N [N]
            (for optional logging)
        """
        # --- Ewald (C or Python) ---
        if self.use_ctypes and (self.lib is not None):
            mp, F_ew_N = self._ewald_ctypes()
        else:
            mp, F_ew_N = self._ewald_python()

        # Total Ewald energy [eV]: 0.5 * sum_i q_i * Ke * phi_i
        q = np.array([s[4] for s in self.sites], float)
        E_ew_eV = 0.5 * float(np.sum(q * ebm.Ke * mp)) * ebm.eV_per_J

        # --- Born–Mayer (short range, Python) ---
        E_bm_eV_site, F_bm_N = ebm.compute_short_range_born_mayer(
            self.lattice_parameters, self.sites, self.sx1_db,
            sr_cut=self.sr_cut, include_coulomb=self.sr_include_coulomb
        )
        E_bm_eV = float(np.sum(E_bm_eV_site))

        # --- Combine; forces converted from N to eV/Å ---
        F_tot_N = F_ew_N + F_bm_N
        F_tot_eV_A = F_tot_N * N_to_eV_per_A
        E_tot_eV = E_ew_eV + E_bm_eV

        breakdown = dict(E_ew_eV=E_ew_eV, E_bm_eV=E_bm_eV, F_sum_N=np.sum(F_tot_N, axis=0))
        return E_tot_eV, F_tot_eV_A, breakdown

# ===================== CLI =====================
def parse_triplet(s):
    """Parse an 'n1,n2,n3' CLI string into a 3-tuple of non-negative ints.

    Empty components and components whose integer value is 0 (e.g. '0', '00',
    '+0') become None, meaning "use the automatically estimated value".

    Raises:
        argparse.ArgumentTypeError: wrong number of components, a non-integer
            component, or a negative component.
    """
    parts = [p.strip() for p in s.split(',')]
    if len(parts) != 3:
        raise argparse.ArgumentTypeError("must be 'n1,n2,n3'")
    out = []
    for p in parts:
        if p == '':
            out.append(None)
            continue
        try:
            v = int(p)
        except ValueError:
            raise argparse.ArgumentTypeError(f"component {p!r} is not an integer") from None
        if v < 0:
            raise argparse.ArgumentTypeError("components must be >=0")
        # 0 means "auto" regardless of how it is spelled.
        out.append(None if v == 0 else v)
    return tuple(out)

def parse_abc_scale(s):
    """Parse 'sa,sb,sc' into a tuple of three positive, finite float scale factors.

    Raises:
        argparse.ArgumentTypeError: wrong number of components, a component
            that is not a float, or a factor that is <= 0, NaN or infinite
            (such factors would produce an invalid lattice).
    """
    parts = [p.strip() for p in s.split(',')]
    if len(parts) != 3:
        raise argparse.ArgumentTypeError("scale-abc must be 'sa,sb,sc'")
    try:
        scales = tuple(float(x) for x in parts)
    except ValueError:
        raise argparse.ArgumentTypeError("scale-abc components must be float") from None
    if not all(math.isfinite(v) and v > 0.0 for v in scales):
        raise argparse.ArgumentTypeError("scale-abc components must be positive finite numbers")
    return scales

def parse_positive_int(s):
    """argparse type: parse a strictly positive integer.

    Raises:
        argparse.ArgumentTypeError: if ``s`` is not an integer or is <= 0.
    """
    try:
        v = int(s)
    except ValueError:
        raise argparse.ArgumentTypeError(f"{s!r} is not an integer") from None
    if v <= 0:
        raise argparse.ArgumentTypeError("must be a positive integer (>0)")
    return v

def build_argparser():
    """Build the CLI parser for the MD driver.

    Note: --dt is given in seconds (converted to fs in main). The print/save
    intervals must be positive because they are used as modulo divisors.
    """
    p = argparse.ArgumentParser(description="MD（Ewald(C DLL) + SX-1 Born–Mayer, pymatgen読込）")
    # Input
    p.add_argument("--cif", required=True, help="入力 CIF（pymatgenで読込）")
    p.add_argument("--sx1", required=True, help="SX-1 パラメータ TSV（element mass charge ai bi ci ARAD）")
    p.add_argument("--use-cif-charges", action="store_true",
                   help="CIFの電荷(oxi_state/charge_map)を使用（既定: SX-1のchargeで上書き）")

    # Lattice scaling
    p.add_argument("--scale", type=float, default=1.0, help="a,b,c 一様スケール")
    p.add_argument("--scale-abc", type=str, default=None, help="a,b,c 異方スケール 'sa,sb,sc'")

    # Ewald
    p.add_argument("--prec", type=float, default=1e-5, help="Ewald 近似精度")
    p.add_argument("--alpha", type=float, default=None, help="Ewald α [1/Å]（未指定は 0.3 or --auto）")
    p.add_argument("--auto", action="store_true", help="α/カットオフ自動選択（auto_choose_alpha）")
    p.add_argument("--alpha-grid", type=int, default=40)
    p.add_argument("--nrmax", type=parse_triplet, default=None, help="実空間反復 'nx,ny,nz'")
    p.add_argument("--hgmax", type=parse_triplet, default=None, help="逆格子反復 'hx,hy,hz'")
    p.add_argument("--rmin", type=float, default=0.1, help="近接スキップ閾値[Å]")
    p.add_argument("--full-g", action="store_true", help="G 空間で半空間ではなく全空間")

    # DLL
    p.add_argument("--ewald-dll", default="./ewald_c.dll", help="Ewald C DLL のパス（既定: ./ewald_c.dll）")
    p.add_argument("--no-ctypes", action="store_true", help="ctypes（C DLL）を使わずPython実装でEwaldを計算")

    # Born–Mayer (short range)
    p.add_argument("--sr-cut", type=float, default=8.0, help="Born–Mayer 近距離カット[Å]")
    p.add_argument("--sr-include-coulomb", action="store_true", help="短距離にも Coulomb を加える（通常は不要）")

    # MD
    p.add_argument("--dt", type=float, default=1e-15, help="タイムステップ[s]")
    p.add_argument("--steps", type=int, default=10000, help="ステップ数")
    p.add_argument("--temperature", type=float, default=300.0, help="目標温度[K]")
    p.add_argument("--thermostat", default="none", choices=["none", "berendsen"])
    p.add_argument("--tau", type=float, default=1000.0, help="Berendsen の時定数[fs]")
    p.add_argument("--init-vel", default="mb", choices=["mb","zero"])
    p.add_argument("--init-perturb", type=float, default=0.01, help="初期微小乱数変位[Å]")

    # Output (intervals are modulo divisors in the MD loop, so they must be > 0)
    p.add_argument("--save", default=None, help="XYZ トラジェクトリ出力ファイル")
    p.add_argument("--save-interval", type=parse_positive_int, default=50)
    p.add_argument("--print-interval", type=parse_positive_int, default=10)
    p.add_argument("--print-breakdown", action="store_true", help="Ewald/BM内訳と力合計をログ表示")
    return p

# ===================== メイン =====================
def main():
    """Command-line entry point: run MD with Ewald + SX-1 Born–Mayer forces.

    Parses the CLI, loads the CIF via ewaldf_BM, optionally scales the lattice,
    assigns SX-1 charges (unless --use-cif-charges), builds an EwaldSX1Engine
    and integrates with velocity Verlet (--dt is given in seconds and converted
    to fs), optionally with a Berendsen thermostat. A log line
    (step, time [ps], E_pot [eV], T [K]) is printed every --print-interval
    steps; XYZ frames are written every --save-interval steps when --save is set.

    Invalid CLI values are reported via ``parser.error`` (exit status 2).
    """
    parser = build_argparser()
    args = parser.parse_args()

    # Both intervals are used as modulo divisors in the MD loop; reject <= 0
    # before any expensive setup instead of failing with ZeroDivisionError.
    if args.print_interval <= 0:
        parser.error("--print-interval must be a positive integer")
    if args.save_interval <= 0:
        parser.error("--save-interval must be a positive integer")

    # Lattice scale factors, parsed up front so malformed input fails fast with
    # a CLI error instead of an uncaught ArgumentTypeError traceback.
    if args.scale_abc:
        try:
            sa, sb, sc = parse_abc_scale(args.scale_abc)
        except argparse.ArgumentTypeError as err:
            parser.error(f"--scale-abc: {err}")
    else:
        sa = sb = sc = float(args.scale)

    # --- Load the structure (pymatgen via ewaldf_BM) ---
    lattice_parameters, sites = ebm.cif_to_lattice_and_sites(args.cif)

    # Scale a, b, c (entries 0..2 of lattice_parameters).
    lattice_parameters[0] *= sa
    lattice_parameters[1] *= sb
    lattice_parameters[2] *= sc

    # Read the SX-1 parameter table.
    sx1_db = ebm.read_SX1_potential(args.sx1)

    # Default: overwrite site charges with the SX-1 charges so that the Coulomb
    # and Born–Mayer terms use consistent parameters.
    if not args.use_cif_charges:
        for site in sites:
            elem = site[0]
            if elem not in sx1_db:
                raise SystemExit(f"SX-1に {elem} の行がありません。ファイル: {args.sx1}")
            site[4] = float(sx1_db[elem]['charge'])

    # Ewald sums are normally only well defined for a charge-neutral cell; warn
    # (without aborting) otherwise. The conversion mirrors the engine's charge
    # array. NOTE(review): confirm whether ewaldf_BM adds a neutralizing background.
    total_charge = float(np.sum(np.array([s[4] for s in sites], float)))
    if abs(total_charge) > 1e-6:
        print(f"# [warning] セルの総電荷がゼロではありません: sum(q) = {total_charge:+.6g}",
              file=sys.stderr)

    # Initial coordinates (fractional -> Cartesian).
    lat = ebm.build_lattice_info(lattice_parameters)
    aij = np.asarray(lat['aij'], float)
    frac0 = np.vstack([s[7] for s in sites])
    pos = frac_to_cart(aij, frac0)

    # Small random initial displacement, then wrap back into the cell.
    if args.init_perturb > 0.0:
        pos += np.random.normal(scale=args.init_perturb, size=pos.shape)
        f = wrap_frac(cart_to_frac(aij, pos))
        pos = frac_to_cart(aij, f)

    symbols = [s[0] for s in sites]
    mass_amu = np.array([float(s[3]) for s in sites], float)

    # Initial velocities.
    if args.init_vel == "mb":
        vel = maxwell_boltzmann_velocities(mass_amu, args.temperature)
    else:
        vel = np.zeros_like(pos)

    # Ewald + Born–Mayer engine.
    engine = EwaldSX1Engine(
        lattice_parameters, sites, sx1_db,
        prec=args.prec, alpha=args.alpha, auto=args.auto, alpha_grid=args.alpha_grid,
        nrmax=args.nrmax, hgmax=args.hgmax, rmin=args.rmin, halfspace=(not args.full_g),
        sr_cut=args.sr_cut, sr_include_coulomb=args.sr_include_coulomb,
        use_ctypes=(not args.no_ctypes), dll_path=args.ewald_dll
    )

    # Initial energy and forces.
    engine.update_positions_cart(pos)
    Epot, F_eV_A, br = engine.compute_energy_forces()

    # Initial log; --dt is in seconds, the integrator works in fs.
    dt_fs = float(args.dt) / 1e-15
    T0 = instantaneous_temperature(mass_amu, vel)
    Ekin0 = kinetic_energy(mass_amu, vel)
    print("# Initial state")
    print(f"# T0   = {T0:.2f} K")
    print(f"# Epot0= {Epot:.6f} eV")
    print(f"# Ekin0= {Ekin0:.6f} eV")
    print("#\n# step  time[ps]   E_pot[eV]   T[K]")

    # Save the initial frame.
    saved = 0
    if args.save:
        write_xyz(args.save, pos, symbols, comment="step=0", append=False)
        saved = 1

    # Inverse masses [Å^2/(eV·fs^2)].
    inv_m = 1.0 / (mass_amu[:, None] * amu_to_eVA2_fs2)

    # MD loop (velocity Verlet).
    for step in range(1, int(args.steps)+1):
        # v(t+dt/2)
        vel += 0.5 * dt_fs * inv_m * F_eV_A
        # x(t+dt)
        pos += dt_fs * vel
        # Wrap into the cell.
        frac = wrap_frac(cart_to_frac(aij, pos))
        pos = frac_to_cart(aij, frac)

        # Force update.
        engine.update_positions_cart(pos)
        Epot, F_eV_A, br = engine.compute_energy_forces()

        # v(t+dt)
        vel += 0.5 * dt_fs * inv_m * F_eV_A

        # Thermostat.
        if args.thermostat == "berendsen":
            vel = berendsen_scale(vel, mass_amu, args.temperature, tau_fs=float(args.tau), dt_fs=dt_fs)

        # Log.
        if (step % args.print_interval) == 0 or step == 1:
            t_ps = step * dt_fs * 1e-3
            Tnow = instantaneous_temperature(mass_amu, vel)
            print(f"{step:6d}  {t_ps:8.3f}  {Epot:11.6f}  {Tnow:8.2f}")
            if args.print_breakdown:
                Fsum = br['F_sum_N']
                print("         -> E_ew={:.6f} eV, E_bm={:.6f} eV, ΣF=[{:+.3e},{:+.3e},{:+.3e}] N"
                      .format(br['E_ew_eV'], br['E_bm_eV'], Fsum[0], Fsum[1], Fsum[2]))

        # Save a frame.
        if args.save and (step % args.save_interval == 0):
            write_xyz(args.save, pos, symbols, comment=f"step={step}", append=True)
            saved += 1

    if args.save:
        print(f"# Wrote {saved} frames to {args.save}")

if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
