import sys
import numpy as np
import matplotlib.pyplot as plt

from tklib.tkutils import terminate, pint, pfloat
from tklib.tkapplication import tkApplication
from tklib.tkvariousdata import tkVariousData
from tklib.tkgraphic.tkplotevent import tkPlotEvent
from tklib.tksci.tkFit import tkFit

# ================================================
# 1. 設定・グローバル変数
# ================================================

# Default run configuration (each value can be overridden from the command line in main()).
DEFAULT_NORDER = 3                    # polynomial order of the fit
DEFAULT_INFILE = 'random-poly.xlsx'   # input data file
# The x/y column selectors are read from sys.argv by a simple positional
# parser in main() rather than by argparse.

# x-range filter applied while loading data; +-1e100 effectively disables it.
DEFAULT_XMIN, DEFAULT_XMAX = -1.0e100, 1.0e100
# Range of the dense evaluation grid. '*' is a placeholder that main() passes
# through pfloat() with the data min/max as defval (presumably '*' does not
# parse as a number, so the data range is used) -- TODO confirm pfloat semantics.
DEFAULT_XCALMIN = DEFAULT_XCALMAX = '*'

# Font size used for axis labels and legends.
FONT_SIZE = 16

# ================================================
# 2. データ読み込み＆前処理
# ================================================
def load_data(infile, xlabel, ylabel, xmin, xmax, app):
    """
    Excel 等からデータを読み込み、x, y のリストを返す。
    - infile: ファイルパス
    - xlabel, ylabel: データ列指定（ラベル名またはインデックス）
    - xmin, xmax: フィルタリングする x 範囲
    - app: tkApplication インスタンス（エラー時 terminate 用）
    戻り値: x_list, y_list, labels (元の列ラベル情報など)
    """
    print(f"Read data from [{infile}]")
    datafile = tkVariousData(infile)
    labels, datalist = datafile.Read_minimum_matrix(close_fp=True, force_numeric=False)
    _xlabel, xin = datafile.FindDataArray(xlabel, flag='i')
    _ylabel, yin = datafile.FindDataArray(ylabel, flag='i')
    if xin is None:
        app.terminate(f"\nError: xlabel [{xlabel}] not found\n", pause=True)
    if yin is None:
        app.terminate(f"\nError: ylabel [{ylabel}] not found\n", pause=True)

    # フィルタリング
    x_list = []
    y_list = []
    for xi, yi in zip(xin, yin):
        if xmin <= xi <= xmax:
            x_list.append(xi)
            y_list.append(yi)
    ndata = len(x_list)
    print(f"  Number of data points after filtering: {ndata}")
    print(f"{labels[0]:12} {labels[1]:12}")
    for xi, yi in zip(x_list, y_list):
        print(f"{xi:12.4g} {yi:12.4g}")
    return x_list, y_list, labels

