#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
CIF を読み込み、Ewald 法 + Born–Mayer（SX-1）で分子動力学を行う最小実装。
- 単位系: 位置[Å], 速度[Å/fs], 力[eV/Å], 質量[eV·fs^2/Å^2], 時間[fs]
- 積分: Velocity Verlet、Berendsen サーモスタット(任意)
- 出力: XYZ トラジェクトリ

機能ポイント:
- SX-1 パラメータ（Coulomb + Born–Mayer）対応
- Coulomb は Ewald（実空間 + 逆格子 + 自己項）
- --ewald-auto <err[eV/atom]> で α・Gmax・cutoff を自動設計
- 近接が空になるのを防ぐための safeguard として、cutoff の最終下限を **4.5 Å** に固定
- 初期構造に微小乱数変位（デフォルト 0.01 Å）
- 初期 T/Epot/Ekin を表示
- Reciprocal cell は atoms.cell.reciprocal()（Deprecation 対応）

依存: pip install pymatgen numpy

python md_anim.py --cif MgO.cif --sx1 SX-1.txt --steps 1000 --dt 1e-15 --temperature 300 --thermostat berendsen --tau 1000
python md_anim.py --cif MgO.cif --sx1 SX-1.txt --steps 2000 --dt 1e-15 --temperature 300 --thermostat berendsen --tau 1000 --auto --sr-cut 8.0 --prec 1e-5 --anim --yield-every 10 --anim-interval 40
"""

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
CIF を pymatgen で読み込み、ewaldf_BM をライブラリとして用いて
Ewald（Coulomb）+ Born–Mayer（SX-1）で MD (Velocity Verlet) を行う最小実装。

- 依存: numpy, pymatgen, (あなたの) ewaldf_BM.py
- ASE には依存しません（CIF読込も近接探索も不要。ewaldf_BM 側で格子和を評価）

例:
python md.py --cif MgO.cif --sx1 SX-1.txt --steps 2000 --dt 1e-15 \
  --temperature 300 --thermostat berendsen --tau 1000 \
  --sr-cut 8.0 --prec 1e-5 --auto --alpha-grid 40 --save traj.xyz
"""

import argparse
import sys
import math
import numpy as np
import matplotlib
matplotlib.use("Qt5Agg")
from mpl_toolkits.mplot3d import Axes3D


# ==== ewaldf_BM ライブラリ ====
import ewaldf_BM as ebm
# 使う主関数:
# - cif_to_lattice_and_sites(path, charge_map=None, zero_charge=False)
# - build_lattice_info(lattice_parameters)
# - ewald_potential_force_at_site(lattice_parameters, sites, i_index, ...)
# - compute_short_range_born_mayer(lattice_parameters, sites, sx1_db, sr_cut, include_coulomb)
# - read_SX1_potential(dbpath)
# 物理定数（単位変換）:
# - ANG2M, eV_per_J, J_per_eV, Ke

# ==== 単位定数（Å/fs 系） ====
kB_eV = 8.617333262145e-5              # eV/K
amu_to_eVA2_fs2 = 103.6426965          # 1 amu = 103.6426965 eV·fs^2/Å^2
N_to_eV_per_A = ebm.eV_per_J * ebm.ANG2M  # N → eV/Å の換算

# ------------------------------------------------------------
# 便利関数
# ------------------------------------------------------------
def frac_to_cart(aij, frac):
    """aij: (3,3)（ewaldf_BMの build 返却の aij） / frac: (...,3)"""
    return (aij.T @ frac.T).T

def cart_to_frac(aij, cart):
    """逆変換"""
    return (np.linalg.solve(aij.T, cart.T)).T

def wrap_frac(frac):
    """[0,1) に折り畳み"""
    return frac - np.floor(frac)

def instantaneous_temperature(mass_amu, velocities):
    m = mass_amu[:, None] * amu_to_eVA2_fs2
    Ekin = 0.5 * float(np.sum(m * velocities**2))   # eV
    ndof = 3 * len(mass_amu)
    return (2.0 * Ekin) / (ndof * kB_eV)

def kinetic_energy(mass_amu, velocities):
    m = mass_amu[:, None] * amu_to_eVA2_fs2
    return 0.5 * float(np.sum(m * velocities**2))   # eV

