import numpy as np
import matplotlib.pyplot as plt
from numpy.linalg import inv, norm
import matplotlib

# 日本語ラベル用フォント（お使いの環境に合わせて変更してください）
matplotlib.rcParams['font.family'] = 'MS Gothic'

# --- 設定 ---
np.random.seed(42)              # 再現性のための乱数シード
n_features = 3                  # 特徴量の次元
n_blocks = 10                   # データブロック数（更新回数）
block_size = 20                 # 各ブロックのサンプル数
sigma = 1.0                     # 観測ノイズの標準偏差
tau2 = 10.0                     # 事前分布の分散（共分散は tau2 × I）

# 真のパラメータベクトル（ground truth）
w_true = np.random.randn(n_features)

# 初期事前分布の平均と共分散
A0 = np.zeros(n_features)
Sigma0 = tau2 * np.eye(n_features)

# --- データ生成（すべてのブロック） ---
X_all = []
y_all = []

for _ in range(n_blocks):
    X_block = np.random.randn(block_size, n_features)
    noise = sigma * np.random.randn(block_size)
    y_block = X_block @ w_true + noise
    X_all.append(X_block)
    y_all.append(y_block)

X_all = np.vstack(X_all)
y_all = np.concatenate(y_all)

# --- 方法1: 全データ一括ベイズ推論 ---
Sigma_batch_inv = inv(Sigma0) + (1 / sigma**2) * X_all.T @ X_all
Sigma_batch = inv(Sigma_batch_inv)
A_batch = Sigma_batch @ (inv(Sigma0) @ A0 + (1 / sigma**2) * X_all.T @ y_all)

# --- 方法2: 逐次ベイズ更新（修正後） ---
A_seq = A0.copy()
Sigma_seq = Sigma0.copy()
history_seq = []

for i in range(n_blocks):
    X_block = X_all[i * block_size:(i + 1) * block_size]
    y_block = y_all[i * block_size:(i + 1) * block_size]

    Sigma_seq_old_inv = inv(Sigma_seq)  # 修正：古いΣで更新計算
    Sigma_seq = inv(Sigma_seq_old_inv + (1 / sigma**2) * X_block.T @ X_block)
    A_seq = Sigma_seq @ (Sigma_seq_old_inv @ A_seq + (1 / sigma**2) * X_block.T @ y_block)

    history_seq.append((A_seq.copy(), Sigma_seq.copy()))

# --- 方法3: リスタート更新（修正後） ---
A_restart = A0.copy()
Sigma_restart = Sigma0.copy()
history_restart = []

for i in range(n_blocks):
    X_block = X_all[i * block_size:(i + 1) * block_size]
    y_block = y_all[i * block_size:(i + 1) * block_size]

    Sigma_restart_old_inv = inv(Sigma_restart)
    Sigma_restart = inv(Sigma_restart_old_inv + (1 / sigma**2) * X_block.T @ X_block)
    A_restart = Sigma_restart @ (Sigma_restart_old_inv @ A_restart + (1 / sigma**2) * X_block.T @ y_block)

    history_restart.append((A_restart.copy(), Sigma_restart.copy()))

# --- 最終差分（L2ノルム/Frobenius） ---
mean_seq_diff = norm(A_seq - A_batch)
mean_restart_diff = norm(A_restart - A_batch)
cov_seq_diff = norm(Sigma_seq - Sigma_batch, ord='fro')
cov_restart_diff = norm(Sigma_restart - Sigma_batch, ord='fro')

print("パラメータ平均の差（逐次 vs 一括）:", mean_seq_diff)
print("パラメータ平均の差（リスタート vs 一括）:", mean_restart_diff)
print("共分散行列の差（逐次 vs 一括）:", cov_seq_diff)
print("共分散行列の差（リスタート vs 一括）:", cov_restart_diff)

# --- 差分の推移を可視化 ---
steps = np.arange(1, n_blocks + 1)
mean_diffs_seq = [norm(A - A_batch) for A, _ in history_seq]
mean_diffs_restart = [norm(A - A_batch) for A, _ in history_restart]

plt.figure(figsize=(8, 5))
plt.plot(steps, mean_diffs_seq, label="逐次更新", marker='o')
plt.plot(steps, mean_diffs_restart, label="リスタート更新", marker='s')
plt.xlabel("ステップ数（ブロック数）")
plt.ylabel("事後平均の差（A vs A_batch）")
plt.title("逐次/リスタート vs 一括ベイズの平均ベクトルの収束")
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()
