import numpy as np
import matplotlib.pyplot as plt

def model(x, ai):
    """Evaluate the straight line ``ai[0] + ai[1] * x``.

    ``ai[0]`` is the intercept and ``ai[1]`` the slope; any further entries
    of ``ai`` are ignored. ``x`` may be a scalar or a NumPy array, in which
    case the line is evaluated element-wise.
    """
    intercept = ai[0]
    slope = ai[1]
    return intercept + slope * x

def generate_data(N=100, beta_true=None, noise_sigma=0.3, seed=0):
    """Generate noisy samples of the straight-line ``model``.

    Parameters
    ----------
    N : int
        Number of samples; ``x`` is evenly spaced on ``[0, 10]``.
    beta_true : array-like, optional
        True ``[intercept, slope]`` passed to ``model``. Defaults to
        ``[2.0, 1.5]``.
    noise_sigma : float
        Multiplier applied to draws from ``N(0, 2.0**2)``. The actual noise
        standard deviation is therefore ``2.0 * noise_sigma``; callers that
        need the true sigma must account for that factor.
    seed : int
        Seed for a private random generator. The global NumPy RNG state is
        left untouched.

    Returns
    -------
    X : ndarray, shape (N, 2)
        Design matrix: a column of ones followed by ``x``.
    y : ndarray, shape (N,)
        Noisy observations ``model(x, beta_true) + noise``.
    x : ndarray, shape (N,)
        Sample locations.
    """
    # Avoid a shared mutable array as the default argument.
    if beta_true is None:
        beta_true = np.array([2.0, 1.5])
    # A private legacy RandomState produces exactly the same draws as
    # ``np.random.seed(seed); np.random.normal(...)`` without resetting the
    # global generator that other code may rely on.
    rng = np.random.RandomState(seed)
    x = np.linspace(0, 10, N)
    noise = noise_sigma * rng.normal(loc=0.0, scale=2.0, size=N)
    y = model(x, beta_true) + noise
    X = np.column_stack([np.ones(N), x])  # design matrix
    return X, y, x

def plot_variances(tau2_list, bayes_vars, ols_var):
    """Plot Bayesian posterior variances against tau^2 with OLS references.

    Solid curves show ``bayes_vars[:, 0]`` and ``bayes_vars[:, 1]`` versus
    ``tau2_list`` on a logarithmic x-axis. Dashed horizontal lines show the
    matching diagonal entries of ``ols_var``, spanning the first to the last
    ``tau2_list`` value. The figure is displayed with ``plt.show()``.
    """
    series = (
        (0, 'blue', 'Bayes Var(a₀)', 'OLS Var(a₀)'),
        (1, 'green', 'Bayes Var(a₁)', 'OLS Var(a₁)'),
    )
    x_start, x_end = tau2_list[0], tau2_list[-1]

    fig, ax = plt.subplots(figsize=(10, 5))
    # Draw every Bayesian curve before any reference line so the legend
    # entries keep their original order.
    for col, colour, bayes_label, _ in series:
        ax.plot(tau2_list, bayes_vars[:, col], label=bayes_label, color=colour)
    for col, colour, _, ols_label in series:
        ax.hlines(ols_var[col, col], x_start, x_end,
                  linestyles='dashed', color=colour, label=ols_label)

    ax.set_xscale('log')
    ax.set_xlabel(r'$\tau^2$ (prior variance)')
    ax.set_ylabel('Parameter variance')
    ax.set_title('Bayesian vs OLS Parameter Variance')
    ax.legend()
    ax.grid(True)
    fig.tight_layout()
    plt.show()

def compute_ols_variance(X, y, sigma):
    """Return the OLS coefficient covariance and estimate for a known noise sigma.

    Parameters
    ----------
    X : ndarray, shape (N, p)
        Design matrix.
    y : ndarray, shape (N,)
        Observations.
    sigma : float
        Noise standard deviation, treated as known.

    Returns
    -------
    var_beta : ndarray, shape (p, p)
        Coefficient covariance ``sigma**2 * inv(X.T @ X)``.
    beta_hat : ndarray, shape (p,)
        Least-squares coefficient estimate.

    Raises
    ------
    numpy.linalg.LinAlgError
        If ``X.T @ X`` is singular (e.g. collinear columns).
    """
    XtX = X.T @ X
    # Solve the normal equations directly rather than multiplying by an
    # explicit inverse: mathematically identical, numerically more accurate.
    beta_hat = np.linalg.solve(XtX, X.T @ y)
    # The covariance itself does require the inverse.
    var_beta = sigma**2 * np.linalg.inv(XtX)
    return var_beta, beta_hat

