import os
import numpy as np
from numpy import log, log10, sqrt, exp
from scipy.signal import savgol_filter 
import openpyxl
import pandas as pd
import matplotlib.pyplot as plt
from functools import lru_cache


from tklib.tkparams import tkParams
from tklib.tksci.tksci import e, kB
import mobility_pi as cmu


# Module-level default file names (Excel workbooks).
infile = 'decay.xlsx'        # input data workbook
datafile = 'simulated.xlsx'  # workbook for simulated data


# Basis functions for the linear least-squares fit.
# Each returns only the factor that multiplies its linear coefficient.
def lsqfunc(idx, T, const_names, const_values):
    """Return the value of linear basis function ``idx`` at temperature ``T``.

    Args:
        idx: Basis-function index.
            ``0`` selects ``1 / (exp(Eop*e/(kB*T)) - 1)`` (Eop presumably in eV,
            given the e/kB conversion).
            ``idx >= 1`` selects the power law ``T**(-pi_lin[idx])``.
        T: Temperature.
        const_names: Not used here; kept for the caller's interface.
        const_values: ``[pi_lin, [Eb], [s_phi], [Eop]]`` as built by
            build_llsq_params. ``pi_lin[0]`` is a placeholder, so exponent
            ``pi_lin[j]`` pairs with basis function ``j``.

    Returns:
        The basis value. For ``idx == 0`` it is 0.0 when ``T == 0`` or when the
        exponent argument exceeds 70; the term is then below ~1e-30 and is
        treated as zero to avoid overflow.

    Raises:
        SystemExit: (exit status 1, after printing an error) when ``idx`` has
            no matching exponent in ``pi_lin``.
    """
    if idx == 0:
        if T == 0.0:
            return 0.0
        Eop = const_values[3][0]
        k = Eop * e / kB / T
        # exp() overflows quickly; beyond k = 70 the term is negligible.
        if k > 70.0:
            return 0.0
        return 1.0 / (exp(k) - 1.0)

    pi_lin = const_values[0]
    nfunc = len(pi_lin)
    # pi_lin[0] is the placeholder for basis 0, so power-law indices run 1 .. nfunc-1.
    # The old check (idx - 1 >= nfunc) let idx == nfunc through to an IndexError.
    if idx >= nfunc:
        print(f"\nError in lsqfunc: Too many functions required (idx = {idx}, nfunc={nfunc})\n")
        raise SystemExit(1)

    return pow(T, -pi_lin[idx])


# Return the y values that enter the linear regression, derived from ylist.
def cal_linearpart(xlist, ylist):
    """Return the part of ``ylist`` used for the linear regression.

    No transformation is applied: ``ylist`` itself (the same object, not a
    copy) is returned. ``xlist`` is accepted for interface compatibility and
    ignored.
    """
    linear_part = ylist
    return linear_part

# Merge the linear-regression coefficients back into the full fitting-parameter list.
def recover_pk_all(pk, ai_all):
    """Return a copy of ``pk`` with the linear coefficients ``ai_all`` written back.

    The index mapping mirrors build_llsq_params. The linear parameters are the
    full parameters at indices 3, 5, 7, ... (aop, a1, a2, ...), so ``ai_all[j]``
    is stored at ``pk_all[2*j + 3]``.

    Args:
        pk: Full parameter vector. It must support ``.copy()`` (e.g. a list or
            numpy array).
        ai_all: Linear coefficients, in build_llsq_params order.

    Returns:
        The updated copy of ``pk``.
    """
    pk_all = pk.copy()

    print("pk=", pk)
    print("ai_all=", ai_all)
    # i = 3, 5, 7, ... receives ai_all[0], ai_all[1], ai_all[2], ...
    # The former extra `pk_all[3] = ai_all[3]` used the wrong source index and
    # raised IndexError when fewer than four coefficients were supplied.
    for j, i in enumerate(range(3, len(pk), 2)):
        pk_all[i] = ai_all[j]

    return pk_all

