import sys
import os
from pprint import pprint
import numpy as np
from numpy import exp, log, sin, cos, tan, arcsin, arccos, arctan, sinh, cosh, tanh, sqrt, abs
import pandas as pd
from scipy.optimize import minimize
from scipy.signal import savgol_filter 
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LinearRegression, Ridge, Lasso, ElasticNet, SGDRegressor
import matplotlib.pyplot as plt
import matplotlib.widgets as wg
import pandas as pd


from tklib.tkutils import replace_path, getarg, getintarg, getfloatarg, pint, pfloat
#from tklib.tksci.tksci import kB, e
from tklib.tkapplication import tkApplication
from tklib.tkparams import tkParams
from tklib.tkvariousdata import tkVariousData
from tklib.tksci.tkoptimize_linear import Smoothing_Penalty_Regression_by_base
from tklib.tkgraphic.tkplotevent import tkPlotEvent


#=======================================
# Program start
#=======================================
def initialize(app):
    """Create the default parameter set and attach it to ``app`` as ``app.cparams``.

    Every default below is assigned as an attribute of a fresh tkParams object;
    command-line options registered in update_vars() may later override them.
    """
    cparams = tkParams()
    app.cparams = cparams

    defaults = {
        # Input file
        'infile': '',

        # Run mode and smoothing settings
        'mode':    'fit',
        'nsmooth': 5,
        'norder':  2,

        # Method name passed to scipy.optimize.minimize, e.g.
        #   nelder-mead   downhill simplex
        #   powell        modified Powell
        #   cg            conjugate gradient (Polak-Ribiere)
        #   bfgs          BFGS
        #   newton-cg     Newton-CG
        #   trust-ncg     trust-region Newton-CG
        #   dogleg        trust-region dog-leg
        #   also: L-BFGS-B, TNC, COBYLA, SLSQP, trust-constr, trust-exact, trust-krylov
        'method': "nelder-mead",

        # Initial values for the least-squares fits (baseline, amplitudes, decay times)
        'b0':  0.0,
        'A0':  [1.1,  0.5, 0.2, 0.2],
        'tau': [0.05, 0.2, 0.5, 0.8],

        # Window size for the rough tau estimation by a single-decay model
        'nLSQ_tau': 7,

        # Minimization controls: finite-difference step, tolerance,
        # iteration limit and plot refresh interval (in iterations)
        'h':             1.0e-2,
        'tol':           1.0e-5,
        'nmaxiter':      1000,
        'plot_interval': 20,

        # Ridge / LASSO regression
        'alpha': 0.1,
        'tau0':  1.0e-3,
        'tau1':  2.0,
        'ntau':  101,

        # Grid used to draw model curves (fit() recomputes xgmin/xgmax from the data)
        'xgmin':  0.0,
        'xgmax':  1.0,
        'ngdata': 101,

        # Figure settings
        'figsize':         [10, 8],
        'fontsize':        12,
        'legend_fontsize': 10,

        # Fitting x range (effectively unbounded)
        'xmin': -1.0e100,
        'xmax':  1.0e100,
    }

    # The dict literal is rebuilt on each call, so the list values are never shared.
    for name, value in defaults.items():
        setattr(cparams, name, value)

def update_vars(app):
    """Register the command-line options on ``app`` and read them into ``app.cparams``.

    Each option's ``var_name`` is the key written into ``cparams.__dict__`` by
    ``app.read_args``. If read_args() reports an error (first element is None),
    the application is terminated with the error message and the usage text.
    """
    cparams = app.cparams

    app.add_argument(opt = "-mode", type = "str", var_name = 'mode',  opt_str = "-mode=fit",
                     desc = 'mode: [plot|fit]',
                     defval = "fit", optional = True)

    app.add_argument(opt = "-method", type = "str", var_name = 'method',  opt_str = "-method=nelder-mead",
                     desc = 'method: [nelder-mead|cg|bfgs]',
                     defval = "nelder-mead", optional = True)

    app.add_argument(opt = "-xlabel", type = "str", var_name = 'xlabel',  opt_str = "-xlabel=0",
                     desc = 'Data column for x', defval = "0", optional = True)
    # y defaults to column 1, as opt_str advertises; column 0 would duplicate x.
    app.add_argument(opt = "-ylabel", type = "str", var_name = 'ylabel',  opt_str = "-ylabel=1",
                     desc = 'Data column for y', defval = "1", optional = True)

    # Defaults are +-1e100 (effectively no limit), the same values initialize() sets.
    app.add_argument(opt = "-xmin", type = "double", var_name = 'xmin',  opt_str = "-xmin=-1.0e100",
                     desc = 'Lower limit of T-related data (x) for fitting', defval = "-1.0e100", optional = True)
    app.add_argument(opt = "-xmax", type = "double", var_name = 'xmax',  opt_str = "-xmax=1.0e100",
                     desc = 'Upper limit of T-related data (x) for fitting', defval = "1.0e100", optional = True)

    # smoothing
    app.add_argument(opt = "-nsmooth", type = "int", var_name = 'nsmooth',  opt_str = "-nsmooth=5",
                     desc = 'nsmooth: Number of data for smoothing to calculate differential',
                     defval = 5, optional = True)

    app.add_argument(opt = "-norder", type = "int", var_name = 'norder',  opt_str = "-norder=2",
                     desc = 'norder: Order of polynomial for smoothing',
                     defval = 2, optional = True)

    # minimize
    app.add_argument(opt = "-h", type = "float", var_name = 'h',  opt_str = "-h=1.0e-2",
                     desc = 'h: Finite difference for numerical differentiation',
                     defval = 1.0e-2, optional = True)

    app.add_argument(opt = "-b0", type = "float", var_name = 'b0',  opt_str = "-b0=0.0",
                     desc = 'b0: Initial value of constant baseline',
                     defval = 0.0, optional = True)

    app.add_argument(opt = "-pA0", type = "str", var_name = 'pA0',  opt_str = "-pA0='1.0,0.5,0.2'",
                     desc = 'pA0: Initial values of intensities separated by comma',
                     defval = "1.0,0.5,0.2", optional = True)

    app.add_argument(opt = "-ptau0", type = "str", var_name = 'ptau0',  opt_str = "-ptau0='1.0,0.5,0.2'",
                     desc = 'ptau0: Initial relaxation time values separated by comma',
                     defval = "1.0,0.5,0.2", optional = True)

    app.add_argument(opt = "-nmaxiter", type = "int", var_name = 'nmaxiter',  opt_str = "-nmaxiter=1000",
                     desc = 'nmaxiter: Maximum number of iteration for Ridge/LASSO/minimize',
                     defval = 1000, optional = True)

    app.add_argument(opt = "-tol", type = "double", var_name = 'tol',  opt_str = "-tol=1.0e-5",
                     desc = 'tol: Convergence criterion for Ridge/LASSO/minimize',
                     defval = 1.0e-5, optional = True)

    # Single decay time approximation.
    # Stored as cparams.nLSQ_tau: the attribute initialize() defines and plot() reads.
    app.add_argument(opt = "-nLSQ", type = "int", var_name = 'nLSQ_tau',  opt_str = "-nLSQ=7",
                     desc = 'nLSQ: Number of data to be used for estimating tau by single decay model',
                     defval = 7, optional = True)

    # LASSO
    app.add_argument(opt = "-ntau", type = "int", var_name = 'ntau',  opt_str = "-ntau=101",
                     desc = 'ntau: Number of relaxation time for Ridge/LASSO regression',
                     defval = 101, optional = True)

    app.add_argument(opt = "-tau0", type = "double", var_name = 'tau0',  opt_str = "-tau0=1.0e-2",
                     desc = 'tau0: Lower limit of tau for Ridge/LASSO descriptors',
                     defval = 1.0e-2, optional = True)

    app.add_argument(opt = "-tau1", type = "double", var_name = 'tau1',  opt_str = "-tau1=1.0",
                     desc = 'tau1: Upper limit of tau for Ridge/LASSO descriptors',
                     defval = 1.0, optional = True)

    app.add_argument(opt = "-alpha", type = "double", var_name = 'alpha',  opt_str = "-alpha=0.1",
                     desc = 'alpha: Ridge/LASSO alpha parameter',
                     defval = 0.1, optional = True)

    app.add_argument(opt = None, type = "str", var_name = 'infile', opt_str = "infile",  desc = 'Input file',
                     defval = "decay.xlsx", optional = True)

    args_opt, args_idx, args_vars = app.read_args(vars = cparams.__dict__, check_allowed_args = True)
    if args_opt is None:
        # On failure read_args returns (None, error number, error message).
        error_message = args_vars
        app.terminate("\n\n"
                   +  "!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\n"
                   + f"!  {error_message}\n"
                   +  "!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!",
                   usage = app.usage)

