import sys
import csv
import openpyxl
import numpy as np
import matplotlib.pyplot as plt
import pymc as pm
import arviz as az
import sys
from types import SimpleNamespace

plt.rcParams.update({'font.size': 16})  # Set the default font size for all figures

def initialize():
    """Build and return the default configuration namespace.

    Returns:
        SimpleNamespace holding axis labels, the synthetic-data settings
        (true coefficients, basis functions, sample count, noise level),
        the iteration controls (``tol``, ``max_iter``), the sampler
        settings (``draws``, ``tune``, ``seed``), the Normal prior
        parameters (``prior_mu``, ``prior_std``) and ``hdi_prob``.
    """
    # Basis function definitions; one coefficient is fitted per function.
    basis_functions = [
        lambda x: np.sin(x),
        lambda x: x**2,
        lambda x: np.exp(-x),
    ]

    return SimpleNamespace(
        # Axis labels used when plotting.
        xlabel='x',
        ylabel='y',
        # True coefficients used to generate the synthetic data.
        true_a=[1.0, -2.0, 3.0],
        basis_functions=basis_functions,
        p=len(basis_functions),  # number of coefficients
        n=100,                   # number of data points
        sigma_noise=0.3,         # std of the Gaussian noise added to the data
        # Iteration / sampling controls.
        tol=1e-4,
        max_iter=20,
        draws=10000,
        tune=500,
        seed=42,
        # Mean and standard deviation of the Normal prior on each coefficient.
        prior_mu=0.0,
        prior_std=5.0,
        # Probability mass of the intervals drawn in the plot.
        hdi_prob=0.68,
    )

def update_vars(cfg):
    """Override attributes of ``cfg`` from ``key=value`` command-line arguments.

    Every element of ``sys.argv[1:]`` that contains ``=`` is treated as an
    override. The text is converted to the type of the attribute's current
    value: bool (``"true"``, case-insensitive, is True; anything else is
    False), int, float, or a list of numbers (comma-separated, parsed as
    floats). Any other attribute is replaced by the raw string.

    Arguments without ``=`` are ignored. Unknown keys and values that cannot
    be converted are reported on stdout and leave ``cfg`` untouched.

    Args:
        cfg: Namespace-like object whose existing attributes may be updated
            in place.
    """
    for arg in sys.argv[1:]:
        # Split on the first '=' only so that values may contain '=' themselves
        # (e.g. xlabel=a=b); a plain split('=') would fail to unpack.
        key, sep, value = arg.partition('=')
        if not sep:
            continue
        if not hasattr(cfg, key):
            print(f"Unknown parameter: {key}")
            continue

        current = getattr(cfg, key)
        is_numeric_list = (
            isinstance(current, list)
            and bool(current)
            and all(
                isinstance(v, (int, float)) and not isinstance(v, bool)
                for v in current
            )
        )
        try:
            # bool must be tested before int because bool is a subclass of int.
            if isinstance(current, bool):
                new_value = value.lower() == 'true'
            elif isinstance(current, int):
                new_value = int(value)
            elif isinstance(current, float):
                new_value = float(value)
            elif is_numeric_list:
                # Keep numeric lists numeric instead of storing the raw text.
                new_value = [float(item) for item in value.split(',')]
            else:
                new_value = value
        except ValueError:
            print(f"Invalid value for {key}: {value}")
        else:
            setattr(cfg, key, new_value)

def model(x, coeffs, cfg):
    """Evaluate the basis expansion ``sum_i coeffs[i] * f_i(x)``.

    Args:
        x: Points at which to evaluate the model; an array or anything
            ``np.asarray`` accepts (scalars, lists, integer arrays).
        coeffs: Indexable sequence with one coefficient per entry of
            ``cfg.basis_functions``.
        cfg: Configuration object providing ``basis_functions``, a sequence
            of callables applied to ``x``.

    Returns:
        Array with the shape of ``x`` holding the model values.

    Raises:
        IndexError: If ``coeffs`` has fewer entries than
            ``cfg.basis_functions``.
    """
    x = np.asarray(x)
    # Accumulate in a floating dtype. With integer x, zeros_like(x) would be an
    # integer array and the in-place addition of float terms would raise a
    # casting error. Floating/complex inputs keep their own dtype.
    acc_dtype = x.dtype if np.issubdtype(x.dtype, np.inexact) else np.float64
    y = np.zeros(x.shape, dtype=acc_dtype)
    for i, f in enumerate(cfg.basis_functions):
        y += coeffs[i] * f(x)

    return y

