import os
import numpy as np
from numpy import sin, cos, tan, pi, exp, log, sqrt
from scipy.optimize import minimize
import csv
import pandas as pd
from sklearn.metrics import mean_absolute_error
from sklearn.metrics import mean_squared_error
from sklearn.metrics import r2_score
from matplotlib import pyplot as plt
import matplotlib.widgets as wg


from tklib.tkutils import terminate, pint, pfloat, pconv, getarg, getintarg, getfloatarg, pconv_by_type, conv_float
from tklib.tkutils import is_numeric, is_exist, is_file, is_dir, get_ext
from tklib.tkinifile import tkIniFile
from tklib.tkexcel import tkExcel
from tklib.tkparams import tkParams
from tklib.tkgraphic.tkplotevent import tkPlotEvent

from tklib.tkvariousdata import tkVariousData
#from tklib.tksci.tkmatrix import make_matrix1
from tklib.tksci.tksci import pi, h, hbar,c, e, kB, me, Ea_Arrhenius, log10
from tklib.tksci.tkoptimize import mlsq_general


'''
Fitting and interactive matplotlib helper object library
'''


def read_csv_to_dict(path, delimiter = ',', print_level = 1):
    """Read a CSV file whose first row is a header into a column dictionary.

    Parameters
    ----------
    path : str
        Path of the CSV file.
    delimiter : str
        Field delimiter handed to ``csv.reader`` (default ``','``).
    print_level : int
        When non-zero, print a warning if the file cannot be read.

    Returns
    -------
    (labels, dict)
        ``labels`` is the header row; ``dict`` maps each label to the list of its
        column values, each converted with ``pconv``.  ``(None, None)`` is returned
        when the file cannot be opened or has no header row.
    """
    try:
        f = open(path, newline = '')
    except (OSError, TypeError, ValueError):
        if print_level:
            print(f"Warning in tkFit_object::read_csv_to_dict(): Can not read [{path}]")
        # Always bail out: the return used to sit inside the print branch, so
        # print_level = 0 fell through and used an unbound file handle.
        return None, None

    with f:
        reader = csv.reader(f, delimiter = delimiter)
        labels = next(reader, None)
        if labels is None:
            # Empty file: there is no header to build columns from.
            if print_level:
                print(f"Warning in tkFit_object::read_csv_to_dict(): No header row in [{path}]")
            return None, None

        _dict = {l: [] for l in labels}
        for row in reader:
            # csv.reader yields [] for blank lines (e.g. trailing newlines); skip them.
            if not row:
                continue
            for i, l in enumerate(labels):
                _dict[l].append(pconv(row[i]))

    return labels, _dict

def read_fit_config_from_file(vars, path, 
            labels = ["varname", "unit", "pk_scale", "optid", "linid", "x0", "dx", "kmin", "kmax", "kpenalty"], 
            defval = None):
    """Load a fit-configuration table (.xlsx or .csv) into attributes of `vars`.

    Parameters
    ----------
    vars : object
        Receives one list attribute per column (set with ``setattr``).
    path : str
        Configuration file; the extension selects the reader.
    labels : list of str
        Columns to read from an .xlsx sheet.  For .csv every header column is
        loaded instead.
    defval :
        Value used by ``dict.get`` when an .xlsx column is missing.

    Cells that are None become '', and True/False become 1/0.  Afterwards the
    columns x0/dx/kmin/kmax/kpenalty are converted with ``pfloat`` and optid/linid
    with ``pint``.  Entries that cannot be converted are kept as they are.

    Returns
    -------
    bool
        True on success.  False if the file is missing, unreadable, has an
        unsupported extension, or lacks the required 'x0' column.
    """
    if not os.path.exists(path): return False

    def _normalize(values):
        # Blank cells -> '', Excel booleans -> 1 / 0; everything else untouched.
        return ['' if v is None else 1 if v is True else 0 if v is False else v for v in values]

    ext = get_ext(path).lower()
    if ext == '.xlsx':
        _dict, _sheetnames, _keys, _datalist = tkExcel().read_sheet_dict(path)
        if _dict is None: return False

        for l in labels:
            _list = _dict.get(l, defval)
            if _list is None:
                print(f"\nError in tkFit_object.read_fit_config_from_file(): Can not find a key [{l}] in [{path}]")
                if l == 'x0': 
                    print(f"  You would forget to change the key 'pk' to 'x0'?\n")
                    return False
                if l == 'linid': 
                    print(f"  The key for linear LSQ ID [linid] would not be used. Skip.\n")
                # Any missing column is skipped (the attribute keeps its previous
                # value); iterating None below used to raise TypeError.
                continue

            setattr(vars, l, _normalize(_list))
    elif ext == '.csv':
        csv_labels, _dict = read_csv_to_dict(path)
        if _dict is None: return False

        # Every column present in the CSV header is copied, not only `labels`.
        for l in csv_labels:
            setattr(vars, l, _normalize(_dict[l]))
    else:
        print(f"Error in tkFit_object::read_fit_config_from_file(): Invalid ext [{ext}] for [{path}]")
        return False

    # Store the converted lists back on `vars`.  The previous loops built the
    # converted list and then discarded it.  defval = v keeps unconvertible
    # entries (e.g. '') unchanged.
    for fvar in ["x0", "dx", "kmin", "kmax", "kpenalty"]:
        vlist = getattr(vars, fvar, None)
        if vlist is None: continue
        setattr(vars, fvar, [pfloat(v, defval = v) for v in vlist])

    for ivar in ["optid", "linid"]:
        vlist = getattr(vars, ivar, None)
        if vlist is None: continue
        setattr(vars, ivar, [pint(v, defval = v) for v in vlist])

    return True

