import sys
import csv
import openpyxl
import numpy as np
import matplotlib.pyplot as plt
import pymc as pm
import arviz as az
import pytensor.tensor as at
from types import SimpleNamespace

plt.rcParams.update({'font.size': 16})  # Default font size for all matplotlib figures

def initialize():
    """Build the default run configuration.

    Returns:
        SimpleNamespace with input/plot labels, the model choice, the true
        coefficients and basis functions used for synthetic data, the
        synthetic-data size and noise level, the convergence settings of the
        noise loop, the sampler settings and the prior hyperparameters.
        Every attribute may later be overridden by ``update_vars``.
    """
    # Basis functions of the 'poly' model. The NumPy variants are used for
    # plotting and for evaluating the posterior-mean fit (noise
    # self-consistency loop); the PyTensor variants build the PyMC graph.
    numpy_basis = [np.sin, lambda x: x**2, lambda x: np.exp(-x)]
    tensor_basis = [at.sin, lambda x: x**2, lambda x: at.exp(-x)]

    return SimpleNamespace(
        infile=None,
        xlabel='x',
        ylabel='y',
        model_type='poly',          # 'poly' or 'peak'
        true_a=[1.0, -2.0, 3.0],    # true coefficients for synthetic data
        basis_functions_np=numpy_basis,
        basis_functions_tensor=tensor_basis,
        p=len(numpy_basis),
        n=100,
        sigma_noise=0.3,
        tol=1e-4,
        max_iter=20,
        draws=10000,
        tune=500,
        seed=42,
        prior_mu=0.0,
        prior_std=5.0,
        hdi_prob=0.68,
    )

def update_vars(cfg, argv=None):
    """Override attributes of *cfg* from ``key=value`` command-line arguments.

    Each ``key=value`` argument whose ``key`` is an existing attribute of
    *cfg* is converted according to the type of the current value:

    * ``bool``   -> ``value.lower() == 'true'``
    * ``int``    -> ``int(value)``
    * ``float``  -> ``float(value)``
    * list/tuple of numbers -> comma-separated floats, e.g. ``true_a=1,-2,3``
      (optional surrounding ``[]``/``()`` are accepted); the container type
      is preserved
    * list/tuple of non-numbers (e.g. basis functions) -> rejected
    * anything else (``str``, ``None``) -> the raw string

    Unknown keys and unparsable values are reported and skipped; arguments
    without ``=`` are ignored.

    Args:
        cfg: namespace-like object updated in place.
        argv: arguments to parse; defaults to ``sys.argv[1:]``.
    """
    args = sys.argv[1:] if argv is None else argv
    for arg in args:
        if '=' not in arg:
            continue
        key, value = arg.split('=', 1)
        if not hasattr(cfg, key):
            print(f"Unknown parameter: {key}")
            continue

        current = getattr(cfg, key)
        try:
            # bool must be tested before int: bool is a subclass of int.
            if isinstance(current, bool):
                new_value = value.lower() == 'true'
            elif isinstance(current, int):
                new_value = int(value)
            elif isinstance(current, float):
                new_value = float(value)
            elif isinstance(current, (list, tuple)):
                # Only numeric sequences (e.g. true_a) can be given on the
                # command line; a string would break the model later.
                if not all(isinstance(e, (int, float)) for e in current):
                    print(f"Cannot set {key} from the command line")
                    continue
                items = value.strip().strip('[]()').split(',')
                new_value = type(current)(float(v) for v in items)
            else:
                new_value = value
        except ValueError:
            print(f"Invalid value for {key}: {value}")
            continue
        setattr(cfg, key, new_value)

