import os
from pprint import pprint
import numpy as np
from numpy import log, log10, sqrt, exp
import openpyxl
import pandas as pd
import matplotlib.pyplot as plt


# Path of the Excel workbook that main() reads as input
infile  = 'test/peaks.xlsx'
# Path of the Excel workbook that main() writes the loaded data to
outfile = 'simulated.xlsx'


# Basis functions for linear regression; returns the part without the linear coefficient
def lsqfunc(idx, x, xc, w):
    """Evaluate the basis function selected by ``idx`` at ``x``.

    Args:
        idx: Basis selector. 0 selects the constant term, 1 the linear
            term, and any other value the Gaussian term.
        x: Abscissa value.
        xc: Center of the Gaussian term (only read for the Gaussian).
        w: Width divisor of the Gaussian term (only read for the Gaussian).

    Returns:
        ``1.0`` for ``idx == 0``, ``x`` unchanged for ``idx == 1``,
        otherwise ``exp(-((x - xc) / w) ** 2)``.
    """
    if idx == 0:
        return 1.0
    if idx == 1:
        return x

    # Gaussian term: normalized distance from the center, squared and negated.
    scaled = (x - xc) / w
    return exp(-(scaled * scaled))

# Return the 2-D list of function values (y_list) for the 2-D list of variables (x_list)
def cal_ylist(app, xk_all, x_list, fit = None, run = True, print_level = 1):
    """Evaluate the background-plus-Gaussian model on the first x data set.

    The model is ``bgc0 + bgc1 * x + I0 * exp(-((x - xc) / w) ** 2)``,
    assembled from the ``lsqfunc`` basis functions.

    Args:
        app: Not used by this function.
        xk_all: The five model parameters ``(bgc0, bgc1, I0, xc, w)``.
        x_list: Sequence of x data sets; only ``x_list[0]`` is evaluated.
        fit: Not used by this function.
        run: Not used by this function.
        print_level: Not used by this function.

    Returns:
        A one-element list holding a float ndarray of model values, one per
        entry of ``x_list[0]``.
    """
    bgc0, bgc1, I0, xc, w = xk_all

    # Evaluate every point in one vectorized pass; lsqfunc only uses
    # arithmetic and numpy's exp, so it applies element-wise to an array.
    x = np.asarray(x_list[0], dtype = float)
    ylist = bgc0 * lsqfunc(0, x, xc, w) \
          + bgc1 * lsqfunc(1, x, xc, w) \
          + I0   * lsqfunc(2, x, xc, w)

    return [ylist]

def save_data(path, labels, data_list):
    """Write column data to an Excel file and announce the destination.

    Args:
        path: Destination path handed to ``DataFrame.to_excel``.
        labels: Column headers for the written sheet.
        data_list: Sequence of columns (assumed to be equally long); each
            entry becomes one column of the sheet.
    """
    print()
    print(f"Save data to [{path}]")

    # data_list holds one entry per column, so transpose to get one row per sample.
    columns = np.array(data_list)
    frame = pd.DataFrame(columns.T, columns = labels)
    frame.to_excel(path, header = True, index = False)

# Read the file and return 1-D lists of x/y labels and 2-D lists of x/y values
def read_input_data(app, infile, options = None):
    """Read the first two columns of an Excel file as x and y data.

    Args:
        app: Not used by this function.
        infile: Path of the Excel file, read with ``pandas.read_excel``.
        options: Optional dict. Accepted for interface compatibility but not
            consulted: no range filtering (e.g. "xmin"/"xmax") is applied.

    Returns:
        A 5-tuple ``(xlabels, ylabels, xdata_list, ydata_list, wdata_list)``:
        the fixed labels ``['x']`` and ``['y']``, one-element lists holding
        the values of the first and second column, and ``None`` for the
        last item. All five items are ``None`` when ``infile`` is not an
        existing file.
    """
    if not os.path.isfile(infile):
        return None, None, None, None, None

    df = pd.read_excel(infile)
    # Transpose so that each row of data_list is one column of the sheet.
    data_list = df.values.T

    return ['x'], ['y'], [data_list[0]], [data_list[1]], None

# Plot the input data
def plot_input(app, cfg, mf):
    """Plot the x/y data read from ``cfg.infile`` and wait for the user.

    Args:
        app: Application object; passed on to ``read_input_data`` and asked
            to ``terminate(pause = True)`` once the figure is drawn.
        cfg: Configuration object providing ``infile``, ``figsize``,
            ``fontsize`` and ``legend_fontsize``.
        mf: Not used by this function.

    Raises:
        FileNotFoundError: If ``read_input_data`` returns no data for
            ``cfg.infile``.
    """
    print()
    print("Plot input data")
    print(f"infile   : {cfg.infile}")

    print()
    xlabels, ylabels, xdata_list, ydata_list, _ = read_input_data(app, cfg.infile, options = {})
    if xdata_list is None:
        # read_input_data reports a missing file by returning None for every
        # item; fail with a clear error instead of a TypeError on None[0].
        raise FileNotFoundError(f"Input file not found: {cfg.infile}")

    # Only the first data set is plotted.
    xlist = xdata_list[0]
    ylist = ydata_list[0]
    x_label = xlabels[0]
    y_label = ylabels[0]

    fig, ax = plt.subplots(1, 1, figsize = cfg.figsize, dpi = 100, tight_layout = True)
    ax.tick_params(labelsize = cfg.fontsize)

    ax.plot(xlist, ylist, label = "input", linewidth = 1.0, marker = "o", markersize = 1.0)
    ax.set_xlabel(x_label, fontsize = cfg.fontsize)
    ax.set_ylabel(y_label, fontsize = cfg.fontsize)
    ax.legend(fontsize = cfg.legend_fontsize)

    # Give the GUI event loop a moment to draw before blocking in terminate().
    plt.pause(0.01)
    app.terminate(pause = True)
    
    
def main():
    """Read the input workbook, save a copy of its data, then plot it.

    Uses the module-level ``infile`` and ``outfile`` paths. Prints an error
    and returns without saving or plotting when the input file is missing.
    """
    class tkParams():
        """Plain attribute container for the configuration values."""

    class tkApplication():
        """Minimal application object providing ``terminate``."""

        def __init__(self):
            pass

        def terminate(self, pause = False):
            """Wait for ENTER before returning when ``pause`` is true."""
            if pause:
                input("Press ENTER to terminate>>")


    app = tkApplication()
    cfg = tkParams()
    cfg.infile = infile
    cfg.mode = "plot"
    cfg.xlabel = 0
    cfg.ylabel = 1
    cfg.xmin = 20.0
    cfg.xmax = 40.0
    cfg.figsize = (8, 6)
    cfg.fontsize = 12
    cfg.legend_fontsize = 12

    print()
    print("Peak fit test program / odata module")
    print(f"infile={infile}")
    print(f"outfile={outfile}")

    xlabels, ylabels, xdatalist, ydatalist, wdatalist = read_input_data(app, infile, options = {})
    if xdatalist is None:
        # read_input_data returns None for every item when the file is
        # missing; stop here instead of failing on None + None below.
        print(f"Error: input file [{infile}] not found")
        return

    save_data(outfile, xlabels + ylabels, xdatalist + ydatalist)
    plot_input(app, cfg, mf = None)


# Run main() only when this file is executed as a script, not when imported
if __name__ == '__main__':
    main()