def save_fit_config(vars, outfile, 
            labels = ["varname", "unit", "pk_scale", "optid", "linid", "x0", "dx", "kmin", "kmax", "kpenalty"],
            section = "Parameters", print_level = 1):
    """Write the fit-configuration lists held by `vars` to a .csv or .xlsx file.

    Each name in `labels` whose attribute on `vars` is not None becomes one
    column, saved through ``tkVariousData().save``.  That call's result is
    returned.  `section` is accepted for interface compatibility and is not used.

    If `print_level` is non-zero and `vars` has a ``print_variables`` method, it
    is called first.  Any other extension prints an error and terminates the
    process via ``exit()``.
    """
    # Builtin hasattr(): an arbitrary object has no `.hasattr` method, so the
    # previous `vars.hasattr(...)` could raise AttributeError.
    if print_level and hasattr(vars, 'print_variables'): vars.print_variables()

    ext = get_ext(outfile).lower()
    if ext in ['.csv', '.xlsx']:
        _labels = []
        _data_list = []
        for l in labels:
            val = getattr(vars, l, None)
            if val is None:
                continue

            _labels.append(l)
            _data_list.append(val)

        return tkVariousData().save(outfile, data_list = _data_list, labels = _labels, print_level = print_level)

    print(f"Error in tkFit_object.save_fit_config(): Invalid ext [{ext}]")
    exit()