def read_data(infile):
    """Read a numeric table with a one-line header.

    The format is chosen from the file extension (case-insensitive):
    ``csv`` (comma-separated), ``txt`` (tab-separated) or ``xlsx`` (the
    active worksheet).

    Args:
        infile: path of the file to read.

    Returns:
        (labels, data): the header cells, and every data row converted to a
        list of floats. Completely blank rows (e.g. a trailing empty line or
        an empty spreadsheet row) are skipped.

    Raises:
        ValueError: unsupported extension, empty file, or a non-empty row
            containing a value that cannot be converted to float.
    """
    ext = infile.split('.')[-1].lower()
    delimiters = {'csv': ',', 'txt': '\t'}

    if ext in delimiters:
        print(f"Read [{infile}] as {ext.upper()} file")
        # newline='' is required by the csv module; 'utf-8-sig' strips the
        # BOM that spreadsheet programs often prepend to exported files.
        with open(infile, 'r', encoding='utf-8-sig', newline='') as f:
            rows = list(csv.reader(f, delimiter=delimiters[ext]))
    elif ext == 'xlsx':
        print(f"Read [{infile}] as XLSX file")
        wb = openpyxl.load_workbook(infile, data_only=True)
        try:
            rows = [list(row) for row in wb.active.iter_rows(values_only=True)]
        finally:
            wb.close()
    else:
        raise ValueError("Unsupported file format.")

    if not rows:
        raise ValueError(f"[{infile}] contains no header row.")
    labels, body = rows[0], rows[1:]

    data = []
    for lineno, row in enumerate(body, start=2):
        if all(v is None or str(v).strip() == '' for v in row):
            continue  # blank line / empty spreadsheet row
        try:
            data.append([float(v) for v in row])
        except (TypeError, ValueError) as err:
            raise ValueError(
                f"[{infile}] row {lineno}: non-numeric value in {row!r}"
            ) from err
    return labels, data

def generate_data(cfg, infile=None):
    """Return ``(x, y_obs, y_true)``, loaded from *infile* or synthesized.

    Without *infile*: NumPy's global RNG is seeded with ``cfg.seed``,
    ``cfg.n`` evenly spaced points on [0, 1] are passed through ``model``
    with ``cfg.true_a``, and Gaussian noise with standard deviation
    ``cfg.sigma_noise`` is added.

    With *infile*: the first two columns become ``x`` and ``y_obs``,
    ``y_true`` is ``None``, and *cfg* is switched in place to the
    4-parameter 'peak' model with prior means ``[1.0, max(y_obs), 0.0, 1.0]``
    and ``prior_std = 10.0``.
    """
    if not infile:
        np.random.seed(cfg.seed)
        x = np.linspace(0, 1, cfg.n)
        y_true = model(x, cfg.true_a, cfg)
        noise = np.random.normal(0, cfg.sigma_noise, size=cfg.n)
        return x, y_true + noise, y_true

    _labels, rows = read_data(infile)
    x = np.array([row[0] for row in rows])
    y_obs = np.array([row[1] for row in rows])

    # File data is fitted with the peak model a + b*exp(-((x - c)/d)**2).
    cfg.model_type = 'peak'
    cfg.p = 4
    cfg.prior_mu = [1.0, max(y_obs), 0.0, 1.0]
    cfg.prior_std = 10.0
    return x, y_obs, None

def model(x, coeffs, cfg, list_type='numpy'):
    """Evaluate the configured model at *x*.

    ``cfg.model_type == 'poly'``: the linear combination
    ``sum_j coeffs[j] * f_j(x)`` over ``cfg.basis_functions_np`` (NumPy mode)
    or the first ``cfg.p`` entries of ``cfg.basis_functions_tensor``
    (PyTensor mode).

    Any other ``model_type`` ('peak'): ``a + b * exp(-((x - c) / d)**2)``
    with ``coeffs = (a, b, c, d)``.

    Args:
        x: sample locations; array-like in NumPy mode, a PyTensor variable
            otherwise.
        coeffs: model coefficients (floats, NumPy arrays that broadcast
            against ``x``, or PyMC random variables in PyTensor mode).
        cfg: configuration namespace.
        list_type: ``'numpy'`` evaluates with NumPy; anything else builds a
            PyTensor expression.

    Returns:
        A float NumPy array (NumPy mode) or a PyTensor expression.
    """
    if list_type == 'numpy':
        # Coerce to a float array so integer inputs and plain lists work
        # (the original in-place accumulation failed on integer arrays).
        x = np.asarray(x, dtype=float)

    if cfg.model_type == 'poly':
        if list_type == 'numpy':
            y = np.zeros_like(x)
            for i, f in enumerate(cfg.basis_functions_np):
                # Out-of-place add so array-valued coefficients can broadcast.
                y = y + coeffs[i] * f(x)
        else:
            y = sum(
                coeffs[j] * cfg.basis_functions_tensor[j](x) for j in range(cfg.p)
            )
    else:  # peak
        a, b, c, d = coeffs
        if list_type == 'numpy':
            y = a + b * np.exp(-((x - c) / d) ** 2)
        else:
            y = a + b * at.exp(-((x - c) / d) ** 2)

    return y