def maxwell_boltzmann_velocities(mass_amu, T):
    m_eVA2_fs2 = mass_amu * amu_to_eVA2_fs2
    sigma = np.sqrt(kB_eV * T / m_eVA2_fs2)    # (N,)
    v = np.random.normal(size=(len(mass_amu), 3))
    v *= sigma[:, None]                        # Å/fs
    # 全運動量ゼロ化
    mcol = m_eVA2_fs2[:, None]
    vcm = np.sum(mcol * v, axis=0) / np.sum(mcol)
    v -= vcm[None, :]
    return v

def berendsen_scale(vel, mass_amu, T_target, tau_fs, dt_fs):
    if tau_fs <= 0:
        return vel
    T_inst = instantaneous_temperature(mass_amu, vel)
    if T_inst <= 1e-30:
        return vel
    c = math.sqrt(1.0 + (dt_fs / tau_fs) * (T_target / T_inst - 1.0))
    return vel * c

def write_xyz(path, positions, symbols, comment, append=False):
    mode = "a" if append else "w"
    with open(path, mode, encoding="utf-8") as f:
        n = len(symbols)
        f.write(f"{n}\n{comment}\n")
        for s, (x,y,z) in zip(symbols, positions):
            f.write(f"{s:2s} {x:16.8f} {y:16.8f} {z:16.8f}\n")

# ===============================
# 1) 1ステップだけ進める小関数
# ===============================
def md_step(pos, vel, mass_amu, inv_m, dt_fs, engine, thermostat, T_target, tau_fs):
    # 半ステップ速度 → 位置更新 → wrap → 力/エネルギー再計算 → 半ステップ速度
    Epot, F_eV_A, br = engine.compute_energy_forces()  # 直前位置の力
    vel += 0.5 * dt_fs * inv_m * F_eV_A
    pos += dt_fs * vel

    lat = ebm.build_lattice_info(engine.lattice_parameters)
    aij = np.asarray(lat['aij'], float)
    frac = wrap_frac(cart_to_frac(aij, pos))
    pos = frac_to_cart(aij, frac)

    engine.update_positions_cart(aij, pos)
    Epot, F_eV_A, br = engine.compute_energy_forces()

    vel += 0.5 * dt_fs * inv_m * F_eV_A

    # サーモスタット
    if thermostat == "berendsen":
        vel = berendsen_scale(vel, mass_amu, T_target, tau_fs=tau_fs, dt_fs=dt_fs)

    # 観測量
    Tnow  = instantaneous_temperature(mass_amu, vel)
    Ekin  = kinetic_energy(mass_amu, vel)
    return pos, vel, F_eV_A, Epot, Ekin, Tnow, br

# ========================================
# 2) nステップごとにまとめて返すgenerator
# ========================================
def md_run(pos, vel, mass_amu, engine, *,
           dt_fs, nsteps, yield_every=10,
           thermostat="none", T_target=300.0, tau_fs=1000.0):
    inv_m = 1.0 / (mass_amu[:, None] * amu_to_eVA2_fs2)
    t_ps = 0.0

    # 初期状態（描画初期化用に一度返すと便利）
    Epot, F, br = engine.compute_energy_forces()
    Ekin = kinetic_energy(mass_amu, vel)
    Tnow = instantaneous_temperature(mass_amu, vel)
    yield dict(step=0, t_ps=0.0,
               pos=pos.copy(), vel=vel.copy(), force=F.copy(),
               Epot=Epot, Ekin=Ekin, T=Tnow, breakdown=br)

    for step in range(1, nsteps + 1):
        pos, vel, F, Epot, Ekin, Tnow, br = md_step(
            pos, vel, mass_amu, inv_m, dt_fs, engine, thermostat, T_target, tau_fs
        )
        t_ps += dt_fs * 1e-3

        if (step % yield_every) == 0:
            # 渡す直前にコピー（描画側が保持しても安全）
            yield dict(step=step, t_ps=t_ps,
                       pos=pos.copy(), vel=vel.copy(), force=F.copy(),
                       Epot=Epot, Ekin=Ekin, T=Tnow, breakdown=br)