# Extract the linear-regression parameters, and the constants the basis
# functions need, from the full fitting-parameter set.
def build_llsq_params(xlist, ylist, varname_all, pk_all, optid_all, linid_all):
    """Split the full parameter set into the linear-regression sub-problem.

    Layout of the full parameter vectors (index: name)::

        0 VB  1 sigma0  2 Eop  3 aop  4 p1  5 a1  6 p2  7 a2  8 p3  9 a3 ...

    The linear parameters are aop (index 3) and a1, a2, ... (odd indices from
    5 on). For each a_k at index i, the exponent p_k at index i-1 is collected
    into ``pi_lin``. Slot 0 of ``pi_lin`` is a ``None`` placeholder, so
    ``pi_lin[j]`` pairs with linear parameter ``j``.

    Args:
        xlist: x data; returned unchanged as the linear x data.
        ylist: y data; transformed by ``cmu.cal_linearpart``.
        varname_all: Names of all parameters.
        pk_all: Values of all parameters.
        optid_all: Per-parameter flags.
        linid_all: Per-parameter flags. A parameter is flagged (1) in the
            result only when both its ``optid_all`` and ``linid_all`` entries
            are truthy.

    Returns:
        ``(xlin_list, ylin_list, varname_lin, pk_lin, optid_lin, const_names,
        const_values)`` with ``const_values = [pi_lin, [Eb], [s_phi], [Eop]]``.
    """
    flags = [int(opt and lin) for opt, lin in zip(optid_all, linid_all)]

    (Eb, s_phi, Eop,
     _aop, _pi, _ai, _aop_id, _ai_id) = cmu.convert_xk(pk_all, flags)
    ylin_list = cmu.cal_linearpart(xlist, ylist, Eb, s_phi)

    # Odd indices 5, 7, 9, ... hold the linear coefficients a1, a2, a3, ...
    coeff_idx = range(5, len(varname_all), 2)
    varname_lin = [varname_all[3]] + [varname_all[i] for i in coeff_idx]
    pk_lin = [pk_all[3]] + [pk_all[i] for i in coeff_idx]
    pi_lin = [None] + [pk_all[i - 1] for i in coeff_idx]
    optid_lin = [flags[3]] + [flags[i] for i in coeff_idx]

    const_names = ["pi", "Eb", "s_phi", "Eop"]
    const_values = [pi_lin, [Eb], [s_phi], [Eop]]

    return xlist, ylin_list, varname_lin, pk_lin, optid_lin, const_names, const_values


# Holder for the cal_ylist() arguments that are kept out of the lru_cache key;
# cal_ylist_cachable() reads them back from here.
_global = tkParams()

# Return the 2-D list of function values y_list for the 2-D list of variables x_list.
def cal_ylist(app, xk_all, x_list, fit = None, run = True, print_level = 1):
    """Evaluate the mobility model for parameters ``xk_all`` on ``x_list``.

    Args:
        app: Application object, stored as ``_global.app``.
        xk_all: Full parameter vector (any iterable). It is converted to a
            tuple so it can serve as a cache key.
        x_list: 2-D list of x values. ``x_list[0]`` is the temperature list
            passed to ``cmu.cal_mu_list``.
        fit: Stored on ``_global``; the model evaluation itself does not read it.
        run: Stored on ``_global``; the model evaluation itself does not read it.
        print_level: Stored on ``_global``; the model evaluation itself does
            not read it.

    Returns:
        ``[mu_tot]`` as returned by ``cal_ylist_cachable``. The list object is
        shared with the cache, so callers should not mutate it.
    """
    _global.app = app
    _global.x_list = x_list
    _global.fit = fit
    _global.run = run
    _global.print_level = print_level

    # The result depends on x_list as well as xk_all, so a hashable snapshot
    # of x_list is part of the cache key. Otherwise a call with new x data but
    # the same parameters would return the result cached for the old data.
    x_key = tuple(tuple(x) for x in x_list)
    return cal_ylist_cachable(tuple(xk_all), x_key)

# Bounded: an unbounded cache grows with every parameter vector a fit tries.
@lru_cache(maxsize = 1024)
def cal_ylist_cachable(xk_all, x_key = None):
    """Cached worker behind cal_ylist().

    Args:
        xk_all: Parameter vector as a tuple (hashable cache key).
        x_key: Hashable snapshot of the x data. It is used only to key the
            cache; the x values themselves are read from ``_global.x_list``.

    Returns:
        ``[mu_tot]``, the first of the five results of
        ``cmu.cal_mu_list(..., rettype = 'all')``.
    """
    T_list = _global.x_list[0]
    # NOTE(review): convert_xk is unpacked into 6 values here, but into 8 values
    # (called with an extra optid argument) in build_llsq_params. Confirm that
    # mobility_pi supports both call forms.
    Eb, s_phi, Eop, aop, pi, ai = cmu.convert_xk(xk_all)
    mu_tot, _mu_KGB, _mu_ingrain, _mu_op, _mu_pi = cmu.cal_mu_list(
        T_list, Eb, s_phi, Eop, aop, pi, ai, rettype = 'all')

    return [mu_tot]

