import sys
import argparse
import numpy as np
from numpy import exp, log
import matplotlib.pyplot as plt

from tklib.tkutils import replace_path, pint, pfloat
from tklib.tkvariousdata import tkVariousData
from tklib.tkapplication import tkApplication
from tklib.tkparams import tkParams
from tklib.tksci.tkFit import tkFit
from tklib.tkgraphic.tkplotevent import tkPlotEvent
from tklib.tksci.tksci import kB, e


def initialize():
    """
    Build the command-line parser, parse ``sys.argv`` and fill the config.

    Every parsed option is copied as an attribute onto a new ``tkParams``
    object, which is also attached to the application as ``app.cfg``.
    Output paths (log file, fitting workbook, parameter workbook) are then
    derived from ``--infile`` with ``app.replace_path``.

    Returns:
        tuple: ``(app, cfg, parser)`` -- the ``tkApplication`` instance, the
        populated ``tkParams`` config and the ``argparse.ArgumentParser``.
    """

    app = tkApplication()
    cfg = tkParams()

    parser = argparse.ArgumentParser(description="Arrhenius plot と多項式フィット")
    parser.add_argument('--infile', type=str, default="Hall-T.xlsx", help='Input file')
    parser.add_argument('--model', type=str, default='simple Arrhenius #log(P)=A-(eEa/kB)*(1/T)', help="モデル選択")
    parser.add_argument('--Tlabel', type=str, default='T(K)', help='T関連データ列ラベル')
    parser.add_argument('--Plabel', type=str, default='P', help='P関連データ列ラベル')
    parser.add_argument('--Ttype', type=str, choices=['T(K)', 'T(C)', '1/T', '1000/T'],
                        default='T(K)', help='Tデータ変換方法')
    parser.add_argument('--Ptype', type=str, choices=['P', 'log10(P)', 'log_e(P)'],
                        default='P', help='Pデータ変換方法')
    parser.add_argument('--xmin', type=float, default=-1.0e100, help='フィットする x の下限')
    parser.add_argument('--xmax', type=float, default=1.0e100, help='フィットする x の上限')
    parser.add_argument('--Tmin', type=float, default=-1.0e100, help='フィットする T の下限')
    parser.add_argument('--Tmax', type=float, default=1.0e100, help='フィットする T の上限')
    # Kept as strings so that '*' can request automatic limits.
    parser.add_argument('--Tcalmin', type=str, default='*', help="計算に使用する T の下限（'*' で自動）")
    parser.add_argument('--Tcalmax', type=str, default='*', help="計算に使用する T の上限（'*' で自動）")
    parser.add_argument('--ncal', type=int, default=201, help='Number of points in xcal grid')
    parser.add_argument('--xlsm_template', type=str, default="StandardGraph.xlsm", help='Excel template path')
    # NOTE(review): --plot_ci is parsed but not read anywhere in execute().
    parser.add_argument('--plot_ci', type=int, default=1, help='Plot confidence interval')
    parser.add_argument('--plot_sigma_param', type=int, default=1, help='Plot parameter uncertainty band')
    parser.add_argument('--plot_sigma_pred', type=int, default=1, help='Plot prediction uncertainty band')
    parser.add_argument('--plot_sigma_combined', type=int, default=0, help='Plot combined uncertainty band')
    parser.add_argument('--figsize', type=float, nargs=2, default=[12, 8], help='プロット figsize, 例: --figsize 8 8')
    parser.add_argument('--fontsize', type=int, default=16, help='Font size for plots')
    parser.add_argument('--fontsize_legend', type=int, default=12, help='Legend font size')
    # --pause / --no-pause write the same dest; pausing is the default.
    group = parser.add_mutually_exclusive_group()
    group.add_argument('--pause', dest='pause', action='store_true', help='Pause on terminate')
    group.add_argument('--no-pause', dest='pause', action='store_false', help='Do not pause on terminate')
    parser.set_defaults(pause=True)

    args = parser.parse_args()
    app.cfg = cfg
    # Mirror every CLI option onto the config object.
    for key in vars(args):
        setattr(cfg, key, getattr(args, key))

    cfg.logfile = app.replace_path(cfg.infile)
    cfg.output_fitting_path = app.replace_path(cfg.infile, template = "{dirname}/{filebody}-fit.xlsm")
    cfg.output_parameter_path = app.replace_path(cfg.infile, template = ["{dirname}", "{filebody}-parameters.xlsx"])

    return app, cfg, parser