# ================================================
# 3. 回帰：パラメータ推定部分
# ================================================
def fit_polynomial_least_squares(x, y, order, iPrint=False):
    """
    Unweighted polynomial least-squares fit  y ~ c0 + c1*x + ... + c_m*x^m.

    The design matrix is the increasing-power Vandermonde matrix. The system is
    solved through a QR factorisation of the design matrix rather than by
    forming and inverting X^T X (normal equations). The normal equations square
    the already poor conditioning of a Vandermonde matrix and lose accuracy at
    higher orders.

    Parameters
    - x, y: list or 1-D array of equal length. Both are converted to float64,
      so integer input cannot overflow when raised to high powers.
    - order: polynomial order m (>= 0)
    - iPrint: if True, print the estimates, their standard errors and the
      residual variance

    Returns
      beta_est: ndarray, shape (m+1,), coefficients [c0, c1, ..., c_m]
      residuals: ndarray, shape (N,), y_i - y_mean_i
      sigma2_resid: float, residual variance RSS / (N - (m+1))
      cov_beta: ndarray, shape (m+1, m+1), sigma2_resid * (X^T X)^-1
      y_mean: ndarray, shape (N,), fitted values at the data points x_i
      X_design: ndarray, shape (N, m+1), design matrix [x^0, x^1, ..., x^m]

    Raises
    - ValueError: order < 0, or x and y have different lengths
    - RuntimeError: fewer than m+2 data points (no residual degrees of
      freedom), or a rank-deficient design matrix (e.g. too few distinct x)
    """
    x_arr = np.asarray(x, dtype=float)
    y_arr = np.asarray(y, dtype=float)
    N = len(x_arr)
    m = order
    if m < 0:
        raise ValueError(f"Polynomial order must be >= 0 (got {m}).")
    if len(y_arr) != N:
        raise ValueError(f"x and y must have the same length (got {N} and {len(y_arr)}).")

    n_params = m + 1
    # Check the degrees of freedom before solving, so that an under-determined
    # problem reports "not enough points" instead of a singular-matrix error.
    dof = N - n_params
    if dof <= 0:
        raise RuntimeError(
            f"Not enough data points ({N}) for polynomial order {m} "
            f"(requires > {n_params} points).")

    # Vandermonde with increasing=True: columns [x^0, x^1, ..., x^m]
    X_design = np.vander(x_arr, N=n_params, increasing=True)  # shape (N, m+1)
    if np.linalg.matrix_rank(X_design) < n_params:
        raise RuntimeError("Design matrix X^T X is singular; check data or reduce polynomial order.")

    # Reduced QR: X = Q R with Q (N, m+1) orthonormal columns and R (m+1, m+1)
    # upper triangular. Then beta = R^-1 Q^T y and (X^T X)^-1 = R^-1 R^-T.
    Q, R = np.linalg.qr(X_design)
    R_inv = np.linalg.inv(R)
    beta_est = R_inv @ (Q.T @ y_arr)  # shape (m+1,)
    XtX_inv = R_inv @ R_inv.T

    # Fitted values and residuals
    y_mean = X_design @ beta_est  # shape (N,)
    residuals = y_arr - y_mean    # shape (N,)

    RSS = float(residuals @ residuals)
    sigma2_resid = RSS / dof

    # Parameter covariance matrix
    cov_beta = sigma2_resid * XtX_inv  # shape (m+1, m+1)
    beta_std = np.sqrt(np.diag(cov_beta))

    if iPrint:
        print("Estimated parameters (beta_est):", beta_est)
        print("Std error of params (beta_std):", beta_std)
        print("Estimated residual variance sigma2_resid:", sigma2_resid)

    return beta_est, residuals, sigma2_resid, cov_beta, y_mean, X_design

# ================================================
# 4. 不確かさ計算
# ================================================
def compute_measurement_error_estimate(y):
    """
    Return the sample standard deviation of y (unbiased variance, ddof=1),
    used as a rough measurement-error scale sigma_meas.

    Note: this is the spread of y around its own mean, so it also contains any
    trend in the data, not only the scatter about the fit.

    - y: array-like of values
    Returns: sigma_meas (float); 0.0 when fewer than two values are given.
    """
    values = np.asarray(y)
    if len(values) >= 2:
        return np.sqrt(np.var(values, ddof=1))
    return 0.0

def compute_uncertainty_bands(x_points, beta_est, cov_beta, sigma2_resid=None, sigma_meas=None):
    """
    Evaluate the fitted polynomial and its 1-sigma band widths on x_points.

    - x_points: 1-D array-like of x values to evaluate at
    - beta_est: coefficients [c0, ..., c_m], shape (m+1,)
    - cov_beta: coefficient covariance matrix, shape (m+1, m+1)
    - sigma2_resid: residual variance; if None, 'sigma_pred' is None
    - sigma_meas: measurement-error scale; if None, 'sigma_combined' is None

    Returns a dict:
      'x'              : x_points as ndarray
      'y_mean'         : fitted values
      'sigma_param'    : sqrt(Var[y(x)]) from parameter uncertainty only
      'sigma_pred'     : sqrt(sigma_param^2 + sigma2_resid), or None
      'sigma_combined' : sqrt(sigma_param^2 + sigma_meas^2), or None
    """
    xs = np.asarray(x_points)
    n_coef = len(beta_est)
    # Rows [x^0, x^1, ..., x^m], the same basis as the fit's design matrix
    basis = np.vander(xs, N=n_coef, increasing=True)
    fitted = basis @ beta_est

    # Var[y(x_i)] = basis[i] @ cov_beta @ basis[i]: the diagonal of
    # basis @ cov_beta @ basis.T, computed row-wise without the full matrix.
    var_param = np.sum((basis @ cov_beta) * basis, axis=1)

    def _widened(extra_var):
        # Clip tiny negative round-off before taking the square root
        return np.sqrt(np.clip(var_param + extra_var, 0, None))

    return {
        'x': xs,
        'y_mean': fitted,
        'sigma_param': np.sqrt(np.clip(var_param, 0, None)),
        'sigma_pred': None if sigma2_resid is None else _widened(sigma2_resid),
        'sigma_combined': None if sigma_meas is None else _widened(sigma_meas**2),
    }

