import os
import numpy as np
from numpy import log, log10, sqrt, exp
from scipy.signal import savgol_filter 
import openpyxl
import pandas as pd
import matplotlib.pyplot as plt
#from functools import lru_cache


import decay


# Default input workbook read by main(); its first column is x and its second column is y
infile  = 'test/decay.xlsx'
# Workbook written by main() with the x/y columns read from infile
outfile = 'simulated.xlsx'


class tkParams:
    """Empty attribute container; fields are attached dynamically (e.g. ``cparams.tol``)."""
    

def lsqfunc(idx, x, const_names, const_values):
    """Return the idx-th basis function of the linear least-squares model.

    The model is y = a_0 * 1 + sum_k a_k * exp(-x / tau_k). This returns the
    factor multiplying the linear coefficient a_idx:
    - the scalar 1.0 for idx == 0 (constant background);
    - exp(-x / tau[idx]) otherwise.

    Args:
        idx: index of the linear coefficient.
        x: abscissa value(s); scalars and numpy arrays both work because
            numpy's exp is used.
        const_names: names of the constant (non-linear) parameters.
        const_values: values parallel to const_names. The entry named "tau"
            (or const_values[0] when no such name is present) must be a
            sequence of decay times indexed like the linear coefficients.
    """
    if idx == 0:
        return 1.0

    # Look tau up by name so the order of the constants does not matter;
    # otherwise use the first constant.
    if const_names and "tau" in const_names:
        tau = const_values[list(const_names).index("tau")]
    else:
        tau = const_values[0]
    return exp(-x / tau[idx])

def build_fit_params(xdata_list, ydata_list, options):
    """Build the initial fitting-parameter table for a multi-exponential decay.

    The parameter vector is laid out as
    [bg_c0, I01, tau1, I02, tau2, ..., I0<ntau>, tau<ntau>]: a constant
    background, then ntau (amplitude, decay time) pairs. The background and
    the amplitudes are linear parameters (linid 1); the taus are non-linear
    (linid 0). Initial taus are log-spaced between the smallest and largest
    local decay times returned by decay.estimate_taus().

    Args:
        xdata_list, ydata_list: lists of data series; only the first x and y
            series are used. At least two points are required.
        options: dict with "ntau" (number of exponential terms) and optional
            "nLSQ_tau" (passed to decay.estimate_taus, default 5).

    Returns:
        Tuple (varname, unit, pk_scale, optid, linid, x0, dx, kmin, kmax,
        kpenalty) of parallel lists, one entry per fitting parameter.
    """
    ntau = options["ntau"]
    xX = xdata_list[0]
    yY = ydata_list[0]
    ndata = len(xX)

    # Estimation of tau(x)
    print()
    print(f"Rough estimation of tau(x) by single decay model for [{ntau}] points")

    maxx = max(xX)
    # NOTE(review): presumably meant to skip a leading x = 0 point
    minx2 = max([min(xX), xX[1]])
    maxy = max(yY)
    # Shift y so that log(y - miny) stays finite even when y <= 0
    miny = min([0.0, min(yY) - 1.0e-10])
    logy = log(np.asarray(yY, dtype = float) - miny)

    # Single-exponential estimate tau = -dx / d(log y) from the first and last
    # intervals; printed for reference only.
    tau_first = -1.0 / (logy[1] - logy[0]) * (xX[1] - xX[0])
    tau_last = -1.0 / (logy[ndata - 1] - logy[ndata - 2]) * (xX[ndata - 1] - xX[ndata - 2])
    print(f"  tau range from gradients: {tau_first} - {tau_last}")

    nLSQ_tau = options.get("nLSQ_tau", 5)
    cparams = tkParams()
    cparams.b0 = None
    cparams.A0 = None
    cparams.tau = None
    cparams.method = 'nelder-mead'
    cparams.tol    = 1.0e-5
    cparams.nmaxiter = 100
    _, tau_tau, _, _, _ = decay.estimate_taus(nLSQ_tau, xX, yY, cparams, print_level = 0)
    tau0 = min(tau_tau)
    tau1 = max(tau_tau)
    print(f"  tau range from local LSQ: {tau0} - {tau1}")
    print(f"Use tau range from local LSQ")

    # Exactly ntau log-spaced values from tau0 to tau1 inclusive, also for
    # ntau == 1 or tau0 == tau1 (a float-step arange can miscount or fail here).
    tau_list = np.geomspace(tau0, tau1, ntau).tolist()
    print(f"  tau list:", tau_list)

    # Bounds shared by all taus.
    # NOTE(review): these scale like 1/x while tau has the units of x
    # (exp(-x / tau)); confirm whether they were meant for a rate 1/tau.
    tau_kmin = 1.0e-3 / abs(maxx)
    tau_kmax = 1.0e3 / abs(minx2)

    # One row per parameter: (name, linid, x0, dx, kmin, kmax)
    rows = [("bg_c0", 1, 0.0, maxy * 0.01, 0.0, maxy)]
    for i, tau in enumerate(tau_list, start = 1):
        rows.append((f"I0{i}", 1, 0.0, maxy * 0.1, 0.0, maxy * 10.0))
        rows.append((f"tau{i}", 0, tau, tau * 0.1, tau_kmin, tau_kmax))

    varname, linid, x0, dx, kmin, kmax = (list(col) for col in zip(*rows))
    nparams = len(rows)
    unit     = [""] * nparams
    pk_scale = [""] * nparams
    optid    = [1] * nparams
    kpenalty = [1.0] * nparams

    return varname, unit, pk_scale, optid, linid, x0, dx, kmin, kmax, kpenalty