def build_design_matrix(x, order):
    """
    Construct the design matrix for polynomial regression.

    Column ``j`` holds ``x**j`` for ``j = 0 .. order``, so column 0 is all
    ones (the intercept).

    Args:
        x (array-like): 1-D input data points.
        order (int): Polynomial order (>= 0).
    Returns:
        np.ndarray: (N, order+1) float design matrix.
    Raises:
        ValueError: if ``order`` is negative.
    """
    if order < 0:
        raise ValueError(f"build_design_matrix: order must be >= 0, got {order}")
    x = np.asarray(x, dtype=float)
    # Broadcasting (N, 1) ** (order+1,) evaluates every power in one
    # vectorized step instead of a Python-level loop over each element.
    return x[:, np.newaxis] ** np.arange(order + 1)

def mlsq_error(X, y):
    """
    Ordinary least-squares fit with parameter uncertainties.

    Args:
        X (np.ndarray): Design matrix (N, p).
        y (array-like): Observed values (N,).
    Returns:
        tuple: ``(beta, beta_std, cov_beta, sigma2_resid)`` where
            beta (np.ndarray): Estimated coefficients (p,).
            beta_std (np.ndarray): Standard deviation of each coefficient (p,).
            cov_beta (np.ndarray): Covariance matrix of the coefficients (p, p).
            sigma2_resid (float): Unbiased residual variance RSS / (N - p);
                NaN when N == p (no degrees of freedom remain).
    Raises:
        ValueError: if N < p (underdetermined system).
        numpy.linalg.LinAlgError: if X^T X is singular.
    """
    X = np.asarray(X, dtype=float)
    y = np.asarray(y)
    N, p = X.shape
    if N < p:
        raise ValueError(
            f"mlsq_error: need at least {p} data points for {p} parameters, got {N}")

    # Normal equations; the inverse is reused for the covariance matrix.
    XtX_inv = np.linalg.inv(X.T @ X)
    beta = XtX_inv @ (X.T @ y)

    # Residuals and unbiased residual variance
    residuals = y - X @ beta
    RSS = float(residuals.T @ residuals)
    dof = N - p
    # With N == p the fit is exact and the residual variance is undefined:
    # report NaN instead of dividing by zero.
    sigma2_resid = RSS / dof if dof > 0 else float("nan")
    cov_beta = sigma2_resid * XtX_inv
    beta_std = np.sqrt(np.diag(cov_beta))
    return beta, beta_std, cov_beta, sigma2_resid

def compute_param_uncertainty(X, cov_beta):
    """
    Variance of the fitted mean at each row of a design matrix.

    For every row x_i this evaluates the quadratic form x_i^T C x_i, i.e. the
    diagonal of X @ C @ X.T, without forming the full (N, N) product.

    Args:
        X (np.ndarray): Design matrix (N, p).
        cov_beta (np.ndarray): Parameter covariance matrix C (p, p).
    Returns:
        np.ndarray: Prediction variances (N,).
    """
    weighted_rows = X @ cov_beta
    return (weighted_rows * X).sum(axis=1)

def compute_measurement_error(y):
    """
    Sample standard deviation (ddof=1) of the observed values.

    NOTE(review): this is the spread of ``y`` itself, so any trend in the data
    (e.g. the Arrhenius slope) is included; it is an upper bound on pure
    measurement noise rather than a noise estimate proper.

    Args:
        y (array-like): Observed values.
    Returns:
        float: Unbiased standard deviation, or 0.0 for fewer than two values.
    """
    values = np.asarray(y)
    if len(values) <= 1:
        return np.sqrt(0.0)
    return np.sqrt(np.var(values, ddof=1))