# =======================================
# 3) matplotlibアニメ：xy投影の散布図例
# =======================================
def run_live_animation_2d(frames_gen, symbols, cell=None, interval_ms=50):
    import matplotlib.pyplot as plt
    from matplotlib.animation import FuncAnimation

    # 初回フレームで点群サイズや範囲決め
    first = next(frames_gen)
    pos = first["pos"]
    x, y = pos[:,0], pos[:,1]

    fig, ax = plt.subplots()
    sc = ax.scatter(x, y, s=20)  # 最初は小さめ

    # 軸範囲（適宜セルや±marginから決める）
    margin = 0.1 * max((x.max()-x.min()), (y.max()-y.min()), 1.0)
    ax.set_xlim(x.min()-margin, x.max()+margin)
    ax.set_ylim(y.min()-margin, y.max()+margin)
    ax.set_aspect('equal', adjustable='box')
    title = ax.set_title(f"step={first['step']}  t={first['t_ps']:.3f} ps  T={first['T']:.1f} K")

    # ラベル色分けしたい場合は、elements→色mapを別途用意

    def update(frame):
        pos = frame["pos"]
        sc.set_offsets(pos[:, :2])  # xy投影
        title.set_text(f"step={frame['step']}  t={frame['t_ps']:.3f} ps  T={frame['T']:.1f} K")
        return (sc, title)

    ani = FuncAnimation(fig, update, frames=frames_gen,
                        interval=interval_ms, blit=False)
    plt.show()

# =======================================
# 3) matplotlibアニメ：3D散布図版
# =======================================
def run_live_animation(frames_gen, symbols, cell=None, interval_ms=50):
    import matplotlib.pyplot as plt
    from matplotlib.animation import FuncAnimation
    from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 (確実に3Dを有効化)

    # 初回フレームで初期化
    first = next(frames_gen)
    pos0 = first["pos"]
    x0, y0, z0 = pos0[:, 0], pos0[:, 1], pos0[:, 2]

    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')

    # 3D散布図（色やサイズを元素に応じて変える拡張は後で可能）
    sc = ax.scatter(x0, y0, z0, s=20, depthshade=True)

    # 軸範囲（初期フレームから固定。PBCならセルから決めてもOK）
    xmin, xmax = float(x0.min()), float(x0.max())
    ymin, ymax = float(y0.min()), float(y0.max())
    zmin, zmax = float(z0.min()), float(z0.max())
    # スケールのゼロ割れ回避
    rx = max(xmax - xmin, 1.0)
    ry = max(ymax - ymin, 1.0)
    rz = max(zmax - zmin, 1.0)
    margin = 0.1 * max(rx, ry, rz)

    ax.set_xlim(xmin - margin, xmax + margin)
    ax.set_ylim(ymin - margin, ymax + margin)
    ax.set_zlim(zmin - margin, zmax + margin)

    # 3D等方比（Matplotlib>=3.3）
    try:
        ax.set_box_aspect((rx + 2*margin, ry + 2*margin, rz + 2*margin))
    except Exception:
        pass

    title = ax.set_title(f"step={first['step']}  t={first['t_ps']:.3f} ps  T={first['T']:.1f} K")

    def update(frame):
        pos = frame["pos"]
        # 3D scatterの更新は _offsets3d を使う
        sc._offsets3d = (pos[:, 0], pos[:, 1], pos[:, 2])
        title.set_text(f"step={frame['step']}  t={frame['t_ps']:.3f} ps  T={frame['T']:.1f} K")
        # blit=True は3Dで効かないので False のまま
        return (sc, title)

    ani = FuncAnimation(fig, update, frames=frames_gen,
                        interval=interval_ms, blit=False)
    plt.show()

