tiny_simulations.diffeq2nd_planet_generator のソースコード

# ============================================================
# Planet simulator (NumPy): Euler / Heun / Verlet / velocity-Verlet
# ============================================================
"""
惑星の軌道シミュレーションを実行するモジュール。

本モジュールは、Euler、Heun、Verlet、Velocity-Verlet法などの数値解法を用いて、
複数惑星の運動方程式を解き、その軌跡を計算・描画します。ジェネレータを通じて
シミュレーション状態を順次取得することが可能です。

関連リンク:
:doc:`diffeq2nd_planet_generator_usage`
"""

import sys
import csv
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import signal
import time

matplotlib.use("TkAgg")   # または "QtAgg"


# ===================
# constants
# ===================
G = 6.67259e-11                       # Nm2/kg2
DayToSecond      = 60 * 60 * 24       # s
AstronomicalUnit = 1.49597870e11      # m
AU               = AstronomicalUnit
G1 = G * DayToSecond * DayToSecond / AU / AU / AU

# ===================
# defaults / params
# ===================
solver = 'vverlet'      # 'Euler' | 'Heun' | 'Verlet' | 'vverlet'
fplot = 1
dbfile   = 'planet_db.csv'
dt = 0.1
nt = 20000
iprint_interval    = 100
nprint_planets     = 4
xgrange = (-5.0, 5.0)
ygrange = (-5.0, 5.0)
yield_every = 10
enable_ctrlc = 1