def compute_bands(xcal, beta, cov_beta, sigma2_resid, sigma_meas):
    """
    Evaluate the fitted polynomial and its uncertainty bands on a grid.

    Args:
        xcal (array-like): Points at which to evaluate the fit.
        beta (np.ndarray): Polynomial coefficients (lowest order first).
        cov_beta (np.ndarray): Parameter covariance matrix.
        sigma2_resid (float): Residual variance of the fit.
        sigma_meas (float): Measurement-error standard deviation.
    Returns:
        dict: 'y_mean' (fitted values), 'sigma_param' (parameter-only std),
        'sigma_pred' (parameter + residual std) and 'sigma_combined'
        (parameter + measurement std), each an array over ``xcal``.
    """
    design = build_design_matrix(xcal, len(beta) - 1)
    var_param = compute_param_uncertainty(design, cov_beta)
    bands = {'y_mean': design @ beta}
    bands['sigma_param'] = np.sqrt(var_param)
    bands['sigma_pred'] = np.sqrt(var_param + sigma2_resid)
    bands['sigma_combined'] = np.sqrt(var_param + sigma_meas**2)
    return bands

def execute(app):
    """
    Run the Arrhenius analysis configured in ``app.cfg``.

    Reads the T / P columns from ``cfg.infile``, converts them to T (K) and P,
    fits log10(P) with a polynomial in 1000/T inside the requested x / T
    window, prints coefficients and scores, writes the parameter and fitting
    workbooks, derives the apparent activation energy from the local slope of
    the fit, and draws a 2x3 figure before calling ``app.terminate``.

    Args:
        app: tkApplication whose ``cfg`` was filled by ``initialize()``.
    """
    cfg = app.cfg

    print(f"Open logfile [{cfg.logfile}]")
    app.redict(targets=["stdout", cfg.logfile], mode='w')

    # Fitting-window limits; Tcalmin/Tcalmax ('*' = auto) are resolved later.
    cfg.xmin = pfloat(cfg.xmin)
    cfg.xmax = pfloat(cfg.xmax)
    cfg.Tmin = pfloat(cfg.Tmin)
    cfg.Tmax = pfloat(cfg.Tmax)
    print("#======================================================")
    print("# Analyze activation energy etc by Arrhenius plot")
    print("#======================================================")
    print(f"infile : {cfg.infile}")
    print(f"model  : {cfg.model}")
    print(f"Tlabel : {cfg.Tlabel}, Plabel: {cfg.Plabel}")
    print(f"Ttype  : {cfg.Ttype}, Ptype: {cfg.Ptype}")
    print(f"Fitting x range: {cfg.xmin} - {cfg.xmax}")
    print(f"Fitting T range: {cfg.Tmin} - {cfg.Tmax}")

    if '***' in cfg.model:
        app.terminate("Error: Choose model", pause=cfg.pause)

    # Read the input data
    print(f"Read [{cfg.infile}]")
    datafile = tkVariousData(cfg.infile)
    labels, datalist = datafile.Read_minimum_matrix(close_fp=True, usage=app.usage)
    label_x, xX = datafile.FindDataArray(cfg.Tlabel, flag='i')
    label_y, yY = datafile.FindDataArray(cfg.Plabel, flag='i')
    if xX is None or yY is None:
        app.terminate("Error: 指定ラベルのデータが見つかりません", pause=cfg.pause)

    # Convert the x column to absolute temperature T (K)
    if cfg.Ttype == 'T(K)':
        T = xX
    elif cfg.Ttype == 'T(C)':
        T = [v + 273.15 for v in xX]
    elif cfg.Ttype == '1/T':
        T = [1.0 / v for v in xX]
    elif cfg.Ttype == '1000/T':
        T = [1000.0 / v for v in xX]
    else:
        app.terminate(f"Invalid Ttype [{cfg.Ttype}]", usage=app.usage, pause=cfg.pause)

    # Convert the y column to linear P
    if cfg.Ptype == 'P':
        P = yY
    elif cfg.Ptype == 'log10(P)':
        P = [10**v for v in yY]
    elif cfg.Ptype == 'log_e(P)':
        P = [np.exp(v) for v in yY]
    else:
        app.terminate(f"Invalid Ptype [{cfg.Ptype}]", usage=app.usage, pause=cfg.pause)

    print("T(K)=", T)
    print("P   =", P)

    # Arrhenius coordinates for every input point
    T1000 = [1000.0 / v for v in T]
    log10P = [log(v) / log(10.0) for v in P]

    # Keep only points inside both the raw-x window and the T window.
    # T is tested directly (not re-derived as 1000/(1000/T)) to avoid
    # round-off excluding points that sit exactly on a boundary.
    x_fit = []
    y_fit = []
    for xi_orig, Ti, xi, yi in zip(xX, T, T1000, log10P):
        if cfg.xmin <= xi_orig <= cfg.xmax and cfg.Tmin <= Ti <= cfg.Tmax:
            x_fit.append(xi)
            y_fit.append(yi)

    # Polynomial order selected by model name
    if cfg.model == 'percolation':
        norder = 2
    elif cfg.model == '3rd order':
        norder = 3
    elif cfg.model == '4th order':
        norder = 4
    else:
        norder = 1

    # The residual variance is RSS/(N - p) with p = norder + 1 parameters,
    # so at least norder + 2 points are needed for a meaningful fit.
    if len(x_fit) < norder + 2:
        app.terminate(f"Error: Only {len(x_fit)} data points in the fitting range; "
                      f"at least {norder + 2} are required for {norder}-th order fitting",
                      pause=cfg.pause)

    print(f"\nLeast-squares fitting with {norder}-th order polynomial of 1000/T")
    print("Data to be fitted:")
    print(f"  {'1000/T':12} {'log10(P)':12}")
    for xi, yi in zip(x_fit, y_fit):
        print(f"  {xi:12.4g} {yi:12.4g}")

    X = build_design_matrix(x_fit, norder)
    beta, beta_std, cov_beta, sigma2_resid = mlsq_error(X, y_fit)
    print("Fitted polynomial coefficients:")
    for i, coef in enumerate(beta):
        print(f"  c{i} = {coef:g} +- {beta_std[i]:g}")
    print(f"  Residual variance sigma2_resid = {sigma2_resid:12.4g}")

    print(f"Save parameters and stds to {cfg.output_parameter_path}")
    fit = tkFit()
    fit.to_excel(cfg.output_parameter_path, ["coeff", "std"], [beta, beta_std])

    # Rough measurement-error estimate used for the combined band
    sigma_meas = compute_measurement_error(y_fit)
    print(f"Estimated measurement error sigma_meas = {sigma_meas:g}")

    Tcalmin = pfloat(cfg.Tcalmin, defval=min(T))
    Tcalmax = pfloat(cfg.Tcalmax, defval=max(T))
    # Evaluation grid, evenly spaced in the fitting variable 1000/T
    xcal = np.linspace(1000 / Tcalmax, 1000 / Tcalmin, cfg.ncal)

    # Predictions at EVERY input point (not only the fitted subset) so they
    # align element-by-element with P / log10P in the scores and Excel output.
    log10P_cal = build_design_matrix(T1000, norder) @ beta
    P_cal = [10**val for val in log10P_cal]

    fit_obj = tkFit()
    fit_obj.print_scores(heading="\nScores between P(input) and P(fit)", y1=P, y2=P_cal)
    fit_obj.print_scores(heading="\nScores between log10(P(input)) and log10(P(fit))",
                         y1=log10P, y2=list(log10P_cal))

    bands = compute_bands(xcal, beta, cov_beta, sigma2_resid, sigma_meas)

    print(f"Save results to [{cfg.output_fitting_path}]")
    fit.to_excel(cfg.output_fitting_path,
                 [label_x, label_y, "", "1000/T (K^-1)", "log10(P)", "log10(P)(cal)", "",
                  "1000/T (K^-1)", "log10(P)(cal)", "sigma(param)", 'sigma(param&resid)', 'sigma(param&noise)'],
                 [xX, yY, [], T1000, log10P, log10P_cal, [],
                  xcal, bands['y_mean'], bands['sigma_param'], bands['sigma_pred'], bands['sigma_combined']],
                 template = cfg.xlsm_template)

    # Apparent activation energy from the local slope of the fitted curve:
    # log10 P = A - Ea / (kB T ln 10) with x = 1000/T gives
    # Ea = -d(log10 P)/dx * 1000 * ln(10) * kB; dividing by e expresses it in eV
    # (assuming kB and e are the SI constants from tksci).
    diff_plot = np.gradient(bands['y_mean'], xcal)
    Ea_plot = [-d * kB / e * 1000.0 * log(10.0) for d in diff_plot]

    fig, axes = plt.subplots(2, 3, figsize=cfg.figsize)
    # Interactive click-handler registration via plot_event is currently disabled.
    plot_event = tkPlotEvent(plt)
    axes = axes.flatten()
    for ax in axes:
        ax.tick_params(labelsize=cfg.fontsize)

    # (0) Raw X-Y data
    axes[0].plot(xX, yY, linestyle='', marker='o', markerfacecolor='black', markersize=5.0)
    axes[0].set_xlabel(label_x, fontsize=cfg.fontsize)
    axes[0].set_ylabel(label_y, fontsize=cfg.fontsize)

    # (1) T-P
    axes[1].plot(T, P, linestyle='', marker='o', markerfacecolor='black', markersize=5.0)
    axes[1].set_xlabel('$T$ (K)', fontsize=cfg.fontsize)
    axes[1].set_ylabel('$P$', fontsize=cfg.fontsize)

    # (2) Arrhenius plot with optional uncertainty bands
    axes[2].plot(x_fit, y_fit, 'o', label='data', color = 'black', markersize=1.5)
    axes[2].plot(xcal, bands['y_mean'], label='fit', linewidth = 0.5, color='red')
    if cfg.plot_sigma_param:
        axes[2].fill_between(xcal,
                             bands['y_mean'] - bands['sigma_param'],
                             bands['y_mean'] + bands['sigma_param'],
                             color='#0000cc', alpha=0.5, label='±σ(param)')

    if cfg.plot_sigma_pred:
        axes[2].fill_between(xcal,
                             bands['y_mean'] - bands['sigma_pred'],
                             bands['y_mean'] + bands['sigma_pred'],
                             color='#ddddFF', alpha=0.5, label='±σ(param&resid)')

    if cfg.plot_sigma_combined:
        axes[2].fill_between(xcal,
                             bands['y_mean'] - bands['sigma_combined'],
                             bands['y_mean'] + bands['sigma_combined'],
                             color='purple', alpha=0.5, label='±σ(param&noise)')

    axes[2].set_xlabel('$1000/T$ (K$^{-1}$)', fontsize=cfg.fontsize)
    axes[2].set_ylabel(r'$\log_{10}(P)$', fontsize=cfg.fontsize)
    axes[2].legend(fontsize=cfg.fontsize_legend)

    # (3) Apparent activation energy
    axes[3].plot(xcal, Ea_plot, label='Ea (eV)', linestyle='-', linewidth=0.5, color='black')
    axes[3].set_xlabel('$1000/T$ (K$^{-1}$)', fontsize=cfg.fontsize)
    axes[3].set_ylabel('$E_a$ (eV)', fontsize=cfg.fontsize)
    axes[3].legend(fontsize=cfg.fontsize_legend)
    minEa = min(Ea_plot)
    maxEa = max(Ea_plot)
    # A (near-)constant Ea would give a degenerate y-range; widen it slightly.
    if abs(maxEa - minEa) < 1.0e-6:
        minEa -= 1.0e-3
        maxEa += 1.0e-3
        axes[3].set_ylim([minEa, maxEa])

    # (4) Fitted coefficients with 1-sigma error bars
    idxs = np.arange(len(beta))
    axes[4].errorbar(idxs, beta, yerr=beta_std, fmt='o', capsize=3, label='coeff ±1σ')
    axes[4].set_xlabel('$i$', fontsize=cfg.fontsize)
    axes[4].set_ylabel('$c_i$', fontsize=cfg.fontsize)
    axes[4].legend(fontsize=cfg.fontsize_legend)

    plt.tight_layout()
    plt.pause(0.001)

    app.terminate("", usage=None, pause=cfg.pause)

def main():
    """Script entry point: parse the command line, then run the analysis."""
    application = initialize()[0]
    execute(application)

if __name__ == "__main__":
    # Run the analysis only when executed as a script, not when imported.
    main()
