import sys
import numpy as np
from scipy.stats import norm

import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from matplotlib.animation import PillowWriter
#from matplotlib.animation import FuncAnimation, FFMpegWriter


# Acquisition strategy used to pick the next sample point.
# One of: "random" / "std" / "max" / "ucb" / "ei".
mode = "ucb"

# Where to save the animation; "" means "do not save" (e.g. "gp_std_mode.gif").
outfile = ""
# Number of animation frames, i.e. number of observations to acquire.
nmaxiter = 20

# Optional positional command-line overrides:
#     <mode> [<nmaxiter> [<outfile>]]
argv = sys.argv
nargs = len(argv)
if nargs > 1:
    mode = argv[1]
if nargs > 2:
    nmaxiter = int(argv[2])
if nargs > 3:
    outfile = argv[3]

# The plot labels are Japanese, so a font with Japanese glyphs is selected
# (assumed to be installed on the machine running this script).
plt.rcParams['font.family'] = 'MS Gothic'
# Draw the minus sign as an ASCII hyphen instead of the Unicode minus glyph.
plt.rcParams['axes.unicode_minus'] = False

# ==== True (noise-free) objective function ====
def f_true(x):
    """Return sin(2*pi*x) + 0.3*cos(4*pi*x), element-wise for scalars or arrays."""
    phase = 2 * np.pi * x
    # 4*pi*x == 2 * (2*pi*x) exactly in floating point (scaling by 2 is exact).
    return np.sin(phase) + 0.3 * np.cos(2 * phase)

# ==== RBF (squared-exponential) kernel ====
def rbf_kernel(x1, x2, length_scale=0.2, sigma_f=1.0):
    """Return the RBF kernel matrix between the 1-D point sets ``x1`` and ``x2``.

    ``K[i, j] = sigma_f**2 * exp(-0.5 * (x1[i] - x2[j])**2 / length_scale**2)``.
    Scalars are treated as length-1 arrays, so for 1-D/scalar inputs the
    result has shape ``(len(x1), len(x2))``.
    """
    column = np.atleast_2d(x1).T   # (n1, 1) for 1-D input
    row = np.atleast_2d(x2)        # (1, n2) for 1-D input
    sq_dist = (column - row) ** 2  # pairwise squared distances via broadcasting
    amplitude = sigma_f**2
    return amplitude * np.exp(-0.5 * sq_dist / length_scale**2)

# ==== Gaussian-process regression ====
def gp_posterior(x_train, y_train, x_test, noise_sigma,
                 length_scale=0.2, sigma_f=1.0):
    """Return the GP posterior mean and marginal variance at ``x_test``.

    Zero-mean GP prior with the RBF kernel (``rbf_kernel``) and i.i.d.
    Gaussian observation noise of standard deviation ``noise_sigma``.

    Parameters
    ----------
    x_train, y_train : array_like, shape (n,)
        Observed inputs and noisy outputs.
    x_test : array_like, shape (m,)
        Points at which to evaluate the posterior.
    noise_sigma : float
        Observation-noise standard deviation (added to the kernel diagonal).
    length_scale, sigma_f : float
        RBF kernel hyperparameters.

    Returns
    -------
    mu : ndarray, shape (m,)
        Posterior mean of the latent function.
    var : ndarray, shape (m,)
        Posterior variance of the latent function (noise excluded), clipped
        at 0.
    """
    K = rbf_kernel(x_train, x_train, length_scale, sigma_f)
    K_s = rbf_kernel(x_train, x_test, length_scale, sigma_f)

    K_y = K + (noise_sigma**2) * np.eye(len(x_train))

    # K_y = L L^T; alpha = K_y^{-1} y via two triangular solves.
    L = np.linalg.cholesky(K_y)
    alpha = np.linalg.solve(L.T, np.linalg.solve(L, y_train))

    mu = K_s.T @ alpha

    # Only the diagonal of the posterior covariance is needed:
    #   diag(K_ss - v^T v) = k(x, x) - sum_i v_i^2,
    # and k(x, x) = sigma_f**2 for the RBF kernel. This avoids building the
    # full (m, m) K_ss and the O(m^2 n) product v^T v.
    v = np.linalg.solve(L, K_s)
    var = sigma_f**2 - np.sum(v**2, axis=0)
    # Round-off can make the variance slightly negative; clip so sqrt is valid.
    var = np.clip(var, 0, np.inf)

    return mu, var