# ------------------------------------------------------------
# 力・エネルギー計算（ewaldf_BM を利用）
# ------------------------------------------------------------
class EwaldSX1Engine:
    """
    ewaldf_BM を呼び出して
      - Ewald(Coulomb): per-site 力[N], φ[1/m] → E = 0.5 Σ_i q_i Ke φ_i
      - Born–Mayer(SX-1): per-site エネルギー[eV], 力[N]
    を合成し、全エネルギー[eV] と 力[eV/Å] を返す。
    """
    def __init__(self, lattice_parameters, sites, sx1_db,
                 prec=1e-5, alpha=None, auto=False, alpha_grid=40,
                 nrmax=None, hgmax=None, rmin=0.1, halfspace=True,
                 sr_cut=8.0, sr_include_coulomb=False):
        self.lattice_parameters = list(lattice_parameters)  # [a,b,c,α,β,γ]
        self.sites = sites
        self.sx1_db = sx1_db

        self.prec = prec
        self.alpha = alpha if alpha is not None else 0.3
        self.auto = bool(auto)
        self.alpha_grid = int(alpha_grid)
        self.nrmax = nrmax
        self.hgmax = hgmax
        self.rmin = float(rmin)
        self.halfspace = bool(halfspace)

        self.sr_cut = float(sr_cut)
        self.sr_include_coulomb = bool(sr_include_coulomb)

        # 1回だけ α/カットを決め、以降は override で固定（セル一定の想定）
        self._common_kwargs = None
        self._prepare_cutoffs()

    def _prepare_cutoffs(self):
        # サイト0を使って一度だけ self.cutoffs を決定
        res0 = ebm.ewald_potential_force_at_site(
            self.lattice_parameters, self.sites, i_index=0,
            alpha=self.alpha, prec=self.prec, rmin=self.rmin,
            nrmax_override=self.nrmax, hgmax_override=self.hgmax,
            halfspace=self.halfspace, auto_alpha=self.auto, alpha_grid=self.alpha_grid
        )
        self.alpha = res0['alpha']
        self.cut_nrmax = res0['cutoffs']['nrmax']
        self.cut_hgmax = res0['cutoffs']['hgmax']
        self._common_kwargs = dict(
            alpha=self.alpha, prec=self.prec, rmin=self.rmin,
            nrmax_override=self.cut_nrmax, hgmax_override=self.cut_hgmax,
            halfspace=self.halfspace, auto_alpha=False, alpha_grid=self.alpha_grid
        )

    def update_positions_cart(self, aij, positions_cart):
        """
        MD で更新された Cartesian 座標[Å]を sites の frac に反映（PBC wrap）。
        """
        frac = cart_to_frac(aij, positions_cart)
        frac = wrap_frac(frac)
        for i in range(len(self.sites)):
            self.sites[i][7] = frac[i]

    def compute_energy_forces(self):
        """
        戻り:
          E_pot_eV: float
          forces_eV_per_A: (N,3) ndarray
          breakdown: dict（任意表示用）
        """
        n = len(self.sites)

        # --- Ewald を全サイトで ---
        res_list = [ebm.ewald_potential_force_at_site(
                        self.lattice_parameters, self.sites, i_index=i, **self._common_kwargs
                    ) for i in range(n)]

        # 力（N）を集約
        F_ew_N = np.vstack([res['F_tot'] for res in res_list])  # (N,3)
        # Ewald 全エネルギー（J）
        E_ew_J = 0.0
        for i, res in enumerate(res_list):
            qi = self.sites[i][4]
            E_ew_J += 0.5 * qi * ebm.Ke * res['MP']
        E_ew_eV = E_ew_J * ebm.eV_per_J

        # --- Born–Mayer（短距離） ---
        E_bm_eV_site, F_bm_N = ebm.compute_short_range_born_mayer(
            self.lattice_parameters, self.sites, self.sx1_db,
            sr_cut=self.sr_cut, include_coulomb=self.sr_include_coulomb
        )
        E_bm_eV = float(np.sum(E_bm_eV_site))

        # --- 合成 ---
        F_tot_N = F_ew_N + F_bm_N
        F_tot_eV_A = F_tot_N * N_to_eV_per_A
        E_tot_eV = E_ew_eV + E_bm_eV

        # 任意のデバッグ用内訳
        breakdown = dict(
            E_ew_eV=E_ew_eV,
            E_bm_eV=E_bm_eV,
            F_sum_N=np.sum(F_tot_N, axis=0)
        )
        return E_tot_eV, F_tot_eV_A, breakdown


# ------------------------------------------------------------
# CLI / メイン
# ------------------------------------------------------------
def parse_triplet(s):
    parts = [p.strip() for p in s.split(',')]
    if len(parts)!=3: raise argparse.ArgumentTypeError("must be 'n1,n2,n3'")
    out=[]
    for p in parts:
        if p=='' or p=='0': out.append(None)
        else:
            v=int(p)
            if v<0: raise argparse.ArgumentTypeError("components must be >=0")
            out.append(v)
    return tuple(out)