#    app.set_usage(usage_str)

class MinimizeFunc():
    """Objective, gradient and plotting callback for the multi-exponential decay fit.

    Parameter vectors ``pk`` are laid out as ``[b0, A_1..A_n, tau_1..tau_n]``.
    The model is ``f(x) = b0 + sum_i A_i * exp(-x / tau_i)``.

    The data (``X``, ``Y``), the drawing grid ``xg`` and the plot state (``plt``,
    ``axes``, ``data``, ``fmin``, ``window``) are assigned by the caller after
    construction. They are initialised here so every instance has the full
    attribute set.
    """

    def __init__(self, cparams):
        self.cparams = cparams

        # Set by the "stop" button handler; checked in callback().
        self.stop_flag = False

        self.X   = None
        self.Y   = None
        self.b0  = cparams.b0
        self.A0  = cparams.A0
        self.tau = cparams.tau

        # Relative finite-difference step for diff1(); follows cparams.h when present.
        self.h = getattr(cparams, 'h', 0.01)

        self.iter = 0
        self.ax = None
        self.xg = None

        # Plot state used by callback().
        self.plt = None
        self.axes = None
        self.data = None
        self.fmin = []
        self.window = None
        self._fmin_line = None   # single Line2D reused for the error history

    def fitting_func(self, x, pk):
        """Evaluate the multi-exponential model at a scalar ``x``.

        Terms with ``x / tau > 300`` are skipped because their contribution
        (< exp(-300)) is negligible.
        """
        ntau = (len(pk) - 1) // 2

        f = pk[0]
        for i in range(ntau):
            A0  = pk[i + 1]
            tau = pk[i + 1 + ntau]
            xx = x / tau
            # Written as "not >" so NaN is still propagated, as before.
            if not xx > 300.0:
                f += A0 * exp(-xx)
        return f

    def lsqfunc(self, idx, x, tau):
        """Basis function ``idx`` for the linear LSQ: 1 for idx == 0, else exp(-x / tau[idx-1])."""
        if idx == 0:
            return 1.0
        return exp(-x / tau[idx - 1])

    def mlsq_general(self, x, y, m, lsqfunc, print_level = 1):
        """Linear least squares for ``y ≈ sum_l c_l * lsqfunc(l, x)``, l = 0..m-1.

        The design matrix ``Phi[i, l] = lsqfunc(l, x[i])`` is evaluated once.
        The normal equations ``(Phi^T Phi) c = Phi^T y`` are then solved.

        Returns:
            (ci, Si, Sij): ``ci`` is a list of the m coefficients, ``Si`` the (m, 1)
            vector Phi^T y and ``Sij`` the (m, m) matrix Phi^T Phi.
        Raises:
            numpy.linalg.LinAlgError: if the normal matrix is singular.
        """
        n = len(x)
        # One basis evaluation per (point, function) pair: O(n*m) calls instead of O(n*m^2).
        phi = np.array([[lsqfunc(l, x[i]) for l in range(m)] for i in range(n)],
                       dtype = float).reshape(n, m)
        yv = np.asarray(y, dtype = float).reshape(n, 1)

        Si  = phi.T @ yv
        Sij = phi.T @ phi

        if print_level:
            print("MinimizeFunc.mlsq_general:: Vector and Matrix:")
            print("Si=")
            pprint(Si)
            print("Sij=")
            pprint(Sij)
            print("")

        # solve() is cheaper and numerically more stable than forming inv(Sij).
        ci = np.linalg.solve(Sij, Si)
        return ci[:, 0].tolist(), Si, Sij

    def minimize_func(self, pk, weight = None):
        """Sum of squared residuals of the model against (X, Y).

        With ``weight`` None, each residual is divided by the baseline-subtracted
        signal ``S = Y - b0``, i.e. a 1/S^2 (relative-error) weighting. Any other
        ``weight`` value gives uniform weights; the value itself is not used.
        """
        f = 0.0
        for x, yobs in zip(self.X, self.Y):
            ycal = self.fitting_func(x, pk)
            S = yobs - pk[0] if weight is None else 1.0
            f += (ycal - yobs)**2 / S / S
        return f

    def diff1(self, pk):
        """Central-difference gradient of minimize_func() at ``pk``.

        Component i uses the step ``h * |pk[i]|``, or ``h`` when pk[i] == 0.
        Supplying this gradient enables gradient methods such as 'cg' or 'bfgs'.
        """
        # Work in float so integer inputs are not truncated when perturbed.
        pk = np.asarray(pk, dtype = float)
        grad = np.empty_like(pk)

        for i, p in enumerate(pk):
            step = self.h if p == 0.0 else self.h * abs(p)
            xx = pk.copy()
            xx[i] = p - step
            ym = self.minimize_func(xx)
            xx[i] = p + step
            yp = self.minimize_func(xx)
            grad[i] = (yp - ym) / 2.0 / step

        return grad

    def callback(self, pk):
        """Per-iteration hook for scipy.optimize.minimize.

        Prints the current parameters and records the objective value. It also
        refreshes the fitted curve every ``cparams.plot_interval`` iterations and
        the error-history plot every iteration. Returns False without doing
        anything once the stop button was pressed or the current figure window
        differs from ``self.window``.

        NOTE(review): minimize() ignores the callback's return value, so returning
        False only suppresses output; it does not terminate the optimization.
        SciPy >= 1.11 terminates when the callback raises StopIteration — confirm
        the installed version before switching to that.
        """
        if self.stop_flag:
            return False

        if self.plt.get_current_fig_manager().window != self.window:
            return False

        ntau = (len(pk) - 1) // 2
        fmin = self.minimize_func(pk)
        print(f"callback {self.iter}: func={fmin}")
        print(f"    b0={pk[0]:10.4g}")
        for i in range(ntau):
            A0  = pk[i + 1]
            tau = pk[i + 1 + ntau]
            print(f"    #{i:2}: tau={tau:10.4g} A0={A0:10.4g}")

        self.iter += 1
        self.fmin.append(fmin)

        if self.iter % self.cparams.plot_interval == 0:
            # Only evaluate the model curve when it is actually redrawn.
            ycal = [self.fitting_func(_x, pk) for _x in self.xg]
            self.data[0].set_data(self.xg, ycal)

        # Reuse one Line2D for the error history instead of adding a new artist
        # (holding the whole history) on every iteration.
        ax = self.axes[1]
        iters = np.arange(self.iter + 1)
        if self._fmin_line is None:
            self._fmin_line, = ax.plot(iters, self.fmin, linestyle = '', marker = 'x', markersize = 2.0)
            ax.set_xlabel('iteration', fontsize = self.cparams.fontsize)
            ax.set_ylabel('error', fontsize = self.cparams.fontsize)
        else:
            self._fmin_line.set_data(iters, self.fmin)
            ax.relim()
            ax.autoscale_view()
        self.plt.pause(0.01)

        return True