def estimate_noise(y_obs, y_pred):
    """Return the standard deviation (ddof=0) of the residuals ``y_obs - y_pred``."""
    residuals = np.subtract(np.asarray(y_obs), np.asarray(y_pred))
    return np.std(residuals)

def _select_prior(value, i):
    """Return ``value[i]`` for a per-parameter list/tuple, else the scalar *value*."""
    return value[i] if isinstance(value, (list, tuple)) else value


def _lognormal_params(mean, std, name):
    """Return ``(mu_log, sigma_log)`` of a LogNormal with the given mean and std.

    Moment matching: ``sigma_log**2 = log(1 + (std/mean)**2)`` and
    ``mu_log = log(mean) - sigma_log**2 / 2``.

    Raises:
        ValueError: if *mean* or *std* is not strictly positive (the formulas
            would otherwise produce inf/NaN prior parameters).
    """
    if not (mean > 0 and std > 0):
        raise ValueError(
            f"LogNormal prior for {name} needs positive mean and std "
            f"(got mean={mean}, std={std})"
        )
    sigma_log = np.sqrt(np.log(1 + (std / mean) ** 2))
    mu_log = np.log(mean) - 0.5 * sigma_log ** 2
    return mu_log, sigma_log


def run_mcmc(x, y_obs, cfg):
    """Sample the posterior of the configured model with ``pm.sample``.

    Coefficient priors (``a0 .. a{p-1}``), with prior mean/std taken from
    ``cfg.prior_mu`` / ``cfg.prior_std`` (scalar or per-parameter list):

    * 'peak' model: ``a0`` (background) TruncatedNormal with lower bound 0;
      ``a1`` (intensity) and ``a3`` (width) LogNormal moment-matched to the
      prior mean/std; ``a2`` (center) Normal.
    * 'poly' model: Normal for every coefficient.

    Noise ``sigma ~ HalfNormal(1.0)``; likelihood ``y ~ Normal(model(x), sigma)``.

    Returns:
        The InferenceData returned by ``pm.sample``.

    Raises:
        ValueError: if a LogNormal prior has a non-positive mean or std.
    """
    with pm.Model():
        x_shared = pm.Data("x_shared", x)

        coeffs = []
        for i in range(cfg.p):
            m0 = _select_prior(cfg.prior_mu, i)
            s0 = _select_prior(cfg.prior_std, i)
            name = f"a{i}"

            if cfg.model_type == 'peak' and i == 0:
                # Background: Normal truncated to be non-negative.
                term = pm.TruncatedNormal(name, mu=m0, sigma=s0, lower=0.0)
            elif cfg.model_type == 'peak' and i in (1, 3):
                # Intensity and width: strictly positive, so LogNormal.
                mu_log, sigma_log = _lognormal_params(m0, s0, name)
                term = pm.LogNormal(name, mu=mu_log, sigma=sigma_log)
            else:
                term = pm.Normal(name, mu=m0, sigma=s0)

            coeffs.append(term)

        sigma = pm.HalfNormal("sigma", sigma=1.0)
        mu = model(x_shared, coeffs, cfg, list_type='tensor')

        pm.Normal("y", mu=mu, sigma=sigma, observed=y_obs)
        trace = pm.sample(
            draws=cfg.draws,
            tune=cfg.tune,
            target_accept=0.9,
            random_seed=cfg.seed,
            return_inferencedata=True
        )

    return trace

def extract_summary(trace, cfg):
    """Return ``az.summary`` (rounded to 5 digits) for ``a0 .. a{p-1}`` and ``sigma``."""
    names = [f"a{k}" for k in range(cfg.p)]
    names.append("sigma")
    return az.summary(trace, var_names=names, round_to=5)

