import sys
import argparse
import types
import numpy as np
from math import pow
import matplotlib.pyplot as plt

from tklib.tkutils import terminate, pint, pfloat
from tklib.tkapplication import tkApplication
from tklib.tkvariousdata import tkVariousData
from tklib.tkgraphic.tkplotevent import tkPlotEvent
from tklib.tksci.tkFit import tkFit


def initialize():
    """
    Parse command-line arguments and build the run configuration.

    Returns:
        tuple: (app, cfg, parser) where
            app (tkApplication): Application helper (used by main for path
                templating, stdout redirection and termination).
            cfg (types.SimpleNamespace): One attribute per command-line option,
                plus output_fitting_path and output_parameter_path derived
                from --infile.
            parser (argparse.ArgumentParser): The parser that produced cfg.
    """

    app = tkApplication()
    parser = argparse.ArgumentParser(description='Polynomial least-squares fitting with uncertainty visualization')
    parser.add_argument('--infile', type=str, default='random-poly.xlsx', help='Input data file')
    parser.add_argument('--xlabel', type=str, default='0', help='Label or index for x data')
    parser.add_argument('--ylabel', type=str, default='1', help='Label or index for y data')
    parser.add_argument('--norder', type=int, default=3, help='Polynomial order')
    parser.add_argument('--xmin', type=float, default=-1.0e100, help='Minimum x for fitting')
    parser.add_argument('--xmax', type=float, default=1.0e100, help='Maximum x for fitting')
    # Kept as str so that '*' ("use the data range") can be passed; main converts to float.
    parser.add_argument('--xcalmin', type=str, default=None, help='Minimum x for calculation grid')
    parser.add_argument('--xcalmax', type=str, default=None, help='Maximum x for calculation grid')
    parser.add_argument('--ncal', type=int, default=201, help='Number of points in xcal grid')
    parser.add_argument('--xlsm_template', type=str, default="StandardGraph.xlsm", help='Excel template path')
    parser.add_argument('--figsize', type=float, nargs=2, default=[8, 8], help='プロット figsize, 例: --figsize 8 8')
    parser.add_argument('--plot_ci', type=int, default=1, help='Plot confidence interval')
    parser.add_argument('--plot_sigma_param', type=int, default=1, help='Plot parameter uncertainty band')
    parser.add_argument('--plot_sigma_pred', type=int, default=1, help='Plot prediction uncertainty band')
    parser.add_argument('--plot_sigma_combined', type=int, default=0, help='Plot combined uncertainty band')
    parser.add_argument('--fontsize', type=int, default=16, help='Font size for plots')
    parser.add_argument('--fontsize_legend', type=int, default=12, help='Legend font size')
    group = parser.add_mutually_exclusive_group()
    group.add_argument('--pause', dest='pause', action='store_true', help='Pause on terminate')
    group.add_argument('--no-pause', dest='pause', action='store_false', help='Do not pause on terminate')
    parser.set_defaults(pause=True)

    args = parser.parse_args()
    # Copy every parsed option into a plain namespace that we can extend below.
    cfg = types.SimpleNamespace(**vars(args))

    cfg.output_fitting_path = app.replace_path(cfg.infile, template=["{dirname}", "{filebody}-fit.xlsm"])
    cfg.output_parameter_path = app.replace_path(cfg.infile, template=["{dirname}", "{filebody}-parameters.xlsx"])

    return app, cfg, parser

def build_design_matrix(x, order):
    """
    Construct the polynomial (Vandermonde) design matrix.

    Column j holds x**j for j = 0..order, so row i is
    [1, x_i, x_i**2, ..., x_i**order].

    Args:
        x (array-like): Input data points (1-D).
        order (int): Polynomial order.
    Returns:
        np.ndarray: (N, order+1) float design matrix.
    """
    # Convert to float up front so integer inputs cannot overflow and the
    # result dtype matches the previous float-filled matrix.
    x = np.asarray(x, dtype=float)
    # Broadcast (N, 1) ** (order+1,) -> (N, order+1) in one vectorised call
    # instead of a Python-level pow() per element; x**0 is 1.0 for column 0.
    return x[:, np.newaxis] ** np.arange(order + 1)

