import os
import numpy as np
from numpy import sin, cos, tan, pi, exp, log, sqrt
from scipy.optimize import minimize
import pandas as pd
from matplotlib import pyplot as plt
import matplotlib.widgets as wg


from tklib.tkutils import terminate, pint, pfloat, getarg, getintarg, getfloatarg, pconv_by_type
#analyze_varstr, save_csv
from tklib.tkutils import is_exist, is_file, is_dir
from tklib.tkinifile import tkIniFile
from tklib.tkparams import tkParams
from tklib.tkgraphic.tkplotevent import tkPlotEvent

from tklib.tkvariousdata import tkVariousData
#from tklib.tksci.tkmatrix import make_matrix1
from tklib.tksci.tksci import pi, h, hbar,c, e, kB, me, Ea_Arrhenius, log10
from tklib.tksci.tkoptimize import mlsq_general
from tklib.tksci.tkFit_object import tkFit_object


'''
Fitting and interactive matplotlib helper library
'''


class tkFit_m(tkFit_object):
    """Multi-variable fitter with live matplotlib feedback.

    The model is ``self.func(xs, pk)``: ``xs`` is the list of independent-variable
    values of one data point and ``pk`` the full (recovered) parameter list.
    ``minimize()`` runs ``scipy.optimize.minimize`` on ``minimize_func`` and, via
    ``callback()``, redraws the plots created by ``initial_plot()``.

    ``callback()`` uses attributes set by ``initial_plot()`` (plt, window, fit_data,
    data_axis, view_ylim_data), so ``initial_plot()`` must run before ``minimize()``.
    """

    def __init__(self, parameter_file = None, app = None,
                    tol = 1.0e-5, nmaxiter = 100,
                    print_interval = 1, plot_interval = 1, **args):
        """Create an empty fitter.

        tol, nmaxiter: default tolerance / iteration cap used by minimize().
        print_interval: callback() prints parameters every N iterations.
        plot_interval: callback() redraws the plots every N iterations.
        parameter_file, app: accepted for interface compatibility; not used here.
        **args: forwarded to the parent __init__ (see note below).
        """
        # NOTE(review): super(tkParams, self) starts the MRO lookup *after* tkParams,
        # so tkParams.__init__ (and tkFit_object.__init__, assuming it derives from
        # tkParams -- TODO confirm) is NOT run here. Kept as-is to preserve behavior.
        super(tkParams, self).__init__(**args)

        self.plot_event = tkPlotEvent(plt)
        self.varname  = []   # parameter names, used in progress / penalty messages
        self.unit     = []   # parameter units, printed next to values in callback()
        self.pk       = []   # presumably the full parameter list (used by base-class helpers)
        self.optid    = []   # presumably optimization flags per parameter (base-class helpers)
        self.kmin     = []   # lower bound per penalized parameter
        self.kmax     = []   # upper bound per penalized parameter
        self.kpenalty = []   # penalty coefficient per parameter (len defines how many are checked)

        self.tol = tol
        self.print_interval = print_interval
        self.nmaxiter       = nmaxiter
        self.plot_interval  = plot_interval
        self.stop_flag = False

        self.x_list = None
        self.y      = None
        self.func   = None
        self.iter   = 0
        self.f      = None

    @staticmethod
    def _pad_bounds(bounds, n):
        """Return `bounds` as a list of at least n entries, padding with None (= unbounded)."""
        bounds = [] if bounds is None else list(bounds)
        return bounds + [None] * (n - len(bounds))

    def read_data(self, infile, xlabels = [0], ylabel = 1, xmins = [None], xmaxs = [None], usage = None):
        """Read a data file and keep the rows whose x values lie inside the given ranges.

        infile: file path handed to tkVariousData.
        xlabels: labels or column indexes of the independent variables.
        ylabel: label or column index of the dependent variable.
        xmins, xmaxs: per-variable lower / upper bounds; None means unbounded.
            They may be shorter than xlabels (or None); missing entries are unbounded.
        usage: forwarded to tkVariousData.Read_minimum_matrix().

        Sets self.x_list (one list per x variable), self.y, self.included_index
        (row numbers of kept rows), self.nx, self.ndata_all, self.ndata, self.index.
        """
        print("")
        print(f"Read [{infile}]")
        datafile = tkVariousData(infile)
        labels, datalist = datafile.Read_minimum_matrix(close_fp = True, force_numeric = False, usage = usage)

        _xlabels = []
        xins     = []
        for l in xlabels:
            xlabel, xin = datafile.FindDataArray(l, flag = 'i')
            print("x l=", l, xlabel)
            _xlabels.append(xlabel)
            xins.append(xin)
        print("ylabel=", ylabel)
        ylabel, yin = datafile.FindDataArray(ylabel, flag = 'i')

        self.datafile = datafile
        self.labels   = labels
        self.datalist = datalist
        self.xlabels  = _xlabels
        self.ylabel   = ylabel

        self.nx        = len(xlabels)
        # Row count is taken from the last x column, as before.
        self.ndata_all = len(xins[-1])

        # Every x variable needs a bound entry; missing ones mean "no bound".
        xmins = self._pad_bounds(xmins, self.nx)
        xmaxs = self._pad_bounds(xmaxs, self.nx)

        self.x_list = [[] for _ in range(self.nx)]
        self.y = []
        self.included_index = []
        for i in range(self.ndata_all):
            # Decide inclusion over ALL x variables before appending anything, so that
            # every x_list[il] and y stay aligned row by row.
            is_included = True
            for il in range(self.nx):
                x = xins[il][i]
                if xmins[il] is not None and xmins[il] > x:
                    is_included = False
                    break
                if xmaxs[il] is not None and xmaxs[il] < x:
                    is_included = False
                    break
            if not is_included:
                continue

            for il in range(self.nx):
                self.x_list[il].append(xins[il][i])
            self.y.append(yin[i])
            self.included_index.append(i)

        self.ndata = len(self.y)
        self.index = range(self.ndata)

    def cal_ylist(self, pk, x_list = None):
        """Evaluate the model at every data point.

        pk: full (recovered) parameter list, passed unchanged to self.func.
        x_list: one sequence per x variable; defaults to self.x_list.
        Returns a list with self.func(xs, pk) for each data point.
        """
        if x_list is None:
            x_list = self.x_list

        # zip(*x_list) yields the x values of one data point at a time.
        return [self.func(list(xs), pk) for xs in zip(*x_list)]

    def minimize_func(self, pk, x_list = None, y = None):
        """Objective for scipy.optimize.minimize: RMS residual plus bound penalties.

        pk: optimizer-space parameters; expanded with
            self.recover_parameters(pk, set_member = False).
        x_list, y: data to fit; default to self.x_list / self.y.
        Optional weights come from self.get('w', None); when given, w[i]
        multiplies the squared residual of point i.

        Returns sqrt(sum(w_i * d_i**2) / ndata) + p_tot.  For each parameter i
        with a penalty coefficient kpenalty[i], p_tot gains
        kpenalty[i] * (distance outside [kmin[i], kmax[i]])**2.
        """
        if x_list is None:
            x_list = self.x_list
        if y is None:
            y = self.y

        recoverpk = self.recover_parameters(pk, set_member = False)

        f = 0.0
        nx    = len(x_list)
        ndata = len(x_list[0])
        w     = self.get('w', None)
        for i in range(ndata):
            xs = [x_list[j][i] for j in range(nx)]
            d = self.func(xs, recoverpk) - y[i]
            if w is None:
                f += d * d
            else:
                f += w[i] * d * d

        # Soft bounds: quadratic penalty outside [kmin, kmax].
        kp   = self.kpenalty
        kmin = self.kmin
        kmax = self.kmax
        p_tot = 0.0
        for i in range(len(kp)):
            if recoverpk[i] < kmin[i]:
                d = recoverpk[i] - kmin[i]
                p = kp[i] * d * d
                p_tot += p
                print(f"**Warning: [{self.varname[i]}] is smaller than [{kmin[i]}]. Add penalty [{p:10.3g}] to f")
            elif recoverpk[i] > kmax[i]:
                d = recoverpk[i] - kmax[i]
                p = kp[i] * d * d
                p_tot += p
                print(f"**Warning: [{self.varname[i]}] is larger than [{kmax[i]}]. Add penalty [{p:10.3g}] to f")

        return sqrt(f / ndata) + p_tot

    def minimize(self, method = None, jac = None, tol = None, nmaxiter = None):
        """Optimize the free parameters with scipy.optimize.minimize.

        method: scipy method name; defaults to self.method.
        jac: gradient specification; defaults to self.get('jac', '3-point').
        tol, nmaxiter: default to self.tol / self.nmaxiter.

        Sets self.finalpk (recovered parameters), self.iter and self.ffin.
        Returns (finalpk, res.fun, res.success, res).
        """
        if method is None:
            method = self.method
        if jac is None:
            jac = self.get('jac', '3-point')
        if tol is None:
            tol = self.tol
        if nmaxiter is None:
            nmaxiter = self.nmaxiter

        print("")
        print("Optimizing parameters:")
        optpk = self.extract_parameters()
        print("   optpk=", optpk)
        print("Minimize:")
        res = minimize(self.minimize_func, optpk,
                    method = method, jac = jac,
                    callback = self.callback,
                    tol = tol, options = {'maxiter': nmaxiter, "disp": True})

        self.finalpk = self.recover_parameters(res.x)
        # OptimizeResult is a dict; not every method reports 'nit', so fall back
        # to the counter maintained by callback().
        self.iter = res.get('nit', self.iter)
        if self.get('fun', None) is not None:
            self.ffin = res.fun
        else:
            self.ffin = self.minimize_func(res.x)

        return self.finalpk, res.fun, res.success, res

    def callback(self, pk):
        """Per-iteration hook passed to scipy.optimize.minimize.

        pk: current optimizer-space parameters.
        Every self.print_interval iterations: print the parameters and objective
        (and record it in the error history when one exists).
        Every self.plot_interval iterations: redraw the fit / error curves.
        Then advance self.iter and return True.

        Returns False without doing anything once the stop button was pressed or
        the current figure window differs from the one stored by initial_plot().
        NOTE(review): scipy ignores the callback's return value for most methods, so
        returning False does not by itself end minimize() -- check the scipy docs.
        """
        if self.stop_flag:
            return False

        # Presumably the original figure was closed/replaced: stop updating it.
        w = self.plt.get_current_fig_manager().window
        if w != self.window:
            self.stop_flag = True
            return False

        recoverpk = self.recover_parameters(pk, set_member = False)

        if self.iter % self.print_interval == 0:
            print(f"iter: {self.iter}")
            for i in range(len(recoverpk)):
                print(f"  {self.varname[i]:10}: {recoverpk[i]:10.4g} {self.unit[i]}")
            f = self.minimize_func(pk)
            print(f"    f={f:12.6g}")

            if self.get('fmin_list', None) is not None:
                self.iter_list.append(self.iter + 1)
                self.fmin_list.append(f)

        if self.iter % self.plot_interval == 0:
            ycal = self.cal_ylist(recoverpk)

            # One x variable: plot against x; several: plot against the data index.
            xdata = self.x_list[0] if len(self.x_list) == 1 else self.index
            self.fit_data[0].set_data(xdata, ycal)
            self.data_axis.set_ylim([min([self.view_ylim_data[0], min(ycal)]),
                                     max([self.view_ylim_data[1], max(ycal)])])

            if self.get('fmin_list', None) is not None:
                self.error_data[0].set_data(self.iter_list, self.fmin_list)
                self.error_axis.set_xlim([min(self.iter_list), max(self.iter_list)])
                self.error_axis.set_ylim([min(self.fmin_list), max(self.fmin_list)])

            self.plt.tight_layout()
            self.plt.subplots_adjust(top = self.plot_region[0], bottom = self.plot_region[1])
            self.plt.pause(0.0001)

        self.iter += 1

        return True

    def button_click(self, e):
        """Stop-button handler: make callback() skip all further updates."""
        self.stop_flag = True

    def initial_plot(self, data_axis, error_axis = None,
                        label_input = None,
                        yini = None, label_ini = None,
                        label_fit = None,
                        label_error = None, fmin = None,
                        fontsize = 16, fig = None, use_pause = 0.00001,
                        button_region = None, plot_region = None):
        """Draw input data, initial model curve and an empty fit curve; add a stop button.

        data_axis: axes for data / model curves.
        error_axis: optional axes for the objective-vs-iteration history (log scale).
        label_input, label_ini, label_fit, label_error: legend labels
            (defaults 'input', 'initial', 'fit', 'error').
        yini: model values for the initial parameters, drawn in red.
        fmin: initial objective value; computed from the current parameters if None.
        fontsize: tick / label / legend font size.
        fig: figure used to bind the mouse-click handler; if None, no handler.
        use_pause: pause length for plt.pause(); None calls plt.show() instead.
        button_region: [left, bottom, width, height] of the stop button
            (default [0.15, 0.95, 0.10, 0.03]).
        plot_region: [top, bottom] passed to subplots_adjust (default [0.92, 0.15]).
        """
        if label_input is None:
            label_input = 'input'
        if label_ini is None:
            label_ini = 'initial'
        if label_fit is None:
            label_fit = 'fit'
        if label_error is None:
            label_error = 'error'
        # Fresh lists per call: they are stored on self and must not be shared defaults.
        if button_region is None:
            button_region = [0.15, 0.95, 0.10, 0.03]
        if plot_region is None:
            plot_region = [0.92, 0.15]

        self.plt = plt
        self.fig = fig
        self.use_pause     = use_pause
        self.button_region = button_region
        self.plot_region   = plot_region

        self.data_axis = data_axis
        axis = data_axis
        axis.tick_params(labelsize = fontsize)

        # One x variable: plot against x; several: plot against the data index and
        # give the plot-event handler the raw x columns.
        if self.nx == 1:
            xdata  = self.x_list[0]
            extra  = {}
            xtitle = self.xlabels[0]
        else:
            xdata  = range(self.ndata)
            extra  = {"x_list": self.x_list, "xlabels": self.xlabels}
            xtitle = 'index'

        self.input_data   = axis.plot(xdata, self.y, label = label_input, linestyle = '',  marker = 'o', markersize = 5.0)
        self.initial_data = axis.plot(xdata, yini,   label = label_ini,   linestyle = '-', linewidth = 0.5, color = 'red')
        self.fit_data     = axis.plot([], [], label = label_fit, linestyle = '-', linewidth = 0.5, color = 'blue')
        self.plot_event_input_data   = self.plot_event.add_data({"label": label_input, "plot_type": "2D", "axis": axis,
                                        "data": self.input_data, **extra})
        self.plot_event_initial_data = self.plot_event.add_data({"label": label_ini,   "plot_type": "2D", "axis": axis,
                                        "data": self.initial_data, **extra})
        self.plot_event_fit_data     = self.plot_event.add_data({"label": label_fit,   "plot_type": "2D", "axis": axis,
                                        "data": self.fit_data, **extra})
        axis.set_xlabel(xtitle,      fontsize = fontsize)
        axis.set_ylabel(self.ylabel, fontsize = fontsize)

        # Remember the initial view so later redraws only ever widen the y range.
        self.view_xlim_data = axis.get_xlim()
        self.view_ylim_data = axis.get_ylim()
        axis.legend(fontsize = fontsize)

        if self.fig is None:
            print("")
            print("Warning: figure object is not given.")
            print("    Mouse click event is not bound")
            print("")
        else:
            self.plot_event.register_event(self.fig, event = "button_press_event",
                        callback = lambda event: self.plot_event.onclick(event))

        if error_axis is not None:
            error_axis.tick_params(labelsize = fontsize)
            self.error_axis = error_axis
            if fmin is None:
                # minimize_func expects optimizer-space parameters, as in minimize().
                fmin = self.minimize_func(self.extract_parameters())
            self.iter_list = [0]
            self.fmin_list = [fmin]

            self.error_data = error_axis.plot(self.iter_list, self.fmin_list, label = label_error,
                        linestyle = '-', linewidth = 0.5, color = 'black',
                        marker = 'o', markersize = 5.0, markerfacecolor = 'black', markeredgecolor = 'black')
            error_axis.set_xlabel('iteration', fontsize = fontsize)
            error_axis.set_ylabel('error', fontsize = fontsize)
            error_axis.set_yscale('log')
            error_axis.legend(fontsize = fontsize)

            self.plot_event_error_data = self.plot_event.add_data(
                        {"label": label_error, "plot_type": "2D", "axis": self.error_axis, "data": self.error_data})

        self.plt.tight_layout()
        self.plt.subplots_adjust(top = plot_region[0], bottom = plot_region[1])

        self.ax_button = plt.axes(button_region)
        self.stop_button = wg.Button(self.ax_button, 'stop', color = '#f8e58c', hovercolor = '#38b48b')
        self.stop_button.on_clicked(self.button_click)

        if use_pause is not None:
            self.plt.pause(use_pause)
        else:
            self.plt.show()
        # callback() compares against this to detect that the figure went away.
        self.window   = plt.get_current_fig_manager().window

        self.axis = axis
        self.yini = yini

    def finalize_plot(self, yfin, iter = None, fmin = None, use_pause = 0.00001, button_region = None, plot_region = None):
        """Draw the final fit, append the last point to the error history, relabel the button.

        yfin: model values for the final parameters.
        iter: iteration number of the final point; defaults to self.iter.
        fmin: final objective value; defaults to self.ffin (set by minimize()).
        use_pause: pause length for plt.pause(); None calls plt.show() instead.
        button_region: kept for interface compatibility (defaults to self.button_region).
        plot_region: [top, bottom] for subplots_adjust; defaults to self.plot_region.
        """
        if button_region is None:
            button_region = self.button_region
        if plot_region is None:
            plot_region = self.plot_region

        # One x variable: plot against x; several: plot against the data index.
        xdata = self.x_list[0] if len(self.x_list) == 1 else self.index
        self.fit_data[0].set_data(xdata, yfin)
        self.data_axis.set_xlim([min(xdata), max(xdata)])
        self.data_axis.set_ylim([min([self.view_ylim_data[0], min(yfin)]),
                                 max([self.view_ylim_data[1], max(yfin)])])

        # The error history only exists when initial_plot() was given an error_axis.
        if self.get('fmin_list', None) is not None:
            if iter is None:
                iter = self.iter
            if fmin is None:
                fmin = self.ffin
            self.iter_list.append(iter)
            self.fmin_list.append(fmin)

        if self.get('error_data', None) is not None:
            self.error_data[0].set_data(self.iter_list, self.fmin_list)
            self.error_axis.set_xlim([min(self.iter_list), max(self.iter_list)])
            self.error_axis.set_ylim([min(self.fmin_list), max(self.fmin_list)])

        self.stop_button.label.set_text('finished')

        self.plt.tight_layout()
        self.plt.subplots_adjust(top = plot_region[0], bottom = plot_region[1])

        if use_pause is not None:
            self.plt.pause(use_pause)
        else:
            self.plt.show()