# ==== Nudge a candidate that is too close to an existing point ====
def avoid_duplicate(x_new, x_train, eps=1e-3, rng=None, max_tries=10_000):
    """Jitter ``x_new`` until it is at least ``eps`` away from every ``x_train`` point.

    Each retry adds a uniform offset in [-0.01, 0.01] and clips to [0, 1].

    Parameters
    ----------
    x_new : float
        Candidate input.
    x_train : array_like
        Existing inputs (may be empty; lists are accepted).
    eps : float
        Minimum allowed distance to any existing point.
    rng : numpy.random.Generator, optional
        Random source for the jitter. ``None`` uses the global ``np.random``
        state (the previous behaviour); pass a seeded generator for
        reproducible results.
    max_tries : int
        Maximum number of jitter attempts, so a crowded ``x_train`` cannot
        hang the caller.

    Returns
    -------
    float
        ``x_new`` unchanged if it is already far enough away, otherwise the
        jittered value (inside [0, 1]).

    Raises
    ------
    RuntimeError
        If no acceptable point is found within ``max_tries`` attempts.
    """
    x_train = np.asarray(x_train, dtype=float)
    if x_train.size == 0:
        return x_new

    source = np.random if rng is None else rng

    def too_close(x):
        return np.min(np.abs(x_train - x)) < eps

    tries = 0
    while too_close(x_new):
        if tries >= max_tries:
            raise RuntimeError(
                f"could not move x_new at least {eps} away from existing "
                f"points after {max_tries} attempts"
            )
        x_new = np.clip(x_new + source.uniform(-0.01, 0.01), 0.0, 1.0)
        tries += 1
    return x_new

# ==== Main ====
# Acquisition strategies understood by ``main`` / ``_choose_next_x``.
_VALID_MODES = ("random", "std", "max", "ucb", "ei")

# Output suffixes (lower-case) written as video through FFmpeg; every other
# suffix (e.g. ".gif") is written with PillowWriter.
_VIDEO_SUFFIXES = (".mp4", ".mov", ".avi", ".mkv")


def _choose_next_x(mode, x_plot, mu, std, y_train, rng, kappa=2.0):
    """Return the next query point for acquisition strategy ``mode``.

    Every mode except "random" returns an entry of the ``x_plot`` grid that
    maximises its score. ``mu``/``std`` are the GP posterior mean and standard
    deviation on ``x_plot``; ``y_train`` is used by "ei" for the incumbent
    best value; ``kappa`` is the "ucb" exploration weight.

    Raises
    ------
    ValueError
        If ``mode`` is not one of ``_VALID_MODES``.
    """
    if mode == "random":
        return rng.uniform(0, 1)

    if mode == "std":
        # Pure exploration: sample where the model is most uncertain.
        score = std
    elif mode == "max":
        # Pure exploitation: sample at the predicted maximum.
        score = mu
    elif mode == "ucb":
        # Upper confidence bound: mean plus kappa standard deviations.
        score = mu + kappa * std
    elif mode == "ei":
        # Expected improvement over the best observation so far (maximisation).
        y_best = np.max(y_train)
        improvement = mu - y_best
        Z = improvement / (std + 1e-9)
        score = improvement * norm.cdf(Z) + std * norm.pdf(Z)
        # Where the posterior is (numerically) certain there is nothing to gain.
        score[std < 1e-12] = 0.0
    else:
        raise ValueError(
            f"mode must be one of {', '.join(_VALID_MODES)}; got {mode!r}"
        )
    return x_plot[np.argmax(score)]


