import os
import numpy as np
#from numpy import log, log10, sqrt, exp
import openpyxl
import pandas as pd
import matplotlib.pyplot as plt
#from functools import lru_cache


import peaksearch


# Default input workbook (first column = x, second column = y) and the
# workbook that main() writes the loaded data back out to.
infile, outfile = 'test/xrd.xlsx', 'simulated.xlsx'


def Gaussian(x, x0, whalf, A = None):
    """Gaussian profile centred at *x0* with half width at half maximum *whalf*.

    ``G(x) = A * exp(-((x - x0) / a)**2)`` with ``a = whalf / sqrt(ln 2)``,
    so that ``G(x0 +/- whalf) = A / 2``.

    Parameters
    ----------
    x : float or numpy.ndarray
        Evaluation point(s); arrays are evaluated element-wise.
    x0 : float
        Peak centre.
    whalf : float
        Half width at half maximum (must be non-zero).
    A : float, optional
        Peak height. When omitted, ``A = sqrt(ln 2 / pi) / whalf``, which
        normalises the profile to unit area.

    Returns
    -------
    float or numpy.ndarray
        Profile value(s) at *x*.
    """
    if A is None:
        # sqrt(ln2 / pi) = 0.469718639...
        A = 0.469718639 / whalf
    # sqrt(ln2) = 0.832554611...; converts the HWHM into the 1/e half width a.
    a = whalf / 0.832554611
    X = (x - x0) / a
    # np.exp works for both scalars and arrays (bare `exp` is not imported).
    return A * np.exp(-X * X)

def Lorentzian(x, x0, whalf, A = None):
    """Lorentzian profile centred at *x0* with half width at half maximum *whalf*.

    ``L(x) = A / (1 + ((x - x0) / whalf)**2)``, so ``L(x0 +/- whalf) = A / 2``.

    Parameters
    ----------
    x : float or numpy.ndarray
        Evaluation point(s); arrays are evaluated element-wise.
    x0 : float
        Peak centre.
    whalf : float
        Half width at half maximum (must be non-zero).
    A : float, optional
        Peak height. When omitted, ``A = 1 / (pi * whalf)``, which normalises
        the profile to unit area.

    Returns
    -------
    float or numpy.ndarray
        Profile value(s) at *x*.
    """
    if A is None:
        # np.pi: the bare name `pi` is not imported in this module.
        A = 1.0 / whalf / np.pi
    X = (x - x0) / whalf
    return A / (1.0 + X * X)

# Peak function used by the linear regression.
# When I0 is None, only the basis-function part (unit peak height) is returned.
def peak_func(x, I0, xc, w):
    """Peak profile used for fitting: a unit-height Lorentzian scaled by *I0*.

    Parameters
    ----------
    x : float or numpy.ndarray
        Evaluation point(s).
    I0 : float or None
        Peak height. ``None`` returns the unscaled basis function, i.e. the
        part multiplied by the linear coefficient in the regression.
    xc : float
        Peak centre.
    w : float
        Half width at half maximum.

    Returns
    -------
    float or numpy.ndarray
        ``I0 * L(x)`` or ``L(x)`` where ``L(xc) == 1``.
    """
    # A = 1.0 makes the profile equal 1 at x == xc, so I0 is the peak height.
    # (Gaussian(x, xc, w, 1.0) has the same property if a Gaussian shape is wanted.)
    y = Lorentzian(x, xc, w, 1.0)
    if I0 is None:
        return y
    return I0 * y

# Basis function of the linear regression: returns everything except the linear coefficient.
def lsqfunc(idx, x, const_names, const_values):
    """Return the *idx*-th basis function of the linear model evaluated at *x*.

    Parameters
    ----------
    idx : int
        Basis index. 0 is the constant background term; ``idx >= 1`` is the
        peak whose centre/width are ``const_values[0][idx]`` /
        ``const_values[1][idx]``.
    x : float or numpy.ndarray
        Evaluation point(s).
    const_names : list of str
        Unused here; names matching *const_values* (``["xc", "w"]``).
    const_values : list
        ``[xc_list, w_list]`` of per-basis constants.

    Returns
    -------
    float or numpy.ndarray
        ``1.0`` for the background, otherwise the unscaled peak profile.
    """
    if idx == 0:
        # Constant background term.
        return 1.0
    xc, w = const_values[0], const_values[1]
    return peak_func(x, None, xc[idx], w[idx])