def plot_results(x, y_obs, y_true, trace, cfg):
    """Plot the data, the posterior median prediction and two HDI bands.

    For every posterior sample the model is evaluated on 200 evenly spaced
    points spanning ``x``. Bands at ``cfg.hdi_prob`` are drawn for the model
    alone (parameter uncertainty) and for the predictive distribution (model
    plus Gaussian noise drawn with that sample's ``sigma``). The dashed line
    is the pointwise median of the predictive draws. When *y_true* is not
    None, ``model(x_pred, cfg.true_a)`` is drawn as the true function.
    """
    x_pred = np.linspace(np.min(x), np.max(x), 200)
    post = trace.posterior.stack(samples=("chain", "draw"))
    n_samples = post.sizes["samples"]

    # Extract all samples from xarray once, shape (n_samples, p), instead of
    # indexing the DataArrays inside the per-sample loop.
    coeff_samples = np.column_stack(
        [np.asarray(post[f"a{i}"].values, dtype=float) for i in range(cfg.p)]
    )
    sigma_samples = np.asarray(post["sigma"].values, dtype=float)

    y_model = np.empty((n_samples, x_pred.size))
    for idx, coeffs_i in enumerate(coeff_samples):
        y_model[idx] = model(x_pred, coeffs_i.tolist(), cfg)
    # Per-sample observation noise, vectorized over samples and x.
    y_preds = y_model + np.random.normal(
        0.0, sigma_samples[:, np.newaxis], size=y_model.shape
    )

    y_med = np.median(y_preds, axis=0)
    # az.hdi interprets a bare 2-D array ambiguously (and warns); add an
    # explicit chain axis so the input is unambiguously (chain, draw, x).
    hdi_pred = az.hdi(y_preds[np.newaxis], hdi_prob=cfg.hdi_prob)
    hdi_model = az.hdi(y_model[np.newaxis], hdi_prob=cfg.hdi_prob)

    # Keep the historical "1σ" label for the default 0.68 mass; otherwise
    # label the band with its actual probability mass.
    if np.isclose(cfg.hdi_prob, 0.68):
        band = "1σ"
    else:
        band = f"{cfg.hdi_prob:.0%}"

    plt.figure(figsize=(10,6))
    plt.scatter(x, y_obs, label="Observed Data", alpha=0.5)

    if y_true is not None:
        y_true_pred = model(x_pred, cfg.true_a, cfg)
        plt.plot(x_pred, y_true_pred, label="True Function", color="red")

    plt.plot(x_pred, y_med, linestyle="--", label="Posterior Median", color="green")
    plt.fill_between(x_pred, hdi_pred[:,0], hdi_pred[:,1],
                     alpha=0.3, label=f"{band} Predictive Interval")
    plt.fill_between(x_pred, hdi_model[:,0], hdi_model[:,1],
                     alpha=0.2, label=f"{band} Model Interval")

    plt.xlabel(cfg.xlabel)
    plt.ylabel(cfg.ylabel)
    plt.title(f"{cfg.model_type.capitalize()} Model (p={cfg.p})")
    plt.legend()
    plt.grid(True)
    plt.show()

def main():
    """Run the pipeline: configure, load/generate data, sample, report, plot.

    The sampler is re-run until the residual standard deviation of the
    posterior-mean fit changes by less than ``cfg.tol`` (relative to the
    previous value) or ``cfg.max_iter`` runs are done; at least one run is
    always made. The last estimate is stored in ``cfg.sigma_noise``.

    NOTE(review): run_mcmc as written does not read cfg.sigma_noise and uses
    a fixed random_seed, so every iteration re-samples the same posterior.
    Confirm whether sigma_noise was meant to feed the likelihood or prior.
    """
    cfg = initialize()
    update_vars(cfg)
    x, y_obs, y_true = generate_data(cfg, cfg.infile)

    # Guarantee at least one run so `trace` is always defined below.
    n_iter = max(1, cfg.max_iter)
    for it in range(n_iter):
        trace = run_mcmc(x, y_obs, cfg)
        post = trace.posterior.stack(samples=("chain","draw"))
        coeffs_mean = [float(post[f"a{j}"].mean().values) for j in range(cfg.p)]
        y_pred = model(x, coeffs_mean, cfg)
        sigma_new = estimate_noise(y_obs, y_pred)
        # Relative change; use an absolute test if the previous value is 0
        # to avoid ZeroDivisionError.
        scale = abs(cfg.sigma_noise) or 1.0
        converged = abs(sigma_new - cfg.sigma_noise) / scale < cfg.tol
        cfg.sigma_noise = sigma_new
        if converged:
            print(f"収束: iter={it+1}, σ_noise={sigma_new:.5f}")
            break
    else:
        print(f"Not converged after {n_iter} iterations, σ_noise={cfg.sigma_noise:.5f}")

    summary = extract_summary(trace, cfg)
    print("\n=== パラメータ推定値 ===")
    for name, row in summary.iterrows():
        print(f"  {name}: {row['mean']:.5f} ± {row['sd']:.5f}")
    print("========================\n")

    plot_results(x, y_obs, y_true, trace, cfg)

# Run the pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
