import numpy as np
import matplotlib.pyplot as plt
from types import SimpleNamespace

# ===== 設定 =====
def initialize():
    """Build and return the experiment configuration.

    Returns
    -------
    types.SimpleNamespace
        Attributes (in this order): ``n``, ``noise_sigma``, ``alpha``,
        ``beta``, ``max_degree``, ``models``, ``seed``.
    """
    noise_sigma = 1.0
    max_degree = 5
    # Candidate models to compare. Each model is a list of basis indices
    # (powers of x) that it includes; any subset may be listed, e.g. [0, 2, 3].
    # The full nested polynomial series up to max_degree would instead be
    #     [list(range(d + 1)) for d in range(max_degree + 1)]
    candidate_models = [
        [1],
        [0, 1],
        [0, 1, 2],
        [0, 1, 2, 3],
        [0, 2, 3],
        [0, 1, 3],
    ]
    return SimpleNamespace(
        n=100,                          # number of data points
        noise_sigma=noise_sigma,        # observation-noise standard deviation
        alpha=1.0,                      # prior precision of the weights
        beta=1.0 / noise_sigma**2,      # noise precision (1 / sigma^2)
        max_degree=max_degree,          # only used by the alternative series above
        models=candidate_models,
        seed=0,                         # random seed
    )

# ===== 基底/モデル定義 =====
def basis(i, x):
    """Return the i-th monomial basis function evaluated at x, i.e. ``x ** i``."""
    return pow(x, i)

def model(x, coeffs, basis_indices):
    """Evaluate the linear-in-parameters model sum_k coeffs[k] * basis(basis_indices[k], x).

    Parameters
    ----------
    x : array_like or scalar
        Points at which to evaluate the model. Inputs whose dtype is not
        floating/complex (e.g. int, bool, Python lists of ints) are promoted to
        float64 first, so float coefficients can be accumulated and high powers
        do not overflow an integer dtype.
    coeffs : iterable of numbers
        Weight of each basis function.
    basis_indices : iterable of int
        Basis-function index paired with each entry of ``coeffs``.

    Returns
    -------
    numpy.ndarray
        Model values with the same shape as ``x``.

    Raises
    ------
    ValueError
        If ``coeffs`` and ``basis_indices`` have different lengths (previously
        the extra entries were silently ignored).
    """
    x = np.asarray(x)
    if not np.issubdtype(x.dtype, np.inexact):
        # Accumulating float terms into an integer buffer fails, and x**i on
        # integer arrays can overflow; work in float64 instead.
        x = x.astype(np.float64)

    coeffs = list(coeffs)
    basis_indices = list(basis_indices)
    if len(coeffs) != len(basis_indices):
        raise ValueError(
            f"coeffs has {len(coeffs)} entries but basis_indices has "
            f"{len(basis_indices)}"
        )

    y = np.zeros_like(x)
    for coef, idx in zip(coeffs, basis_indices):
        y += coef * basis(idx, x)
    return y

# ===== データ生成 =====
def generate_data(cfg):
    """Simulate noisy observations of a known polynomial on a uniform grid.

    Parameters
    ----------
    cfg : types.SimpleNamespace
        Reads ``n`` (number of points), ``noise_sigma`` (standard deviation of
        the additive Gaussian noise) and ``seed`` (RNG seed). Optionally reads
        ``true_coeffs``, the coefficients of the true model on basis indices
        0..len(true_coeffs)-1; defaults to the cubic [1.0, -2.0, 0.5, 0.1].

    Returns
    -------
    x : numpy.ndarray, shape (n,)
        Evenly spaced inputs on [-3, 3].
    y_obs : numpy.ndarray, shape (n,)
        Noisy observations ``y_true + noise``.
    y_true : numpy.ndarray, shape (n,)
        Noise-free model values at ``x``.
    """
    # A private legacy RandomState seeded with cfg.seed yields exactly the same
    # draws as np.random.seed(cfg.seed) followed by np.random.normal(...), but
    # without resetting NumPy's global RNG as a hidden side effect.
    rng = np.random.RandomState(cfg.seed)
    x = np.linspace(-3, 3, cfg.n)
    # True coefficients: a cubic polynomial unless the config overrides them.
    true_coeffs = list(getattr(cfg, "true_coeffs", [1.0, -2.0, 0.5, 0.1]))
    y_true = model(x, true_coeffs, list(range(len(true_coeffs))))
    y_obs = y_true + rng.normal(0, cfg.noise_sigma, size=cfg.n)
    return x, y_obs, y_true

