# ============================================================
# Planet simulator (NumPy): Euler / Heun / Verlet / velocity-Verlet
# ============================================================
"""
惑星の軌道シミュレーションを実行するモジュール。
本モジュールは、Euler、Heun、Verlet、Velocity-Verlet法などの数値解法を用いて、
複数惑星の運動方程式を解き、その軌跡を計算・描画します。ジェネレータを通じて
シミュレーション状態を順次取得することが可能です。
関連リンク:
:doc:`diffeq2nd_planet_generator_usage`
"""
import sys
import csv
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import signal
import time
matplotlib.use("TkAgg") # または "QtAgg"
# ===================
# constants
# ===================
G = 6.67259e-11 # Nm2/kg2
DayToSecond = 60 * 60 * 24 # s
AstronomicalUnit = 1.49597870e11 # m
AU = AstronomicalUnit
G1 = G * DayToSecond * DayToSecond / AU / AU / AU
# ===================
# defaults / params
# ===================
solver = 'vverlet' # 'Euler' | 'Heun' | 'Verlet' | 'vverlet'
fplot = 1
dbfile = 'planet_db.csv'
dt = 0.1
nt = 20000
iprint_interval = 100
nprint_planets = 4
xgrange = (-5.0, 5.0)
ygrange = (-5.0, 5.0)
yield_every = 10
enable_ctrlc = 1
# ===================
# helpers
# ===================
[ドキュメント]
def readdb(dbfile):
"""
CSVファイルから惑星データを読み込む。
指定されたCSVファイルを開き、各行を辞書として読み込みます。
数値データはfloatに変換して返します。
:param dbfile: str, 読み込むCSVファイルのパス。
:returns: list, 惑星データの辞書を格納したリスト。
"""
rows = []
with open(dbfile, "r", newline='', encoding="utf-8") as f:
reader = csv.DictReader(f)
for row in reader:
rows.append(row)
for d in rows:
for k in list(d.keys()):
if k != 'Name':
d[k] = float(d[k])
return rows
[ドキュメント]
def normalize_momentum(m, v):
"""
系全体の運動量をゼロに正規化する。
質量と速度から系全体の運動量を計算し、重心系での速度に変換するよう
各惑星の速度をインプレースで調整します。
:param m: numpy.ndarray, 惑星の質量の配列 (N,)。
:param v: numpy.ndarray, 惑星の速度の配列 (N, 3)。
:returns: None
"""
P = (m[:, None] * v).sum(axis=0)
v -= P[None, :] / m.sum()
[ドキュメント]
def kinetic_energy(m, v):
"""
系の全運動エネルギーを計算する。
:param m: numpy.ndarray, 惑星の質量の配列 (N,)。
:param v: numpy.ndarray, 惑星の速度の配列 (N, 3)。
:returns: float, 計算された全運動エネルギー。
"""
return 0.5 * (m * (v*v).sum(axis=1)).sum()
[ドキュメント]
def pairwise_accel(m, r):
"""
万有引力の法則に基づき各惑星の加速度を計算する。
惑星間の相対距離からニュートン力学に基づく重力加速度を計算して返します。
:param m: numpy.ndarray, 惑星の質量の配列 (N,)。
:param r: numpy.ndarray, 惑星の位置の配列 (N, 3)。
:returns: numpy.ndarray, 各惑星の加速度の配列 (N, 3)。
"""
D = r[None, :, :] - r[:, None, :]
r2 = np.einsum('ijk,ijk->ij', D, D)
np.fill_diagonal(r2, np.inf)
inv_r3 = 1.0 / (r2 * np.sqrt(r2))
W = (m[None, :] * inv_r3)[:, :, None]
a = G1 * np.sum(W * D, axis=1)
return a
[ドキュメント]
def potential_energy(m, r):
"""
系の全ポテンシャルエネルギー(重力位置エネルギー)を計算する。
:param m: numpy.ndarray, 惑星の質量の配列 (N,)。
:param r: numpy.ndarray, 惑星の位置の配列 (N, 3)。
:returns: float, 計算されたポテンシャルエネルギー。
"""
D = r[None, :, :] - r[:, None, :]
r2 = np.einsum('ijk,ijk->ij', D, D)
N = r.shape[0]
iu = np.triu_indices(N, k=1)
dist = np.sqrt(r2[iu])
mi_mj = (m[iu[0]] * m[iu[1]])
return (G1 * mi_mj / dist).sum()
[ドキュメント]
def total_momentum(m, v):
"""
系の全運動量とその二乗平均平方根を計算する。
:param m: numpy.ndarray, 惑星の質量の配列 (N,)。
:param v: numpy.ndarray, 惑星の速度の配列 (N, 3)。
:returns: tuple, 運動量のx, y, z成分と、その二乗平均平方根の値 (Px, Py, Pz, Pmsm)。
"""
P = (m[:, None] * v).sum(axis=0)
Pmsm = np.sqrt((P*P).sum() / (3.0 * max(1, len(m))))
return P[0], P[1], P[2], Pmsm
[ドキュメント]
def initialize(planets):
"""
惑星データからシミュレーション用の初期状態を生成する。
:param planets: list, 辞書形式の惑星データリスト。
:returns: tuple, 名前(names), 質量(m), 位置(r), 速度(v), 加速度(a)を含むタプル。
"""
names = np.array([p['Name'] for p in planets], dtype=object)
m = np.array([p['Mass'] for p in planets], dtype=float)
r = np.zeros((len(planets), 3), dtype=float)
v = np.zeros((len(planets), 3), dtype=float)
r[:, 0] = np.array([p['Revolution Radius'] / AU for p in planets])
v[:, 1] = np.array([p['Revolution Velocity'] * DayToSecond / AU for p in planets])
normalize_momentum(m, v)
a = pairwise_accel(m, r)
return names, m, r, v, a
# ===================
# generator
# ===================
[ドキュメント]
def md_generator(planets, solver, dt, nt, yield_every=1):
"""
惑星シミュレーションを実行し、状態を順次返すジェネレータ。
指定された数値解法(Euler, Heun, Verlet, Velocity-Verlet)を用いて
指定回数の時間発展を計算し、定期的なステップ間隔で現在の状態を生成します。
:param planets: list, 辞書形式の惑星データリスト。
:param solver: str, 使用する数値解法名 ('Euler', 'Heun', 'Verlet', 'vverlet')。
:param dt: float, タイムステップ幅。
:param nt: int, 総ステップ数。
:param yield_every: int, 状態をyieldするステップ間隔。
:returns: generator, (時間, 位置, 速度, 加速度, (U, K, E), (Px, Py, Pz, Pmsm)) のタプルを生成。
"""
s = (solver or '').lower()
if s == 'velet':
s = 'vverlet'
names, m, r, v, a = initialize(planets)
U = potential_energy(m, r)
K = kinetic_energy(m, v)
E = U + K
Px, Py, Pz, Pmsm = total_momentum(m, v)
yield (0.0, r, v, a, (U, K, E), (Px, Py, Pz, Pmsm))
if s == 'verlet':
r_prev = r - dt * v
for it in range(1, nt+1):
t = it * dt
if s == 'euler':
v += dt * a
r += dt * v
a = pairwise_accel(m, r)
elif s == 'heun':
v_pred = v + dt * a
r_pred = r + dt * v
a_pred = pairwise_accel(m, r_pred)
v += 0.5 * dt * (a + a_pred)
r += 0.5 * dt * (v + v_pred)
a = pairwise_accel(m, r)
elif s == 'verlet':
r_new = 2.0 * r - r_prev + (dt*dt) * a
v[:] = (r_new - r_prev) / (2.0 * dt)
r_prev, r = r, r_new
a = pairwise_accel(m, r)
else:
v += 0.5 * dt * a
r += dt * v
a_new = pairwise_accel(m, r)
v += 0.5 * dt * a_new
a = a_new
if (it % yield_every) == 0:
U = potential_energy(m, r)
K = kinetic_energy(m, v)
E = U + K
Px, Py, Pz, Pmsm = total_momentum(m, v)
yield (t, r, v, a, (U, K, E), (Px, Py, Pz, Pmsm))
# ===================
# main
# ===================
[ドキュメント]
def main():
"""
シミュレーションを実行するメイン関数。
コマンドライン引数を解析してパラメータを設定し、シミュレーションを実行します。
計算結果はCSVファイルに保存され、指定に応じてリアルタイムで軌道のプロットを行います。
:param: なし
:returns: None
"""
global solver, dt, nt, fplot, yield_every, enable_ctrlc
argv = sys.argv
if len(argv) >= 2: solver = argv[1]
if len(argv) >= 3: dt = float(argv[2])
if len(argv) >= 4: nt = int(argv[3])
if len(argv) >= 5: fplot = int(argv[4])
if len(argv) >= 6: yield_every = int(argv[5])
if len(argv) >= 7: enable_ctrlc = int(argv[6])
if enable_ctrlc:
signal.signal(signal.SIGINT, signal.default_int_handler)
else:
signal.signal(signal.SIGINT, signal.SIG_IGN)
print("Planet simulator NumPy ({})".format(solver))
print("G1 = {}".format(G1))
print("dt = {}, nt = {}, yield_every = {}, Ctrl-C={}".format(dt, nt, yield_every, "stop" if enable_ctrlc else "ignore"))
print("")
planets = readdb(dbfile)
names = [p['Name'] for p in planets]
print("Planets:", ", ".join(names))
out_traj = "diffeq2nd_Planet_{}.csv".format(solver)
out_cons = "diffeq2nd_Planet_{}_conservation.csv".format(solver)
wtraj = csv.writer(open(out_traj, 'w', newline='', encoding='utf-8'), lineterminator='\n')
wcon = csv.writer(open(out_cons, 'w', newline='', encoding='utf-8'), lineterminator='\n')
lab = ['t'] + sum(([f"x({nm})", f"y({nm})"] for nm in names), [])
wtraj.writerow(lab)
wcon.writerow(['t','U','K','E','Px','Py','Pz','Pmsm'])
if fplot:
fig, ax = plt.subplots(1,1)
ax.set_aspect('equal', adjustable='box')
ax.set_xlim(xgrange); ax.set_ylim(ygrange)
ax.set_xlabel('x [AU]'); ax.set_ylabel('y [AU]')
nshow = min(7, len(names))
trails_x = [[] for _ in range(nshow)]
trails_y = [[] for _ in range(nshow)]
lines = [ax.plot([], [], linewidth=0.6)[0] for _ in range(nshow)]
gen = md_generator(planets, solver, dt, nt, yield_every=yield_every)
print("{:^7}".format('t'), end='')
for i in range(0, min(nprint_planets, len(names))):
print(" {:^12} {:^12}".format(f"x({names[i]})", f"y({names[i]})"), end='')
print("")
try:
for (t, r, v, a, (U,K,E), (Px,Py,Pz,Pmsm)) in gen:
row = [t]
xy = r[:, :2]
row = [t]
for i in range(xy.shape[0]):
row.append(float(xy[i,0]))
row.append(float(xy[i,1]))
wtraj.writerow(row)
wcon.writerow([t, U, K, E, Px, Py, Pz, Pmsm])
istep = int(round(t/dt))
if (istep % iprint_interval) == 0:
print(f"{t:^7.2f}", end='')
for i in range(0, min(nprint_planets, xy.shape[0])):
print(" {:>12.4g} {:>12.4g}".format(xy[i,0], xy[i,1]), end='')
print("")
if fplot:
nshow = min(len(lines), xy.shape[0])
for i in range(nshow):
trails_x[i].append(xy[i,0])
trails_y[i].append(xy[i,1])
if len(trails_x[i]) > 2000:
trails_x[i] = trails_x[i][-1000:]
trails_y[i] = trails_y[i][-1000:]
lines[i].set_data(trails_x[i], trails_y[i])
ax.set_xlim(xgrange); ax.set_ylim(ygrange)
plt.pause(0.01)
except KeyboardInterrupt:
if enable_ctrlc:
print("\nInterrupted by user (Ctrl-C).")
if __name__ == '__main__':
start = time.perf_counter()
main()
end = time.perf_counter()
print(f"\nElapsed time: {end - start:.3f} seconds")
print("Press ENTER to exit>>", end='')
try:
input()
except EOFError:
pass