class tkFit_object(tkParams):
    def __init__(self, parameter_file = None, app = None, 
                    tol = 1.0e-5, nmaxiter = 100,
                    print_interval = 1, plot_interval = 1, **args):
        """Create a fitting helper.

        parameter_file, app : accepted for interface compatibility; not used here.
        tol            : stored as ``self.tol``.
        nmaxiter       : stored as ``self.nmaxiter``.
        print_interval : stored as ``self.print_interval``.
        plot_interval  : stored as ``self.plot_interval``.
        **args         : forwarded to ``tkParams.__init__``.
        """
        super().__init__(**args)

        self._xlabel = None
        self._ylabel = None

        # Per-variable fit configuration: parallel lists, one entry per variable.
        self.varname  = []
        self.unit     = []
        self.pk       = []
        self.optid    = []
        self.linid    = []
        self.kpenalty = []
        self.kmin     = []
        self.kmax     = []
        # `dx` and `pk_scale` are read by read_parameters(), save_parameters(),
        # print_variables() and build_initial_simplex() but were never created.
        # Default them to None (meaning "not given") unless the base class set them.
        if not hasattr(self, 'dx'):
            self.dx = None
        if not hasattr(self, 'pk_scale'):
            self.pk_scale = None

        self.tol = tol
        self.print_interval = print_interval
        self.nmaxiter       = nmaxiter
        self.plot_interval  = plot_interval
        # Set to True by the stop button callback.
        self.stop_flag = False

        self.x    = None
        self.y    = None
        self.func = None
        self.iter = 0
        self.f    = None

    def get_value(self, target_key, keys, values, def_val = None):
        """Look up `target_key` in the parallel lists `keys` / `values`.

        Returns the value paired with the first matching key, or `def_val` when
        no key matches.
        """
        for pos, candidate in enumerate(keys):
            if candidate == target_key:
                return values[pos]
        return def_val

    def read_parameters_to_params(self, prmfile, keys, params, heading = None, section = None):
        """Read `keys` from an INI-style parameter file into the object `params`.

        Entries may be ``value`` or ``value:optid:...``; only the text before the
        first ':' is used.  The text is converted according to the exact type of
        the value `params` already holds for that key:
          * int   -> ``pint``
          * float -> ``pfloat``
          * other -> ``pconv``
          * None  -> the raw text is stored.
        If conversion fails, the raw text is kept.

        Keys absent from the file are skipped.  `keys` defaults to self.varname.
        """
        if heading is not None: print(heading)
        if keys is None: keys = self.varname

        p = tkParams()
        p.read_parameters(prmfile, section = section)
        for key in keys:
            val = p.get(key, None)
            if val is None:
                continue

            # str.split always yields at least one field, so no emptiness check is
            # needed.  str() guards against readers that already typed the value.
            text = (val if isinstance(val, str) else str(val)).split(':')[0]

            current = params.get(key, None)
            if current is None:
                setattr(params, key, text)
            elif type(current) is int:        # exact type: bools go to pconv
                setattr(params, key, pint(text, defval = text))
            elif type(current) is float:
                setattr(params, key, pfloat(text, defval = text))
            else:
                setattr(params, key, pconv(text, defval = text))

    def read_parameters(self, prmfile, heading = None, section = None, 
                        params = None,
                        keys = None, values = None, optid = None, 
                        kmin = None, kmax = None, kpenalty = None,
                        dx = None,
                        ignore_keys = []):
        """Read fit variables from a parameter file.

        Dispatch:
          * `params` given -> read_parameters_to_params().
          * .csv / .xlsx   -> read_fit_config_from_file(); its bool result is returned.
          * otherwise      -> INI-style file read through tkParams.  Each key holds
            ``value[:optid[:kmin[:kmax[:kpenalty[:dx]]]]]``.

        For the INI path, `keys` defaults to self.varname.  The lists `values`,
        `optid`, `kmin`, `kmax`, `kpenalty` and `dx` default to self.pk, self.optid,
        ... and are updated in place.  Keys listed in `ignore_keys` or absent from
        the file are left untouched.

        NOTE(review): when self's list is empty, a local default list is created
        ([0.0]*n, [1]*n, [None]*n, [1.0]*n).  Values parsed into it are not stored
        on self — confirm callers always pre-populate self or pass the lists.
        """
        if params is not None:
            return self.read_parameters_to_params(prmfile, keys = keys, params = params, heading = heading, section = section)

        ext = get_ext(prmfile).lower()
        if ext in [".csv", ".xlsx"]:
            return read_fit_config_from_file(self, prmfile)

        if heading is not None: print(heading)
        if keys is None: keys = self.varname

        nvars = len(keys)
        if values is None:
            values = self.pk if self.pk else [0.0] * nvars
        if optid is None:
            # Bug fix: this fallback used to overwrite `values` and leave optid None.
            optid = self.optid if self.optid else [1] * nvars
        if kmin is None:
            kmin = self.kmin if self.kmin else [None] * nvars
        if kmax is None:
            kmax = self.kmax if self.kmax else [None] * nvars
        if kpenalty is None:
            kpenalty = self.kpenalty if self.kpenalty else [1.0] * nvars
        # self.dx may be absent; without a step list the dx field is ignored.
        if dx is None and getattr(self, 'dx', None):
            dx = self.dx

        p = tkParams()
        p.read_parameters(prmfile, section = section)

        for i, key in enumerate(keys):
            if key in ignore_keys:
                continue

            v = p.__dict__.get(key, None)
            if v is None:
                continue

            if ':' not in v:
                # Plain value: convert to the type of the entry it replaces.
                values[i] = pconv_by_type(v, type = type(values[i]), defval = values[i], strict = True)
                continue

            fields = v.split(':')
            n = len(fields)
            values[i] = pconv_by_type(fields[0], type = type(values[i]), defval = values[i], strict = True)
            if n >= 2:
                optid[i] = pint(fields[1])
            # Optional fields are only stored when the target list is long enough.
            if n >= 3 and len(kmin) > i:
                kmin[i] = pfloat(fields[2])
            if n >= 4 and len(kmax) > i:
                kmax[i] = pfloat(fields[3])
            if n >= 5 and len(kpenalty) > i:
                kpenalty[i] = pfloat(fields[4])
            if n >= 6 and dx:
                dx[i] = pfloat(fields[5])

    def save_parameters_by_keys(self, prmfile, heading = None, section = None, keys = None, print_level = 1):
        """Write the attributes of self named in `keys` to an INI-style file.

        Each value is written with ``tkIniFile.write_string`` under `section`.
        Keys whose value on self is None are skipped.  `keys` defaults to
        self.varname.  `print_level` is accepted for interface compatibility.

        NOTE(review): values are read from self.  save_parameters(params=...) routes
        here without passing its `params` object — confirm that is intended.
        """
        if heading is not None: print(heading)
        if keys is None: keys = self.varname

        ini = tkIniFile(path = prmfile)
        for key in keys:
            v = self.get(key, None)
            if v is None:
                continue
            ini.write_string(section = section, key = key, value = v, is_print = False)

    def save_parameters(self, prmfile, heading = None, section = None, 
                        params = None,
                        keys = None, values = None, optid = None, linid = None, pk_scale = None,
                        kmin = None, kmax = None, kpenalty = None,
                        dx = None,
                        labels = ["varname", "unit", "pk_scale", "optid", "linid", "pk", "dx", "kmin", "kmax", "kpenalty"],
                        vars = None, print_level = 1):
        """Save fit variables to `prmfile`.

        * `params` given -> save_parameters_by_keys().
        * .csv / .xlsx   -> one column per name in `labels` via save_fit_config();
          its result is returned.  If `vars` is given, each variable's value is
          taken from ``getattr(vars, name, None)`` instead of `values`.
        * otherwise      -> INI file under `section`.  Each key is written as
          ``value:optid`` or ``value:optid:kmin:kmax:kpenalty[:dx]``.  The penalty
          fields appear only when `kpenalty` is non-empty, and dx only when `dx`
          is non-empty.

        Every list argument defaults to the attribute of self with the same
        meaning (`keys` -> varname, `values` -> pk, ...).
        """
        if params is not None:
            return self.save_parameters_by_keys(prmfile, heading = heading, section = section, keys = keys, print_level = print_level)

        if heading is not None: print(heading)

        if keys is None         : keys = self.varname
        if values is None       : values = self.pk
        if optid is None        : optid = self.optid
        if linid is None        : linid = self.linid
        if pk_scale is None     : pk_scale = getattr(self, 'pk_scale', None)
        if kmin is None         : kmin = self.kmin
        if kmax is None         : kmax = self.kmax
        if kpenalty is None     : kpenalty = self.kpenalty
        if dx is None           : dx = getattr(self, 'dx', None)

        ext = get_ext(prmfile).lower()
        if ext in [".csv", ".xlsx"]:
            if vars is not None:
                # Previously this list was built and then overwritten by `values`.
                values = [getattr(vars, l, None) for l in keys]

            override = {
                'varname' : keys,
                'pk'      : values,
                'optid'   : optid,
                'linid'   : linid,    # was assigned to a misspelled `lineid`
                'pk_scale': pk_scale,
                'kmin'    : kmin,
                'kmax'    : kmax,
                'kpenalty': kpenalty,
                'dx'      : dx,
            }
            # save_fit_config() reads columns from attributes of the object it is
            # given.  Install the requested lists on self temporarily and always
            # restore the originals, even if saving raises or exits.
            saved = {name: getattr(self, name, None) for name in override}
            try:
                for name, val in override.items():
                    setattr(self, name, val)
                return save_fit_config(self, prmfile, labels = labels, print_level = print_level)
            finally:
                for name, val in saved.items():
                    setattr(self, name, val)

        ini = tkIniFile(path = prmfile)
        for i, key in enumerate(keys):
            v = conv_float(values[i])
            if kpenalty:
                _kmin = conv_float(kmin[i])
                _kmax = conv_float(kmax[i])
                _kpenalty = conv_float(kpenalty[i])
                if dx:
                    _dx = conv_float(dx[i])
                    value = f"{v}:{optid[i]}:{_kmin}:{_kmax}:{_kpenalty}:{_dx}"
                else:
                    value = f"{v}:{optid[i]}:{_kmin}:{_kmax}:{_kpenalty}"
            else:
                value = f"{v}:{optid[i]}"
            ini.write_string(section = section, key = key, value = value, is_print = False)

    def extract_parameters(self, pk = None, optid = None, target = None):
        """Select the entries to be optimised.

        pk     : list whose length fixes how many entries are examined (default self.pk).
        optid  : per-variable flags; an entry is kept when its flag is truthy
                 (default self.optid).
        target : list to pick the entries from (default `pk`).
        """
        pk = self.pk if pk is None else pk
        optid = self.optid if optid is None else optid
        target = pk if target is None else target

        return [target[idx] for idx in range(len(pk)) if optid[idx]]

    def recover_parameters(self, optpk, allpk = None, optid = None, set_member = True):
        """Merge the optimised subset `optpk` back into the full parameter list.

        Positions whose `optid` flag equals 1 take the next value from `optpk`.
        All other positions keep the value from `allpk` (default self.pk).  The
        merged list is returned and, when `set_member` is true, stored as self.pk.
        NOTE: the test here is ``== 1``, while extract_parameters() uses truthiness.
        """
        allpk = self.pk if allpk is None else allpk
        optid = self.optid if optid is None else optid

        merged = []
        used = 0
        for idx, current in enumerate(allpk):
            if optid[idx] == 1:
                merged.append(optpk[used])
                used += 1
            else:
                merged.append(current)

        if set_member:
            self.pk = merged
        return merged

    def disconnect_backward_data(self, xlist, ylist, disconnect = True):
        """Insert (None, None) gaps wherever x decreases.

        Matplotlib breaks a line at None, so backward sweeps are not joined to the
        previous segment.  Returns new lists.  When `disconnect` is false, the
        input lists are returned unchanged.
        """
        if not disconnect: return xlist, ylist

        _xlist = []
        _ylist = []
        x_prev = None
        for x, y in zip(xlist, ylist):
            # Explicit None test: a previous x of 0 is falsy but must still cause
            # a break when the next x is smaller.
            if x_prev is not None and x_prev > x:
                _xlist.append(None)
                _ylist.append(None)

            _xlist.append(x)
            _ylist.append(y)
            x_prev = x

        return _xlist, _ylist

    def build_initial_simplex(self, print_level = 1):
        """Build an (n+1)-vertex starting simplex for the optimised variables.

        Vertex 0 is the current optimised parameter vector.  Vertex k+1 is the same
        vector with its k-th component shifted by the matching entry of self.dx.
        Returns [] when self.dx is None or empty.
        """
        if self.dx is None or len(self.dx) == 0: return []

        if print_level: print("\nBuild initial simplex")

        base = self.extract_parameters()
        steps = self.extract_parameters(self.dx)
        vertices = [base]
        for k in range(len(base)):
            vertices.append([value + steps[j] if j == k else value for j, value in enumerate(base)])

        return vertices

    def build_parameter_list(self, var_list = None, source_obj = None):
        """Return the values of the attributes named in `var_list` (default
        self.varname), read from `source_obj` (default self), in order."""
        names = self.varname if var_list is None else var_list
        source = self if source_obj is None else source_obj

        collected = []
        for name in names:
            collected.append(getattr(source, name))
        return collected

    def retrieve_parameter_list(self, parameter_list, var_list = None, target = None, print_level = 1):
        """Assign ``parameter_list[i]`` to the attribute named ``var_list[i]``.

        Values are written on `target` (default self).  `var_list` defaults to
        self.varname.  `print_level` is accepted for interface compatibility.
        """
        dest = self if target is None else target
        names = self.varname if var_list is None else var_list

        for idx, name in enumerate(names):
            setattr(dest, name, parameter_list[idx])

    def configure(self, **kwargs):
        """Set each keyword argument as an attribute of self."""
        for name in kwargs:
            setattr(self, name, kwargs[name])

    def copy_attributes(self, source_obj, var_list):
        """Copy each attribute named in `var_list` from `source_obj` onto self.
        Attributes missing on the source are copied as None."""
        for name in var_list:
            setattr(self, name, getattr(source_obj, name, None))

    def print_variables(self, heading = None, varname = None, unit = None, pk = None, pk_scale = None, dx = None, 
                        optid = None, linid = None,
                        kmin = None, kmax = None, kpenalty = None, f = None, fmin = None):
        """Print one line per fit variable.

        Every list argument defaults to the attribute of the same name on self.
        When `kpenalty` is non-empty, each line is numbered and also shows the
        penalty weight and the (kmin - kmax) range; otherwise a shorter line
        without penalty information is printed.  Present `linid` entries are
        shown as ``(linear=...)``.  `f` and `fmin`, if given, are printed last.
        `pk_scale` and `dx` are accepted for interface compatibility but not shown.

        Exits the process if an entry of `pk` is not numeric.
        """
        if heading is not None: print(heading)

        if varname is None  : varname = self.varname
        if unit is None     : unit = self.unit
        if pk is None       : pk = self.pk
        if optid is None    : optid = self.optid
        if linid is None    : linid = self.linid
        if kmin is None     : kmin = self.kmin
        if kmax is None     : kmax = self.kmax
        if kpenalty is None : kpenalty = self.kpenalty

        if not pk:
            print(f"Warning in tkFit_object.print_variables(): No parameter is given")
            return

        def _item(seq, i, default = ''):
            # Per-variable lists may be shorter than pk (e.g. `unit` starts as []).
            return seq[i] if seq is not None and len(seq) > i else default

        def _fmt(value, spec):
            # Blank ('' / None) entries from config files cannot take numeric specs.
            try:
                return format(value, spec)
            except (TypeError, ValueError):
                return str(value)

        # Penalty info is shown only when penalty weights exist.  The previous
        # code aborted instead, which made its short-form branch unreachable.
        use_penalty = bool(kpenalty)
        for i, value in enumerate(pk):
            name = _item(varname, i)
            if not is_numeric(value):
                print(f"\nError in tkFit_object.print_variables(): For variable [{name}]: pk[{i}]={value} is not numeric (type={type(value)})\n")
                exit()

            line = f"   {i:02d}: " if use_penalty else "   "
            line += f"{_fmt(name, '>10')}={_fmt(value, '12.6g')} {_fmt(_item(unit, i), '10')} (id={_item(optid, i)})"
            if linid and len(linid) > i:
                line += f" (linear={linid[i]})"
            if use_penalty:
                line += (f" penality: {_fmt(_item(kpenalty, i), '8.3g')}"
                         + f" * ({_fmt(_item(kmin, i), 'g')} - {_fmt(_item(kmax, i), 'g')})")
            print(line)

        if f is not None   : print("  f=", f)
        if fmin is not None: print("  fmin=", fmin)

    def read_datalist(self, infile = None, idata = None,
                xlabel = None, ylabel = None, x_label = 0, y_label = 1, xmin = None, xmax = None, usage = None,
                flag = 'i', flags = ''):
        """Read `infile` with read_data() and return the columns selected by `idata`.

        xlabel / ylabel override x_label / y_label as the x and y selectors.
        xmin, xmax and usage are forwarded to read_data().
        flag and flags are forwarded both to read_data() and to
        ``self.datafile.find_data_array``.

        Returns (labels, data_list), one entry per element of `idata`, as returned
        by ``self.datafile.find_data_array``.
        """
        if xlabel is None: xlabel = x_label
        if ylabel is None: ylabel = y_label
        # idata is not forwarded, so read_data() does not route back here.
        # Fixes: the `xma` typo, the ignored ylabel, and undefined flag/flags names.
        self.read_data(infile = infile, xlabel = xlabel, ylabel = ylabel, flag = flag, flags = flags,
                       xmin = xmin, xmax = xmax, usage = usage)

        labels = []
        data_list = []
        for idx in (idata if idata is not None else []):
            label, data = self.datafile.find_data_array(idx, flag = flag, flags = flags)
            labels.append(label)
            data_list.append(data)

        return labels, data_list

    def _label_property(attr_name, doc):
        # Build a plain read/write property that forwards to `attr_name`.
        def _get(self):
            return getattr(self, attr_name)

        def _set(self, value):
            setattr(self, attr_name, value)

        return property(_get, _set, doc = doc)

    # `xlabel` / `x_label` (and `ylabel` / `y_label`) are two spellings of the same
    # value, stored in `_xlabel` (`_ylabel`).
    xlabel  = _label_property('_xlabel', "x-data label (stored in ``_xlabel``).")
    x_label = _label_property('_xlabel', "Alias of ``xlabel``.")
    ylabel  = _label_property('_ylabel', "y-data label (stored in ``_ylabel``).")
    y_label = _label_property('_ylabel', "Alias of ``ylabel``.")

    # The factory is only needed while the class body executes.
    del _label_property

    def read_data(self, infile = None, xlabel = None, ylabel = None, x_label = 0, y_label = 1, 
                    idata = None, flag = 'i', flags = '',
                    xmin = None, xmax = None, usage = None):
        """Read x/y columns from `infile` into self.

        xlabel / ylabel (when given) take precedence over x_label / y_label.
        Either one selects the column passed to ``find_data_array``.
        When `idata` is given, the call is delegated to read_datalist().

        Sets self.datafile, labels, datalist, x_label, y_label and ndata_all.
        Sets self.x / self.y to the points with xmin <= x <= xmax (a bound of None
        is ignored).  Also sets included_index (source row of each kept point),
        ndata and index.
        """
        # Keep both spellings equal; the explicit `xlabel`/`ylabel` wins.
        xlabel = x_label if xlabel is None else xlabel
        x_label = xlabel
        ylabel = y_label if ylabel is None else ylabel
        y_label = ylabel

        if idata is not None:
            return self.read_datalist(infile = infile, idata = idata, flag = flag, 
                        xlabel = x_label, y_label = y_label, xmin = xmin, xmax = xmax, usage = usage)

        print("")
        print(f"tkFit_object.read_data(): Read [{infile}]")
        datafile = tkVariousData(infile)
        labels, datalist = datafile.Read_minimum_matrix(close_fp = True, force_numeric = False, usage = usage)
        found_xlabel, xin = datafile.find_data_array(xlabel, flag = flag, flags = flags)
        found_ylabel, yin = datafile.find_data_array(ylabel, flag = flag, flags = flags)

        self.datafile = datafile
        self.labels   = labels
        self.datalist = datalist
        self.x_label  = found_xlabel
        self.y_label  = found_ylabel

        self.ndata_all = len(xin)
        print("ndata_all=", self.ndata_all)

        def _in_range(value):
            # A point is dropped when it lies below xmin or above xmax.
            if xmin is not None and xmin > value:
                return False
            if xmax is not None and xmax < value:
                return False
            return True

        kept = [row for row in range(self.ndata_all) if _in_range(xin[row])]
        self.x = [xin[row] for row in kept]
        self.y = [yin[row] for row in kept]
        self.included_index = kept

        self.ndata = len(self.x)
        self.index = range(self.ndata)
        
    def print_data(self, heading = '', yini = None, yfin = None):
        """Print the data table.

        Columns: index, source row, every x column, y(input), plus y(ini) and
        y(fin) when those lists are given.  `heading` is printed first when
        non-empty.

        Reads self.x_list (list of x columns), self.xlabels, self.index,
        self.included_index and self.y.  NOTE(review): x_list and xlabels are not
        assigned in this section of the file — presumably set by a subclass or
        caller.
        """
        if heading: print(heading)

        nx    = len(self.x_list)
        # With no x columns, take the row count from y instead of x_list[0].
        ndata = len(self.x_list[0]) if nx > 0 else len(self.y)

        header = f"{'i':>4}:{'sn':<4} "
        header += "".join(f"{self.xlabels[j]:>10} " for j in range(nx))
        header += f"{'y(input)':>10}"
        if yini is not None: header += f" {'y(ini)':>10}"
        if yfin is not None: header += f" {'y(fin)':>10}"
        print(header)

        for i in range(ndata):
            line = f"{self.index[i]:>4}:{self.included_index[i]:<4}: "
            line += "".join(f"{self.x_list[j][i]:10.4g} " for j in range(nx))
            line += f"{self.y[i]:10.4g}"
            # Optional columns are omitted instead of indexing a None list.
            if yini is not None: line += f" {yini[i]:10.4g}"
            if yfin is not None: line += f" {yfin[i]:10.4g}"
            print(line)

    def to_excel(self, outfile, labels, data_list, template = None):
        """Write the columns in `data_list`, headed by `labels`, to an Excel file.

        Thin wrapper around ``tkVariousData().to_excel``; `template` is forwarded
        and that call's result is returned.  (An unreachable legacy pandas
        implementation that followed the return has been removed.)
        """
        return tkVariousData().to_excel(outfile, labels, data_list, template = template)


    def print_scores(self, y1, y2, heading = None):
        """Print summary statistics comparing `y1` and `y2`.

        `y1` is passed to scikit-learn as ``y_true`` and `y2` as ``y_pred``.
        Printed values:
          * means, population variances and standard deviations of both series;
          * sample standard deviations (ddof = 1);
          * MAE, MSE, RMSE and R^2 = 1 - SS_res / SS_tot, with SS_tot taken
            around <y1>;
          * r = <y1, y2> / sqrt(<y1, y1> <y2, y2>).  This is an uncentered
            (cosine-type) correlation: the means are not subtracted, so it equals
            Pearson's r only for zero-mean data.
        The printed formula labels were corrected to match what is computed.
        """
        if heading is not None:
            print(heading)

        y1_mean = np.mean(y1)
        y2_mean = np.mean(y2)
        y1_var  = np.var(y1)
        y2_var  = np.var(y2)
        y1_std0 = np.std(y1)
        y2_std0 = np.std(y2)
        y1_std1 = np.std(y1, ddof = 1)
        y2_std1 = np.std(y2, ddof = 1)
        MAE  = mean_absolute_error(y1, y2)
        MSE  = mean_squared_error(y1, y2)
        RMSE = sqrt(MSE)
        R2 = r2_score(y1, y2)
        sxx = np.dot(y1, y1)
        syy = np.dot(y2, y2)
        sxy = np.dot(y1, y2)
        rcorr = sxy / sqrt(sxx * syy)

        print(f"  Mean values = <y1> = sum(y1) / n                 : {y1_mean:12.4g}")
        print(f"                <y2> = sum(y2) / n                 : {y2_mean:12.4g}")
        print(f"  Variance    = sum((y1 - <y1>)^2) / n             : {y1_var:12.4g}")
        print(f"              = sum((y2 - <y2>)^2) / n             : {y2_var:12.4g}")
        print(f"  Standard deviation = sqrt(Variance(y1))          : {y1_std0:12.4g}")
        print(f"                       sqrt(Variance(y2))          : {y2_std0:12.4g}")
        print(f"  Sample std  = sqrt(sum((y1 - <y1>)^2) / (n-1))   : {y1_std1:12.4g}")
        print(f"                sqrt(sum((y2 - <y2>)^2) / (n-1))   : {y2_std1:12.4g}")
        print(f"  MAE  (mean absolute error) = sum(|y1 - y2|) / n  : {MAE:12.4g}")
        print(f"  MSE  (mean squared error)  = sum((y1 - y2)^2) / n: {MSE:12.4g}")
        print(f"  RMSE (root MSE)            = sqrt(MSE)           : {RMSE:12.4g}")
        print(f"  R^2  (coefficient of determination)")
        print(f"        = 1 - sum((y1 - y2)^2) / sum((y1 - <y1>)^2): {R2:12.4g}")
        print(f"  r    (uncentered correlation coefficient)        : {rcorr:12.4g}")

    def button_click(self, e = None):
        """Stop-button callback: set ``self.stop_flag`` so the running fit stops.

        `e` is the event object passed by the matplotlib Button (unused).  The
        previous signature lacked `self`, so the bound callback received too many
        arguments and `self` was undefined.
        """
        self.stop_flag = True
        print("\ntkFit_object.button_click(): Stop button was pressed. Terminating...\n")
    
    def initial_plot(self, fig = None, use_pause = 0.00001, 
                        button_region = [0.15, 0.95, 0.10, 0.03], plot_region = [0.92, 0.15],
                        on_clicked = None):
        """Prepare interactive plotting.

        Stores plt, `fig` and `use_pause` on self.  Binds mouse clicks on `fig` to
        ``tkPlotEvent.onclick`` (a warning is printed when no figure is given).
        Adds a 'stop' Button at `button_region` whose click handler is
        `on_clicked`, defaulting to self.button_click.  `plot_region` (top, bottom)
        is stored for layout().
        """
        callback = self.button_click if on_clicked is None else on_clicked

        self.plot_event = tkPlotEvent(plt)
        self.plt = plt
        self.fig = fig
        self.use_pause = use_pause

        if self.fig is not None:
            self.plot_event.register_event(self.fig, event = "button_press_event", 
                        callback = lambda event: self.plot_event.onclick(event))
        else:
            for message in ("", "Warning: figure object is not given.", "    Mouse click event is not bound", ""):
                print(message)

        self.button_region = button_region
        self.plot_region   = plot_region
        self.ax_button     = plt.axes(button_region)
        self.stop_button   = wg.Button(self.ax_button, 'stop', color = '#f8e58c', hovercolor = '#38b48b')
        self.stop_button.on_clicked(callback)

    def initialize_plot(self, fig = None, use_pause = 0.00001, 
                        button_region = [0.15, 0.95, 0.10, 0.03], plot_region = [0.92, 0.15],
                        on_clicked = None):
        """Alias of initial_plot(); all arguments are forwarded unchanged."""
        options = dict(fig = fig, use_pause = use_pause, button_region = button_region,
                       plot_region = plot_region, on_clicked = on_clicked)
        return self.initial_plot(**options)


    def layout(self, show = False):
        """Apply the (top, bottom) margins stored by initial_plot() and optionally display.

        show : when true, call ``plt.pause(self.use_pause)`` if a pause interval is
               set, otherwise ``plt.show()``.
        """
        self.plt.subplots_adjust(top = self.plot_region[0], bottom = self.plot_region[1])

        if show:
            # `use_pause` is stored on self by initial_plot(); the bare name used
            # here before was undefined and raised NameError.
            use_pause = getattr(self, 'use_pause', None)
            if use_pause is not None:
                self.plt.pause(use_pause)
            else:
                self.plt.show()