def mlsq_error(X, y):
    """
    Solve the ordinary least-squares problem X @ beta ~= y and estimate uncertainties.

    Args:
        X (np.ndarray): Design matrix (N, p).
        y (array-like): Observed values (N,).
    Returns:
        beta (np.ndarray): Estimated coefficients (p,).
        beta_std (np.ndarray): Standard error of each coefficient (p,),
            i.e. sqrt(diag(cov_beta)).
        cov_beta (np.ndarray): Covariance matrix of parameters (p, p).
        sigma2_resid (float): Residual variance estimate RSS / (N - p).
    Raises:
        ValueError: If N <= p, so the residual variance has no degrees of freedom.
        np.linalg.LinAlgError: May be raised by numpy when X is rank deficient.
    """
    X = np.asarray(X, dtype=float)
    y = np.asarray(y, dtype=float)
    N, p = X.shape
    if N <= p:
        raise ValueError(
            f"Least-squares fit needs more data points than parameters (N={N}, p={p})")

    # Reduced QR (X = Q R) avoids forming X^T X explicitly, whose condition
    # number is the square of X's -- significant for polynomial design matrices.
    Q, R = np.linalg.qr(X)
    beta = np.linalg.solve(R, Q.T @ y)
    R_inv = np.linalg.inv(R)
    XtX_inv = R_inv @ R_inv.T  # (X^T X)^{-1} = R^{-1} R^{-T}

    # Residuals and unbiased residual variance
    residuals = y - X @ beta
    RSS = float(np.sum(residuals ** 2))
    sigma2_resid = RSS / (N - p)
    cov_beta = sigma2_resid * XtX_inv
    beta_std = np.sqrt(np.diag(cov_beta))
    return beta, beta_std, cov_beta, sigma2_resid

def compute_param_uncertainty(X, cov_beta):
    """
    Propagate the parameter covariance to the prediction at every design row.

    For row x_i the prediction variance is x_i @ cov_beta @ x_i, which is the
    diagonal of X @ cov_beta @ X.T; it is computed row-wise so the full
    (N, N) product is never formed.

    Args:
        X (np.ndarray): Design matrix (N, p).
        cov_beta (np.ndarray): Parameter covariance matrix (p, p).
    Returns:
        np.ndarray: Prediction variances (N,).
    """
    projected = X @ cov_beta
    elementwise = projected * X
    return elementwise.sum(axis=1)

def compute_measurement_error(y):
    """
    Estimate a measurement-error scale as the sample standard deviation of y.

    Args:
        y (array-like): Observed values.
    Returns:
        float: sqrt of the unbiased (ddof=1) sample variance of y, or 0.0
        when fewer than two values are given.

    NOTE(review): the spread of y also contains the trend the polynomial
    explains, so this overstates pure measurement noise unless y is
    roughly constant -- confirm this is the intended estimator.
    """
    values = np.asarray(y)
    if len(values) <= 1:
        return np.sqrt(0.0)
    return np.sqrt(np.var(values, ddof=1))

def compute_bands(xcal, beta, cov_beta, sigma2_resid, sigma_meas):
    """
    Evaluate the fitted polynomial on xcal together with its 1-sigma bands.

    Args:
        xcal (array-like): Points at which to evaluate the fit.
        beta (np.ndarray): Least-squares coefficients (polynomial order = len(beta) - 1).
        cov_beta (np.ndarray): Parameter covariance matrix.
        sigma2_resid (float): Residual variance added for the prediction band.
        sigma_meas (float): Measurement-error std added for the combined band.
    Returns:
        dict: 'y_mean' (fitted values), 'sigma_param' (parameter uncertainty only),
        'sigma_pred' (parameter + residual variance), and 'sigma_combined'
        (parameter + sigma_meas**2), each evaluated on xcal.
    """
    order = len(beta) - 1
    design = build_design_matrix(xcal, order)
    param_var = compute_param_uncertainty(design, cov_beta)

    bands = {'y_mean': design @ beta}
    bands['sigma_param'] = np.sqrt(param_var)
    bands['sigma_pred'] = np.sqrt(param_var + sigma2_resid)
    bands['sigma_combined'] = np.sqrt(param_var + sigma_meas**2)
    return bands

