import numpy as np
import matplotlib.pyplot as plt
from types import SimpleNamespace
from sklearn.linear_model import ARDRegression

# ===== Initialization =====
def initialize():
    """Build the experiment configuration.

    Returns:
        A ``SimpleNamespace`` holding the data size, noise level, prior and
        observation precisions, the candidate polynomial models, the random
        seed, the ARD switch and the ARD Gamma-prior hyperparameters.
    """
    noise_sigma = 1.0  # Observation noise standard deviation
    max_degree = 5     # Maximum basis index
    return SimpleNamespace(
        n=100,                          # Number of data points
        noise_sigma=noise_sigma,
        alpha=1.0,                      # Prior precision for weights
        beta=1.0 / noise_sigma**2,      # Observation precision
        max_degree=max_degree,
        # Models to compare: basis indices [0], [0, 1], ..., [0..max_degree]
        models=[list(range(size)) for size in range(1, max_degree + 2)],
        seed=0,                         # Random seed
        sparse=True,                    # Enable sparse prior (ARD)
        # ARD hyperparameters (Gamma prior shape and rate)
        lambda_1=1e-6,                  # weight precision prior shape
        lambda_2=1e-6,                  # weight precision prior rate
        alpha_1=1e-6,                   # noise precision prior shape
        alpha_2=1e-6,                   # noise precision prior rate
    )

# ===== Basis and model =====
def basis(i, x):
    """Evaluate the ``i``-th monomial basis function, ``x`` raised to ``i``."""
    return pow(x, i)

def model(x, coeffs, basis_indices):
    """Evaluate a linear combination of basis functions at ``x``.

    Args:
        x: Points at which to evaluate the model (array-like accepted by
            ``np.zeros_like`` and by ``basis``).
        coeffs: Coefficients, paired positionally with ``basis_indices``.
        basis_indices: Index of the basis function each coefficient multiplies.

    Returns:
        Array shaped like ``x`` holding ``sum_k coeffs[k] * basis(idx_k, x)``.
    """
    # Accumulate in (at least) float64: inheriting an integer dtype from x
    # would make the in-place add of float-valued terms fail to cast.
    y = np.zeros_like(x, dtype=np.result_type(x, np.float64))
    for coef, idx in zip(coeffs, basis_indices):
        y += coef * basis(idx, x)
    return y

# ===== Data generation =====
def generate_data(cfg):
    """Create a noisy sample of a fixed cubic polynomial.

    Seeds NumPy's global random state with ``cfg.seed``, evaluates the true
    polynomial on ``cfg.n`` evenly spaced points in [-3, 3] and adds Gaussian
    noise with standard deviation ``cfg.noise_sigma``.

    Returns:
        Tuple ``(x, y_obs, y_true, basis_true, true_coeffs)``.
    """
    np.random.seed(cfg.seed)
    true_coeffs = [1.0, -2.0, 0.5, 0.1]  # True polynomial
    basis_true = [degree for degree, _ in enumerate(true_coeffs)]
    x = np.linspace(-3, 3, cfg.n)
    y_true = model(x, true_coeffs, basis_true)
    noise = np.random.normal(0, cfg.noise_sigma, size=cfg.n)
    y_obs = y_true + noise
    return x, y_obs, y_true, basis_true, true_coeffs

# ===== Design matrix =====
def design_matrix(x, basis_indices):
    """Build the design matrix for the given basis functions.

    Args:
        x: Input points (expected to be a 1-D array).
        basis_indices: Indices of the basis functions to include.

    Returns:
        Array with one row per input point and one column per basis index.
        An empty ``basis_indices`` yields a matrix with zero columns instead
        of raising, so an empty model can still be represented.
    """
    columns = [basis(i, x) for i in basis_indices]
    if not columns:
        # np.vstack rejects an empty list; keep the row count it would have
        # produced (the last axis of x after promotion to 2-D).
        return np.empty((np.atleast_2d(x).shape[-1], 0))
    return np.vstack(columns).T

