# bayesian_linear_regression_incremental_general.py

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from matplotlib.animation import FuncAnimation
from matplotlib.animation import PillowWriter


# === Parameters ===
# Run mode: "save" writes the animation to a GIF file; any other value
# (normally "plot") displays it in a window instead.
mode = "plot"

# Data-set and model sizes:
#   n_total  - total number of synthetic data points
#   ninitial - points absorbed before the first animation frame
#   nadd     - points absorbed per animation frame (batch size)
#   nparams  - number of regression coefficients / basis functions
n_total, ninitial, nadd, nparams = 100, 2, 3, 4

# basis functions: x^0, x^1, ..., x^{nparams-2}, and sqrt(x)
def make_basis(nparams):
    funcs = [lambda x, i=i: x**i for i in range(nparams-1)]
    funcs.append(np.sqrt)
    return funcs

basis_funcs = make_basis(nparams)

# True parameters for data generation: one coefficient per basis function,
# in the same order as ``basis_funcs``.
true_params = np.array([1.0, 2.0, 0.5, -1.0])
# The data-generating sum below indexes true_params with range(nparams).
# Without this guard, a mismatch either raises a bare IndexError (too few
# coefficients) or silently ignores the extra ones (too many).
if len(true_params) != nparams:
    raise ValueError(
        f"true_params has {len(true_params)} entries but nparams={nparams}; "
        "provide exactly one true coefficient per basis function"
    )

# Noise settings
sigma_noise = 0.2          # observation-noise standard deviation
sigma2 = sigma_noise**2    # observation-noise variance
seed = 0

# === Data generation ===
# Seed NumPy's legacy global RNG so the synthetic data set is reproducible.
np.random.seed(seed)
x_all = np.random.uniform(0, 1, size=n_total)
# y = sum_i true_params[i] * basis_i(x) + Gaussian noise
y_all = sum(true_params[i] * basis_funcs[i](x_all)
            for i in range(nparams)) + np.random.normal(0, sigma_noise, size=n_total)

# === Prior distribution ===
# Zero-mean Gaussian prior over the coefficients with a broad isotropic
# covariance (variance 1e3 on every coefficient, no correlations).
mu0 = np.zeros(nparams)
Sigma0 = 1e3 * np.eye(nparams)

# === Bayesian update function (single point) ===
def bayesian_linear_update(mu_prev, Sigma_prev, x_new, y_new,
                           basis=None, noise_var=None):
    """Condition a Gaussian belief over the weights on one observation.

    Model: ``y = phi(x) @ w + eps`` with ``eps ~ N(0, noise_var)`` and
    current belief ``w ~ N(mu_prev, Sigma_prev)``. The exact posterior is
    Gaussian with

        Sigma_new = (Sigma_prev^-1 + phi phi^T / noise_var)^-1
        mu_new    = Sigma_new (Sigma_prev^-1 mu_prev + phi y_new / noise_var)

    This function evaluates it in the algebraically equivalent rank-one
    (Sherman-Morrison / Kalman-gain) form. That form needs no matrix
    inversion, so it costs O(p^2) per point instead of two O(p^3)
    inversions, and the returned covariance is explicitly symmetrized.

    Parameters
    ----------
    mu_prev : ndarray, shape (p,)
        Current posterior mean.
    Sigma_prev : ndarray, shape (p, p)
        Current posterior covariance. It is not modified.
    x_new : float
        Input of the new observation.
    y_new : float
        Target of the new observation.
    basis : sequence of callables, optional
        Basis functions producing the p features of ``x_new``.
        Defaults to the module-level ``basis_funcs``.
    noise_var : float, optional
        Observation-noise variance. Defaults to the module-level ``sigma2``.

    Returns
    -------
    mu_new : ndarray, shape (p,)
    Sigma_new : ndarray, shape (p, p)
    """
    if basis is None:
        basis = basis_funcs
    if noise_var is None:
        noise_var = sigma2

    phi = np.array([f(x_new) for f in basis], dtype=float)
    s = Sigma_prev @ phi                 # Sigma_prev phi
    denom = noise_var + phi @ s          # predictive variance of y_new
    gain = s / denom                     # gain vector applied to the residual
    residual = y_new - phi @ mu_prev     # prediction error under the old mean
    mu_new = mu_prev + gain * residual
    Sigma_new = Sigma_prev - np.outer(gain, s)
    # Remove any asymmetric part introduced by floating-point round-off.
    Sigma_new = 0.5 * (Sigma_new + Sigma_new.T)
    return mu_new, Sigma_new