def cal_linearpart(xlist, ylist):
    """Return the y values used as the target of the linear regression.

    No transformation is applied: ylist itself (same object, not a copy) is
    returned and xlist is ignored.
    """
    linear_y = ylist
    return linear_y

def recover_pk_all(pk, ai_all):
    """Merge linear-regression coefficients back into the full parameter vector.

    The full vector is laid out as [c0, I0_1, tau_1, I0_2, tau_2, ...] and
    ai_all holds the linear coefficients [c0, I0_1, I0_2, ...]. The coefficients
    overwrite slot 0 and every odd slot; the tau slots keep their values from
    pk. The result is built on pk.copy(), so pk itself is not modified.
    """
    merged = pk.copy()
    merged[0] = ai_all[0]
    # Odd slot 2k-1 receives linear coefficient k.
    for slot in range(1, len(pk), 2):
        merged[slot] = ai_all[(slot + 1) // 2]
    return merged

def build_llsq_params(xlist, ylist, varname_all, pk_all, optid_all, linid_all):
    """Extract the linear-regression subproblem from the full parameter set.

    Parameters are laid out as [bg_c0, I0_1, tau_1, I0_2, tau_2, ...]. The
    linear subproblem keeps the background and every amplitude (odd index).
    Each amplitude's tau (the following index) is passed as a constant.

    Returns:
        (xlin_list, ylin_list, varname_lin, pk_lin, optid_lin, const_names,
        const_values), where const_names == ["tau"] and const_values[0] is the
        tau list aligned with pk_lin (None for the background term).
    """
    ylin_list = cal_linearpart(xlist, ylist)

    # A parameter takes part in the linear step only if it is both flagged
    # for optimisation and marked as linear.
    optid_combined = [int(oid and lid) for oid, lid in zip(optid_all, linid_all)]

    amp_slots = range(1, len(varname_all), 2)
    varname_lin = [varname_all[0]] + [varname_all[k] for k in amp_slots]
    pk_lin      = [pk_all[0]] + [pk_all[k] for k in amp_slots]
    optid_lin   = [optid_combined[0]] + [optid_combined[k] for k in amp_slots]
    tau_lin     = [None] + [pk_all[k + 1] for k in amp_slots]

    return xlist, ylin_list, varname_lin, pk_lin, optid_lin, ["tau"], [tau_lin]

def cal_ylist(app, xk_all, x_list, fit = None, run = True, print_level = 1):
    """Evaluate the multi-exponential decay model on the first x series.

    Model: y(x) = c0 + sum_k I0_k * exp(-x / tau_k), with parameters laid out
    as xk_all = [c0, I0_1, tau_1, I0_2, tau_2, ...].

    Args:
        app, fit, run, print_level: accepted for interface compatibility;
            not used here.
        xk_all: parameter vector as described above (odd length expected).
        x_list: list of x series; only x_list[0] is evaluated.

    Returns:
        A one-element list holding a float numpy array of model y values.
    """
    x = np.asarray(x_list[0], dtype = float)
    # Start from the constant background, then add each exponential term in
    # the same order as the parameters appear.
    y = np.full(len(x), xk_all[0], dtype = float)
    for ip in range(1, len(xk_all), 2):
        y += xk_all[ip] * exp(-x / xk_all[ip + 1])

    return [y]

def save_data(path, labels, data_list):
    """Write column data to an Excel file at path.

    data_list is expected to be a sequence of equally long columns. It is
    transposed so each inner sequence becomes one column, headed by the
    matching entry of labels. The DataFrame row index is not written.
    """
    print()
    print(f"Save data to [{path}]")
    table = np.array(data_list).T
    frame = pd.DataFrame(table, columns = labels)
    frame.to_excel(path, index = False, header = True)


def _is_bound_unset(bound):
    """Return True if an xmin/xmax option means "no limit" (None, '' or '*')."""
    return bound is None or bound == '' or bound == '*'


def _point_weight(xs, i, wx):
    """Return the fit weight of point i of xs for weighting mode wx.

    Modes: 'constant' -> 1, 'x' -> |x|, '1/x' -> 1/|x|, '1/sqrt(x)' -> 1/sqrt(|x|).
    For the inverse modes, a point at x == 0 borrows |x| from its right-hand
    neighbour (left-hand for the last point) to avoid dividing by zero.
    """
    if wx == 'constant':
        return 1.0
    if wx == 'x':
        return abs(xs[i])
    ax = abs(xs[i])
    if ax == 0.0 and len(xs) > 1:
        ax = abs(xs[i + 1] if i + 1 < len(xs) else xs[i - 1])
    if wx == '1/x':
        return 1.0 / ax
    return 1.0 / sqrt(ax)


def read_input_data(app, infile, options = None):
    """Read x (first column) and y (second column) from an Excel file.

    Args:
        app: application object. Optional settings are read from app.cfg
            (dict-like; treated as empty when missing):
            - "wx": weighting mode, one of 'constant', '1/x' (default),
              '1/sqrt(x)' or 'x';
            - "nx_skip": default 0; when > 0, points past index
              int(0.1 * N) are kept only every nx_skip-th.
            An invalid wx calls app.terminate(message, pause = True).
        infile: path of the Excel file.
        options: optional dict. "xmin"/"xmax" restrict the x range; None, ''
            or '*' means no limit.

    Returns:
        (xlabels, ylabels, xdatalist, ydatalist, wdatalist). The label lists
        hold one entry each and the data lists hold one series each. All five
        are None if infile does not exist or wx is invalid.
    """
    if options is None:
        options = {}
    xmin = options.get("xmin", None)
    xmax = options.get("xmax", None)

    if not os.path.isfile(infile):
        return None, None, None, None, None

    cfg = getattr(app, "cfg", None) or {}
    wx = cfg.get("wx", "1/x")
    # NOTE(review): '1/sqrt(x)' is the name chosen for the sqrt weighting that
    # was previously unreachable (duplicate '1/x' branch) - confirm with callers.
    if wx not in ('constant', '1/x', '1/sqrt(x)', 'x'):
        print()
        app.terminate(f"Error in decay_model.read_input_data(): Invalid wx=[{wx}]\n", pause = True)
        # Stop here instead of continuing with a weight list shorter than x.
        return None, None, None, None, None
    nx_skip = cfg.get("nx_skip", 0)

    df = pd.read_excel(infile)
    labels = df.columns.tolist()
    columns = df.values.T
    xin = columns[0]
    yin = columns[1]

    # Keep the points inside [xmin, xmax] and compute their weights.
    has_xmin = not _is_bound_unset(xmin)
    has_xmax = not _is_bound_unset(xmax)
    x, y, w = [], [], []
    for i, (xi, yi) in enumerate(zip(xin, yin)):
        if has_xmin and xi < xmin:
            continue
        if has_xmax and xmax < xi:
            continue
        x.append(xi)
        y.append(yi)
        w.append(_point_weight(xin, i, wx))

    # Thin the data: keep every point up to index int(0.1 * N), then only
    # every nx_skip-th point.
    ndata10p = int(len(x) * 0.1)
    xlist, ylist, wlist = [], [], []
    for i, (xi, yi, wi) in enumerate(zip(x, y, w)):
        if nx_skip > 0 and i > ndata10p and (i - ndata10p) % nx_skip != 0:
            continue
        xlist.append(xi)
        ylist.append(yi)
        wlist.append(wi)

    return [labels[0]], [labels[1]], [xlist], [ylist], [wlist]

def plot_input(app, cfg, mf):
    """Plot the first x/y series of cfg.infile, then call app.terminate(pause = True).

    Args:
        app: application object, passed to read_input_data() and terminated
            at the end.
        cfg: settings object providing infile, figsize, fontsize and
            legend_fontsize.
        mf: unused; kept for interface compatibility.
    """
    print()
    print("Plot input data")
    print(f"infile   : {cfg.infile}")

    print()
    xlabels, ylabels, xdata_list, ydata_list, wdata_list = read_input_data(app, cfg.infile, options = {})
    if xdata_list is None:
        # read_input_data() returns all None when the file cannot be used.
        print(f"Error in plot_input(): Could not read [{cfg.infile}]")
        app.terminate(pause = True)
        return

    _, ax = plt.subplots(1, 1, figsize = cfg.figsize, dpi = 100, tight_layout = True)
    ax.tick_params(labelsize = cfg.fontsize)
    ax.plot(xdata_list[0], ydata_list[0], label = "input", linewidth = 1.0, marker = "o", markersize = 1.0)
    ax.set_xlabel(xlabels[0], fontsize = cfg.fontsize)
    ax.set_ylabel(ylabels[0], fontsize = cfg.fontsize)
    ax.legend(fontsize = cfg.legend_fontsize)

    # Draw without blocking so the figure is shown before terminate() waits.
    plt.pause(0.01)
    app.terminate(pause = True)



def main():
    """Read infile, copy its first x/y columns to outfile, then plot them."""

    class tkApplication():
        """Minimal stand-in for the application object used by this module."""

        def __init__(self):
            # read_input_data() looks up optional settings via app.cfg.get()
            self.cfg = {}

        def terminate(self, message = None, pause = False):
            # read_input_data() passes an error message positionally.
            if message:
                print(message)
            if pause:
                input("Press ENTER to terminate>>")
            raise SystemExit(1 if message else 0)

    app = tkApplication()
    cfg = tkParams()
    cfg.infile = infile
    cfg.mode = "plot"
    cfg.xlabel = 0
    cfg.ylabel = 1
    # NOTE(review): xmin/xmax are not forwarded - read_input_data() is called
    # with empty options here and in plot_input(), so no x-range filter applies.
    cfg.xmin = 20.0
    cfg.xmax = 40.0
    cfg.figsize = (8, 6)
    cfg.fontsize = 12
    cfg.legend_fontsize = 12

    print()
    print("Peak fit test program / odata module")
    print(f"infile={infile}")
    print(f"outfile={outfile}")

    xlabels, ylabels, xdatalist, ydatalist, wdatalist = read_input_data(app, infile, options = {})
    if xdatalist is None:
        print(f"Error in main(): Could not read [{infile}]")
        return

    save_data(outfile, xlabels + ylabels, xdatalist + ydatalist)
    plot_input(app, cfg, mf = None)


# Run main() only when this file is executed as a script, not when imported
if __name__ == '__main__':
    main()