def load_data(infile, xlabel, ylabel, xmin, xmax, pause_flag):
    """
    Read the x/y columns from infile and keep the points with xmin <= x <= xmax.

    Args:
        infile (str): Path of the data file, opened with tkVariousData.
        xlabel, ylabel: Column label or index passed to FindDataArray.
        xmin, xmax (float): Inclusive x-range filter.
        pause_flag (bool): Forwarded as `pause` to terminate() on errors.
    Returns:
        dict: {'x': list, 'y': list} of the retained points, in file order.
    """
    reader = tkVariousData(infile)
    # Read_minimum_matrix is called before FindDataArray (presumably it loads
    # the table); its return values are not used here.
    _labels, _data = reader.Read_minimum_matrix(close_fp=True, force_numeric=False)
    _, xcol = reader.FindDataArray(xlabel, flag='i')
    _, ycol = reader.FindDataArray(ylabel, flag='i')
    if xcol is None:
        terminate(f"Error: xlabel [{xlabel}] not found", pause=pause_flag)
    if ycol is None:
        terminate(f"Error: ylabel [{ylabel}] not found", pause=pause_flag)

    kept = [(xv, yv) for xv, yv in zip(xcol, ycol) if xmin <= xv <= xmax]
    x = [xv for xv, _ in kept]
    y = [yv for _, yv in kept]
    if not x:
        terminate("Error: No data in specified x-range", pause=pause_flag)
    return {'x': x, 'y': y}

# =====================================================================
# Section 6: Main Execution and Plotting
# =====================================================================
def _resolve_xcal_limit(value, fallback):
    """
    Resolve one end of the calculation grid from its command-line value.

    Args:
        value (str or None): Raw --xcalmin / --xcalmax value.
        fallback: Value returned when `value` is None or '*'.
    Returns:
        float or fallback: float(value), or `fallback` for None / '*'.
    Raises:
        ValueError: If `value` is not None, '*', or a numeric string.
    """
    if value is None or value == '*':
        return fallback
    # argparse delivers these options as str (so '*' is accepted);
    # np.linspace cannot work with strings, so convert explicitly.
    return float(value)


def _plot_results(cfg, x, y, xcal, bands, beta, beta_std):
    """
    Plot the data, the fitted curve and the selected uncertainty bands.

    The left panel always shows the data and the fit; each band is drawn only
    when its --plot_sigma_* flag is non-zero. When --plot_ci is non-zero, a
    second panel shows the coefficients with ±1σ error bars.
    """
    figsize = tuple(cfg.figsize)  # honour --figsize
    if cfg.plot_ci:
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)
    else:
        fig, ax1 = plt.subplots(1, 1, figsize=figsize)
        ax2 = None

    plot_event = tkPlotEvent(plt)

    # Left: data, fit, bands
    y_mean = bands['y_mean']
    ax1.plot(x, y, 'o', label='data', color='black', markersize=1.5)
    ax1.plot(xcal, y_mean, '-', label='fit', linewidth=0.5, color='red')

    # (enabled flag, band key, fill colour, legend label)
    band_specs = (
        (cfg.plot_sigma_param, 'sigma_param', '#0000cc', '±σ(param)'),
        (cfg.plot_sigma_pred, 'sigma_pred', '#ddddFF', '±σ(param&resid)'),
        (cfg.plot_sigma_combined, 'sigma_combined', 'purple', '±σ(param&noise)'),
    )
    for enabled, key, color, label in band_specs:
        if enabled:
            ax1.fill_between(xcal,
                             y_mean - bands[key],
                             y_mean + bands[key],
                             color=color, alpha=0.5,
                             label=label)

    ax1.set_xlabel(cfg.xlabel, fontsize=cfg.fontsize)
    ax1.set_ylabel(cfg.ylabel, fontsize=cfg.fontsize)
    ax1.tick_params(labelsize=cfg.fontsize)
    ax1.legend(fontsize=cfg.fontsize_legend)

    # Right (optional): coefficients with their standard errors
    if ax2 is not None:
        idxs = np.arange(len(beta))
        ax2.errorbar(idxs, beta, yerr=beta_std, fmt='o', capsize=3, label='coeff ±1σ')
        ax2.set_xlabel('$i$', fontsize=cfg.fontsize)
        ax2.set_ylabel('$c_i$', fontsize=cfg.fontsize)
        ax2.tick_params(labelsize=cfg.fontsize)
        ax2.legend(fontsize=cfg.fontsize_legend)

    plot_event.register_event(fig, event='button_press_event',
                              callback=lambda e: plot_event.onclick(e))
    plt.tight_layout()
    plt.pause(0.1)