def fit(app):
    """Fit ``y = b0 + sum_i A_i exp(-x / tau_i)`` to the input data and plot the progress.

    Steps:
      1. Open the log file and read the x/y columns of ``cparams.infile``, keeping
         only points with xmin <= x <= xmax.
      2. Solve a linear least-squares problem for b0 and A_i with the initial
         tau_i held fixed.
      3. Refine all parameters with scipy.optimize.minimize (``cparams.method``),
         with live plots and a "stop" button.
      4. Save the initial, linear-LSQ and final curves to
         ``<infile without extension>-fitting.xlsx``, draw the final curve and
         terminate the application.
    """
    cparams = app.cparams

    cparams.logfile = app.replace_path(cparams.infile)
    # NOTE(review): cparams.outfile is printed below, but the results are saved to
    # "<infile body>-fitting.xlsx" instead.
    cparams.outfile = app.replace_path(cparams.infile, template = ["{dirname}", "{filebody}-out.xlsx"])
    print(f"Open logfile [{cparams.logfile}]")
    app.redirect(targets = ["stdout", cparams.logfile], mode = 'w')

    cparams.xmin = pfloat(cparams.xmin)
    cparams.xmax = pfloat(cparams.xmax)
    cparams.tol = pfloat(cparams.tol)

    min_func = MinimizeFunc(cparams)
    cparams.alpha = pfloat(cparams.alpha)

    print("")
    print("#======================================================")
    print("# Analyze decay responce with various decay time")
    print("#======================================================")
    print(f"mode    : {cparams.mode}")
    print(f"method  : {cparams.method}")
    print(f"infile : {cparams.infile}")
    print(f"outfile: {cparams.outfile}")
    print(f"logfile: {cparams.logfile}")
    print(f"xlabel  : {cparams.xlabel}")
    print(f"ylabel  : {cparams.ylabel}")
    print(f"Fitting x range: {cparams.xmin} - {cparams.xmax}")
    print( "Smoothing:")
    print(f"  nsmooth : {cparams.nsmooth}")
    print(f"  norder  : {cparams.norder}")

    # Values containing '***' / '**' are treated as "not selected yet".
    if '***' in cparams.method:
        app.terminate("\nError: Choose method\n", usage = app.usage, pause = True)

    if type(cparams.xlabel) is str and '**' in cparams.xlabel:
        print("")
        print("Error: Choose xlabel")
        app.terminate("", usage = app.usage, pause = True)

    if type(cparams.ylabel) is str and '**' in cparams.ylabel:
        print("")
        print("Error: Choose ylabel")
        app.terminate("", usage = app.usage, pause = True)

    # An empty string means "no limit".
    if type(cparams.xmin) is str and cparams.xmin == "": cparams.xmin = None
    if type(cparams.xmax) is str and cparams.xmax == "": cparams.xmax = None

    print( "LASSO:")
    print(f"  tau range: {cparams.tau0} - {cparams.tau1}")
    print(f"  ntau     : {cparams.ntau}")
    print(f"  alpha    : {cparams.alpha}")

    print( "Non-linear LSQ:")
    cparams.A0  = [pfloat(s) for s in cparams.pA0.split(',')]
    cparams.tau = [pfloat(s) for s in cparams.ptau0.split(',')]
    print( "  h   :", cparams.h)
    print( "  b0  :", cparams.b0)
    print( "  A0  :", cparams.A0)
    print( "  tau0:", cparams.tau)
    print(f"  nmaxiter: {cparams.nmaxiter}")
    print(f"  tol     : {cparams.tol}")

    datafile = tkVariousData(cparams.infile)
    labels, datalist = datafile.Read_minimum_matrix(close_fp = True, usage = app.usage)
    label_x, _xX = datafile.FindDataArray(cparams.xlabel, flag = 'i')
    label_y, _yY = datafile.FindDataArray(cparams.ylabel, flag = 'i')

    # Keep only the points inside the fitting x range.
    xX = []
    yY = []
    for _x, _y in zip(_xX, _yY):
        if cparams.xmin is not None and _x < cparams.xmin:
            continue
        if cparams.xmax is not None and cparams.xmax < _x:
            continue
        xX.append(_x)
        yY.append(_y)

    ndata = len(xX)
    print("")
    print("ndata:", ndata)
    if ndata == 0:
        # min()/max() below would fail on an empty range; report it clearly instead.
        app.terminate("\nError: No data in the fitting x range\n", usage = app.usage, pause = True)
        return

    min_func.X   = xX
    min_func.Y   = yY
    min_func.h   = cparams.h
    min_func.b0  = cparams.b0
    min_func.A0  = cparams.A0
    min_func.tau = cparams.tau
    # Parameter layout: [b0, A_1..A_n, tau_1..tau_n]
    p0s = [cparams.b0, *cparams.A0, *cparams.tau]
    print("initial parametes: p0s=", p0s)

    # x grid on which model curves are drawn
    cparams.xgmin = min(xX)
    cparams.xgmax = max(xX)
    xgstep = (cparams.xgmax - cparams.xgmin) / (cparams.ngdata - 1)
    min_func.xg = [cparams.xgmin + xgstep * i for i in range(cparams.ngdata)]
    yini = [min_func.fitting_func(_x, p0s) for _x in min_func.xg]

    # Figure: top = data and model curves, bottom = error history
    min_func.plt = plt
    fig, axes = plt.subplots(2, 1, figsize = cparams.figsize)
    min_func.axes = axes

    axes[0].tick_params(labelsize = cparams.fontsize)
    axes[1].tick_params(labelsize = cparams.fontsize)

    axes[0].plot(xX, yY,   label = 'raw data', linestyle = '',  marker = 'x', markersize = 3.0)
    axes[0].plot(min_func.xg, yini, label = 'initial',  linestyle = 'dashed', linewidth = 0.5, color = 'red')
    min_func.data = axes[0].plot([], [], label = 'fitted',  linestyle = 'dashed', linewidth = 1.0, color = 'blue')
    axes[0].set_xlabel(label_x, fontsize = cparams.fontsize)
    axes[0].set_ylabel(label_y, fontsize = cparams.fontsize)
    axes[0].set_yscale('log')

    min_func.fmin = [min_func.minimize_func(p0s)]
    axes[1].plot([0], min_func.fmin, linestyle = '',  marker = 'x', markersize = 2.0)
    axes[1].set_xlabel('iteration', fontsize = cparams.fontsize)
    axes[1].set_ylabel('error', fontsize = cparams.fontsize)
    axes[1].set_yscale('log')

    # tight_layout() cancels subplots_adjust(), so they must be called in this order.
    axes[0].legend()

    plt.tight_layout()
    plt.subplots_adjust(top = 0.93, bottom = 0.10)

    def button_click(e):
        # MinimizeFunc.callback() checks this flag.
        min_func.stop_flag = True

    ax_button = plt.axes([0.15, 0.95, 0.15, 0.03])
    button = wg.Button(ax_button, 'stop', color = '#f8e58c', hovercolor = '#38b48b')
    button.on_clicked(button_click)

    plt.pause(0.001)
    min_func.window = plt.get_current_fig_manager().window

    print("")
    print("Minimize by linear LSQ:")
    ci, Si, Sij = min_func.mlsq_general(min_func.X, min_func.Y, len(cparams.tau) + 1,
                    lambda idx, x: min_func.lsqfunc(idx, x, cparams.tau))
    print("tau=", cparams.tau)
    print("ci=", ci)
    # ci = [b0, A_1..A_n] has len(tau) + 1 entries; every term, including the last
    # exponential, must be summed.
    min_func.Yllsq = [sum(c * min_func.lsqfunc(idx, x, cparams.tau) for idx, c in enumerate(ci))
                      for x in min_func.xg]

    axes[0].plot(min_func.xg, min_func.Yllsq, label = 'linear LSQ', linestyle = 'dashed', color = 'green')
    axes[0].legend()
    plt.pause(0.001)

    print("")
    print("Minimize by non-linear LSQ:")
    p0s = [*ci, *cparams.tau]
    print("  next parameters: p0s=", p0s)

    # Derivative-free methods do not use jac; passing one only triggers a SciPy warning.
    gradient_free = str(cparams.method).lower() in ('nelder-mead', 'powell', 'cobyla')
    jac = None if gradient_free else (lambda pk: min_func.diff1(pk))
    res = minimize(lambda pk: min_func.minimize_func(pk), p0s,
                jac = jac,
                method = cparams.method, tol = cparams.tol,
                callback = lambda pk: min_func.callback(pk),
                options = {'maxiter': cparams.nmaxiter, "disp":True})
    print("")
    if res.success:
        print(f"Converged at {res.nit} iteration with y_min={res.fun}")
    else:
        print(f"Function did not converge")

    print("Final parameters")
    b0 = res.x[0]
    aa = res.x[1:]
    n = len(aa) // 2
    print(f"  b0: {b0}")
    for i in range(n):
        A0  = aa[i]
        tau = aa[n + i]
        print(f"  tau{i+1}={tau:10.4g}  A0={A0:10.4g}")

    yfin = [min_func.fitting_func(_x, res.x) for _x in min_func.xg]

    base_name = os.path.splitext(cparams.infile)[0]
    outfile = f'{base_name}-fitting.xlsx'
    print()
    print(f"Save fitting results to [{outfile}]")
    pd.DataFrame(np.array([min_func.xg, yini, min_func.Yllsq, yfin]).T, columns = ["t", "y(ini)", "y(llsq)", "yfinal"]).to_excel(outfile)

    #=========================================
    # plot
    #=========================================
    axes[0].plot(min_func.xg, yfin, label = 'final',  linestyle = '-', linewidth = 1.0, color = 'blue')

    plt.tight_layout()
    plt.subplots_adjust(top = 0.93, bottom = 0.10)
    plt.pause(0.001)

    app.terminate("", usage = app.usage, pause = True)