def generate_data(cfg):
    """Create a synthetic data set on [0, 1] from the true coefficients.

    Seeds NumPy's global random generator with ``cfg.seed``, evaluates the
    model with ``cfg.true_a`` on ``cfg.n`` evenly spaced points and adds
    zero-mean Gaussian noise with standard deviation ``cfg.sigma_noise``.

    Returns:
        Tuple ``(x_data, y_data, y_true)``: the grid, the noisy observations
        and the noise-free model values.
    """
    np.random.seed(cfg.seed)

    grid = np.linspace(0, 1, cfg.n)
    clean = model(grid, cfg.true_a, cfg)
    noise = np.random.normal(0, cfg.sigma_noise, size=cfg.n)

    return grid, clean + noise, clean

def estimate_noise(y_obs, y_pred):
    """Return the standard deviation of the residuals ``y_obs - y_pred``.

    Uses ``np.std`` with its default arguments.
    """
    residuals = y_obs - y_pred
    return np.std(residuals)

def run_mcmc(x, y_obs, cfg):
    """Sample the posterior of the basis-expansion regression with PyMC.

    The model places a Normal(``cfg.prior_mu``, ``cfg.prior_std``) prior on
    each of the ``cfg.p`` coefficients (named ``a0``, ``a1``, ...), a
    HalfNormal(1) prior on the noise scale ``sigma`` and a Normal likelihood
    named ``y`` on the observations.

    Args:
        x: Input values passed to each basis function.
        y_obs: Observed responses.
        cfg: Configuration providing ``p``, ``basis_functions``, ``prior_mu``,
            ``prior_std``, ``draws``, ``tune`` and ``seed``.

    Returns:
        The object returned by ``pm.sample`` with
        ``return_inferencedata=True``.
    """
    sampler_options = dict(
        draws=cfg.draws,
        tune=cfg.tune,
        target_accept=0.9,
        return_inferencedata=True,
        random_seed=cfg.seed,
    )

    with pm.Model():
        coeffs = [
            pm.Normal(f"a{k}", mu=cfg.prior_mu, sigma=cfg.prior_std)
            for k in range(cfg.p)
        ]
        sigma = pm.HalfNormal("sigma", sigma=1)

        # Mean function: coefficient-weighted sum of the basis functions.
        mu = 0
        for k in range(cfg.p):
            mu = mu + coeffs[k] * cfg.basis_functions[k](x)

        pm.Normal("y", mu=mu, sigma=sigma, observed=y_obs)
        return pm.sample(**sampler_options)


def extract_summary(trace, cfg):
    """Return ``az.summary`` for the coefficients ``a0..a{p-1}`` and ``sigma``.

    Values are rounded to 5 decimals via ``round_to=5``.
    """
    names = [f"a{k}" for k in range(cfg.p)]
    names.append("sigma")
    return az.summary(trace, var_names=names, round_to=5)


