import numpy as np
import pymc as pm
import arviz as az
import matplotlib.pyplot as plt
import multiprocessing

# --- Model function definitions ---
def model_true(x, a_true):
    """
    True model: y = a_true[0] * exp(a_true[1] * x)

    Only the first two entries of ``a_true`` are read: the multiplicative
    scale and the coefficient inside the exponent.
    """
    scale = a_true[0]
    exponent_coef = a_true[1]
    return scale * np.exp(exponent_coef * x)

def model_trial(x, a):
    """
    Trial model: y = a[0] * exp(a[1] * x) + a[2] + a[3] * x

    An exponential term plus a constant offset and a linear term in ``x``.
    Exactly the entries a[0]..a[3] are read.
    """
    amplitude, rate, offset, slope = a[0], a[1], a[2], a[3]
    exponential_term = amplitude * np.exp(rate * x)
    # Keep the left-to-right summation order: (exp + offset) + slope * x.
    return exponential_term + offset + slope * x

# --- Data generation ---
def generate_data(n=100, noise_sigma=0.2, seed=0):
    """Generate noisy synthetic observations from the true exponential model.

    Args:
        n: Number of sample points, evenly spaced on [0, 1].
        noise_sigma: Standard deviation of the additive Gaussian noise.
        seed: Seed for the private random generator. The default (0)
            reproduces the data the script has always used.

    Returns:
        Tuple ``(x, y_obs, y_true, a_true)``: the inputs, the noisy
        observations, the noise-free model values, and the true parameters.
    """
    # A private RandomState produces the same stream as the legacy
    # ``np.random.seed(seed)`` + ``np.random.normal(...)`` pair, but does not
    # reseed the process-wide generator as a hidden side effect.
    rng = np.random.RandomState(seed)
    x = np.linspace(0, 1, n)
    # True parameters (a_true[0], a_true[1])
    a_true = np.array([2.0, -1.5])
    y_true = model_true(x, a_true)
    y_obs = y_true + rng.normal(0, noise_sigma, size=n)
    return x, y_obs, y_true, a_true

# --- ARD hierarchical Bayesian model ---
def run_ard_model(x, y_obs, p=4, cores=1):
    """Sample the trial model under ARD (automatic relevance determination) priors.

    Each coefficient ``a[i]`` has a zero-mean Normal prior whose precision
    ``lambda_hp[i]`` carries its own Gamma(1e-3, 1e-3) hyperprior. The
    observation noise scale ``sigma`` has a HalfNormal(1.0) prior.

    Args:
        x: Inputs passed to ``model_trial``.
        y_obs: Observed responses used in the Normal likelihood.
        p: Number of coefficients. Must be at least 4, because
            ``model_trial`` reads ``a[0]`` through ``a[3]``. Entries beyond
            the fourth are not referenced by ``model_trial``.
        cores: Forwarded to ``pm.sample``.

    Returns:
        The object returned by ``pm.sample`` (the calling script reads its
        ``posterior`` group).

    Raises:
        ValueError: If ``p`` is smaller than 4.
    """
    # model_trial indexes a[0]..a[3]; fail early with a clear message instead
    # of an indexing error from inside the model graph.
    n_required = 4
    if p < n_required:
        raise ValueError(
            f"p must be at least {n_required} because model_trial uses "
            f"a[0]..a[{n_required - 1}]; got p={p}"
        )

    with pm.Model():
        # ARD hyperpriors: one precision per coefficient
        lambda_hp = pm.Gamma("lambda_hp", alpha=1e-3, beta=1e-3, shape=p)
        # Coefficient priors: a_i ~ Normal(0, lambda_hp[i]^{-1/2})
        a = pm.Normal("a", mu=0, sigma=lambda_hp**-0.5, shape=p)
        # Noise prior
        sigma = pm.HalfNormal("sigma", sigma=1.0)
        # Trial model expectation
        mu = model_trial(x, a)
        # Likelihood
        pm.Normal("y", mu=mu, sigma=sigma, observed=y_obs)
        # Sampling (fixed seed for reproducibility)
        trace = pm.sample(draws=2000,
                          tune=1000,
                          target_accept=0.9,
                          random_seed=42,
                          cores=cores)
    return trace

if __name__ == '__main__':
    # Supports frozen Windows executables; has no effect otherwise.
    multiprocessing.freeze_support()

    # Synthetic observations plus the ground truth used to create them
    x, y_obs, y_true, a_true = generate_data()

    # Fit the four-coefficient ARD model on a single core
    p = 4
    trace = run_ard_model(x, y_obs, p=p, cores=1)

    # Posterior summary table with 95% HDIs
    monitored = ["lambda_hp", "a", "sigma"]
    print("\n=== ARD 階層ベイズ 回帰結果 ===")
    summary_df = az.summary(
        trace,
        var_names=monitored,
        hdi_prob=0.95,
        round_to=4,
    )
    print(summary_df)

    # Trace plots for the same variables
    az.plot_trace(trace, var_names=monitored)
    plt.show()

    # Posterior means (averaged over chains and draws) and 95% HDIs
    sample_dims = ("chain", "draw")
    posterior = trace.posterior
    a_mean = posterior['a'].mean(dim=sample_dims).values
    a_hdi = az.hdi(trace, var_names=['a'], hdi_prob=0.95)['a'].values
    lam_mean = posterior['lambda_hp'].mean(dim=sample_dims).values

    # Decide which coefficients look unnecessary
    print("\n=== パラメータ選択の判断 ===")
    for i, (mean, (low, high), lam) in enumerate(zip(a_mean, a_hdi, lam_mean)):
        print(f"a[{i}]: mean={mean:.4f}, 95% HDI=({low:.4f},{high:.4f}), lambda_hp={lam:.1f}")
        # A coefficient is flagged only when all three heuristics agree.
        near_zero = abs(mean) < 0.05
        hdi_spans_zero = low < 0 < high
        strongly_shrunk = lam > 100
        if near_zero and hdi_spans_zero and strongly_shrunk:
            print("  → パラメータ不要候補")

    # Show the parameters of the true model for comparison
    print(f"\nTrue a: {a_true}")

    # Curves on a finer grid: trial model at the posterior-mean coefficients
    # versus the true model
    x_pred = np.linspace(0, 1, 200)
    y_pred = model_trial(x_pred, a_mean)
    y_ref = model_true(x_pred, a_true)

    plt.figure(figsize=(8,4))
    plt.scatter(x, y_obs, c='blue', alpha=0.5, label='Observed')
    plt.plot(x_pred, y_ref, 'k-', label='True model')
    plt.plot(x_pred, y_pred, 'r--', label='ARD Posterior mean')
    plt.xlabel('x')
    plt.ylabel('y')
    plt.title('ARD Sparse Modeling: Trial vs True')
    plt.legend()
    plt.grid(True)
    plt.show()