def main():
    """
    Run the full workflow end to end.

    Steps: parse options, load and range-filter the data, fit the polynomial,
    write the parameter and fitting workbooks, and plot the results.
    """
    app, cfg, _parser = initialize()

    # Mirror stdout into a log file derived from the input path.
    # NOTE(review): replace_path() is called without a template here; confirm
    # its default never resolves to cfg.infile itself, since mode='w' would
    # then overwrite the input data.
    logfile = app.replace_path(cfg.infile)
    print()
    print(f"Open logfile [{logfile}]")
    app.redict(targets=["stdout", logfile], mode='w')

    # xlabel/ylabel may be a column label or an index (see --xlabel/--ylabel help);
    # they are passed through unchanged. xmin/xmax are already floats.
    xlabel = cfg.xlabel
    ylabel = cfg.ylabel
    print()
    print(f"Configuration: infile={cfg.infile}, xlabel={xlabel}, ylabel={ylabel}, norder={cfg.norder}")
    print(f"Fitting range: {cfg.xmin} to {cfg.xmax}")

    # Load data
    data = load_data(cfg.infile, xlabel, ylabel, cfg.xmin, cfg.xmax, cfg.pause)
    x = data['x']
    y = data['y']
    ndata = len(x)
    print(f"ndata = {ndata}")
    for xi, yi in zip(x, y):
        print(f"{xi:12.4g} {yi:12.4g}")

    # The residual variance RSS / (N - p) needs N > p = norder + 1.
    if ndata <= cfg.norder + 1:
        terminate(f"Error: ndata={ndata} is too small for norder={cfg.norder} "
                  f"(need at least {cfg.norder + 2} points)", pause=cfg.pause)

    # Fit parameters
    X = build_design_matrix(x, cfg.norder)
    beta, beta_std, cov_beta, sigma2_resid = mlsq_error(X, y)
    ycal = X @ beta

    print("Fitted polynomial coefficients:")
    for i, (coef, std) in enumerate(zip(beta, beta_std)):
        print(f"  c{i} = {coef:g} +- {std:g}")
    print(f"  Residual variance sigma2_resid = {sigma2_resid:g}")

    print(f"Save parameters and stds to {cfg.output_parameter_path}")
    fit = tkFit()
    fit.to_excel(cfg.output_parameter_path, ["coeff", "std"], [beta, beta_std])

    # Calculation grid: default to the (filtered) data range.
    xcalmin = _resolve_xcal_limit(cfg.xcalmin, min(x))
    xcalmax = _resolve_xcal_limit(cfg.xcalmax, max(x))
    xcal = np.linspace(xcalmin, xcalmax, cfg.ncal)
    print()
    print(f"Calculation grid: {xcalmin} to {xcalmax}, points = {cfg.ncal}")

    sigma_meas = compute_measurement_error(y)
    print(f"Estimated measurement error sigma_meas = {sigma_meas:g}")
    bands = compute_bands(xcal, beta, cov_beta, sigma2_resid, sigma_meas)

    print("")
    print(f"Save results to [{cfg.output_fitting_path}]")
    fit.to_excel(cfg.output_fitting_path,
                 [xlabel, ylabel, f"{ylabel}(fit)", "",
                  f"{xlabel}(cal)", f"{ylabel}(mean)", f"{ylabel}(std(param))",
                  f"{ylabel}(std(param&resid))", f"{ylabel}(std(param&noise))"],
                 [x, y, ycal, [], xcal, bands['y_mean'], bands['sigma_param'],
                  bands['sigma_pred'], bands['sigma_combined']],
                 template=cfg.xlsm_template)

    _plot_results(cfg, x, y, xcal, bands, beta, beta_std)

    # Terminate application
    app.terminate("", pause=cfg.pause)

if __name__ == '__main__':
    main()