def plot_results(x, y_obs, y_true, trace, cfg):
    """Plot the data, the true curve and posterior intervals of the fit.

    For every posterior sample the model is evaluated on a 200-point grid
    spanning the range of ``x``, once without noise and once with Gaussian
    noise drawn using that sample's ``sigma``. The figure shows the observed
    points, the true function (when ``y_true`` is not None), the pointwise
    median of the noisy predictions and the two ``az.hdi`` bands.

    Args:
        x: Observed input values.
        y_obs: Observed responses.
        y_true: Noise-free responses, or None. Only its presence is used;
            the true curve itself is recomputed from ``cfg.true_a``.
        trace: Sampling result exposing ``posterior`` with variables
            ``a0..a{p-1}`` and ``sigma`` over ``chain`` and ``draw``.
        cfg: Configuration providing ``p``, ``true_a``, ``hdi_prob``,
            ``xlabel``, ``ylabel`` and what ``model`` needs.
    """
    x_pred = np.linspace(min(x), max(x), 200)
    post = trace.posterior.stack(samples=("chain", "draw"))
    nsamples = post.samples.size

    # Extract the sample arrays from xarray once. Looking them up inside the
    # per-sample loop repeats the same work for every posterior draw.
    coeff_draws = [np.asarray(post[f"a{i}"].values) for i in range(cfg.p)]
    sigma_draws = np.asarray(post["sigma"].values)

    y_preds = np.zeros((nsamples, len(x_pred)))
    y_preds_wo_noise = np.zeros_like(y_preds)
    for idx in range(nsamples):
        coeffs = [float(draws[idx]) for draws in coeff_draws]
        sigma_i = float(sigma_draws[idx])
        y_model = model(x_pred, coeffs, cfg)
        y_preds_wo_noise[idx] = y_model
        y_preds[idx] = y_model + np.random.normal(0, sigma_i, size=len(x_pred))

    y_median = np.median(y_preds, axis=0)
    # NOTE(review): az.hdi receives 2-D arrays laid out as (sample, x point)
    # and the result is indexed as (x point, lower/upper) below. Confirm that
    # the installed ArviZ version interprets 2-D input this way.
    hdi_pred = az.hdi(y_preds, hdi_prob=cfg.hdi_prob)
    hdi_model = az.hdi(y_preds_wo_noise, hdi_prob=cfg.hdi_prob)

    plt.figure(figsize=(10, 6))
    plt.scatter(x, y_obs, label="Observed Data", color="blue", alpha=0.5)
    if y_true is not None:
        plt.plot(
            x_pred, model(x_pred, cfg.true_a, cfg),
            label="True Function", color="red"
        )
    plt.plot(x_pred, y_median, label="Posterior Median", color="green", linestyle="--")

    # NOTE(review): the "1σ" labels are fixed text; they are not updated when
    # cfg.hdi_prob is changed from its default of 0.68.
    plt.fill_between(
        x_pred, hdi_pred[:, 0], hdi_pred[:, 1],
        color='gray', alpha=0.3,
        label="1σ Predictive Interval (incl. noise)"
    )
    plt.fill_between(
        x_pred, hdi_model[:, 0], hdi_model[:, 1],
        color='orange', alpha=0.2,
        label="1σ Model Interval (w/o noise)"
    )

    plt.xlabel(cfg.xlabel)
    plt.ylabel(cfg.ylabel)
    plt.title(f"Bayesian Nonlinear Regression (MCMC): p={cfg.p}")
    plt.legend()
    plt.grid(True)
    plt.show()

def main():
    """Run the full pipeline: configure, simulate data, sample, report, plot.

    Builds the default configuration, applies command-line overrides,
    generates synthetic data and runs MCMC up to ``cfg.max_iter`` times.
    After each run the noise level is re-estimated from the residuals of the
    posterior-mean fit and stored in ``cfg.sigma_noise``; the loop stops once
    its relative change is below ``cfg.tol``. Finally the parameter summary
    is printed and the results are plotted.
    """
    cfg = initialize()
    update_vars(cfg)
    x, y_obs, y_true = generate_data(cfg)

    # Run at least once so that `trace` exists even when max_iter has been
    # overridden to zero or a negative value.
    n_iter = max(1, cfg.max_iter)

    # NOTE(review): run_mcmc does not read cfg.sigma_noise, so updating it
    # does not change the next run's inputs; with a fixed seed the repeated
    # runs presumably reproduce the same trace. Confirm whether the
    # iteration is still needed.
    for iteration in range(n_iter):
        trace = run_mcmc(x, y_obs, cfg)
        post = trace.posterior.stack(samples=("chain", "draw"))
        coeffs_mean = [float(post[f"a{k}"].mean().values) for k in range(cfg.p)]
        y_pred = model(x, coeffs_mean, cfg)
        sigma_new = estimate_noise(y_obs, y_pred)

        rel_diff = abs(sigma_new - cfg.sigma_noise) / cfg.sigma_noise
        cfg.sigma_noise = sigma_new
        if rel_diff < cfg.tol:
            print(f"収束: iteration={iteration+1}, σ_noise = {sigma_new:.5f}")
            break

    summary = extract_summary(trace, cfg)
    print("\n=== ベイズ非線形回帰（MCMC）パラメータ推定値 ===")
    for name, row in summary.iterrows():
        print(f"  {name}: mean = {row['mean']:.5f} ± {row['sd']:.5f}")
    print("==============================================\n")

    plot_results(x, y_obs, y_true, trace, cfg)

# Run the pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()