def estimate_taus(nLSQ, xX, yY, cparams, print_level = 1):
    """Roughly estimate a local decay time tau(x) with a single-exponential model.

    Windows start every ``nLSQ`` points. Each window holds ``nLSQ + 1`` points, so
    neighbours share one point. In each window ``y = b0 + A0 * exp(-x / tau)`` is
    fitted with uniform weights, starting from the previous accepted result.
    A window is accepted if minimize() reports success or the objective dropped
    below ``cparams.tol`` during the iterations.

    Returns:
        (x_tau, tau_tau, tau_b0, xs_tau, ys_tau): for each accepted window, the x
        at the window center, the fitted tau, the fitted b0, the window's x values
        and the fitted y values on those x.
    """
    ndata = len(xX)

    if print_level:
        print("")
        print(f"Rough estimation of tau(x) by single decay model with [{nLSQ}] data points")

    x_tau   = []
    tau_tau = []
    tau_b0  = []
    xs_tau  = []
    ys_tau  = []
    min_func = MinimizeFunc(cparams)

    # Starting guesses: baseline from the last point, amplitude from the first,
    # tau from the log-slope of the first two baseline-subtracted points.
    # NOTE(review): assumes yY[0] and yY[1] lie above yY[-1]; otherwise log()
    # yields nan/-inf.
    _b0  = yY[ndata-1]
    _A0  = yY[0]
    _tau = -(xX[1] - xX[0]) / (log(yY[1] - _b0) - log(yY[0] - _b0))

    if print_level:
        print(f"Initial values: b0={_b0:8.3g} A0={_A0:8.3g} tau={_tau:8.3g}")

    # Nelder-Mead uses no gradient; other methods get a 3-point finite-difference jac.
    jac = None if cparams.method == 'nelder-mead' else '3-point'

    def objective(pk):
        return min_func.minimize_func(pk, weight = 1.0)

    converged = False
    def callback(pk):
        # minimize() ignores the return value; the flag only lets the result be
        # accepted even when res.success is False.
        nonlocal converged
        if converged:
            return False
        if objective(pk) < cparams.tol:
            converged = True
            return False

    for i in range(0, ndata, nLSQ):
        if i + nLSQ >= ndata:
            break

        icenter = int(i + nLSQ/2)
        xcenter = xX[icenter]
        min_func.X = xX[i:i+nLSQ+1]
        min_func.Y = yY[i:i+nLSQ+1]
        p0s = [_b0, _A0, _tau]

        converged = False
        res = minimize(objective, p0s,
            jac = jac,
            method = cparams.method, tol = cparams.tol,
            callback = callback,
            options = {'maxiter': cparams.nmaxiter, "disp":False})

        if res.success or converged:
            # Update first so the log line reports the fitted values, not the start values.
            _b0, _A0, _tau = res.x[0], res.x[1], res.x[2]
            if print_level:
                print(f"OK i={icenter:3} x={xcenter:8.3g} b0={_b0:8.3g} A0={_A0:8.3g} tau={_tau:8.3g}  fmin={res.fun:.6g}")

            x_tau.append(xcenter)
            tau_tau.append(_tau)
            tau_b0.append(_b0)
            xs_tau.append(min_func.X)
            ys_tau.append([min_func.fitting_func(_x, res.x) for _x in min_func.X])
        else:
            print(f"***Warning in decay.estimate_taus(): Not converged with S2={res.fun} for max iteration of {res.nit}")
            print(f"  NO i={icenter:3} x={xcenter:8.3g} b0={_b0:8.3g} A0={_A0:8.3g} tau={_tau:10.3g}")

    return x_tau, tau_tau, tau_b0, xs_tau, ys_tau