def build_argparser():
    p = argparse.ArgumentParser(description="MD（Ewald + SX-1 Born–Mayer, pymatgen/ewaldf_BM版）")
    p.add_argument("--cif", required=True, help="入力 CIF ファイル（pymatgen）")
    p.add_argument("--sx1", required=True, help="SX-1 パラメータ TSV（element mass charge ai bi ci ARAD）")
    p.add_argument("--use-cif-charges", action="store_true",
                   help="サイト電荷をSX-1ではなくCIF(oxi_state/charge_map)由来にする（既定：SX-1で上書き）")

    # 格子スケール（任意）
    p.add_argument("--scale", type=float, default=1.0, help="a,b,c 一様スケール")
    p.add_argument("--scale-abc", type=str, default=None, help="a,b,c 異方スケール 'sa,sb,sc'")

    # Ewald
    p.add_argument("--prec", type=float, default=1e-5, help="Ewald 近似精度（小さいほど厳密）")
    p.add_argument("--alpha", type=float, default=None, help="Ewald α（指定が無い場合は 0.3 or --auto）")
    p.add_argument("--auto", action="store_true", help="α/カットオフ自動設計")
    p.add_argument("--alpha-grid", type=int, default=40)
    p.add_argument("--nrmax", type=parse_triplet, default=None, help="実空間サムの最大反復 'nx,ny,nz'（未指定は自動）")
    p.add_argument("--hgmax", type=parse_triplet, default=None, help="逆格子サムの最大反復 'hx,hy,hz'（未指定は自動）")
    p.add_argument("--rmin", type=float, default=0.1, help="自己/極近接スキップ閾値[Å]")
    p.add_argument("--full-g", action="store_true", help="G空間で半空間ではなく全空間を和（既定は半空間）")

    # Born–Mayer（短距離）
    p.add_argument("--sr-cut", type=float, default=8.0, help="Born–Mayer 近距離カット[Å]")
    p.add_argument("--sr-include-coulomb", action="store_true", help="短距離に Coulomb を追加（通常は不要）")

    # MD
    p.add_argument("--dt", type=float, default=1e-15, help="タイムステップ[s]（既定 1e-15 ≒ 1 fs）")
    p.add_argument("--steps", type=int, default=10000, help="ステップ数")
    p.add_argument("--temperature", type=float, default=300.0, help="目標温度[K]")
    p.add_argument("--thermostat", default="none", choices=["none", "berendsen"])
    p.add_argument("--tau", type=float, default=1000.0, help="Berendsen の時定数[fs]")
    p.add_argument("--init-vel", default="mb", choices=["mb","zero"])
    p.add_argument("--init-perturb", type=float, default=0.01, help="初期微小乱数変位[Å]（0で無効）")

# build_argparser() の出力に追加
#    p.add_argument("--anim", action="store_true", help="matplotlibでリアルタイム可視化（xy投影）")
    p.add_argument("--anim", action="store_true",
               help="matplotlibでリアルタイム3D可視化（散布図）")
    p.add_argument("--anim-interval", type=int, default=50, help="アニメ更新間隔[ms]")
    p.add_argument("--yield-every", type=int, default=10, help="nステップごとに描画へ渡す")

    # 出力
    p.add_argument("--save", default=None, help="XYZ 出力ファイル名")
    p.add_argument("--save-interval", type=int, default=50)
    p.add_argument("--print-interval", type=int, default=10)
    p.add_argument("--print-breakdown", action="store_true", help="E_ew/BM/力合計を各ログに表示")
    return p

def parse_abc_scale(s):
    parts = [p.strip() for p in s.split(',')]
    if len(parts) != 3:
        raise argparse.ArgumentTypeError("scale-abc must be 'sa,sb,sc'")
    try:
        return tuple(float(x) for x in parts)
    except Exception:
        raise argparse.ArgumentTypeError("scale-abc components must be float")