def compute_bayesian_variances(X, y, sigma, tau2_list):
    """Posterior means and marginal variances of a conjugate Gaussian linear model.

    For each ``tau2`` the posterior precision is
    ``(X.T @ X + I / tau2) / sigma**2``. This corresponds to a zero-mean
    Gaussian prior with covariance ``sigma**2 * tau2 * I`` and known noise
    ``sigma``. As ``tau2`` grows, the result approaches the OLS solution.
    One summary line per ``tau2`` is printed to stdout.

    Parameters
    ----------
    X : ndarray, shape (N, p)
        Design matrix; any number of columns ``p`` is supported.
    y : ndarray, shape (N,)
        Observations.
    sigma : float
        Noise standard deviation, treated as known.
    tau2_list : array-like of float
        Prior variance scales; each must be strictly positive.

    Returns
    -------
    bayes_vars : ndarray, shape (len(tau2_list), p)
        Diagonal of the posterior covariance for each ``tau2``.
    bayes_means : ndarray, shape (len(tau2_list), p)
        Posterior mean for each ``tau2``.

    Raises
    ------
    ValueError
        If any ``tau2`` is not strictly positive (or is NaN).
    """
    tau2_values = np.atleast_1d(np.asarray(tau2_list, dtype=float))
    # 1 / tau2 must be finite and positive for the precision to make sense.
    if not np.all(tau2_values > 0):
        raise ValueError("tau2_list must contain only positive prior variances")

    n_params = X.shape[1]
    # These terms do not depend on tau2, so compute them once.
    XtX = X.T @ X
    rhs = (1 / sigma**2) * X.T @ y
    identity = np.eye(n_params)

    # Pre-shaped outputs keep the (0, p) shape even for an empty tau2_list.
    bayes_vars = np.empty((tau2_values.size, n_params))
    bayes_means = np.empty((tau2_values.size, n_params))
    for k, tau2 in enumerate(tau2_values):
        precision = (1 / sigma**2) * (XtX + (1 / tau2) * identity)
        cov_post = np.linalg.inv(precision)
        mean_post = cov_post @ rhs
        var_post = np.diag(cov_post)
        bayes_vars[k] = var_post
        bayes_means[k] = mean_post
        coef_report = " | ".join(
            f"Mean(a{i}) = {m:.6f}, Var(a{i}) = {v:.6f}"
            for i, (m, v) in enumerate(zip(mean_post, var_post))
        )
        print(f"tau^2 = {tau2:8.2e} | {coef_report}")
    return bayes_vars, bayes_means

def main():
    """Compare OLS and Bayesian coefficient variances on synthetic data.

    Generates a noisy straight line and computes the OLS estimate and its
    covariance. It then sweeps the Bayesian prior scale tau^2 over a
    log-spaced grid, prints the estimates and plots the variances.
    """
    beta_true = np.array([2.0, 1.5])
    sample_count = 100
    noise_sigma = 0.3
    # generate_data scales N(0, 2.0**2) draws by noise_sigma, so the true
    # noise standard deviation is twice noise_sigma.
    sigma = noise_sigma * 2.0
    tau2_list = np.logspace(-2, 4, 50)

    X, y, _ = generate_data(N=sample_count, beta_true=beta_true, noise_sigma=noise_sigma)

    ols_var, ols_mean = compute_ols_variance(X, y, sigma)
    bayes_vars, _ = compute_bayesian_variances(X, y, sigma, tau2_list)

    print(f"OLS Estimate: a0 = {ols_mean[0]:.6f}, a1 = {ols_mean[1]:.6f}")
    print(f"OLS Var    : a0 = {ols_var[0, 0]:.6f}, a1 = {ols_var[1, 1]:.6f}")

    plot_variances(tau2_list, bayes_vars, ols_var)

# Run the demo only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
    