def plot(app):
    cparams = app.cparams

    cparams.logfile = app.replace_path(cparams.infile)
    cparams.outfile = app.replace_path(cparams.infile, template = ["{dirname}", "{filebody}-out.xlsx"])
    print(f"Open logfile [{cparams.logfile}]")
    app.redirect(targets = ["stdout", cparams.logfile], mode = 'w')

    cparams.xmin = pfloat(cparams.xmin)
    cparams.xmax = pfloat(cparams.xmax)
    cparams.tol  = pfloat(cparams.tol)
    cparams.alpha = pfloat(cparams.alpha)

    print("")
    print("#======================================================")
    print("# Plot decay responce with various decay time")
    print("#======================================================")
    print(f"mode    : {cparams.mode}")
    print(f"method  : {cparams.method}")
    print(f"infile  : {cparams.infile}")
    print(f"outfile : {cparams.outfile}")
    print(f"logfile : {cparams.logfile}")
    print(f"xlabel  : {cparams.xlabel}")
    print(f"ylabel  : {cparams.ylabel}")
    print(f"x range : {cparams.xmin} - {cparams.xmax}")
    print(f"nsmooth : {cparams.nsmooth}")
    print(f"norder  : {cparams.norder}")
    print(f"alpha   : {cparams.alpha}")
    print(f"nmaxiter: {cparams.nmaxiter}")
    print(f"tol     : {cparams.tol}")

    if type(cparams.xlabel) is str and '**' in cparams.xlabel:
        print("")
        print("Error: Choose xlabel")
        app.terminate("", usage = app.usage, pause = True)
    if type(cparams.ylabel) is str and '**' in cparams.ylabel:
        print("")
        print("Error: Choose ylabel")
        app.terminate("", usage = app.usage, pause = True)

    if type(cparams.xmin) is str and cparams.xmin == "": cparams.xmin = None
    if type(cparams.xmax) is str and cparams.xmax == "": cparams.xmax = None
        
    print(f"")
    print(f"Read [{cparams.infile}]")
    datafile = tkVariousData(cparams.infile)
    labels, datalist = datafile.Read_minimum_matrix(close_fp = True, usage = app.usage)
    label_x, _xX = datafile.FindDataArray(cparams.xlabel, flag = 'i')
    label_y, _yY = datafile.FindDataArray(cparams.ylabel, flag = 'i')
    nalldata = len(_xX)
    
    xX = []
    yY = []
    for i in range(nalldata):
        print("cp=", cparams.xmin)
        if cparams.xmin is not None and _xX[i] < cparams.xmin:
            continue
        if cparams.xmax is not None and cparams.xmax < _xX[i]:
            continue

        xX.append(_xX[i])
        yY.append(_yY[i])

    ndata = len(xX)
    print("")
    print("ndata:", ndata)
    

