import numpy as np
from numpy import sqrt
import matplotlib.pyplot as plt


# 乱数シード（再現性のため）
#np.random.seed(0)

# ------------------------------------------------------------
# 1. 人工データの生成
# ------------------------------------------------------------
def model(x, ai):
    """Evaluate the straight line ``ai[0] + ai[1] * x``.

    Only the first two entries of ``ai`` are read: ``ai[0]`` is the
    intercept and ``ai[1]`` is the slope. ``x`` may be anything that
    supports multiplication and addition with those entries (a scalar or
    a NumPy array, for example).
    """
    intercept = ai[0]
    slope = ai[1]
    return intercept + slope * x

N = 100  # number of samples
noise_sigma = 0.3

# True parameters (beta0: intercept, beta1: slope)
beta_true = np.array([2.0, 1.5])

# Explanatory variable x1: N evenly spaced points on [0, 10]
x1 = np.linspace(0, 10, N)
# Target variable y = true line + Gaussian noise.
# NOTE(review): the normal draw already uses scale=2.0 and is then multiplied
# by noise_sigma, so the effective noise standard deviation is
# noise_sigma * 2.0 (0.6 here), not noise_sigma -- confirm this is intended.
noise = noise_sigma * np.random.normal(loc=0.0, scale=2.0, size=N)
#print("noise:", noise)
y = model(x1, beta_true) + noise

# Half-width of the shaded band drawn by main(), in units of y_std
ksigma = 3.0

# When False, main() stops after printing the fit results and draws no figure
plot = True

# Analytic distribution of y along the fitted line (error-propagation formula)
def estimate_yerror(beta_est, beta_std, X, y):
    """Propagate the least-squares parameter covariance onto the fitted values.

    The residual variance is estimated as ``RSS / (N - p)`` from the
    residuals ``y - X @ beta_est``; the parameter covariance is that
    variance times ``inv(X.T @ X)``. For every row ``x_i`` of ``X`` the
    returned variance is ``x_i^T Cov(beta) x_i``, i.e. it reflects the
    parameter uncertainty only and does not add the observation noise.

    Args:
        beta_est: Parameter vector of length ``p``.
        beta_std: Unused. Kept so existing callers keep working; the
            covariance is always recomputed from ``X`` and ``y``.
        X: Design matrix of shape ``(N, p)`` (array-like).
        y: Observed values of length ``N`` (array-like).

    Returns:
        ``(y_mean, y_std)``: the fitted values ``X @ beta_est`` and their
        standard deviations, both of length ``N``.
    """
    X = np.asarray(X, dtype=float)
    y = np.asarray(y, dtype=float)
    beta_est = np.asarray(beta_est, dtype=float)

    n_obs, n_params = X.shape

    y_mean = X @ beta_est                     # fitted value at each row of X
    residuals = y - y_mean
    variance_y = (residuals @ residuals) / (n_obs - n_params)

    variance_beta = variance_y * np.linalg.inv(X.T @ X)

    # Row-wise quadratic form x_i^T Cov(beta) x_i without building the
    # full N x N matrix X Cov X^T.
    y_var = np.sum((X @ variance_beta) * X, axis=1)
    y_std = np.sqrt(y_var)

    return y_mean, y_std