# Build the fitting-parameter tables from the result of a peak search.
def build_fit_params(xlist, ylist, options):
    """Run a peak search and build the initial fitting-parameter tables.

    The parameter vector is ``[bg_c0, I01, xc1, w1, I02, xc2, w2, ...]``:
    a constant background followed by one (intensity, centre, width)
    triplet per detected peak.

    Parameters
    ----------
    xlist, ylist : sequence of float
        Data handed to ``peaksearch.peak_search``.
    options : dict
        Must contain ``"nsmooth"``, ``"norder"``, ``"threshold"`` and
        ``"ydiff1_threshold"``.

    Returns
    -------
    tuple of list
        ``(varname, unit, pk_scale, optid, linid, x0, dx, kmin, kmax, kpenalty)``
        with one entry per parameter. ``linid`` is 1 for the I0 entries and 0
        otherwise. The initial I0 is the smoothed intensity at the peak index;
        the initial xc and w come from the peak-search result.
    """
    search_keys = ("nsmooth", "norder", "threshold", "ydiff1_threshold")
    nsmooth, norder, threshold, ydiff1_threshold = [options[key] for key in search_keys]

    print()
    print("Peak search")
    peaks, info = peaksearch.peak_search(xlist, ylist, nsmooth, norder,
                                         threshold, ydiff1_threshold, is_print = True)
    ysmooth = info['ysmooth']
    print("  peak list:")
    for ipk, peak in enumerate(peaks):
        print(f"  {ipk:03d}: xc={peak[1]:8.4f}  I0={ysmooth[peak[0]]:10g}  w={peak[2]:8.4f}")

    # One row per parameter:
    # (varname, unit, pk_scale, optid, linid, x0, dx, kmin, kmax, kpenalty)
    rows = [("bg_c0", "", "", 1, 0, 0.0, 0.1, -1.0e10, 1.0e10, 1.0)]
    for n, peak in enumerate(peaks, start = 1):
        rows.append((f"I0{n}", "", "", 1, 1, ysmooth[peak[0]], 1.0, -1.0e10, 1.0e10, 1.0))
        rows.append((f"xc{n}", "", "", 1, 0, peak[1], 0.5, 0.0, 180.0, 1.0))
        rows.append((f"w{n}", "", "", 1, 0, peak[2], 0.2, 0.0, 100.0, 1.0))

    # Transpose the rows into the ten per-field lists.
    return tuple(list(column) for column in zip(*rows))

# Return the y values used by the linear-regression step.
def cal_linearpart(xlist, ylist):
    """Return the y data for the linear-regression step.

    Currently the identity: *ylist* itself (not a copy) is returned and
    *xlist* is unused.
    """
    linear_part = ylist
    return linear_part

# Merge the linear-regression coefficients back into the full parameter vector.
def recover_pk_all(pk, ai_all):
    """Return a copy of *pk* with the linear coefficients from *ai_all* inserted.

    *pk* is laid out as ``[bg_c0, I01, xc1, w1, I02, ...]``. Index 0 is replaced
    by ``ai_all[0]`` and the first entry of each triplet (indices 1, 4, 7, ...)
    by ``ai_all[1]``, ``ai_all[2]``, ... in order. *pk* itself is not modified.
    """
    merged = pk.copy()
    merged[0] = ai_all[0]
    for coef_index, param_index in enumerate(range(1, len(pk), 3), start = 1):
        merged[param_index] = ai_all[coef_index]
    return merged