# Smoothing and differentiate to calculate tau(x)
#    miny = min(yY) #- 1.0e-3 * max(yY)
    miny = 0.0
    logy = [log(yY[i] - miny) for i in range(ndata)]
    print("y=", len(logy), cparams.nsmooth, cparams.norder)
    tau_smoothed = savgol_filter(logy, cparams.nsmooth, cparams.norder, deriv = 1)
    tau_smoothed = -1.0 / tau_smoothed

# Estimating tau(x) by single decay model
    nLSQ = cparams.nLSQ_tau
    x_tau, tau_tau, tau_b0, xs_tau, ys_tau = estimate_taus(cparams.nLSQ_tau, xX, yY, cparams, print_level = 1)

# Ridge/LASSO regression
#    cparams.tau0 = min(tau_smoothed)
#    cparams.tau0 = xX[1]
#    cparams.tau1 = max(tau_smoothed)
    cparams.tau0 = float(cparams.tau0)
    cparams.tau1 = float(cparams.tau1)
    tau_step = (cparams.tau1 - cparams.tau0) / (cparams.ntau - 1)
    taus = [cparams.tau0 + i * tau_step for i in range(cparams.ntau)]

    logtau0 = log(cparams.tau0)
    logtau1 = log(cparams.tau1)
    logtau_step = (logtau1 - logtau0) / (cparams.ntau - 1)
    taus_log = [exp(logtau0 + i * logtau_step) for i in range(cparams.ntau)]

# Make descriptors
    xexps = []
    for i in range(cparams.ntau):
        tau = taus[i]
        ys = [exp(-xX[j] / tau) for j in range(ndata)]
#        ys = [tau * exp(-xX[j] / tau) for j in range(ndata)]
        xexps.append(ys)

    xexps_log = []
    for i in range(cparams.ntau):
        tau = taus_log[i]
        ys = [exp(-xX[j] / tau) for j in range(ndata)]
#        ys = [tau * exp(-xX[j] / tau) for j in range(ndata)]
        xexps_log.append(ys)

    x_df     = pd.DataFrame(np.array(xexps).T)
    x_log_df = pd.DataFrame(np.array(xexps_log).T)
    y_df     = pd.DataFrame(np.array([yY]).T)
#    print("x_df=\n", x_df)
#    print("y_df=\n", y_df)

    scaler = StandardScaler()
    scaler.fit(x_df)
    x_scaled = scaler.transform(x_df)

# Ridge regression
    print("")
    print("Ridge regression for linear scale")
    model = Ridge(alpha = cparams.alpha)

    model.fit(x_scaled, y_df)
    b0_ridge = model.intercept_[0]
    ci_ridge = model.coef_[0].copy()

    print("Parameters:")
    print(f"  intercept: {b0_ridge:10.4g}")
    print( "  coefficients")
    for i in range(cparams.ntau):
        tau = taus[i]
        c   = ci_ridge[i]
        print(f"  tau={tau:8.3g}: {c:12.4g}")

    y_cal_ridge = model.predict(x_scaled)

# LASSO regression
    print("")
    print("LASSO regression for linear scale")
    cparams.alpha = float(cparams.alpha)
    model = Lasso(alpha = cparams.alpha, max_iter = cparams.nmaxiter, tol = cparams.tol)

    model.fit(x_scaled, y_df)
    b0_lasso = model.intercept_[0]
    ci_lasso = model.coef_.copy()

    print("Parameters:")
    print(f"  intercept: {b0_lasso:10.4g}")
    print( "  coefficients")
    for i in range(cparams.ntau):
        tau = taus[i]
        c   = ci_lasso[i]
        print(f"  tau={tau:8.3g}: {c:12.4g}")

    y_cal_lasso = model.predict(x_scaled)

    scaler = StandardScaler()
    scaler.fit(x_df)
    x_scaled = scaler.transform(x_df)

# Ridge regression
    print("")
    print("Ridge regression for log(x) scale")
    scaler.fit(x_log_df)
    x_log_scaled = scaler.transform(x_log_df)

    model = Ridge(alpha = cparams.alpha)
    model.fit(x_log_scaled, y_df)
    b0_log_ridge = model.intercept_[0]
    ci_log_ridge = model.coef_[0].copy()

    print("Parameters:")
    print(f"  intercept: {b0_log_ridge:10.4g}")
    print( "  coefficients")
    for i in range(cparams.ntau):
        tau = taus_log[i]
        c   = ci_log_ridge[i]
        print(f"  tau={tau:8.3g}: {c:12.4g}")

    y_cal_ridge_log = model.predict(x_log_scaled)

# Smoothing penalty regression
    def func_base(x, tau):
         return exp(-x / tau)

    do_sp = False
    if do_sp:
        nsmooth = 1
        alpha = cparams.alpha * 0.0001
        ci_log_sp, Si, Sij = Smoothing_Penalty_Regression_by_base(taus_log, y_df.to_numpy(), len(taus_log), 
                    lambda i, x: func_base(x, taus_log[i]), alpha, nsmooth)

# LASSO regression
    print("")
    print("LASSO regression for log(x) scale")
    model = Lasso(alpha = cparams.alpha, max_iter = cparams.nmaxiter, tol = cparams.tol)

    model.fit(x_log_scaled, y_df)
    b0_log_lasso = model.intercept_[0]
    ci_log_lasso = model.coef_.copy()

    print("Parameters:")
    print(f"  intercept: {b0_log_lasso:10.4g}")
    print( "  coefficients")
    for i in range(cparams.ntau):
        tau = taus_log[i]
        c   = ci_log_lasso[i]
        print(f"  tau={tau:8.3g}: {c:12.4g}")

    y_cal_lasso_log = model.predict(x_log_scaled)