# ===== Log evidence =====
def compute_log_evidence(X, y, alpha, beta):
    """Return the log marginal likelihood of a Bayesian linear model.

    The value computed is

        M/2 ln(alpha) + N/2 ln(beta) - N/2 ln(2 pi) - 1/2 ln|A|
        - beta/2 * (y^T y - beta * y^T X A^{-1} X^T y)

    with ``A = alpha * I + beta * X^T X``, which is the log density of ``y``
    under ``N(0, I / beta + X X^T / alpha)``.

    Args:
        X: Design matrix of shape ``(N, M)``.
        y: Observed targets, length ``N``.
        alpha: Prior precision of the weights (must be > 0).
        beta: Observation-noise precision (must be > 0).

    Returns:
        The log evidence as a float.

    Raises:
        ValueError: If ``alpha`` or ``beta`` is not strictly positive (the
            logarithms would otherwise silently produce nan or -inf).
    """
    if alpha <= 0 or beta <= 0:
        raise ValueError(
            f"alpha and beta must be strictly positive, got alpha={alpha}, beta={beta}"
        )
    N, M = X.shape
    Xty = X.T @ y  # reused in the quadratic term below
    A = alpha * np.eye(M) + beta * (X.T @ X)
    # A is positive definite for positive alpha and beta, so only the
    # log-magnitude of the determinant is needed.
    _, logdet_A = np.linalg.slogdet(A)
    term1 = M * np.log(alpha) / 2
    term2 = N * np.log(beta) / 2
    term3 = -N / 2 * np.log(2 * np.pi)
    term4 = -0.5 * logdet_A
    term5 = -beta / 2 * (y.T @ y - beta * Xty.T @ np.linalg.solve(A, Xty))
    return term1 + term2 + term3 + term4 + term5

# ===== Baseline selection =====
def select_models(x, y_obs, cfg):
    """Score every candidate model in ``cfg.models`` by its log evidence.

    The evidences are normalized into weights that sum to one (posterior
    model probabilities when every model has the same prior probability),
    a summary is printed, and the highest-weighted model is reported.

    Returns:
        Tuple ``(models, log_evidences, post_probs, best_model)`` where the
        two arrays are aligned with ``models``.
    """
    models = cfg.models
    log_evidences = np.array([
        compute_log_evidence(design_matrix(x, candidate), y_obs, cfg.alpha, cfg.beta)
        for candidate in models
    ])
    # Subtract the maximum before exponentiating to avoid underflow.
    post_probs = np.exp(log_evidences - log_evidences.max())
    post_probs /= post_probs.sum()

    print("\nBaseline model selection:")
    for candidate, evidence, prob in zip(models, log_evidences, post_probs):
        print(f" Model {candidate}: log-evidence={evidence:.3f}, P={prob:.3f}")

    best_idx = post_probs.argmax()
    best_model = models[best_idx]
    print(f">> Best baseline model: {best_model} (P={post_probs[best_idx]:.3f})\n")
    return models, log_evidences, post_probs, best_model

# ===== ARD and comparison =====
def _posterior_mean(X, y, alpha, beta):
    # Posterior mean of the weights: beta * A^{-1} X^T y, A = alpha*I + beta*X^T X.
    A = alpha * np.eye(X.shape[1]) + beta * (X.T @ X)
    return beta * np.linalg.solve(A, X.T @ y)