def main(mode="random", nmaxiter=40, savefile="gp_animation.mp4"):
    """Animate GP regression with sequential sampling of ``f_true``.

    Each animation frame acquires one noisy observation at a point chosen by
    ``mode``, refits the GP and redraws the posterior mean and a +-2 sigma band.

    Parameters
    ----------
    mode : str
        Acquisition strategy: "random", "std", "max", "ucb" or "ei".
    nmaxiter : int
        Number of frames, i.e. number of observations acquired.
    savefile : str or None
        Output path; ``None`` or "" skips saving. Suffixes in
        ``_VIDEO_SUFFIXES`` (e.g. ".mp4") use FFMpegWriter, which needs
        ffmpeg installed; anything else (e.g. ".gif") uses PillowWriter.

    Raises
    ------
    ValueError
        If ``mode`` is not a known strategy.
    """
    # Validate up front; otherwise a bad mode only surfaces on the second
    # frame, from inside the animation callback.
    if mode not in _VALID_MODES:
        raise ValueError(
            f"mode must be one of {', '.join(_VALID_MODES)}; got {mode!r}"
        )

    seed = 1
    rng = np.random.default_rng(seed)

    x_data = []
    y_data = []
    noise_sigma = 0.15

    x_plot = np.linspace(0, 1, 200)
    y_true = f_true(x_plot)

    fig, ax = plt.subplots(figsize=(8, 5))
    ax.set_xlabel("x")
    ax.set_ylabel("y")

    ax.plot(x_plot, y_true, "k--", label="真の関数 f(x)")
    scatter = ax.scatter([], [], color="black", label="観測データ")
    mean_line, = ax.plot([], [], color="C0", label="GP 予測平均")

    uncert_band = None

    ax.legend(loc="upper right")
    ax.set_xlim(0, 1)
    ax.set_ylim(-2.0, 2.0)
    # Set a title now so tight_layout (below) reserves room for it.
    ax.set_title(f"mode={mode} | データ数 n = 0")

    # ---- Reset before each run ----
    def init():
        # FuncAnimation calls this before a run starts (on first draw and again
        # inside anim.save). Without an init_func it would draw frame 0 as the
        # initial frame, and because update() appends data, that adds an
        # uncounted observation. The reset also makes save-then-show replay the
        # same sequence instead of continuing from the saved state.
        nonlocal rng, uncert_band
        rng = np.random.default_rng(seed)
        x_data.clear()
        y_data.clear()
        scatter.set_offsets(np.empty((0, 2)))
        mean_line.set_data([], [])
        if uncert_band is not None:
            uncert_band.remove()
            uncert_band = None
        ax.set_title(f"mode={mode} | データ数 n = 0")
        return scatter, mean_line

    # ---- Per-frame update: acquire one point, refit, redraw ----
    def update(frame):
        nonlocal uncert_band

        if not x_data:
            # The first point is always drawn uniformly at random.
            x_new = rng.uniform(0, 1)
        else:
            x_train = np.array(x_data)
            y_train = np.array(y_data)
            mu, var = gp_posterior(x_train, y_train, x_plot, noise_sigma)
            x_new = _choose_next_x(mode, x_plot, mu, np.sqrt(var), y_train, rng)

        # Offset the candidate if it lies within eps of an existing observation.
        x_new = avoid_duplicate(x_new, np.array(x_data))

        # Noisy observation of the true function.
        y_new = f_true(x_new) + rng.normal(0.0, noise_sigma)

        x_data.append(x_new)
        y_data.append(y_new)

        # ---- Recompute the GP posterior including the new point ----
        x_train = np.array(x_data)
        y_train = np.array(y_data)
        mu, var = gp_posterior(x_train, y_train, x_plot, noise_sigma)
        std = np.sqrt(var)

        scatter.set_offsets(np.column_stack([x_train, y_train]))
        mean_line.set_data(x_plot, mu)

        # fill_between creates a new artist each call, so replace the old one.
        if uncert_band is not None:
            uncert_band.remove()

        uncert_band = ax.fill_between(
            x_plot,
            mu - 2 * std,
            mu + 2 * std,
            color="C0",
            alpha=0.2,
        )

        ax.set_title(f"mode={mode} | データ数 n = {len(x_train)}")
        return scatter, mean_line, uncert_band

    # ---- Build the animation ----
    anim = FuncAnimation(
        fig,
        update,
        frames=nmaxiter,
        init_func=init,
        interval=500,
        blit=False,
        repeat=False,
    )

    # Apply the layout before saving so the saved file matches the window.
    plt.tight_layout()

    # ---- Save the animation ----
    if savefile:
        if str(savefile).lower().endswith(_VIDEO_SUFFIXES):
            # PillowWriter cannot encode video formats; FFmpeg must be installed.
            from matplotlib.animation import FFMpegWriter
            writer = FFMpegWriter(fps=2)
        else:
            writer = PillowWriter(fps=2)
        anim.save(savefile, writer=writer)
        print(f"Saved animation to {savefile}")

    plt.show()

if __name__ == "__main__":
    # Script entry point: run the demo with the settings resolved above
    # (defaults, optionally overridden from the command line).
    settings = dict(mode=mode, nmaxiter=nmaxiter, savefile=outfile)
    main(**settings)