# ===================
# helpers
# ===================
[ドキュメント] def readdb(dbfile): """ CSVファイルから惑星データを読み込む。 指定されたCSVファイルを開き、各行を辞書として読み込みます。 数値データはfloatに変換して返します。 :param dbfile: str, 読み込むCSVファイルのパス。 :returns: list, 惑星データの辞書を格納したリスト。 """ rows = [] with open(dbfile, "r", newline='', encoding="utf-8") as f: reader = csv.DictReader(f) for row in reader: rows.append(row) for d in rows: for k in list(d.keys()): if k != 'Name': d[k] = float(d[k]) return rows
[ドキュメント] def normalize_momentum(m, v): """ 系全体の運動量をゼロに正規化する。 質量と速度から系全体の運動量を計算し、重心系での速度に変換するよう 各惑星の速度をインプレースで調整します。 :param m: numpy.ndarray, 惑星の質量の配列 (N,)。 :param v: numpy.ndarray, 惑星の速度の配列 (N, 3)。 :returns: None """ P = (m[:, None] * v).sum(axis=0) v -= P[None, :] / m.sum()
[ドキュメント] def kinetic_energy(m, v): """ 系の全運動エネルギーを計算する。 :param m: numpy.ndarray, 惑星の質量の配列 (N,)。 :param v: numpy.ndarray, 惑星の速度の配列 (N, 3)。 :returns: float, 計算された全運動エネルギー。 """ return 0.5 * (m * (v*v).sum(axis=1)).sum()
[ドキュメント] def pairwise_accel(m, r): """ 万有引力の法則に基づき各惑星の加速度を計算する。 惑星間の相対距離からニュートン力学に基づく重力加速度を計算して返します。 :param m: numpy.ndarray, 惑星の質量の配列 (N,)。 :param r: numpy.ndarray, 惑星の位置の配列 (N, 3)。 :returns: numpy.ndarray, 各惑星の加速度の配列 (N, 3)。 """ D = r[None, :, :] - r[:, None, :] r2 = np.einsum('ijk,ijk->ij', D, D) np.fill_diagonal(r2, np.inf) inv_r3 = 1.0 / (r2 * np.sqrt(r2)) W = (m[None, :] * inv_r3)[:, :, None] a = G1 * np.sum(W * D, axis=1) return a
[ドキュメント] def potential_energy(m, r): """ 系の全ポテンシャルエネルギー(重力位置エネルギー)を計算する。 :param m: numpy.ndarray, 惑星の質量の配列 (N,)。 :param r: numpy.ndarray, 惑星の位置の配列 (N, 3)。 :returns: float, 計算されたポテンシャルエネルギー。 """ D = r[None, :, :] - r[:, None, :] r2 = np.einsum('ijk,ijk->ij', D, D) N = r.shape[0] iu = np.triu_indices(N, k=1) dist = np.sqrt(r2[iu]) mi_mj = (m[iu[0]] * m[iu[1]]) return (G1 * mi_mj / dist).sum()
[ドキュメント] def total_momentum(m, v): """ 系の全運動量とその二乗平均平方根を計算する。 :param m: numpy.ndarray, 惑星の質量の配列 (N,)。 :param v: numpy.ndarray, 惑星の速度の配列 (N, 3)。 :returns: tuple, 運動量のx, y, z成分と、その二乗平均平方根の値 (Px, Py, Pz, Pmsm)。 """ P = (m[:, None] * v).sum(axis=0) Pmsm = np.sqrt((P*P).sum() / (3.0 * max(1, len(m)))) return P[0], P[1], P[2], Pmsm
[ドキュメント] def initialize(planets): """ 惑星データからシミュレーション用の初期状態を生成する。 :param planets: list, 辞書形式の惑星データリスト。 :returns: tuple, 名前(names), 質量(m), 位置(r), 速度(v), 加速度(a)を含むタプル。 """ names = np.array([p['Name'] for p in planets], dtype=object) m = np.array([p['Mass'] for p in planets], dtype=float) r = np.zeros((len(planets), 3), dtype=float) v = np.zeros((len(planets), 3), dtype=float) r[:, 0] = np.array([p['Revolution Radius'] / AU for p in planets]) v[:, 1] = np.array([p['Revolution Velocity'] * DayToSecond / AU for p in planets]) normalize_momentum(m, v) a = pairwise_accel(m, r) return names, m, r, v, a
# =================== # generator # ===================
[ドキュメント] def md_generator(planets, solver, dt, nt, yield_every=1): """ 惑星シミュレーションを実行し、状態を順次返すジェネレータ。 指定された数値解法(Euler, Heun, Verlet, Velocity-Verlet)を用いて 指定回数の時間発展を計算し、定期的なステップ間隔で現在の状態を生成します。 :param planets: list, 辞書形式の惑星データリスト。 :param solver: str, 使用する数値解法名 ('Euler', 'Heun', 'Verlet', 'vverlet')。 :param dt: float, タイムステップ幅。 :param nt: int, 総ステップ数。 :param yield_every: int, 状態をyieldするステップ間隔。 :returns: generator, (時間, 位置, 速度, 加速度, (U, K, E), (Px, Py, Pz, Pmsm)) のタプルを生成。 """ s = (solver or '').lower() if s == 'velet': s = 'vverlet' names, m, r, v, a = initialize(planets) U = potential_energy(m, r) K = kinetic_energy(m, v) E = U + K Px, Py, Pz, Pmsm = total_momentum(m, v) yield (0.0, r, v, a, (U, K, E), (Px, Py, Pz, Pmsm)) if s == 'verlet': r_prev = r - dt * v for it in range(1, nt+1): t = it * dt if s == 'euler': v += dt * a r += dt * v a = pairwise_accel(m, r) elif s == 'heun': v_pred = v + dt * a r_pred = r + dt * v a_pred = pairwise_accel(m, r_pred) v += 0.5 * dt * (a + a_pred) r += 0.5 * dt * (v + v_pred) a = pairwise_accel(m, r) elif s == 'verlet': r_new = 2.0 * r - r_prev + (dt*dt) * a v[:] = (r_new - r_prev) / (2.0 * dt) r_prev, r = r, r_new a = pairwise_accel(m, r) else: v += 0.5 * dt * a r += dt * v a_new = pairwise_accel(m, r) v += 0.5 * dt * a_new a = a_new if (it % yield_every) == 0: U = potential_energy(m, r) K = kinetic_energy(m, v) E = U + K Px, Py, Pz, Pmsm = total_momentum(m, v) yield (t, r, v, a, (U, K, E), (Px, Py, Pz, Pmsm))
# =================== # main # ===================
[ドキュメント] def main(): """ シミュレーションを実行するメイン関数。 コマンドライン引数を解析してパラメータを設定し、シミュレーションを実行します。 計算結果はCSVファイルに保存され、指定に応じてリアルタイムで軌道のプロットを行います。 :param: なし :returns: None """ global solver, dt, nt, fplot, yield_every, enable_ctrlc argv = sys.argv if len(argv) >= 2: solver = argv[1] if len(argv) >= 3: dt = float(argv[2]) if len(argv) >= 4: nt = int(argv[3]) if len(argv) >= 5: fplot = int(argv[4]) if len(argv) >= 6: yield_every = int(argv[5]) if len(argv) >= 7: enable_ctrlc = int(argv[6]) if enable_ctrlc: signal.signal(signal.SIGINT, signal.default_int_handler) else: signal.signal(signal.SIGINT, signal.SIG_IGN) print("Planet simulator NumPy ({})".format(solver)) print("G1 = {}".format(G1)) print("dt = {}, nt = {}, yield_every = {}, Ctrl-C={}".format(dt, nt, yield_every, "stop" if enable_ctrlc else "ignore")) print("") planets = readdb(dbfile) names = [p['Name'] for p in planets] print("Planets:", ", ".join(names)) out_traj = "diffeq2nd_Planet_{}.csv".format(solver) out_cons = "diffeq2nd_Planet_{}_conservation.csv".format(solver) wtraj = csv.writer(open(out_traj, 'w', newline='', encoding='utf-8'), lineterminator='\n') wcon = csv.writer(open(out_cons, 'w', newline='', encoding='utf-8'), lineterminator='\n') lab = ['t'] + sum(([f"x({nm})", f"y({nm})"] for nm in names), []) wtraj.writerow(lab) wcon.writerow(['t','U','K','E','Px','Py','Pz','Pmsm']) if fplot: fig, ax = plt.subplots(1,1) ax.set_aspect('equal', adjustable='box') ax.set_xlim(xgrange); ax.set_ylim(ygrange) ax.set_xlabel('x [AU]'); ax.set_ylabel('y [AU]') nshow = min(7, len(names)) trails_x = [[] for _ in range(nshow)] trails_y = [[] for _ in range(nshow)] lines = [ax.plot([], [], linewidth=0.6)[0] for _ in range(nshow)] gen = md_generator(planets, solver, dt, nt, yield_every=yield_every) print("{:^7}".format('t'), end='') for i in range(0, min(nprint_planets, len(names))): print(" {:^12} {:^12}".format(f"x({names[i]})", f"y({names[i]})"), end='') print("") try: for (t, r, v, a, (U,K,E), (Px,Py,Pz,Pmsm)) in gen: row = [t] xy = r[:, :2] row = [t] for i in range(xy.shape[0]): row.append(float(xy[i,0])) row.append(float(xy[i,1])) wtraj.writerow(row) wcon.writerow([t, U, K, E, Px, Py, Pz, Pmsm]) istep = int(round(t/dt)) if (istep % iprint_interval) == 0: print(f"{t:^7.2f}", end='') for i in range(0, min(nprint_planets, xy.shape[0])): print(" {:>12.4g} {:>12.4g}".format(xy[i,0], xy[i,1]), end='') print("") if fplot: nshow = min(len(lines), xy.shape[0]) for i in range(nshow): trails_x[i].append(xy[i,0]) trails_y[i].append(xy[i,1]) if len(trails_x[i]) > 2000: trails_x[i] = trails_x[i][-1000:] trails_y[i] = trails_y[i][-1000:] lines[i].set_data(trails_x[i], trails_y[i]) ax.set_xlim(xgrange); ax.set_ylim(ygrange) plt.pause(0.01) except KeyboardInterrupt: if enable_ctrlc: print("\nInterrupted by user (Ctrl-C).")
if __name__ == '__main__': start = time.perf_counter() main() end = time.perf_counter() print(f"\nElapsed time: {end - start:.3f} seconds") print("Press ENTER to exit>>", end='') try: input() except EOFError: pass