# ===== 設計行列作成 =====
def design_matrix(x, basis_indices):
    """Build the design matrix with one column per entry of ``basis_indices``.

    Column k holds ``basis(basis_indices[k], x)``. For a 1-D ``x`` of length N
    the result has shape (N, len(basis_indices)), with rows in the order of ``x``.
    """
    columns = [basis(idx, x) for idx in basis_indices]
    stacked = np.vstack(columns)  # shape (M, N): one row per basis function
    return np.transpose(stacked)

# ===== 対数証拠計算 =====
def compute_log_evidence(X, y, alpha, beta):
    """Return the log marginal likelihood ln p(y | alpha, beta) of Bayesian linear regression.

    Model: y = X w + eps with prior w ~ N(0, alpha^-1 I) and noise
    eps ~ N(0, beta^-1 I). Integrating out w gives (Bishop, PRML, Sec. 3.5.1)

        ln p(y) = M/2 ln(alpha) + N/2 ln(beta) - E(m_N) - 1/2 ln|A| - N/2 ln(2 pi)

    with A = alpha I + beta X^T X, m_N = beta A^-1 X^T y and
    E(m_N) = beta/2 ||y - X m_N||^2 + alpha/2 ||m_N||^2.

    Parameters
    ----------
    X : numpy.ndarray, shape (N, M)
        Design matrix.
    y : numpy.ndarray, shape (N,)
        Observed targets.
    alpha : float
        Prior precision of the weights; must be > 0.
    beta : float
        Noise precision; must be > 0.

    Returns
    -------
    float
        The log evidence.

    Raises
    ------
    ValueError
        If ``alpha`` or ``beta`` is not strictly positive (or is NaN).
    numpy.linalg.LinAlgError
        If ``A`` is numerically singular.
    """
    if not (alpha > 0 and beta > 0):
        raise ValueError(
            f"alpha and beta must be positive, got alpha={alpha}, beta={beta}"
        )
    N, M = X.shape
    A = alpha * np.eye(M) + beta * (X.T @ X)
    # A is symmetric positive definite for alpha, beta > 0, so only the
    # log-magnitude of its determinant is needed.
    _, logdet_A = np.linalg.slogdet(A)
    m_N = beta * np.linalg.solve(A, X.T @ y)
    # Evaluate E(m_N) from the residual rather than the algebraically equal
    # beta/2 * (y^T y - beta * y^T X A^-1 X^T y): that form subtracts two nearly
    # equal large numbers when the fit is good and loses precision.
    residual = y - X @ m_N
    energy = 0.5 * beta * (residual @ residual) + 0.5 * alpha * (m_N @ m_N)
    return (
        0.5 * M * np.log(alpha)
        + 0.5 * N * np.log(beta)
        - energy
        - 0.5 * logdet_A
        - 0.5 * N * np.log(2 * np.pi)
    )

# ===== モデル選択 & 描画 =====
def _posterior_mean(X, y, alpha, beta):
    """Return the posterior weight mean m_N = beta * (alpha I + beta X^T X)^-1 X^T y."""
    M = X.shape[1]
    A = alpha * np.eye(M) + beta * (X.T @ X)
    return beta * np.linalg.solve(A, X.T @ y)


