import os
import numpy as np
from numpy import sin, cos, tan, pi, exp, log, sqrt
from scipy.optimize import minimize
import pandas as pd
from matplotlib import pyplot as plt
import matplotlib.widgets as wg


from tklib.tkutils import terminate, pint, pfloat, getarg, getintarg, getfloatarg, pconv_by_type
#analyze_varstr, save_csv
from tklib.tkutils import is_exist, is_file, is_dir
from tklib.tkinifile import tkIniFile
from tklib.tkparams import tkParams
from tklib.tkgraphic.tkplotevent import tkPlotEvent

from tklib.tkvariousdata import tkVariousData
from tklib.tkfilter import tkFilter
from tklib.tksci.tksci import pi, h, hbar,c, e, kB, me, Ea_Arrhenius, log10
from tklib.tksci.tkoptimize import mlsq_general
from tklib.tksci.tkFit_object import tkFit_object


'''
Fitting and interactive matplotlib helper library
'''


class tkFit_mxy(tkFit_object):
    """Least-squares fit of several y series against several x variables, with live plotting.

    Data layout used by cal_ylist() and minimize_func():
        self.x_list[j][i] -- value of the j-th x variable at the i-th data point
        self.y_list[k][i] -- k-th observed y series at the i-th data point
    Model: self.func(xs, pk) receives the x values of one data point (xs) and the full
    parameter list, and returns one calculated value per y series.

    extract_parameters(), recover_parameters() and get() come from the base classes, which
    are not part of this module.
    """

    def __init__(self, parameter_file = None, app = None,
                    method = 'nelder-mead', jac = '3-points', tol = 1.0e-5, nmaxiter = 100, nmaxcall = 1000,
                    fplot = 1, fhistory = 0, ffitfiles = 0, print_interval = 1, plot_interval = 1,
                    xlsm_template = None, **args):
        """Initialize data holders, parameter tables and optimizer/plot settings.

        Args:
            parameter_file, app: accepted for interface compatibility; not used here.
            method: scipy.optimize.minimize method name.
            jac: finite-difference scheme for gradient-based methods; the default '3-points'
                is translated to scipy's '3-point' by minimize().
            tol, nmaxiter: forwarded to scipy.optimize.minimize (tol, options['maxiter']).
            nmaxcall, fhistory, ffitfiles, xlsm_template: stored; not used by this class.
            fplot: enable the interactive plot (initial_plot / callback / finalize_plot).
            print_interval, plot_interval: callback() prints / redraws every N iterations.
            **args: forwarded to the base-class constructor reached via super(tkParams, self).
        """
        # NOTE(review): super(tkParams, self) starts the MRO lookup *after* tkParams, so the
        # tkFit_object/tkParams constructors are presumably skipped and parameter_file/app are
        # dropped. The base classes are not visible here -- confirm this is intentional.
        super(tkParams, self).__init__(**args)

        self.plot_event = tkPlotEvent(plt)
        self.xlabels   = None
        self.x_list    = None
        self.ylabels   = None
        self.y_list    = None
        self.y_convert = None   # per-y-series transform ('log' or ''), filled lazily by minimize_func
        self.w_list    = None   # per-y-series weights, default 1.0 each
        self.func      = None   # model: func(xs, pk) -> one value per y series
        self.iter      = 0
        self.f         = None

        # int -> plot against x_list[ix_plot]; anything else -> see _resolve_ix_plot()
        self.ix_plot   = ''

        # Per-parameter tables (index-aligned with the recovered parameter list)
        self.varname    = []
        self.unit       = []
        self.pk         = []
        self.pk_convert = []
        self.dx         = []
        self.optid      = []
        self.kmin       = []
        self.kmax       = []
        self.kpenalty   = []

        self.method         = method
        self.jac            = jac
        self.tol            = tol
        self.nmaxiter       = nmaxiter
        self.nmaxcall       = nmaxcall
        self.icall          = 0
        self.print_interval = print_interval
        self.plot_interval  = plot_interval
        self.fplot          = fplot
        self.fhistory       = fhistory
        self.ffitfiles      = ffitfiles
        self.stop_flag      = False   # set by the 'stop' button or when the plot window changes

        self.xlsm_template = xlsm_template

    def read_data_from_filter(self, infile, module_dir, module_name, imodule = 0, iregion = 0,
                x_indexes = (), y_indexes = (), app = None, cparams = None, is_print = False):
        """Read x/y arrays from *infile* through a tkFilter plugin.

        Loads the "read_data" target of ``{module_dir}/{module_name}.py``, reads *infile* with
        plugin number *imodule*, then collects the arrays selected by *x_indexes* / *y_indexes*
        into self.xlabels / self.x_list and self.ylabels / self.y_list.

        *iregion* is accepted for interface compatibility and is not used.
        """
        data_filter = tkFilter(app = app, cparams = cparams, plugin_dir = module_dir, module_file = f"{module_name}.py")
        self.module_names, self.modules = data_filter.load(target = "read_data", is_print = True)
        self.inf_input = data_filter.read_data(imodule, infile, app = app, cparams = cparams, is_print = is_print)

        def _collect(indexes):
            # Look up each requested array; returns (labels, arrays) in request order.
            labels, arrays = [], []
            for idx in indexes:
                label, data = data_filter.find_data_array(idx)
                labels.append(label)
                arrays.append(data)
            return labels, arrays

        self.xlabels, self.x_list = _collect(x_indexes)
        self.ylabels, self.y_list = _collect(y_indexes)

    def read_data(self, infile, xlabels = (0,), ylabel = 1, xmins = (None,), xmaxs = (None,), usage = None,
                module_dir = None, module_name = None, imodule = 0, iregion = 0,
                x_indexes = (), y_indexes = (), app = None, cparams = None, is_print = False):
        """Read fitting data from a table file, or through a filter plugin.

        When both *module_dir* and *module_name* are given, delegates to read_data_from_filter()
        and ignores the table-file arguments. Otherwise reads *infile* with tkVariousData,
        selects the x columns *xlabels* and the y column *ylabel* (presumably labels or column
        indices, resolved by FindDataArray), and keeps only rows where every x value lies in
        [xmins[il], xmaxs[il]]. A None or missing bound entry means unbounded.

        Sets self.x_list / self.xlabels, self.y / self.ylabel (mirrored as the single-series
        self.y_list / self.ylabels), self.included_index (kept row numbers), self.ndata,
        self.ndata_all, self.nx, self.index and the raw datafile/labels/datalist.
        """
        if module_dir is not None and module_name is not None:
            return self.read_data_from_filter(infile = infile,
                    module_dir = module_dir, module_name = module_name, imodule = imodule, iregion = iregion,
                    x_indexes = x_indexes, y_indexes = y_indexes, app = app, cparams = cparams, is_print = is_print)

        print("")
        print(f"Read [{infile}]")
        datafile = tkVariousData(infile)
        labels, datalist = datafile.Read_minimum_matrix(close_fp = True, force_numeric = False, usage = usage)

        _xlabels = []
        xins     = []
        for l in xlabels:
            xlabel, xin = datafile.FindDataArray(l, flag = 'i')
            print("x l=", l, xlabel)
            _xlabels.append(xlabel)
            xins.append(xin)
        ylabel, yin = datafile.FindDataArray(ylabel, flag = 'i')

        self.datafile = datafile
        self.labels   = labels
        self.datalist = datalist
        self.xlabels  = _xlabels
        self.ylabel   = ylabel

        self.nx        = len(xlabels)
        self.ndata_all = len(xins[-1])

        def _bound(bounds, il):
            # Missing entries mean "no bound", so the single-element defaults work for any nx.
            if bounds is None or il >= len(bounds):
                return None
            return bounds[il]

        self.x_list = [[] for _ in range(self.nx)]
        self.y = []
        self.included_index = []
        for i in range(self.ndata_all):
            # Decide inclusion from *all* x ranges before appending anything, so every
            # x column stays aligned with self.y.
            is_included = True
            for il in range(self.nx):
                v    = xins[il][i]
                xmin = _bound(xmins, il)
                xmax = _bound(xmaxs, il)
                if (xmin is not None and xmin > v) or (xmax is not None and xmax < v):
                    is_included = False
                    break
            if not is_included:
                continue

            for il in range(self.nx):
                self.x_list[il].append(xins[il][i])
            self.y.append(yin[i])
            self.included_index.append(i)

        self.ndata = len(self.x_list[0])
        self.index = range(self.ndata)
        # The fitting/plotting methods work on y_list/ylabels; expose the single y series there.
        self.ylabels = [ylabel]
        self.y_list  = [self.y]

    def cal_ylist(self, pk, x_list = None):
        """Evaluate the model at every data point.

        Args:
            pk: full parameter list passed unchanged to self.func.
            x_list: x data (defaults to self.x_list), laid out as x_list[j][i].

        Returns:
            A list with one list per y series (len(self.y_list) of them); entry [k][i] is the
            k-th value returned by self.func for data point i.
        """
        if x_list is None:
            x_list = self.x_list

        ny    = len(self.y_list)
        nvars = len(x_list)
        ndata = len(x_list[0])
        y_list = [[] for _ in range(ny)]
        for i in range(ndata):
            # xs: values of every x variable at the i-th data point
            xs = [x_list[j][i] for j in range(nvars)]
            for k, v in enumerate(self.func(xs, pk)):
                y_list[k].append(v)

        return y_list

    @staticmethod
    def _convert_val(v_list, conv_list, forward = True):
        """Apply per-series value transforms.

        For conv == 'log': forward maps v > 0 -> log(v), v < 0 -> -log(-v), v == 0 -> -100.0;
        backward maps v -> exp(v). Any other conv leaves the value unchanged. zip() stops at
        the shorter of the two lists.
        """
        ret = []
        for v, conv in zip(v_list, conv_list):
            if conv == 'log':
                if not forward:
                    v = exp(v)
                elif v > 0.0:
                    v = log(v)
                elif v < 0.0:
                    # Sign-preserving log; log() of a negative number would be nan.
                    v = -log(-v)
                else:
                    v = -100.0
            ret.append(v)
        return ret

    def minimize_func(self, pk, x_list = None, y_list = None, w_list = None):
        """Objective function minimized by minimize().

        Args:
            pk: optimized-parameter vector; expanded to the full parameter list with
                recover_parameters(pk, set_member = False).
            x_list, y_list, w_list: override self.x_list / self.y_list / self.w_list.
                w_list holds one weight per y series (1.0 each when None).

        Returns:
            sqrt(sum_i sum_k w_k * (ycal_k,i - yobs_k,i)**2 / ndata) + penalty. ycal and yobs
            are first transformed per series by self.y_convert. The penalty is
            kpenalty[p] * (distance outside [kmin[p], kmax[p]])**2 summed over the first
            len(kpenalty) parameters; such parameters are clamped to the bound before the
            model is evaluated (a warning is printed).

        Side effects: initializes self.y_convert (and self.pk_convert, if None) to '' entries.
        """
        if x_list is None:
            x_list = self.x_list
        if y_list is None:
            y_list = self.y_list
        if w_list is None:
            w_list = self.w_list

        ny = len(y_list)
        if w_list is None:
            w_list = [1.0] * ny
        if self.y_convert is None:
            self.y_convert = [''] * ny
        if self.pk_convert is None:
            # NOTE(review): sized by ny, not by the parameter count, and not used below.
            self.pk_convert = [''] * ny

        recoverpk = self.recover_parameters(pk, set_member = False)

        # Quadratic penalty (and clamping) for parameters outside [kmin, kmax]
        p_tot = 0.0
        kp = self.kpenalty
        if kp is not None:
            kmin = self.kmin
            kmax = self.kmax
            for i in range(len(kp)):
                if recoverpk[i] < kmin[i]:
                    d = recoverpk[i] - kmin[i]
                    p = kp[i] * d * d
                    p_tot += p
                    recoverpk[i] = kmin[i]
                    print(f"**Warning: [{self.varname[i]}] is smaller than [{kmin[i]}]. Add penalty [{p:10.3g}] to f")
                elif recoverpk[i] > kmax[i]:
                    d = recoverpk[i] - kmax[i]
                    p = kp[i] * d * d
                    p_tot += p
                    recoverpk[i] = kmax[i]
                    print(f"**Warning: [{self.varname[i]}] is larger than [{kmax[i]}]. Add penalty [{p:10.3g}] to f")

        nx    = len(x_list)
        ndata = len(x_list[0])
        f = 0.0
        for i in range(ndata):
            xs = [x_list[j][i] for j in range(nx)]
            ycal_list = self._convert_val(self.func(xs, recoverpk), self.y_convert)
            yobs_list = self._convert_val([y_list[j][i] for j in range(ny)], self.y_convert)
            for j in range(ny):
                d = ycal_list[j] - yobs_list[j]
                f += w_list[j] * d * d

        return sqrt(f / ndata) + p_tot

    def minimize(self, method = None, jac = None, tol = None, nmaxiter = None, initial_simplex = None):
        """Run scipy.optimize.minimize on minimize_func, starting from extract_parameters().

        Args:
            method, jac, tol, nmaxiter: override the values given to __init__.
            initial_simplex: optional Nelder-Mead starting simplex (list or numpy array);
                ignored for other methods.

        Returns:
            (finalpk, fun, success, res): recovered full parameter list, final objective value,
            scipy success flag and the raw OptimizeResult. Also sets self.finalpk, self.iter and
            self.ffin.
        """
        if method is None:
            method = self.method
        if jac is None:
            jac = self.get('jac', '3-point')
        if jac == '3-points':
            # Historic default of __init__; scipy names this scheme '3-point'.
            jac = '3-point'
        if tol is None:
            tol = self.tol
        if nmaxiter is None:
            nmaxiter = self.nmaxiter

        print("")
        print("Optimizing parameters:")
        optpk = self.extract_parameters()
        print("   method  =", method)
        print("   nmaxiter=", nmaxiter)
        print("   tol     =", tol)
        print("   optpk=", optpk)
        print("")
        print("Start minimization:")
        options = {'maxiter': nmaxiter, "disp": True}
        # scipy treats method names case-insensitively; match that here.
        if str(method).lower() == 'nelder-mead':
            # Explicit None/len check: truth-testing a numpy array raises ValueError.
            if initial_simplex is not None and len(initial_simplex) > 0:
                options["initial_simplex"] = initial_simplex
            res = minimize(self.minimize_func, optpk,
                    method = method,
                    callback = self.callback,
                    tol = tol, options = options)
        else:
            res = minimize(self.minimize_func, optpk,
                    method = method, jac = jac,
                    callback = self.callback,
                    tol = tol, options = options)

        self.finalpk = self.recover_parameters(res.x)
        # Not every method reports 'nit'; keep the callback's count then.
        self.iter = res.get('nit', self.iter)
        fun = res.get('fun')
        self.ffin = fun if fun is not None else self.minimize_func(res.x)

        return self.finalpk, res.fun, res.success, res

    def _resolve_ix_plot(self, nx):
        """Return the x-column index used as plot abscissa, or None to plot against data index.

        An int self.ix_plot selects that column when it is < nx (otherwise column 0). With a
        single x variable column 0 is used. In all other cases the data index is used.
        """
        ix_plot = self.ix_plot
        if nx == 1 or type(ix_plot) is int:
            if type(ix_plot) is int and ix_plot < nx:
                return ix_plot
            return 0
        return None

    def callback(self, pk):
        """Per-iteration hook passed to scipy.optimize.minimize.

        Every print_interval iterations prints the recovered parameters and the objective
        value (appending it to the error history when one exists). When plotting was set up
        by initial_plot(), every plot_interval iterations redraws the fit curves. Increments
        self.iter.

        Returns False once the stop button was pressed or the current figure window differs
        from the one captured by initial_plot(); otherwise True.
        NOTE(review): scipy.optimize.minimize does not abort most methods on a False return,
        so "stop" only skips this hook's work -- confirm whether a hard stop is wanted.
        """
        if self.stop_flag:
            return False

        fplot = hasattr(self, 'fplot') and self.fplot

        # self.plt / self.window exist only after initial_plot() ran with plotting enabled.
        if fplot and hasattr(self, 'window'):
            if self.plt.get_current_fig_manager().window != self.window:
                self.stop_flag = True
                return False

        recoverpk = self.recover_parameters(pk, set_member = False)

        if self.iter % self.print_interval == 0:
            print(f"iter: {self.iter}")
            for i, value in enumerate(recoverpk):
                print(f"  {self.varname[i]:10}: {value:10.4g} {self.unit[i]}")
            f = self.minimize_func(pk)
            print(f"    f={f:12.6g}")

            if self.get('fmin_list', None) is not None:
                self.iter_list.append(self.iter + 1)
                self.fmin_list.append(f)

        if fplot and hasattr(self, 'fit_data_list') and self.iter % self.plot_interval == 0:
            ycal_list = self.cal_ylist(recoverpk)
            ix_plot = self._resolve_ix_plot(len(self.x_list))
            for i in range(len(self.y_list)):
                if ix_plot is None:
                    xdata = range(len(ycal_list[i]))
                else:
                    xdata = self.x_list[ix_plot]
                self.fit_data_list[i][0].set_data(xdata, ycal_list[i])

            if self.get('fmin_list', None) is not None:
                self.error_data[0].set_data(self.iter_list, self.fmin_list)
                self.error_axis.set_xlim([min(self.iter_list), max(self.iter_list)])
                self.error_axis.set_ylim([min(self.fmin_list), max(self.fmin_list)])

            self.plt.tight_layout()
            self.plt.subplots_adjust(top = self.plot_region[0], bottom = self.plot_region[1])
            self.plt.pause(0.0001)

        self.iter += 1

        return True

    def button_click(self, e):
        """Handler of the 'stop' button: request that callback() stop its work."""
        self.stop_flag = True

    def initial_plot(self, data_axes, error_axis = None,
                        label_input = None,
                        yini_list = None, label_ini = None,
                        label_fit = None,
                        ix_plot = '', x_scale = '',
                        label_error = None, fmin = None,
                        fontsize = 16, legend_fontsize = None,
                        plt = None, fig = None, use_pause = 0.00001,
                        savefig_path = None,
                        button_region = (0.15, 0.95, 0.10, 0.03), plot_region = (0.92, 0.15)):
        """Draw the observed data, the initial model curves and empty fit curves.

        Does nothing unless self.fplot is truthy.

        Args:
            data_axes: one matplotlib Axes per y series (data_axes[i] shows self.y_list[i]).
            error_axis: optional Axes for the objective-value history (log y scale).
            label_input, label_ini, label_fit, label_error: legend labels
                (defaults 'input', 'initial', 'fit', 'error').
            yini_list: initial model curves, one per y series; when None an empty
                'initial' line is drawn.
            ix_plot: int -> use x_list[ix_plot] as abscissa (0 if out of range); otherwise
                x_list[0] for a single x variable, else the data index.
            x_scale: 'log' sets a log x scale (not in data-index mode).
            fmin: first point of the error history; evaluated at extract_parameters() if None.
            fontsize, legend_fontsize: tick/label and legend font sizes.
            plt: pyplot-like module used for drawing; defaults to matplotlib.pyplot.
            fig: Figure on which the click handler is registered (warning printed if None).
            use_pause: plt.pause() duration after drawing; None calls plt.show() instead.
            savefig_path: if non-empty, the figure is saved there.
            button_region: rect passed to plt.axes() for the 'stop' button.
            plot_region: (top, bottom) passed to subplots_adjust.
        """
        fplot = hasattr(self, 'fplot') and self.fplot
        if not fplot:
            return
        if plt is None:
            # The parameter shadows the module-level pyplot import.
            from matplotlib import pyplot as plt

        if label_input is None:
            label_input = 'input'
        if label_ini is None:
            label_ini = 'initial'
        if label_fit is None:
            label_fit = 'fit'
        if label_error is None:
            label_error = 'error'
        if legend_fontsize is None:
            legend_fontsize = fontsize

        self.plt = plt
        self.fig = fig
        self.ix_plot = ix_plot
        self.x_scale = x_scale
        self.use_pause     = use_pause
        self.button_region = button_region
        self.plot_region   = plot_region

        self.yini_list  = yini_list
        self.error_axis = error_axis
        self.data_axes  = data_axes
        for axis in data_axes:
            axis.tick_params(labelsize = fontsize)

        ix_plot = self._resolve_ix_plot(len(self.x_list))
        index_mode = ix_plot is None
        xdata = range(len(self.x_list[0])) if index_mode else self.x_list[ix_plot]

        self.fit_data_list = []
        for i in range(len(self.y_list)):
            axis = data_axes[i]
            self.input_data = axis.plot(xdata, self.y_list[i], label = self.ylabels[i], linestyle = '', marker = 'o', markersize = 5.0)
            if yini_list is None:
                self.initial_data = axis.plot([], [], label = label_ini, linestyle = '-', linewidth = 0.5, color = 'red')
            else:
                self.initial_data = axis.plot(xdata, yini_list[i], label = label_ini, linestyle = '-', linewidth = 0.5, color = 'red')
            fit_data = axis.plot([], [], label = label_fit, linestyle = '-', linewidth = 0.5, color = 'blue')
            self.fit_data_list.append(fit_data)

        # NOTE(review): axis labels, scales, legend and plot-event registration are applied
        # only to the last data axis; confirm this is intended for several distinct Axes.
        axis.set_xlabel('index' if index_mode else self.xlabels[ix_plot], fontsize = fontsize)
        axis.set_ylabel(self.ylabels[0], fontsize = fontsize)
        if not index_mode and self.x_scale == 'log':
            axis.set_xscale('log')
        if self.y_convert == 'log':
            axis.set_yscale('log')
        axis.legend(fontsize = legend_fontsize)

        # In data-index mode the x arrays are attached too (presumably so tkPlotEvent can
        # report x values for a clicked index).
        event_extra = {"x_list": self.x_list, "xlabels": self.xlabels} if index_mode else {}
        self.plot_event_input_data   = self.plot_event.add_data({"label": label_input, "plot_type": "2D", "axis": axis,
                                    "data": self.input_data, **event_extra})
        self.plot_event_initial_data = self.plot_event.add_data({"label": label_ini,   "plot_type": "2D", "axis": axis,
                                    "data": self.initial_data, **event_extra})
        self.plot_event_fit_data     = self.plot_event.add_data({"label": label_fit,   "plot_type": "2D", "axis": axis,
                                    "data": fit_data, **event_extra})

        # Remember the initial view so finalize_plot() can widen it to include the fit.
        self.view_xlim_data_list = []
        self.view_ylim_data_list = []
        for i in range(len(self.y_list)):
            axis = data_axes[i]
            self.view_xlim_data_list.append(axis.get_xlim())
            self.view_ylim_data_list.append(axis.get_ylim())

        if self.fig is None:
            print("")
            print("Warning: figure object is not given.")
            print("    Mouse click event is not bound")
            print("")
        else:
            self.plot_event.register_event(self.fig, event = "button_press_event",
                        callback = lambda event: self.plot_event.onclick(event))

        if error_axis:
            error_axis.tick_params(labelsize = fontsize)
            if fmin is None:
                # Objective at the starting point: the same vector minimize() starts from.
                fmin = self.minimize_func(self.extract_parameters())
            self.iter_list = [0]
            self.fmin_list = [fmin]

            self.error_data = error_axis.plot(self.iter_list, self.fmin_list, label = label_error,
                        linestyle = '-', linewidth = 0.5, color = 'black',
                        marker = 'o', markersize = 5.0, markerfacecolor = 'black', markeredgecolor = 'black')
            error_axis.set_xlabel('iteration', fontsize = fontsize)
            error_axis.set_ylabel('error', fontsize = fontsize)
            error_axis.set_yscale('log')
            error_axis.legend(fontsize = legend_fontsize)

            self.plot_event_error_data = self.plot_event.add_data(
                        {"label": label_error, "plot_type": "2D", "axis": self.error_axis, "data": self.error_data})

        self.plt.tight_layout()
        self.plt.subplots_adjust(top = plot_region[0], bottom = plot_region[1])

        self.ax_button = plt.axes(button_region)
        self.stop_button = wg.Button(self.ax_button, 'stop', color = '#f8e58c', hovercolor = '#38b48b')
        self.stop_button.on_clicked(self.button_click)

        if savefig_path is not None and savefig_path != '':
            plt.savefig(savefig_path)

        if use_pause is not None:
            self.plt.pause(use_pause)
        else:
            self.plt.show()
        # callback() stops the fit when the current figure window is no longer this one.
        self.window = plt.get_current_fig_manager().window

    def finalize_plot(self, yfin_list, iter = None, fmin = None,
                savefig_path = 'final.png', use_pause = 0.00001, button_region = None, plot_region = None):
        """Show the final fit curves and append the last point of the error history.

        Does nothing unless self.fplot is truthy; requires initial_plot() to have run.

        Args:
            yfin_list: final model curves, one per y series.
            iter, fmin: history point appended when an error history exists
                (defaults: self.iter and self.ffin).
            savefig_path: if non-empty, the figure is saved there.
            use_pause: plt.pause() duration; None calls plt.show() instead.
            button_region: defaults to the initial_plot() value; not otherwise used.
            plot_region: (top, bottom) for subplots_adjust; defaults to the initial_plot() value.
        """
        fplot = hasattr(self, 'fplot') and self.fplot
        if not fplot:
            return

        if button_region is None:
            button_region = self.button_region
        if plot_region is None:
            plot_region = self.plot_region

        ix_plot = self._resolve_ix_plot(len(self.x_list))
        if ix_plot is None:
            self.index = list(range(len(yfin_list[0])))
            xdata = self.index
        else:
            xdata = self.x_list[ix_plot]

        # Each y series lives on its own axis/line (see initial_plot); widen the initial
        # view so the final curve is fully visible.
        for i in range(len(self.y_list)):
            axis = self.data_axes[i]
            self.fit_data_list[i][0].set_data(xdata, yfin_list[i])
            axis.set_xlim([min(xdata), max(xdata)])
            axis.set_ylim([min([self.view_ylim_data_list[i][0], min(yfin_list[i])]),
                           max([self.view_ylim_data_list[i][1], max(yfin_list[i])])])
        if self.y_convert == 'log':
            self.data_axes[0].set_yscale('log')

        if iter is None:
            iter = self.iter
        if fmin is None:
            fmin = self.ffin
        if self.get('fmin_list', None) is not None:
            self.iter_list.append(iter)
            self.fmin_list.append(fmin)

        if self.get('error_data', None) is not None:
            self.error_data[0].set_data(self.iter_list, self.fmin_list)
            self.error_axis.set_xlim([min(self.iter_list), max(self.iter_list)])
            self.error_axis.set_ylim([min(self.fmin_list), max(self.fmin_list)])

        self.stop_button.label.set_text('finished')

        self.plt.tight_layout()
        self.plt.subplots_adjust(top = plot_region[0], bottom = plot_region[1])

        if savefig_path is not None and savefig_path != '':
            self.plt.savefig(savefig_path)

        if use_pause is not None:
            self.plt.pause(use_pause)
        else:
            self.plt.show()