def run_ard_and_compare(x, y_obs, y_true, true_coeffs, basis_true, models, log_evidences, post_probs, best_model, cfg):
    """Fit an ARD regression on the full basis and compare it to the baseline.

    Prints the basis indices ARD keeps, their coefficients and the log
    evidence of the selected sub-model, then plots (left) the true function,
    the best baseline fit and the ARD fit, and (right) the posterior-mean fit
    of every baseline model.

    Args:
        x: Input points of the observations.
        y_obs: Noisy observations at ``x``.
        y_true: Noise-free targets (unused here; kept for interface
            compatibility).
        true_coeffs: Coefficients of the true function.
        basis_true: Basis indices paired with ``true_coeffs``.
        models: Candidate baseline models (lists of basis indices).
        log_evidences: Log evidence of each entry of ``models``.
        post_probs: Normalized weight of each entry of ``models``.
        best_model: The baseline model with the largest weight.
        cfg: Configuration namespace (``max_degree``, ``alpha``, ``beta`` and
            the ARD hyperparameters are read).
    """
    full_basis = list(range(cfg.max_degree+1))
    X_full = design_matrix(x, full_basis)
    # The design matrix already contains the constant column x**0, so the
    # estimator must not fit a separate intercept: with the default
    # fit_intercept=True the offset would be stored in ard.intercept_, which
    # is never added back below, and basis index 0 could never be selected.
    ard = ARDRegression(lambda_1=cfg.lambda_1,
                        lambda_2=cfg.lambda_2,
                        alpha_1=cfg.alpha_1,
                        alpha_2=cfg.alpha_2,
                        compute_score=True,
                        fit_intercept=False)
    ard.fit(X_full, y_obs)
    coeffs_ard = ard.coef_
    selected = [i for i, c in enumerate(coeffs_ard) if abs(c) > 1e-3]
    print(f"[ARD] Selected basis indices: {selected}")
    print(f"[ARD] Coefficients: {coeffs_ard}")
    # Evidence of ARD model (skipped when every weight was pruned, since an
    # empty basis has no design matrix to evaluate).
    if selected:
        X_sel = design_matrix(x, selected)
        le_sel = compute_log_evidence(X_sel, y_obs, cfg.alpha, cfg.beta)
        print(f"[ARD] log-evidence selected model = {le_sel:.3f}")
    else:
        print("[ARD] No basis functions selected; skipping evidence computation")
    if selected in models:
        idx = models.index(selected)
        print(f" baseline P={post_probs[idx]:.3f}, baseline log-evidence={log_evidences[idx]:.3f}\n")
    else:
        print("[ARD] Selected model not in baseline list\n")

    # Plot baseline vs ARD and all baseline fits
    fig, axes = plt.subplots(1, 2, figsize=(12, 5), sharey=True)
    x_plot = np.linspace(x.min(), x.max(), 300)
    # True function
    y_true_plot = model(x_plot, true_coeffs, basis_true)
    # Baseline fit: posterior-mean weights of the best baseline model
    mb = _posterior_mean(design_matrix(x, best_model), y_obs, cfg.alpha, cfg.beta)
    y_baseline = model(x_plot, mb, best_model)
    # ARD fit
    y_ard = model(x_plot, coeffs_ard, full_basis)

    ax = axes[0]
    ax.scatter(x, y_obs, alpha=0.4, label='Observed')
    ax.plot(x_plot, y_true_plot, 'k-', label='True')
    ax.plot(x_plot, y_baseline, 'r--', label=f'Baseline {best_model}')
    ax.plot(x_plot, y_ard, 'g-.', label=f'ARD {selected}')
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.legend()
    ax.set_title('Baseline vs ARD')
    ax.grid(True)

    ax2 = axes[1]
    for b, p in zip(models, post_probs):
        mb_all = _posterior_mean(design_matrix(x, b), y_obs, cfg.alpha, cfg.beta)
        y_fit_all = model(x_plot, mb_all, b)
        ax2.plot(x_plot, y_fit_all, label=f'{b}, w={p:.2f}')
    ax2.scatter(x, y_obs, alpha=0.3, color='gray')
    ax2.plot(x_plot, y_true_plot, 'k-', linewidth=2)
    ax2.set_xlabel('x')
    ax2.set_title('All Baseline Models')
    ax2.legend(fontsize='small', loc='upper right')
    ax2.grid(True)
    plt.tight_layout()
    plt.show()

# ===== Main =====
def main():
    """Run the experiment: build the config, generate data, rank the
    baseline models and, when ``cfg.sparse`` is set, compare against ARD."""
    cfg = initialize()
    x, y_obs, y_true, basis_true, true_coeffs = generate_data(cfg)
    models, log_evidences, post_probs, best_model = select_models(x, y_obs, cfg)
    if not cfg.sparse:
        return
    run_ard_and_compare(
        x, y_obs, y_true, true_coeffs, basis_true,
        models, log_evidences, post_probs, best_model, cfg,
    )

# Run the experiment only when this file is executed as a script.
if __name__ == '__main__':
    main()