# ================================================
# 5. プロット関数
# ================================================
def plot_results(x_data, y_data, bands_dict, beta_est, beta_std, sigma2_resid, sigma_meas, xlabel, ylabel):
    """
    Draw the two-panel result figure.

    Left panel: data points, the fitted curve and the uncertainty bands.
    Right panel: the fitted coefficients c_i with +-1 sigma error bars.

    - x_data, y_data: the (filtered) input data
    - bands_dict: dict returned by compute_uncertainty_bands; 'sigma_pred' and
      'sigma_combined' may be None, in which case that band is not drawn
    - beta_est, beta_std: fitted coefficients and their standard errors
    - sigma2_resid, sigma_meas: accepted for interface compatibility; not used
      in the drawing
    - xlabel, ylabel: axis labels of the left panel

    The figure is displayed via plt.pause(0.1) and click events are routed to
    a tkPlotEvent instance.
    """
    fig, (ax_fit, ax_coef) = plt.subplots(1, 2, figsize=(12, 5))
    plot_event = tkPlotEvent(plt)

    xs = bands_dict['x']
    center = bands_dict['y_mean']

    # Left panel: data, fit, then bands from narrowest to widest
    ax_fit.plot(x_data, y_data, 'o', label='data')
    ax_fit.plot(xs, center, '-', label='fit')

    param_width = bands_dict['sigma_param']
    ax_fit.fill_between(xs, center - param_width, center + param_width,
                        color='gray', alpha=0.3,
                        label='confidence band (±σ_param)')

    # Optional bands: (dict key, colour, legend label)
    optional_bands = (
        ('sigma_pred', 'pink', 'prediction band (±√(σ_param²+σ_resid²))'),
        ('sigma_combined', 'orange', 'combined band (±√(σ_param²+σ_meas²))'),
    )
    for key, color, band_label in optional_bands:
        width = bands_dict[key]
        if width is None:
            continue
        ax_fit.fill_between(xs, center - width, center + width,
                            color=color, alpha=0.2, label=band_label)

    ax_fit.set_xlabel(xlabel, fontsize=FONT_SIZE)
    ax_fit.set_ylabel(ylabel, fontsize=FONT_SIZE)
    ax_fit.legend(fontsize=FONT_SIZE)
    ax_fit.set_title('Fit and Uncertainty Bands')

    # Right panel: coefficient estimates with error bars
    coef_index = np.arange(len(beta_est))
    ax_coef.errorbar(coef_index, beta_est, yerr=beta_std, fmt='o', capsize=3, label='coeff ±σ')
    ax_coef.set_xlabel('$i$', fontsize=FONT_SIZE)
    ax_coef.set_ylabel('$c_i$', fontsize=FONT_SIZE)
    ax_coef.legend(fontsize=FONT_SIZE)
    ax_coef.set_title('Parameter Estimates')

    # Register the plotted series and the click handler with tkPlotEvent
    for series_label, axis in (("data", ax_fit), ("fit", ax_fit), ("coeff", ax_coef)):
        plot_event.add_data({"label": series_label, "plot_type": "plot", "axis": axis, "data": []})
    plot_event.register_event(fig, event="button_press_event",
                              callback=lambda ev: plot_event.onclick(ev))

    plt.tight_layout()
    plt.pause(0.1)