#=========================================
# plot
#=========================================
    fig, axes = plt.subplots(2, 3, figsize = cparams.figsize)
    plot_event = tkPlotEvent(plt)

    axes = axes.flatten()
    axes[0].tick_params(labelsize = cparams.fontsize)
    axes[1].tick_params(labelsize = cparams.fontsize)
    axes[2].tick_params(labelsize = cparams.fontsize)
    axes[3].tick_params(labelsize = cparams.fontsize)
    ax2 = axes[3].twinx()
    axes[4].tick_params(labelsize = cparams.fontsize)
    axes[5].tick_params(labelsize = cparams.fontsize)

    input_data = axes[0].plot(xX, yY,         label = 'raw data', linestyle = '',  marker = 'x', markersize = 2.0)
    fit_single_decay = axes[0].plot(xs_tau[0], ys_tau[0], label = 'fit by single decay', linestyle = '-', linewidth = 0.5, color = 'red')
    axes[0].plot(xs_tau[1:], ys_tau[1:], linestyle = '-', linewidth = 0.5, color = 'red')
    axes[0].set_xlabel(label_x, fontsize = cparams.fontsize)
    axes[0].set_ylabel(label_y, fontsize = cparams.fontsize)
    axes[0].legend(fontsize = cparams.legend_fontsize)
    plot_event.add_data({"label": "input",                  "plot_type": "2D", "axis": axes[0], "data": input_data})
    plot_event.add_data({"label": "fit by single decay",    "plot_type": "2D", "axis": axes[0], "data": fit_single_decay})

    input_data1 = axes[1].plot(xX, yY,   label = 'raw data', linestyle = '',  marker = 'x', markersize = 2.0)
    axes[1].set_xlabel(label_x, fontsize = cparams.fontsize)
    axes[1].set_ylabel(label_y, fontsize = cparams.fontsize)
    axes[1].set_xscale('log')
    plot_event.add_data({"label": "input",                  "plot_type": "2D", "axis": axes[1], "data": input_data1})

    input_data2 = axes[2].plot(xX, yY,   label = 'raw data', linestyle = '',  
                    marker = 'x', markersize = 1.0, markerfacecolor = 'red', markeredgecolor = 'red')
    fit_ridge = axes[2].plot(xX, y_cal_ridge_log, label = 'Ridge', linestyle = '-', linewidth = 0.5, color = 'blue') 
    fit_lasso = axes[2].plot(xX, y_cal_lasso_log, label = 'LASSO', linestyle = '-', linewidth = 0.5, color = 'red') 
    axes[2].set_xlabel(label_x, fontsize = cparams.fontsize)
    axes[2].set_ylabel(label_y, fontsize = cparams.fontsize)
    axes[2].set_xscale('log')
    axes[2].set_yscale('log')
    axes[2].legend(fontsize = cparams.legend_fontsize)
    plot_event.add_data({"label": "input",           "plot_type": "2D", "axis": axes[2], "data": input_data2})
    plot_event.add_data({"label": "fit by Ridge",    "plot_type": "2D", "axis": axes[2], "data": fit_ridge})
    plot_event.add_data({"label": "fit by LASSO",    "plot_type": "2D", "axis": axes[2], "data": fit_lasso})

#    axes[3].plot(xX, tau_smoothed, label = '$\\tau$', linestyle = '',  marker = 'x', markersize = 2.0)
    ins1 = ax2.plot(x_tau, tau_tau, label = '$\\tau$', linestyle = '',  marker = 'x', markersize = 5.0)
    ins2 = axes[3].plot(x_tau, tau_b0, label = '$b_0$', 
                linestyle = '',  marker = 'o', markersize = 5.0, markerfacecolor = 'red', markeredgecolor = 'red')
    axes[3].set_xlabel(label_x, fontsize = cparams.fontsize)
    axes[3].set_ylabel("$b_0$", fontsize = cparams.fontsize, color = 'red')
    axes[3].tick_params(axis = 'y', labelcolor = 'red')
    ax2.set_ylabel("$\\tau$", fontsize = cparams.fontsize, color = 'blue')
    ax2.tick_params(axis = 'y', labelcolor = 'blue')
    ax2.set_yscale('log')
    ins = ins1 + ins2
    axes[3].legend(ins, [l.get_label() for l in ins], fontsize = cparams.legend_fontsize, loc = 'lower right') #loc = 'upper center') #loc = 'best')
#    plot_event.add_data({"label": "tau", "plot_type": "2D", "axis": ax2, "data": ins1})
    plot_event.add_data({"label": "tau", "plot_type": "2D", "axis": ax2, "data": ins1, "axis_scale": ax2,
                         "x": x_tau, 'y': tau_tau,
                         "xlist": [x_tau, tau_tau], "xlabels": ["t", "tau"]})
#    plot_event.add_data({"label": "b0",  "plot_type": "2D", "axis": ax2, "data": ins2, "y_lim": axes[3].get_ylim()})
    plot_event.add_data({"label": "b0",  "plot_type": "2D", "axis": ax2, "data": ins2, "axis_scale": axes[3],
                         "x": x_tau, 'y': tau_b0,
                         "xlist": [x_tau, tau_b0], "xlabels": ["t", "b0"]})

    taus_linear_ridge = axes[4].plot(taus, ci_ridge, label = '$c_i$ (Ridge)', linestyle = '-', linewidth = 0.5, marker = 'o', markersize = 2.0)
    taus_linear_lasso = axes[4].plot(taus, ci_lasso, label = '$c_i$ (LASSO)', linestyle = '-', linewidth = 0.5, marker = 'o', markersize = 2.0)
    axes[4].set_xlabel("$\\tau$", fontsize = cparams.fontsize)
    axes[4].set_ylabel("$c_i$",   fontsize = cparams.fontsize)
    axes[4].set_yscale('log')
    axes[4].legend(fontsize = cparams.legend_fontsize)
    plot_event.add_data({"label": "ci by Ridge",   "plot_type": "2D", "axis": axes[4], "data": taus_linear_ridge})
    plot_event.add_data({"label": "ci by LASSO",   "plot_type": "2D", "axis": axes[4], "data": taus_linear_lasso})

    ci_lr = abs(np.array([ci_log_ridge[i] for i in range(cparams.ntau)]))
#    ci_lr = abs(np.array([ci_log_ridge[i] / taus[i] for i in range(cparams.ntau)]))
    if do_sp:
        ci_lsp = abs(np.array([ci_log_sp[i] for i in range(cparams.ntau)]))
    ci_ll = abs(np.array([ci_log_lasso[i] for i in range(cparams.ntau)]))