# Extract the linear-regression parameters from the full parameter set
# (builds the inputs for the linear least-squares step; does not solve it).
def build_llsq_params(xlist, ylist, varname_all, pk_all, optid_all, linid_all):
    """Collect the inputs for the linear least-squares step.

    The full parameter vector is laid out as ``[bg_c0, I01, xc1, w1, I02, ...]``.
    The linear parameters are index 0 and the first entry of every triplet.
    The triplet's second and third entries (xc, w) are passed through as
    per-peak constants.

    Returns
    -------
    tuple
        ``(xlin_list, ylin_list, varname_lin, pk_lin, optid_lin, const_names,
        const_values)``. ``xlin_list`` is *xlist*. ``ylin_list`` is
        ``cal_linearpart(xlist, ylist)``. ``const_names`` is ``["xc", "w"]``.
        ``const_values`` is ``[xc_lin, w_lin]``, whose index-0 entries are
        ``None`` (background term). Each ``optid_lin`` entry is
        ``int(optid and linid)`` of the corresponding parameter.
    """
    ylin_list = cal_linearpart(xlist, ylist)

    # A parameter is flagged for the linear step only when both flags are truthy.
    flags = [int(opt and lin) for opt, lin in zip(optid_all, linid_all)]
    starts = range(1, len(varname_all), 3)

    varname_lin = [varname_all[0]] + [varname_all[k] for k in starts]
    pk_lin      = [pk_all[0]] + [pk_all[k] for k in starts]
    optid_lin   = [flags[0]] + [flags[k] for k in starts]
    xc_lin      = [None] + [pk_all[k + 1] for k in starts]
    w_lin       = [None] + [pk_all[k + 2] for k in starts]

    return xlist, ylin_list, varname_lin, pk_lin, optid_lin, ["xc", "w"], [xc_lin, w_lin]

# Evaluate the model on the x grid and return it as a 2-D list ([y] for x_list[0]).
def cal_ylist(app, xk_all, x_list, fit = None, run = True, print_level = 1):
    """Evaluate the fit model ``bg_c0 + sum_k peak_func(x, I0_k, xc_k, w_k)``.

    Parameters
    ----------
    app, fit, run, print_level :
        Unused; kept for compatibility with the caller's interface.
    xk_all : sequence of float
        Full parameter vector ``[bg_c0, I01, xc1, w1, I02, xc2, w2, ...]``.
        Any number of peaks, including zero, is accepted.
    x_list : list of sequences
        x values; only ``x_list[0]`` is evaluated.

    Returns
    -------
    list of numpy.ndarray
        ``[y]`` with one model value per point of ``x_list[0]``.
    """
    x = np.asarray(x_list[0], dtype = float)
    # Start from the constant background, then add each peak over the whole grid.
    ylist = np.full(x.shape, xk_all[0], dtype = float)
    for ip in range(1, len(xk_all), 3):
        I0 = xk_all[ip]
        xc = xk_all[ip + 1]
        w  = xk_all[ip + 2]
        ylist += peak_func(x, I0, xc, w)

    return [ylist]

def save_data(path, labels, data_list):
    """Write *data_list* to the Excel file *path*, one column per label.

    *data_list* is a list of equal-length series. It is transposed so that
    each series becomes a column headed by the matching entry of *labels*.
    The row index is not written.
    """
    print()
    print(f"Save data to [{path}]")
    series_by_row = np.array(data_list)
    table = pd.DataFrame(series_by_row.T, columns = labels)
    table.to_excel(path, index = False, header = True)


def _range_bound(value):
    """Return an x-range bound as float, or None when it means 'unbounded'."""
    # None, '' and '*' all mean "no limit on this side".
    if value is None or value == '' or value == '*':
        return None
    return float(value)


# Read the file; return 1-D lists of x/y labels and 2-D lists of x/y values.
def read_input_data(app, infile, options = None):
    """Read the first two columns of an Excel file as (x, y) data.

    Parameters
    ----------
    app :
        Unused; kept for interface compatibility.
    infile : str
        Path of the Excel workbook. Column 0 is x and column 1 is y; the
        header row supplies the labels.
    options : dict, optional
        ``"xmin"`` / ``"xmax"`` restrict the x range (inclusive). A bound of
        ``None``, ``''`` or ``'*'`` means unbounded; numeric strings are
        accepted. ``None`` (the default) means no options.

    Returns
    -------
    tuple
        ``(xlabels, ylabels, xdatalist, ydatalist, wdatalist)``: one-element
        label lists, one-element lists holding the filtered x / y values, and
        ``None`` for the weights. All five are ``None`` if *infile* does not
        exist.
    """
    if options is None:
        options = {}
    xmin = _range_bound(options.get("xmin", None))
    xmax = _range_bound(options.get("xmax", None))

    if not os.path.isfile(infile):
        return None, None, None, None, None

    df = pd.read_excel(infile)
    labels = df.columns.tolist()
    data_list = df.values.T

    x = []
    y = []
    for xv, yv in zip(data_list[0], data_list[1]):
        if xmin is not None and xv < xmin:
            continue
        if xmax is not None and xmax < xv:
            continue
        x.append(xv)
        y.append(yv)

    return [labels[0]], [labels[1]], [x], [y], None