# === Initial update with ninitial points ===
# Start from copies of the prior (so mu0 / Sigma0 are never aliased), then
# absorb the first `ninitial` observations one at a time.
mu_current, Sigma_current = mu0.copy(), Sigma0.copy()
for x_obs, y_obs in zip(x_all[:ninitial], y_all[:ninitial]):
    mu_current, Sigma_current = bayesian_linear_update(
        mu_current, Sigma_current, x_obs, y_obs
    )

# History for the right-hand plot and the final table: number of points
# seen, plus the posterior mean and marginal std of each coefficient.
data_counts = [ninitial]
mu_list = [mu_current.copy()]
std_list = [np.sqrt(np.diag(Sigma_current))]

# === Plot setup ===
# Prefer MS Gothic (typically installed only on Windows), but let matplotlib
# fall back to its default sans-serif fonts elsewhere. Assigning 'MS Gothic'
# to font.family directly triggers a "findfont: Font family 'MS Gothic' not
# found" warning on systems lacking it; listing it as the first sans-serif
# candidate selects it when present and is silently skipped otherwise.
plt.rcParams['font.family'] = 'sans-serif'
plt.rcParams['font.sans-serif'] = ['MS Gothic'] + list(plt.rcParams['font.sans-serif'])
# Draw negative tick labels with an ASCII hyphen instead of U+2212.
# Presumably this is because the preferred font lacks that glyph.
plt.rcParams['axes.unicode_minus'] = False

fig, (ax_left, ax_right) = plt.subplots(1, 2, figsize=(14, 5))
plt.subplots_adjust(wspace=0.3)

# Dense x grid for drawing the predictive curve, and its design matrix:
# Phi_plot[j, i] = basis_funcs[i](x_plot[j]), shape (200, nparams).
x_plot = np.linspace(0, 1, 200)
Phi_plot = np.vstack([f(x_plot) for f in basis_funcs]).T

# === Drawing helpers shared by init() and update() ===
def _draw_fit_panel(n_seen):
    """Redraw the left panel for the current posterior.

    Shows the first `n_seen` data points, the posterior mean curve, and two
    95% bands. The orange band covers the regression function only; the blue
    band also includes the observation noise.
    """
    ax_left.clear()
    ax_left.scatter(x_all[:n_seen], y_all[:n_seen], color='black', label='Data')
    y_mean = Phi_plot @ mu_current
    # Row-wise phi(x)^T Sigma phi(x): posterior variance of the fitted function
    var_s = np.sum(Phi_plot @ Sigma_current * Phi_plot, axis=1)
    # Adding the noise variance gives the variance of a new observation
    var_t = var_s + sigma2
    sd_s, sd_t = np.sqrt(var_s), np.sqrt(var_t)
    ax_left.plot(x_plot, y_mean, color='blue', label='Mean')
    ax_left.fill_between(x_plot, y_mean - 1.96 * sd_s, y_mean + 1.96 * sd_s,
                         color='orange', alpha=0.3, label='95% CI (structure)')
    ax_left.fill_between(x_plot, y_mean - 1.96 * sd_t, y_mean + 1.96 * sd_t,
                         color='blue', alpha=0.2, label='95% CI (total)')
    ax_left.set_xlim(0, 1)
    ax_left.set_ylim(-0.5, 3.5)
    ax_left.set_xlabel('x', fontsize=16)
    ax_left.set_ylabel('y', fontsize=16)
    ax_left.set_title(f'Data count {n_seen}', fontsize=16)
    ax_left.tick_params(axis='both', labelsize=16)
    # ax.clear() drops the legend, so it is rebuilt on every redraw.
    ax_left.legend()