def save_data(path, labels, data_list):
    """Write column data to an Excel workbook at ``path``.

    Args:
        path: Output workbook path.
        labels: Column headers, one per entry of ``data_list``.
        data_list: Sequence of columns. It is converted to a 2-D numpy array
            and transposed, so each inner sequence becomes one column.
    """
    print()
    print(f"Save data to [{path}]")
    table = np.array(data_list).T
    frame = pd.DataFrame(table, columns = labels)
    frame.to_excel(path, index = False, header = True)


# Read the input file and return the x/y labels (1-D lists) together with the
# x and y values (2-D lists).
def read_input_data(app, infile, options = None):
    """Read (T, mu) data from ``infile`` and keep the points inside [Tmin, Tmax].

    Args:
        app: Application object; ``app.fit`` is passed to ``cmu.read_data``.
        infile: Input data file.
        options: Object with a dict-like ``.get`` (plot_input passes ``cfg``).
            Keys used:
            ``xlabels`` / ``ylabels``: required; the first entry of each is used.
            ``Tmin`` / ``Tmax``: optional; ``None`` means no bound.
            ``sample``: only printed here.

    Returns:
        ``(['T (K)'], ['mu(obs) (cm2/Vs)'], [T_list], [mu_list], None)``.
        The last element fills the weight-list slot and is always ``None``.

    Raises:
        ValueError: if ``xlabels`` or ``ylabels`` is not provided.
    """
    if options is None:
        options = {}

    xlabels = options.get("xlabels", None)
    ylabels = options.get("ylabels", None)
    Tmin = options.get("Tmin", None)
    Tmax = options.get("Tmax", None)
    sample = options.get("sample", "")

    print(f"  target samples: {sample}")
    print(f"  T range: {Tmin} - {Tmax} K")

    if xlabels is None or ylabels is None:
        raise ValueError("read_input_data: options must provide 'xlabels' and 'ylabels'")

    # The labels returned by cmu.read_data are not used; fixed labels are returned below.
    _, _, T_all, mu_all = cmu.read_data(app.fit, infile, xlabels[0], ylabels[0])

    # A None bound disables that side of the window (comparing with None would raise TypeError).
    T_list = []
    mu_list = []
    for T, mu in zip(T_all, mu_all):
        if Tmin is not None and T < Tmin:
            continue
        if Tmax is not None and T > Tmax:
            continue
        T_list.append(T)
        mu_list.append(mu)

    print()
    print("T, mu")
    for T, mu in zip(T_list, mu_list):
        print(f"{T:8.4g}\t{mu:10.4g}")

    return ['T (K)'], ['mu(obs) (cm2/Vs)'], [T_list], [mu_list], None

# Plot the input data.
def plot_input(app, cfg, mf):
    """Read the input data described by ``cfg`` and show it in a single plot.

    Args:
        app: Application object; ``app.terminate(pause = True)`` is called at the end.
        cfg: Configuration. Its ``infile``, ``figsize``, ``fontsize`` and
            ``legend_fontsize`` are used, and ``cfg`` itself is passed as the
            options of ``mf.read_input_data``.
        mf: Object providing ``init_fit`` and ``read_input_data``.

    The labels and data read are stored on the fit object returned by
    ``mf.init_fit``.
    """
    print()
    print("Plot input data")
    print(f"infile   : {cfg.infile}")

    fit = mf.init_fit(app, cfg)

    print()
    (fit.x_labels, fit.y_labels,
     fit.xd_list, fit.yd_list, fit.w_list) = mf.read_input_data(app, cfg.infile, cfg)

    fit.labels = fit.x_labels + fit.y_labels
    fit.x_list = fit.xd_list
    fit.y_list = fit.yd_list
    xlab, ylab = fit.x_labels[0], fit.y_labels[0]

    fig, ax = plt.subplots(1, 1, figsize = cfg.figsize, dpi = 100, tight_layout = True)
    ax.tick_params(labelsize = cfg.fontsize)

    ax.plot(fit.x_list[0], fit.y_list[0], label = "input",
            linewidth = 1.0, marker = "o", markersize = 1.0)
    ax.set_xlabel(xlab, fontsize = cfg.fontsize)
    ax.set_ylabel(ylab, fontsize = cfg.fontsize)
    ax.legend(fontsize = cfg.legend_fontsize)

    # Let matplotlib draw the figure before handing control to app.terminate().
    plt.pause(0.01)
    app.terminate(pause = True)



def main():
    """Script entry point; currently performs no work."""
    return None


if __name__ == '__main__':
    main()