# ================================================
# 6. メイン処理 & 引数処理
# ================================================
def _parse_args(argv):
    """
    Parse the positional command-line arguments.

    Usage:
        python script.py infile.xlsx [order] [xlabel] [ylabel] [xmin] [xmax] [xcalmin] [xcalmax]

    Arguments that are not given keep their defaults (the DEFAULT_* constants;
    xlabel/ylabel default to column indices 0 and 1).

    Returns: (infile, norder, xlabel, ylabel, xmin, xmax, xcalmin, xcalmax)
    """
    infile = DEFAULT_INFILE
    norder = DEFAULT_NORDER
    xlabel = 0
    ylabel = 1
    xmin = DEFAULT_XMIN
    xmax = DEFAULT_XMAX
    xcalmin = DEFAULT_XCALMIN
    xcalmax = DEFAULT_XCALMAX

    narg = len(argv)
    if narg >= 2:
        infile = argv[1]
    if narg >= 3:
        norder = pint(argv[2])
    if narg >= 4:
        xlabel = argv[3]
    if narg >= 5:
        ylabel = argv[4]
    if narg >= 6:
        xmin = pfloat(argv[5])
    if narg >= 7:
        xmax = pfloat(argv[6])
    if narg >= 8:
        xcalmin = pfloat(argv[7], defval=xcalmin)
    if narg >= 9:
        xcalmax = pfloat(argv[8], defval=xcalmax)
    return infile, norder, xlabel, ylabel, xmin, xmax, xcalmin, xcalmax


def main():
    """
    Command-line entry point.

    Reads the data, fits a polynomial by least squares and computes the
    uncertainty bands. It then writes the fitted values to '<input>-fit.xlsx'
    and plots the results. Errors are reported through app.terminate(),
    including a failed fit.
    """
    app = tkApplication()

    infile, norder, xlabel, ylabel, xmin, xmax, xcalmin, xcalmax = _parse_args(sys.argv)

    # Mirror stdout into a log file whose path is derived from the input file
    logfile = app.replace_path(infile)
    print(f"Open logfile [{logfile}]")
    app.redict(targets=["stdout", logfile], mode='w')

    print(f"Input file: {infile}")
    print(f"Polynomial order: {norder}")
    print(f"xlabel={xlabel}, ylabel={ylabel}")
    print(f"x range filter: {xmin} to {xmax}")
    if xlabel == '' or ylabel == '':
        app.terminate("Error: xlabel/ylabel must be specified\n", pause=True)
        return

    # Read and filter the data
    x_data, y_data, _labels = load_data(infile, xlabel, ylabel, xmin, xmax, app)
    if len(x_data) == 0:
        app.terminate("Error: No data points after filtering\n", pause=True)
        return

    # Sample standard deviation of y, used as a rough measurement-error scale
    sigma_meas = compute_measurement_error_estimate(y_data)
    print(f"Estimated measurement error (from data variance): sigma_meas = {sigma_meas:g}")

    # Least-squares fit. Failures (too few points, singular design, bad order)
    # are reported like the other input errors instead of as a raw traceback.
    try:
        beta_est, _residuals, sigma2_resid, cov_beta, _y_fit, _X_design = fit_polynomial_least_squares(
            x_data, y_data, norder, iPrint=True)
    except (RuntimeError, ValueError, np.linalg.LinAlgError) as err:
        app.terminate(f"\nError: {err}\n", pause=True)
        return
    beta_std = np.sqrt(np.diag(cov_beta))

    # Dense evaluation grid; unspecified limits fall back to the data range
    xcalmin_val = pfloat(xcalmin, defval=min(x_data))
    xcalmax_val = pfloat(xcalmax, defval=max(x_data))
    ncal = 201  # number of grid points; adjust as needed
    xcal = np.linspace(xcalmin_val, xcalmax_val, ncal)

    # Uncertainty bands on the grid
    bands = compute_uncertainty_bands(
        x_points=xcal,
        beta_est=beta_est,
        cov_beta=cov_beta,
        sigma2_resid=sigma2_resid,
        sigma_meas=sigma_meas,
    )

    # Fitted values at the data points and on the grid.
    # beta_est is [c0, ..., c_m]; poly1d expects the highest power first.
    p = np.poly1d(beta_est[::-1])
    ycal_data = p(x_data)
    ycal_cal = p(xcal.tolist())
    fit = tkFit()
    fit.print_scores(heading="\nScores between y(input) and y(fit)", y1=y_data, y2=ycal_data)
    outfile = app.replace_path(infile, template=["{dirname}", "{filebody}-fit.xlsx"])
    print(f"Save results to [{outfile}]")
    fit.to_excel(outfile,
                 [xlabel, ylabel, f"{ylabel}(fit)", "", xlabel, f"{ylabel}(cal)"],
                 [x_data, y_data, ycal_data, [], xcal.tolist(), ycal_cal])

    # Plot
    plot_results(x_data, y_data, bands, beta_est, beta_std, sigma2_resid, sigma_meas, xlabel, ylabel)

    app.terminate("", pause=True)


# Run the command-line entry point only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