# Plot the input data.
def plot_input(app, cfg, mf):
    """Plot the raw (x, y) data read from ``cfg.infile``.

    Parameters
    ----------
    app : object
        Must provide ``terminate(pause=True)``, which is called after drawing.
    cfg : object
        Uses ``infile``, ``figsize``, ``fontsize`` and ``legend_fontsize``.
    mf : object
        Unused; kept for interface compatibility.

    If ``cfg.infile`` does not exist, an error message is printed and the
    function returns without plotting.
    """
    print()
    print("Plot input data")
    print(f"infile   : {cfg.infile}")

    print()
    xlabels, ylabels, xdatalist, ydatalist, _ = read_input_data(app, cfg.infile, options = {})
    if xlabels is None:
        print(f"Error: input file [{cfg.infile}] does not exist.")
        return

    x_label = xlabels[0]
    y_label = ylabels[0]
    x_values = xdatalist[0]
    y_values = ydatalist[0]

    fig, ax = plt.subplots(1, 1, figsize = cfg.figsize, dpi = 100, tight_layout = True)
    ax.tick_params(labelsize = cfg.fontsize)

    ax.plot(x_values, y_values, label = "input", linewidth = 1.0, marker = "o", markersize = 1.0)
    ax.set_xlabel(x_label, fontsize = cfg.fontsize)
    ax.set_ylabel(y_label, fontsize = cfg.fontsize)
    ax.legend(fontsize = cfg.legend_fontsize)

    # Short pause so the figure is drawn before waiting for the user.
    plt.pause(0.01)
    app.terminate(pause = True)

