import numpy as np
from scipy.stats import norm  
import matplotlib.pyplot as plt

# 日本語フォント設定
plt.rcParams['font.family'] = 'MS Gothic'
plt.rcParams['axes.unicode_minus'] = False

def main():
    # 1. データ生成
    np.random.seed(0)
    x = np.linspace(0, 1, 20)
    true_a, true_b = 1.0, 2.0
    sigma_noise = 0.2
    y = true_a + true_b * x + np.random.normal(0, sigma_noise, size=x.shape)

    # 2. X行列作成（バイアス項付き）
    X = np.vstack([np.ones_like(x), x]).T

    # 3. 事前分布（非情報的）
    mu0 = np.zeros(2)
    Sigma0 = np.eye(2) * 1e3

    # 4. ノイズの分散
    sigma2 = sigma_noise ** 2

    # 5. 事後分布の計算（解析解）
    Sigma0_inv = np.linalg.inv(Sigma0)
    Sigma_n = np.linalg.inv(X.T @ X / sigma2 + Sigma0_inv)
    mu_n = Sigma_n @ (X.T @ y / sigma2 + Sigma0_inv @ mu0)

    # 6. パラメータ事後分布からサンプルを生成
    param_samples = np.random.multivariate_normal(mu_n, Sigma_n, size=5000)
    a_samples = param_samples[:, 0]
    b_samples = param_samples[:, 1]

    # 6.1 事後平均と標準偏差の出力
    print("パラメータの事後分布（平均 ± 標準偏差）:")
    print(f"  a（切片）: {np.mean(a_samples):.4f} ± {np.std(a_samples):.4f}")
    print(f"  b（傾き）: {np.mean(b_samples):.4f} ± {np.std(b_samples):.4f}")

# 6.2 ヒストグラム + 理論分布の重ね描き
    plt.figure(figsize=(12, 5))

# a（切片）
    plt.subplot(1, 2, 1)
    plt.hist(a_samples, bins=40, density=True, color='skyblue', label="サンプル分布")
    a_range = np.linspace(np.min(a_samples), np.max(a_samples), 300)
    a_pdf = norm.pdf(a_range, loc=mu_n[0], scale=np.sqrt(Sigma_n[0, 0]))
    plt.plot(a_range, a_pdf, color='blue', lw=2, label="理論分布")
    plt.title("切片 a の事後分布")
    plt.xlabel("a")
    plt.ylabel("確率密度")
    plt.grid(True)
    plt.legend()

# b（傾き）
    plt.subplot(1, 2, 2)
    plt.hist(b_samples, bins=40, density=True, color='salmon', label="サンプル分布")
    b_range = np.linspace(np.min(b_samples), np.max(b_samples), 300)
    b_pdf = norm.pdf(b_range, loc=mu_n[1], scale=np.sqrt(Sigma_n[1, 1]))
    plt.plot(b_range, b_pdf, color='red', lw=2, label="理論分布")
    plt.title("傾き b の事後分布")
    plt.xlabel("b")
    plt.ylabel("確率密度")
    plt.grid(True)
    plt.legend()

    plt.tight_layout()
    plt.show()

    # 7. 予測と信頼区間の計算
    x_plot = np.linspace(0, 1, 100)
    X_plot = np.vstack([np.ones_like(x_plot), x_plot]).T

    y_mean = X_plot @ mu_n
    var_structural = np.sum(X_plot @ Sigma_n * X_plot, axis=1)     # 構造誤差のみ
    var_total = var_structural + sigma2                             # ノイズ込み

    # 信頼区間（95%）
    ci_struct_upper = y_mean + 1.96 * np.sqrt(var_structural)
    ci_struct_lower = y_mean - 1.96 * np.sqrt(var_structural)

    ci_total_upper = y_mean + 1.96 * np.sqrt(var_total)
    ci_total_lower = y_mean - 1.96 * np.sqrt(var_total)

    # 8. 可視化
    plt.figure(figsize=(8, 5))
    plt.scatter(x, y, label="観測データ", color="black")

    # 予測平均
    plt.plot(x_plot, y_mean, label="事後平均（予測）", color="blue")

    # 信頼区間（ノイズなし：構造誤差）
    plt.fill_between(x_plot, ci_struct_lower, ci_struct_upper, color="orange", alpha=0.3, label="95% 信頼区間（ノイズなし）")

    # 信頼区間（ノイズ込み）
    plt.fill_between(x_plot, ci_total_lower, ci_total_upper, color="blue", alpha=0.2, label="95% 信頼区間（ノイズ込み）")

    plt.xlabel("x")
    plt.ylabel("y")
    plt.title("ベイズ線形回帰：予測と不確かさ（構造誤差 vs ノイズ込み）")
    plt.legend()
    plt.grid(True)
    plt.show()

if __name__ == '__main__':
    main()