#    ci_ll = abs(np.array([ci_log_lasso[i] / taus[i] for i in range(cparams.ntau)]))
    tau_lr_p = []
    tau_lr_m = []
    ci_lr_p = []
    ci_lr_m = []
    for i in range(cparams.ntau):
        if ci_log_ridge[i] >= 0.0:
            ci_lr_p.append(ci_lr[i])
            tau_lr_p.append(taus_log[i])
        else:
            ci_lr_m.append(ci_lr[i])
            tau_lr_m.append(taus_log[i])

    if do_sp:
        tau_lsp_p = []
        tau_lsp_m = []
        ci_lsp_p = []
        ci_lsp_m = []
        for i in range(cparams.ntau):
            if ci_log_sp[i] >= 0.0:
                ci_lsp_p.append(ci_lsp[i])
                tau_lsp_p.append(taus_log[i])
            else:
                ci_lsp_m.append(ci_lsp[i])
                tau_lsp_m.append(taus_log[i])

    tau_ll_p = []
    tau_ll_m = []
    ci_ll_p = []
    ci_ll_m = []
    for i in range(cparams.ntau):
        if ci_log_lasso[i] >= 0.0:
            ci_ll_p.append(ci_ll[i])
            tau_ll_p.append(taus_log[i])
        else:
            ci_ll_m.append(ci_ll[i])
            tau_ll_m.append(taus_log[i])

    """
    print("")
    print("lr_p", len(tau_lr_p))
    for i in range(len(tau_lr_p)):
        print(f"  {tau_lr_p[i]:8.3g} {ci_lr_p[i]:10.3g}")
    print("lr_m", len(tau_lr_m))
    for i in range(len(tau_lr_m)):
        print(f"  {tau_lr_m[i]:8.3g} {ci_lr_m[i]:10.3g}")
    """
    
#    axes[5].plot(taus_log, ci_lr, label = '|$c_i$| (Ridge)', linestyle = '-', linewidth = 0.5,marker = 'o', markersize = 2.0)
#    axes[5].plot(taus_log, ci_ll, label = '|$c_i$| (LASSO)', linestyle = '-', linewidth = 0.5,marker = 'o', markersize = 2.0)
    axes[5].plot(taus_log,  ci_lr, label = '|$c_i$| (Ridge)', linestyle = '-', linewidth = 0.5, color = 'red')
    taus_log_ridge_p = axes[5].plot(tau_lr_p,  ci_lr_p, linestyle = '', marker = 'o', markersize = 2.0, markerfacecolor = 'red', markeredgecolor = 'red')
    taus_log_ridge_m = axes[5].plot(tau_lr_m,  ci_lr_m, linestyle = '', marker = 'x', markersize = 2.0, markerfacecolor = 'pink', markeredgecolor = 'pink')
    if do_sp:
        axes[5].plot(taus_log,  ci_lsp, label = '|$c_i$| (sp)', linestyle = '-', linewidth = 0.5, color = 'black')
        axes[5].plot(tau_lsp_p, ci_lsp_p, linestyle = '', marker = 'o', markersize = 2.0, markerfacecolor = 'cyan', markeredgecolor = 'cyan')
        axes[5].plot(tau_lsp_m, ci_lsp_m, linestyle = '', marker = 'x', markersize = 2.0, markerfacecolor = 'orange', markeredgecolor = 'orange')
    axes[5].plot(taus_log, ci_ll, label = '|$c_i$| (LASSO)',linestyle = '-', linewidth = 0.5, color = 'blue')
    taus_log_lasso_p = axes[5].plot(tau_ll_p, ci_ll_p, linestyle = '', marker = 'o', markersize = 2.0, markerfacecolor = 'blue', markeredgecolor = 'blue')
    taus_log_lasso_m = axes[5].plot(tau_ll_m, ci_ll_m, linestyle = '', marker = 'x', markersize = 2.0, markerfacecolor = 'green', markeredgecolor = 'green')
    axes[5].set_xlabel("$\\tau$", fontsize = cparams.fontsize)
    axes[5].set_ylabel("|$c_i$|",   fontsize = cparams.fontsize)
    axes[5].set_xscale('log')
    axes[5].set_yscale('log')
    axes[5].legend(fontsize = cparams.legend_fontsize)
#    plot_event.add_data({"label": "ci by Ridge",   "plot_type": "2D", "axis": axes[5], "data": taus_log_ridge_p})
    plot_event.add_data({"label": "ci by Ridge (positive)",   "plot_type": "2D", "axis": axes[5], "data": taus_log_ridge_p,
                                "x": tau_lr_p, 'y': ci_lr_p,
                                "xlist": [tau_lr_p, ci_lr_p], "xlabels": ["tau", "ci"]})
#    plot_event.add_data({"label": "ci by Ridge",   "plot_type": "2D", "axis": axes[5], "data": taus_log_ridge_m})
    ci_lr_m_org = -np.array(ci_lr_m)
    plot_event.add_data({"label": "ci by Ridge (negative)",   "plot_type": "2D", "axis": axes[5], "data": taus_log_ridge_m,
                                "x": tau_lr_m, 'y': ci_lr_m_org,
                                "xlist": [tau_lr_m, ci_lr_m_org], "xlabels": ["tau", "ci"]})
#    plot_event.add_data({"label": "ci by LASSO",   "plot_type": "2D", "axis": axes[5], "data": taus_log_lasso_p})
    plot_event.add_data({"label": "ci by LASSO (positive)",   "plot_type": "2D", "axis": axes[5], "data": taus_log_lasso_p,
                                "x": tau_lr_p, 'y': ci_ll_p,
                                "xlist": [tau_lr_p, ci_ll_p], "xlabels": ["tau", "ci"]})
#    plot_event.add_data({"label": "ci by LASSO (negative)",   "plot_type": "2D", "axis": axes[5], "data": taus_log_lasso_m})
    ci_ll_m_org = -np.array(ci_ll_m)
    plot_event.add_data({"label": "ci by LASSO (negative)",   "plot_type": "2D", "axis": axes[5], "data": taus_log_lasso_m,
                                "x": tau_lr_m, 'y': ci_ll_m_org,
                                "xlist": [tau_ll_m, ci_ll_m_org], "xlabels": ["tau", "ci"]})

    plot_event.register_click(fig)

    plt.tight_layout()
    plt.pause(0.001)


    app.terminate("", usage = app.usage, pause = True)


#==========================================
# Main routine
#==========================================
def main():
    """Program entry point: set up the application and run the selected mode.

    Parameters are first populated with their defaults by ``initialize`` and
    then overridden by ``update_vars`` (from command-line arguments, per the
    progress message). ``cparams.mode == 'plot'`` dispatches to ``plot``;
    every other mode value dispatches to ``fit``.
    """
    application = tkApplication(suppress_usage = 'options')

    print("Initialize parameters")
    initialize(application)
    print("Update parameters by command-line arguments")
    update_vars(application)

    # Only the selected routine name is looked up and invoked.
    runner = plot if application.cparams.mode == 'plot' else fit
    runner(application)


if __name__ == "__main__":
    main()