def correlation_coefficient(x, y):
    """Return the Pearson correlation coefficient of ``x`` and ``y``.

    Computed as ``Sxy / sqrt(Sxx * Syy)`` where the ``S`` terms are sums
    of products of deviations from the respective means.

    Args:
        x: Sequence or array of numbers.
        y: Sequence or array of numbers with the same shape as ``x``.

    Returns:
        The correlation coefficient as a NumPy float. If either input has
        zero variance the denominator is zero and the result is ``nan``.

    Raises:
        ValueError: If ``x`` and ``y`` differ in shape or are empty.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    # Pairing by zip() used to truncate silently while the means were still
    # taken over len(x) elements, which produced a wrong value.
    if x.shape != y.shape:
        raise ValueError(
            f"x and y must have the same shape, got {x.shape} and {y.shape}"
        )
    if x.size == 0:
        raise ValueError("x and y must not be empty")

    dx = x - x.mean()
    dy = y - y.mean()

    Sxx = np.sum(dx * dx)
    Syy = np.sum(dy * dy)
    Sxy = np.sum(dx * dy)
    return Sxy / sqrt(Sxx * Syy)


def mlsq_line(x1, y):
    """Fit ``y = b0 + b1 * x1`` by least squares using the normal equations.

    Args:
        x1: Values of the explanatory variable.
        y: Observed values, one per entry of ``x1``.

    Returns:
        ``(beta_est, beta_std, X)``: the estimated ``[b0, b1]``, their
        standard errors (square roots of the diagonal of
        ``sigma2 * inv(X.T @ X)`` with ``sigma2 = RSS / (N - p)``), and the
        ``N x 2`` design matrix whose columns are ones and ``x1``.
    """
    n_obs = len(x1)
    # Design matrix X (N x 2): [1, x1]
    X = np.column_stack((np.ones(n_obs), x1))

    gram_inv = np.linalg.inv(X.T @ X)
    beta_est = gram_inv @ (X.T @ y)

    # Residual variance with N - p degrees of freedom
    resid = y - X @ beta_est
    dof = n_obs - X.shape[1]
    sigma2 = resid.T @ resid / dof

    beta_std = np.sqrt(np.diag(sigma2 * gram_inv))
    return beta_est, beta_std, X

def mlsq_bayesian(x, y, tau2):
    """Fit ``y = a + b * x`` with a regularised (Bayesian-style) posterior.

    Each pass forms the posterior precision
    ``(X.T @ X + I / tau2) / sigma**2``, inverts it to get the posterior
    covariance, and computes the posterior mean. ``sigma`` starts at 1.0
    and is then re-estimated from the residuals of the posterior mean with
    ``n - 2`` degrees of freedom; five passes are made.

    Args:
        x: Values of the explanatory variable.
        y: Observed values, one per entry of ``x``.
        tau2: Scale of the regularisation term ``I / tau2`` (presumably
            the prior variance of the parameters -- confirm with caller).

    Returns:
        ``(mean_post, [std_a, std_b])``: the posterior mean ``[a, b]`` from
        the last pass and the square roots of the diagonal of the last
        posterior covariance.
    """
    n = len(x)
    X = np.column_stack((np.ones(n), x))

    # Part of the precision matrix that does not depend on sigma
    regularized_gram = X.T @ X + (1 / tau2) * np.eye(X.shape[1])

    sigma = 1.0
    for _ in range(5):
        inv_sigma2 = 1 / sigma**2
        cov_post = np.linalg.inv(inv_sigma2 * regularized_gram)
        mean_post = cov_post @ (inv_sigma2 * X.T @ y)

        # Re-estimate the noise level from the current fit
        a, b = mean_post
        SSE = sum([(yi - a - b * xi)**2 for xi, yi in zip(x, y)])
        sigma = sqrt(SSE / (n - 2))

    var_a0, var_a1 = np.diag(cov_post)
    return mean_post, [sqrt(var_a0), sqrt(var_a1)]

def lsq_line(x, y):
    """Fit ``y = a + b * x`` by least squares using centred sums.

    Args:
        x: Sequence or array of x values.
        y: Sequence or array of y values, same length as ``x``.

    Returns:
        ``None`` when fewer than three points are given (the residual
        variance uses ``n - 2`` degrees of freedom). Otherwise
        ``(a, b, res)`` where ``a`` is the intercept, ``b`` the slope and
        ``res`` a dict with keys ``'SSE'`` (sum of squared residuals),
        ``'sa'`` and ``'sb'`` (standard errors of ``a`` and ``b``) and
        ``'s'`` (residual standard deviation).

    Raises:
        ValueError: If ``x`` and ``y`` have different lengths.
    """
    n = len(x)
    if n < 3:
        return None
    if len(y) != n:
        raise ValueError(
            f"x and y must have the same length, got {n} and {len(y)}"
        )

    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)

    avg_x = x.mean()
    avg_y = y.mean()
    dx = x - avg_x
    Sxx = dx @ dx
    Sxy = dx @ (y - avg_y)

    b = Sxy / Sxx
    a = avg_y - b * avg_x

    # Sum of squared errors and standard deviation (degrees of freedom n-2)
    resid = y - a - b * x
    SSE = resid @ resid
    sigma_y = sqrt(SSE / (n - 2))

    sb = sigma_y / sqrt(Sxx)
    sa = sigma_y * sqrt(1 / n + avg_x**2 / Sxx)

    # The correlation coefficient was previously computed here and thrown
    # away; for constant y it evaluated 0/0 and emitted a RuntimeWarning.

    res = {'SSE': SSE, 'sa': sa, 'sb': sb, 's': sigma_y}
    return a, b, res

def lsq_line2(x, y):
    """Fit ``y = a + b * x`` by least squares using raw (uncentred) sums.

    Uses the closed form with ``delta = n * sum(x^2) - sum(x)^2``:
    ``b = (n * sum(xy) - sum(x) * sum(y)) / delta`` and
    ``a = (sum(x^2) * sum(y) - sum(x) * sum(xy)) / delta``.
    The residual standard deviation is printed as a side effect.

    Args:
        x: Sequence or array of x values.
        y: Sequence or array of y values, paired with ``x``.

    Returns:
        ``None`` when fewer than three points are given. Otherwise
        ``(a, b, res)`` where ``res`` is a dict with keys ``'SSE'`` (sum of
        squared residuals), ``'sa'`` and ``'sb'`` (standard errors of the
        intercept and slope).
    """
    n = len(x)
    if n < 3:
        return None

    sx = sum(x)
    sy = sum(y)
    sxx = sum([xi**2 for xi in x])
    sxy = sum([xi * yi for xi, yi in zip(x, y)])
    delta = n * sxx - sx * sx

    b = (n * sxy - sx * sy) / delta
    a = (sxx * sy - sx * sxy) / delta

    SSE = sum([(yi - a - b * xi)**2 for xi, yi in zip(x, y)])
    sigma_y = sqrt(SSE / (n - 2))
    print("lsq_line2(): sigma_y=", sigma_y)

    # Same quantities as sigma_y / sqrt(n * Var(x)) and sb * sqrt(sxx / n),
    # written with the delta already computed above.
    sb = sigma_y * sqrt(n / delta)
    sa = sigma_y * sqrt(sxx / delta)
    res = {'SSE': SSE, 'sa': sa, 'sb': sb}
    return a, b, res



def main():
    """Fit the synthetic data with every routine, print the results, and plot.

    Prints the true parameters, the correlation coefficient, and the
    estimates and standard errors from ``mlsq_line``, ``lsq_line``,
    ``lsq_line2`` and ``mlsq_bayesian`` (for several ``tau2`` values).
    When the module-level ``plot`` flag is true, it then draws the data,
    the true line, the ``mlsq_line`` fit, and a ``ksigma``-wide band from
    ``estimate_yerror``.
    """
    print()
    print("True parameters:     ", beta_true)

    r = correlation_coefficient(x1, y)
    print("Correlation coefficient:", r)

    print()
    print("Analytical errors from Liklihood function mlsq_line():")
    # Kept under their own names: these are what the plot below is based on.
    beta_ols, beta_ols_std, X = mlsq_line(x1, y)
    print("  Estimated parameters:", beta_ols)
    print("  Std error of params: ", beta_ols_std)

    print()
    print("Analytical errors for 2-parameter case lsq_line():")
    a, b, res = lsq_line(x1, y)
    print("  Estimated parameters:", [a, b])
    print("  Std error of params: ", [res['sa'], res['sb']])

    print()
    print("Analytical errors for 2-parameter case lsq_line2():")
    a, b, res = lsq_line2(x1, y)
    print("  Estimated parameters:", [a, b])
    print("  Std error of params: ", [res['sa'], res['sb']])

    print()
    print("Analytical errors by Bayesian regression mlsq_bayesian():")
    tau2_list = np.logspace(-1.0, 4, 5)
    for tau2 in tau2_list:
        # Separate names so the loop does not overwrite the mlsq_line result.
        beta_bayes, beta_bayes_std = mlsq_bayesian(x1, y, tau2)
        print("  Cov of prior distrib:", tau2)
        print("    Estimated parameters:", beta_bayes)
        print("    Std error of params: ", beta_bayes_std)

    # estimate_yerror derives its covariance from X alone, so it is paired
    # with the mlsq_line estimate rather than the last Bayesian one.
    y_mean, y_std = estimate_yerror(beta_ols, beta_ols_std, X, y)

    # ------------------------------------------------------------
    # 5. Show the results (sorted by x1 before drawing)
    # ------------------------------------------------------------
    if not plot:
        # Return instead of exit(): exit() is the interactive-shell helper
        # and would also terminate any program that imports and calls main().
        return

    idx = np.argsort(x1)
    x1_s = x1[idx]
    y_mean_s = y_mean[idx]
    y_std_s = y_std[idx]
    line_true = model(x1_s, beta_true)
    line_est = model(x1_s, beta_ols)

    plt.figure(figsize=(10, 6))
    plt.scatter(x1, y, color='blue', label='Data')
    plt.plot(x1_s, line_true, color='green', label='True line')
    plt.plot(x1_s, line_est, color='red', label='Estimated line')
    # NOTE(review): y_std reflects parameter uncertainty only, so this band
    # is a band for the fitted line; the legend calls it "predictive".
    plt.fill_between(x1_s, y_mean_s - ksigma * y_std_s, y_mean_s + ksigma * y_std_s,
                     color='gray', alpha=0.3, label=f'{ksigma}σ predictive range')
    plt.title('Distribution of y based on Parameter Uncertainty (Analytical)')
    plt.xlabel('x1')
    plt.ylabel('y')
    plt.legend()
    plt.show()


if __name__ == "__main__":
    main()