# Peak search.
def peak_search(app, cfg, mf):
    """Run a peak search on ``cfg.infile`` and plot the detected peaks.

    The data are read from ``cfg.infile``, restricted to ``[cfg.xmin, cfg.xmax]``,
    and passed to ``peaksearch.peak_search`` with the smoothing/threshold
    settings from *cfg*. The plot shows the input points and, for each peak:
    a black marker bar below zero, a green bar spanning xc +/- w at half the
    smoothed peak intensity, and the ``peak_func`` profile in red around the peak.

    Parameters
    ----------
    app : object
        Must provide ``terminate(pause=True)``, called after plotting.
    cfg : object
        Uses infile, mode, xlabel, ylabel, xmin, xmax, norder, nsmooth,
        threshold, ydiff1_threshold, figsize, fontsize, legend_fontsize.
        ``cfg.nsmooth`` is incremented by one if it is even.
    mf : object
        Unused; kept for interface compatibility.

    Prints an error and returns without plotting if the input file does not
    exist or fewer than two points lie in the x range.
    """
    # Half width of the plotted profile window, in units of the peak width w.
    krange = 10.0
    cfg.figsize_test  = [12, 8]
    cfg.fontsize_test = 12

    print("")
    print(f"Peak search in the data [{cfg.infile}]")
    print("mode    : ", cfg.mode)
    print("infile  : ", cfg.infile)
    print("xlabel  : ", cfg.xlabel)
    print("ylabel  : ", cfg.ylabel)
    print("  x range : ", cfg.xmin, cfg.xmax)

    print()
    # Read the configured file (not the module default) within the printed x range.
    xlabels, ylabels, xdatalist, ydatalist, _ = read_input_data(
        app, cfg.infile, options = {"xmin": cfg.xmin, "xmax": cfg.xmax})
    if xlabels is None:
        print(f"Error: input file [{cfg.infile}] does not exist.")
        return

    xlabel = xlabels[0]
    ylabel = ylabels[0]
    x = xdatalist[0]
    y = ydatalist[0]
    ndata = len(x)

    print("norder    : ", cfg.norder)
    print("nsmooth   : ", cfg.nsmooth)
    if cfg.nsmooth % 2 == 0:
        cfg.nsmooth += 1
        print(f"  Warning: nsmooth must be odd. Changed to {cfg.nsmooth}")
    print("threshold : ", cfg.threshold)
    print("dy/dx threshold : ", cfg.ydiff1_threshold)
    print("  ndata = ", ndata)
    if ndata < 2:
        print(f"Error: at least 2 data points are required in the x range (ndata = {ndata}).")
        return

    xpeaks, inf = peaksearch.peak_search(x, y, cfg.nsmooth, cfg.norder, cfg.threshold,
                                         cfg.ydiff1_threshold, is_print = True)
    ysmooth = inf['ysmooth']

    print("")
    print("plot")
    # Grid step used to convert widths into point counts; assumes a roughly
    # uniform x grid. abs() keeps the windows non-empty for descending x.
    dx = abs(x[1] - x[0])

    def _plot_peaks(ax):
        """Draw the input data and per-peak markers/profiles on *ax*."""
        maxI = max(y)
        bar_range = [-0.05 * maxI, -0.01 * maxI]

        ax.plot(x, y, label = 'input', linestyle = '', marker = 'o', markersize = 0.5,
                markerfacecolor = 'black', markeredgecolor = 'black')
        ax.plot(ax.get_xlim(), [0.0, 0.0], linestyle = 'dashed', color = 'red', linewidth = 0.5)
        for peak in xpeaks:
            idx = peak[0]
            xc = x[idx]
            I0 = ysmooth[idx]
            w = peak[2]
            Ihalf = I0 / 2.0
            nx = int(w / dx * krange + 1.00001)
            # Symmetric window [idx - nx, idx + nx], clipped to the data.
            xx = x[max(0, idx - nx):min(idx + nx + 1, ndata)]
            profile = [peak_func(xv, I0, xc, w) for xv in xx]

            ax.plot([xc, xc], bar_range, linestyle = '-', color = 'black', linewidth = 0.5)
            ax.plot([xc - w, xc + w], [Ihalf, Ihalf], linestyle = '-', color = 'green', linewidth = 0.5)
            ax.plot(xx, profile, linestyle = '-', color = 'red', linewidth = 0.5)
        ax.set_ylabel(ylabel, fontsize = cfg.fontsize)
        ax.legend(fontsize = cfg.legend_fontsize)

    fig, ax = plt.subplots(1, 1, figsize = cfg.figsize)
    ax.tick_params(labelsize = cfg.fontsize)
    _plot_peaks(ax)
    ax.set_xlabel(xlabel, fontsize = cfg.fontsize)

    plt.tight_layout()

    plt.pause(0.01)
    app.terminate(pause = True)

def main():
    """Test driver: load ``infile``, copy it to ``outfile`` and run ``peak_search``.

    Uses a fixed configuration (x range 20-40, nsmooth 7, norder 5, ...).
    Prints an error and returns if the input file does not exist.
    """
    class tkParams():
        """Plain attribute container for the run settings."""
        pass

    class tkApplication():
        """Minimal application object providing the terminate() hook."""
        def __init__(self):
            pass

        def terminate(self, pause = False):
            # `pause` is ignored: this always waits for ENTER.
            input("Press ENTER to terminate>>")

    app = tkApplication()
    cfg = tkParams()
    cfg.infile = infile
    cfg.mode = "plot"
    cfg.xlabel = 0
    cfg.ylabel = 1
    cfg.norder = 5
    cfg.nsmooth = 7
    cfg.xmin = 20.0
    cfg.xmax = 40.0
    cfg.threshold = 3000
    cfg.ydiff1_threshold = 1.0
    cfg.figsize = (8, 6)
    cfg.fontsize = 12
    cfg.legend_fontsize = 12

    print()
    print("Peak fit test program / odata module")
    print(f"infile={infile}")
    print(f"outfile={outfile}")

    xlabels, ylabels, xdatalist, ydatalist, _ = read_input_data(app, cfg.infile, options = {})
    if xlabels is None:
        # read_input_data returns all None when the file is missing.
        print(f"Error: input file [{cfg.infile}] does not exist.")
        return
    save_data(outfile, xlabels + ylabels, xdatalist + ydatalist)
    peak_search(app, cfg, mf = None)

if __name__ == "__main__":
    # Run the test driver only when executed as a script, not on import.
    main()