def _draw_trace_panel():
    """Redraw the right panel from the recorded history.

    Draws one solid line (posterior mean) and one dashed line (posterior
    marginal std) per coefficient against the number of points seen.
    """
    ax_right.clear()
    xs = np.array(data_counts)
    means = np.array(mu_list)   # shape (n_snapshots, nparams)
    stds = np.array(std_list)   # shape (n_snapshots, nparams)
    for i in range(nparams):
        ax_right.plot(xs, means[:, i], label=f'a{i} mean')
        ax_right.plot(xs, stds[:, i], '--', label=f'a{i} std')
    ax_right.set_xlim(ninitial, n_total)
    # Put the lower limit 0.5 below the smallest true coefficient (true_params
    # includes -1.0), so negative coefficient traces are not clipped away.
    ax_right.set_ylim(min(-0.5, true_params.min() - 0.5), 4.0)
    ax_right.set_xlabel('Data count', fontsize=16)
    ax_right.set_ylabel('Estimate / std', fontsize=16)
    ax_right.set_title('Parameter estimates', fontsize=16)
    ax_right.tick_params(axis='both', labelsize=16)
    ax_right.legend()


# === init function ===
def init():
    """FuncAnimation init_func: draw the state after the first `ninitial` points."""
    _draw_fit_panel(ninitial)
    _draw_trace_panel()
    return []


# === update function ===
def update(frame):
    """FuncAnimation frame function for animation frame number `frame`.

    Absorbs the next batch of up to `nadd` points into the global posterior,
    records the new posterior summary, and redraws both panels.
    """
    global mu_current, Sigma_current
    start = ninitial + frame * nadd
    end = min(start + nadd, n_total)

    for xi, yi in zip(x_all[start:end], y_all[start:end]):
        mu_current, Sigma_current = bayesian_linear_update(mu_current, Sigma_current, xi, yi)

    data_counts.append(end)
    mu_list.append(mu_current.copy())
    std_list.append(np.sqrt(np.diag(Sigma_current)))

    _draw_fit_panel(end)
    _draw_trace_panel()
    return []

# === Run animation ===
# One frame per batch of `nadd` new points; the final batch may be smaller.
n_frames = int(np.ceil((n_total - ninitial) / nadd))
# Delay between frames in milliseconds. The on-screen animation and the saved
# GIF both use this value, so they play at the same speed.
frame_interval_ms = 800
anim = FuncAnimation(fig, update, frames=range(n_frames), init_func=init,
                     interval=frame_interval_ms, repeat=False, blit=False)

if mode == "save":
    # Save animation as GIF
    print("Saving animation as 'bayesian_linear_animation.gif'")
    # interval is in ms, so fps = 1000 / interval
    writer = PillowWriter(fps=1000 / frame_interval_ms)
    anim.save('bayesian_linear_animation.gif', writer=writer)
    print("Saved")
else:
    # Display animation
    print("Displaying animation")
    plt.show()

# === Display results table ===
# One row per recorded snapshot. Columns are the data count `n`, followed by
# interleaved mean/std columns: mu_0, std_0, mu_1, std_1, ...
records = [
    {
        'n': count,
        **{
            f'{prefix}_{i}': values[i]
            for i in range(nparams)
            for prefix, values in (('mu', mu), ('std', std))
        },
    }
    for count, mu, std in zip(data_counts, mu_list, std_list)
]

df = pd.DataFrame(records)
print(df.to_string(index=False))