def _model_probabilities(log_evidences):
    """Normalise log evidences into posterior model probabilities under a uniform model prior."""
    # Shift by the maximum so the largest term is exp(0) = 1 and nothing overflows.
    unnorm = np.exp(log_evidences - np.max(log_evidences))
    return unnorm / np.sum(unnorm)


def select_and_plot_models(x, y_obs, y_true, cfg):
    """Score each candidate model by its log evidence, plot the fits, and return the best one.

    For every basis-index list in ``cfg.models`` this prints the log evidence,
    then prints the posterior model probabilities (softmax of the log
    evidences, i.e. a uniform prior over models). It draws one subplot per
    model with the data, the true curve and the posterior-mean fit, shows the
    figure, and finally prints the most probable model.

    Parameters
    ----------
    x : numpy.ndarray, shape (N,)
        Inputs (need not be sorted or evenly spaced).
    y_obs : numpy.ndarray, shape (N,)
        Noisy observations at ``x``.
    y_true : array_like, shape (N,)
        Noise-free values at ``x``, drawn as the reference curve.
    cfg : types.SimpleNamespace
        Reads ``models``, ``alpha`` and ``beta``.

    Returns
    -------
    list of int
        The entry of ``cfg.models`` with the highest posterior probability.

    Raises
    ------
    ValueError
        If ``cfg.models`` is empty.
    """
    models = cfg.models
    nm = len(models)
    if nm == 0:
        raise ValueError("cfg.models must contain at least one model")

    # Log evidence and posterior-mean weights for each model, computed once.
    log_evidences = np.zeros(nm)
    posterior_means = []
    for i, basis_indices in enumerate(models):
        X = design_matrix(x, basis_indices)
        log_evidences[i] = compute_log_evidence(X, y_obs, cfg.alpha, cfg.beta)
        posterior_means.append(_posterior_mean(X, y_obs, cfg.alpha, cfg.beta))
        print(f"Model {basis_indices}: log-evidence = {log_evidences[i]:.3f}")

    post_probs = _model_probabilities(log_evidences)
    print("\nPosterior probabilities:")
    for basis_indices, p in zip(models, post_probs):
        print(f"  Model {basis_indices}: P = {p:.3f}")

    # Subplot grid: up to 3 columns, as many rows as needed.
    ncols = 3
    nrows = int(np.ceil(nm / ncols))
    fig, axes = plt.subplots(nrows, ncols, figsize=(4*ncols, 3*nrows), sharex=True, sharey=True)
    axes = np.asarray(axes).ravel()
    x_plot = np.linspace(x.min(), x.max(), 300)

    # Draw the true curve through the data points in increasing-x order; this is
    # correct for any sampling of x (the old len()-based shortcut assumed x was
    # the same grid as x_plot, and np.interp assumed x was sorted).
    order = np.argsort(x)
    x_sorted = x[order]
    y_true_sorted = np.asarray(y_true)[order]

    for ax, basis_indices, m_N in zip(axes, models, posterior_means):
        y_fit = model(x_plot, m_N, basis_indices)
        ax.scatter(x, y_obs, alpha=0.4)
        ax.plot(x_sorted, y_true_sorted, 'k-', label='True')
        ax.plot(x_plot, y_fit, 'r--', label='Fit')
        ax.set_title(f"Basis {basis_indices}")
    axes[0].legend()

    # Remove the unused axes in the last row.
    for ax in axes[nm:]:
        fig.delaxes(ax)

    plt.tight_layout()
    plt.show()

    best_idx = int(np.argmax(post_probs))
    print(f"\n>> Best model: {models[best_idx]} (P={post_probs[best_idx]:.3f})")
    return models[best_idx]

# ===== 実行 =====
def main():
    """Run the pipeline: build the config, simulate data, then compare and plot models."""
    config = initialize()
    data = generate_data(config)  # (x, y_obs, y_true)
    select_and_plot_models(*data, config)

if __name__ == '__main__':
    main()
