import os
import numpy as np
import multiprocessing
import matplotlib.pyplot as plt
import pymc as pm
import arviz as az
from matplotlib.animation import FuncAnimation
from matplotlib.animation import PillowWriter


# 設定項目
# === Parameters ===
#mode = "plot"
mode = "save"

nparams = 4  # 回帰パラメータの数（a0, a1, a2, a3）
sigma_noise = 3.0  
true_params = [10.0, 5.0, 2.0, 1.0]  # [a0, a1, a2, a3]

ntune = 100  # チューニング期間
nsamples = 400  # 最初のサンプル数
ninitial = 20  # 初期サンプル数
nadd = 50  # サンプリング追加数
acceptance = 0.9

# cpuのコア数を取得
try:
    ncores = multiprocessing.cpu_count()
except NotImplementedError:
    ncores = os.cpu_count()
if ncores is None:
    ncores = 1

# for debug（Windowsでは1にすべき）
#ncores = 1

print()
print(f"ncores={ncores}")

# matplotlibでMS Gothicを使用（日本語ラベルのため） 
plt.rcParams['font.family'] = 'MS Gothic'
plt.rcParams['axes.unicode_minus'] = False  # マイナス記号の表示を有効にする

# フォントサイズを16に設定
font_size = 16
plt.rcParams.update({'font.size': font_size})

def main():
    # 1. データ生成
    np.random.seed(0)
    x = np.linspace(0, 1, 20)
    y = true_params[0] + true_params[1] * np.exp(true_params[2] * x) + true_params[3] * x**2 + np.random.normal(0, sigma_noise, size=x.shape)

    # 2. PyMCモデルの定義
    with pm.Model() as model:
        # パラメータの定義
        params = [pm.Normal(f"a{i}", mu=0, sigma=10) for i in range(nparams)]  # a0, a1, a2, a3
        sigma = pm.HalfNormal("sigma", sigma=1)

        # モデルの予測
        mu = params[0] + params[1] * pm.math.exp(params[2] * x) + params[3] * x**2
        y_obs = pm.Normal("y_obs", mu=mu, sigma=sigma, observed=y)

        # 最初のサンプリングを少なくとも200回行う
#        trace = pm.sample(cores=ncores, draws=nsamples, tune=ntune, return_inferencedata=True, target_accept=acceptance)
        trace = None

    # アニメーションのためにFigureとAxesを準備
    fig, (ax, ax2) = plt.subplots(1, 2, figsize=(12, 5))
    ax.scatter(x, y, label="観測データ", color="black")

    # 理論的な回帰モデルの曲線を描画
    x_plot = np.linspace(0, 1, 100)
    y_theoretical = true_params[0] + true_params[1] * np.exp(true_params[2] * x_plot) + true_params[3] * x_plot**2
    ax.plot(x_plot, y_theoretical, label="理論モデル", color="blue", linestyle="--")

    # 回帰結果の予測線（最初は空）
    line, = ax.plot([], [], label="事後平均による予測", color="red")
    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.set_title("ベイズ回帰モデルの予測（事後平均）")
    ax.legend()
    ax.grid(True)

    # 回帰パラメータの統計量のプロット
    ax2.set_xlabel("サンプル数")
    ax2.set_ylabel("パラメータのmean & std")
    ax2.set_xlim(0, 1000)
    ax2.set_ylim(-10, 20)

    # nsamples_usedを管理
    nsamples_used = ninitial  # 初期サンプル数

    # 初期化関数（アニメーション用）
    def init():
        line.set_data([], [])
        return line,

    # 更新関数（アニメーション用）
    nsamples_list = []
    mean_list = []
    std_list = []
    def update(frame):
        nonlocal nsamples_used, trace  # traceとnsamples_usedを更新するためにnonlocalを使う
        
        # nsamples_usedがnsamplesを超えた場合、アニメーションを停止
        if nsamples_used >= nsamples:  
            return line,

        print(f"frame {frame}, nsample={nsamples_used}")

        # 逐次的にサンプリング（frame回目まで）
        with model:
            trace = pm.sample(cores=ncores, draws=nsamples_used, tune=ntune, return_inferencedata=True, target_accept=acceptance)

        # x_plot を定義（予測のため）
        x_plot = np.linspace(0, 1, 100)

        # 現在のフレームでパラメータの事後平均を更新
        params_samples = [trace.posterior[f"a{i}"].stack(sample=("chain", "draw")).values[:nsamples_used] for i in range(nparams)]

        # 総サンプル数（全チェーンを含む）の取得
        total_samples = nsamples_used
        ax.set_title(f"回帰に使用したサンプル数: {total_samples}")

        # 予測値を計算
        y_preds = np.array([params_samples[0][i] + params_samples[1][i] * np.exp(params_samples[2][i] * x_plot) + params_samples[3][i] * x_plot**2 for i in range(nsamples_used)])
        y_mean_pred = y_preds.mean(axis=0)
        
        # 信頼区間（1σ）を計算
        y_std_pred = y_preds.std(axis=0)
        y_lower = y_mean_pred - y_std_pred
        y_upper = y_mean_pred + y_std_pred

        # 以前の信頼区間を削除
        for collection in ax.collections:
            collection.remove()

        # 更新された予測線を設定
        line.set_data(x_plot, y_mean_pred)

        # 信頼区間（1σ）を塗りつぶす
        ax.fill_between(x_plot, y_lower, y_upper, color='red', alpha=0.2, label="1σ 信頼区間")

        ax.scatter(x, y, label="観測データ", color="black")

        # パラメータのmeanとstdを計算
        param_means = [np.mean(params_samples[i]) for i in range(nparams)]
        param_stds = [np.std(params_samples[i]) for i in range(nparams)]
        print(f"{param_means=}")
        print(f"{param_stds=}")
        nsamples_list.append(total_samples)
        mean_list.append(param_means)
        std_list.append(param_stds)

        # プロット更新
        ax2.clear()
        ax2.set_xlabel("サンプル数")
        ax2.set_ylabel("パラメータのmean & std")

        # nsamples_list、mean_list、std_list に基づいて ylim を設定
        ax2.set_xlim(0, total_samples)
        ax2.set_ylim(np.min(mean_list + std_list) - 2, np.max(mean_list + std_list) + 2)  # ylimを動的に設定
        
        ax2.plot(nsamples_list, mean_list, label="mean")
        ax2.plot(nsamples_list, std_list, linestyle="--", label="std")
        ax2.legend(fontsize=8)

        plt.pause(0.1)

        # nsamples_usedを更新
        nsamples_used += nadd

        return line,

    # アニメーションの作成
    anim = FuncAnimation(fig, update, frames=range(100, 2000, 100), 
                        init_func=init, blit=True, interval=200)

    if mode == "save":
# Save animation as GIF
        print("Saving animation as 'bayesian_nonlinear_animation.gif'")
# interval is in ms, so fps = 1000 / interval
        writer = PillowWriter(fps=1000/800)
        anim.save('bayesian_nonlinear_animation.gif', writer=writer)
        print("Saved")
    else:
# Display animation
        print("Displaying animation")
        plt.show()


# Windows向けに必須のガード
if __name__ == '__main__':
    main()
    input("Press Enter to exit...")