def main():
    args = build_argparser().parse_args()

    # --- 構造読み込み（pymatgen via ewaldf_BM） ---
    lattice_parameters, sites = ebm.cif_to_lattice_and_sites(args.cif)
    # スケール
    if args.scale_abc:
        sa,sb,sc = parse_abc_scale(args.scale_abc)
    else:
        sa = sb = sc = float(args.scale)
    lattice_parameters[0] *= sa
    lattice_parameters[1] *= sb
    lattice_parameters[2] *= sc

    # SX-1 DB 読み込み
    sx1_db = ebm.read_SX1_potential(args.sx1)

    # 既定ではサイト電荷を SX-1 の charge 列で上書き（Coulomb 計算の一貫性確保）
    if not args.use_cif_charges:
        for i in range(len(sites)):
            elem = sites[i][0]
            if elem in sx1_db:
                sites[i][4] = float(sx1_db[elem]['charge'])
            else:
                raise SystemExit(f"SX-1に {elem} の行がありません。ファイル:{args.sx1}")

    # 初期位置（Cartesian Å）/ 元の frac は sites[i][7]
    lat = ebm.build_lattice_info(lattice_parameters)
    aij = np.asarray(lat['aij'], float)
    frac0 = np.vstack([s[7] for s in sites])
    pos = frac_to_cart(aij, frac0)

    # 初期微小変位
    if args.init_perturb > 0:
        pos += np.random.normal(scale=args.init_perturb, size=pos.shape)
        # wrap
        f = wrap_frac(cart_to_frac(aij, pos))
        pos = frac_to_cart(aij, f)

    # 記号、質量
    symbols = [s[0] for s in sites]
    mass_amu = np.array([float(s[3]) for s in sites], dtype=float)  # pymatgen の atomic_mass は amu

    # 初期速度
    if args.init_vel == "mb":
        vel = maxwell_boltzmann_velocities(mass_amu, args.temperature)
    else:
        vel = np.zeros_like(pos)

    # Ewald+BM エンジン
    engine = EwaldSX1Engine(
        lattice_parameters, sites, sx1_db,
        prec=args.prec, alpha=args.alpha, auto=args.auto, alpha_grid=args.alpha_grid,
        nrmax=args.nrmax, hgmax=args.hgmax, rmin=args.rmin, halfspace=(not args.full_g),
        sr_cut=args.sr_cut, sr_include_coulomb=args.sr_include_coulomb
    )

    # 初期エネルギー・力
    engine.update_positions_cart(aij, pos)
    Epot, F_eV_A, br = engine.compute_energy_forces()

    # ログ（初期）
    dt_fs = float(args.dt) / 1e-15
    T0 = instantaneous_temperature(mass_amu, vel)
    Ekin0 = kinetic_energy(mass_amu, vel)
    print("# Initial state")
    print(f"# T0   = {T0:.2f} K")
    print(f"# Epot0= {Epot:.6f} eV")
    print(f"# Ekin0= {Ekin0:.6f} eV")
    print("#\n# step  time[ps]   E_pot[eV]   T[K]")

    # 保存
    if args.save:
        write_xyz(args.save, pos, symbols, comment="step=0", append=False)
        saved = 1
    else:
        saved = 0

    # 逆質量（Å/fs 系）
    inv_m = 1.0 / (mass_amu[:, None] * amu_to_eVA2_fs2)

# main() の末尾、MDループの代わりに
    if args.anim:
        frames = md_run(pos, vel, mass_amu, engine,
                    dt_fs=dt_fs, nsteps=int(args.steps),
                    yield_every=int(args.yield_every),
                    thermostat=args.thermostat,
                    T_target=float(args.temperature),
                    tau_fs=float(args.tau))
        run_live_animation(frames, symbols, cell=None, interval_ms=int(args.anim_interval))
    else:
    # MD ループ
        for step in range(1, int(args.steps)+1):
            # v(t+dt/2)
            vel += 0.5 * dt_fs * inv_m * F_eV_A
            # x(t+dt)
            pos += dt_fs * vel
            # wrap to cell
            frac = wrap_frac(cart_to_frac(aij, pos))
            pos = frac_to_cart(aij, frac)

            # 力を更新（Ewald+BM）
            engine.update_positions_cart(aij, pos)
            Epot, F_eV_A, br = engine.compute_energy_forces()

            # v(t+dt)
            vel += 0.5 * dt_fs * inv_m * F_eV_A

            # サーモスタット
            if args.thermostat == "berendsen":
                vel = berendsen_scale(vel, mass_amu, args.temperature, tau_fs=float(args.tau), dt_fs=dt_fs)

            # ログ
            if (step % args.print_interval) == 0 or step == 1:
                t_ps = step * dt_fs * 1e-3
                Tnow = instantaneous_temperature(mass_amu, vel)
                print(f"{step:6d}  {t_ps:8.3f}  {Epot:11.6f}  {Tnow:8.2f}")
                if args.print_breakdown:
                    Fsum = br['F_sum_N']  # N
                    print("         -> E_ew={:.6f} eV, E_bm={:.6f} eV, ΣF=[{:+.3e},{:+.3e},{:+.3e}] N"
                          .format(br['E_ew_eV'], br['E_bm_eV'], Fsum[0], Fsum[1], Fsum[2]))

            # 保存
            if args.save and (step % args.save_interval == 0):
                write_xyz(args.save, pos, symbols, comment=f"step={step}", append=True)
                saved += 1

    if args.save:
        print(f"# Wrote {saved} frames to {args.save}")


if __name__ == "__main__":
    main()
