#import warnings
#warnings.filterwarnings('ignore')
import sys
import copy
import numpy as np
from numpy import sin, cos, tan, sqrt, pi
import pandas as pd
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
from sklearn.linear_model import Lasso, LassoCV
from sklearn.metrics import mean_absolute_error, mean_squared_error
from matplotlib import pyplot as plt
#import japanize_matplotlib
import seaborn as sns


from tklib.tkobject import tkObject
from tklib.tkutils import pint, pfloat, add_dict
from tklib.tkutils import is_exist, is_file, is_dir, rename_file, delete_file, split_file_path, modify_path
from tklib.tkfile import tkFile
from tklib.tkexcel import tkExcel


class tkLearning(tkObject):
    """ 
    tklib Machine Learning class
    """

    def __init__(self, dataframe = None, **args):
        """
        Create the wrapper and optionally attach a DataFrame.

        dataframe: stored as self.dataframe when it is not None; the
                   label -> position map is then built from its columns.
        **args   : forwarded unchanged to self.initialize().
        """
        # Defaults are assigned before initialize() is called.
        self.infile        = None   # set by read_excel() / read_csv()
        self.dataframe     = None   # data wrapped by this object
        self.icolumns_dict = {}     # column label -> column position
        self.avg           = []
        self.std           = []

        self.initialize(**args)
        
        if dataframe is not None:
            self.dataframe = dataframe
            self.build_icolumns_dict()

#    def __del__(self):
#        print("{} destroyed".format(self.name))

    def __str__(self):
#        return "tkLearning object"
        return self.dataframe.__str__()

    def initialize(self, **args):
        """Forward all keyword arguments to self.update()."""
        # update() is not defined in this class; presumably inherited from tkObject - verify.
        self.update(**args)

    def df(self, dataframe = None):
        if dataframe is None:
            return self.dataframe
        else:
            return dataframe

    def isnull(self):
        """Return the result of isnull() on the stored DataFrame."""
        return self.dataframe.isnull()

    def isna(self):
        """Return the result of isna() on the stored DataFrame."""
        return self.dataframe.isna()

    def isnan(self):
        """Alias of isna(): return the result of isna() on the stored DataFrame."""
        return self.dataframe.isna()

    def check_type(self, val):
        if type(val) is str:
            return 'str'
        if val is None or val == '' or np.isnan(val):
            return 'null'
        return 'val'

    def is_valid_value(self, val):
        if val is None or val == '':
            return False
#        if type(val) is not float and type(val) is not int and type(val) is not 'numpy.float64':
#            return False
        if np.isnan(val):
            return False
        return True

    def check_valid_values(self, list):
        for iv in range(len(list)):
            v = list[iv]
            if not self.is_valid_value(v):
                print(f"Error: ic={ic}({label}) at index={iv}: val={v}: type=", type(v))
                return False
        return True

    def check_not_include_nan(self, list):
        for iv in range(len(list)):
            v = list[iv]
            if np.isnan(v):
                print(f"Error: ic={ic}({label}) at index={iv}: val={v}: type=", type(v))
                return False
        return True

    def value(self, irow, icol, dataframe = None):
        if dataframe is None:
            dataframe = self.dataframe

        if type(icol) is str:
            return dataframe[icol][irow]
        else:
            return dataframe.iat[irow, icol]

    def get_ndarray(self, dataframe = None):
        df = self.df(dataframe = dataframe)
        return df.to_numpy()

    def tonumpy(self, dataframe = None):
        df = self.df(dataframe = dataframe)
        return df.to_numpy()

    def tolist(self, dataframe = None):
        df = self.df(dataframe = dataframe)
        return df.to_numpy().tolist()

    def get_list(self, key, dataframe = None):
        df = self.df(dataframe = dataframe)
        return [df[key][idx] for idx in range(len(df[key]))]

    def get_col_list(self, ic, dataframe = None):
        df = self.df(dataframe = dataframe)

        if type(ic) is int:
            list_org = df.to_numpy().tolist()
            return [df.iat[idx, ic] for idx in range(len(df.index))]
#            return [list_org[idx][ic] for idx in df.index]
#            return [list_org[idx][ic] for idx in range(self.ndata())]
        else:
#            l = []
#            print("len=", len(df[ic]))
#            for idx in range(len(df[ic])):
#                print("a=", idx, df[ic][idx])
#                l.append(df[ic][idx])
            return [df[ic][idx] for idx in df.index]
#            return [df[ic][idx] for idx in range(len(df[ic]))]

    def copy_dataframe(self, dataframe = None):
        """Return .copy() of the given DataFrame, or of the stored one when none is given."""
        return self.df(dataframe = dataframe).copy()

    def copy(self, dataframe = None):
        """Return a new tkLearning built around a .copy() of the given or stored DataFrame."""
        # Only the DataFrame is passed to the new object; other attributes of self are not.
        return tkLearning(dataframe = self.df(dataframe = dataframe).copy())

    def query(self, expression, dataframe = None, renumber_index = False):
        """
        Return a new tkLearning holding the rows that match `expression`.

        expression    : string passed to DataFrame.query()
        dataframe     : DataFrame to filter; the stored one when None
        renumber_index: when True the result index is reset to 0..n-1

        Rows containing NaN are dropped from the result. dropna() returns a
        new frame, so its result has to be kept; the previous code discarded
        it and the NaN rows were never removed.
        """
        selected = self.df(dataframe = dataframe).query(expression).dropna(axis = 0)
        if renumber_index:
            selected = selected.reset_index(drop = True)
        return tkLearning(dataframe = selected)

    def add_contrib_var(self, varname, varname_source, func):
        x_source = self.var_contrib[varname_source]['x']
        x        = func(x_source)
        idx      = self.labels().index(varname)
        self.var_contrib[varname] = { "i": idx, "x": x }
        if self.var_related_vars.get(varname, None) is None:
            self.var_related_vars[varname] = [varname, varname_source]
        else:
            self.var_related_vars[varname].append(varname_source)
        self.var_related_vars[varname_source].append(varname)
#        print(f"self.var_related_vars[{varname}]=", self.var_related_vars[varname])

    def insert(self, loc = None, column = None, var_source = None, value = None, dataframe = None, func = None, add_contrib_var = False):
        df = self.df(dataframe = dataframe)
        if value is None:
            value = df[var_source]
#        print("val=", value)
#        print("c=", column)

        if loc is None:
            loc = len(df.columns)

        if func is None:
            v = value.copy()
        else:
            v = func(value)
#            v = func(value.copy())

#        print("v=", v)        
        df.insert(loc = loc, column = column, value = v)

        if add_contrib_var:
            self.add_contrib_var(column, var_source, func)


    def extract_columns(self, keys, dataframe = None):
        """
        Return a new tkLearning holding only the requested columns.

        keys: column positions (int) and/or column labels. A position is
              translated to its label for the header of the result.

        The column lists are combined through a single numpy array.
        NOTE(review): mixed column types presumably get coerced to one
        common dtype by np.array - verify if that matters to callers.
        """
        df = self.df(dataframe = dataframe)
        print("")
        col_values = [self.get_col_list(k, dataframe = df) for k in keys]
        col_labels = [df.columns[k] if type(k) is int else k for k in keys]
        extracted  = pd.DataFrame(np.array(col_values).T, columns = col_labels)
        return tkLearning(dataframe = extracted)

    def extract_valid_labels(self, labels, dataframe = None):
        df = self.df(dataframe = dataframe)
        valid_keys = df.columns
        keys = []
        for key in labels:
            if key in valid_keys:
                keys.append(key)
        return keys

    def drop(self, labels, axis, dataframe = None):
#        print("dataframe=", dataframe)
        if dataframe is None or dataframe is self.dataframe:
            valid_keys = self.extract_valid_labels(labels)
#            print("l=", labels)
#            print("v=", valid_keys)
            self.dataframe = self.dataframe.drop(columns = valid_keys, axis = axis)
#            print("v.dropped=", self.dataframe.columns)
            return self.dataframe
        else:
            valid_keys = self.extracgt_valid_labels(labels, dataframe = dataframe)
            dataframe_new = df.drop(columns = valid_keys, axis = axis)
            return dataframe_new
        
    def index(self, dataframe = None):
        if dataframe is None:
            return [i for i in self.dataframe.index]
        else:
            return [i for i in dataframe.index]

    def columns(self, dataframe = None):
        if dataframe is None:
            return [label for label in self.dataframe.columns]
        else:
            return [label for label in dataframe.columns]

    def labels(self, dataframe = None):
        """Alias of columns(): return the column labels as a list."""
        return self.columns(dataframe = dataframe)

    def nvars(self, dataframe = None):
        """Return the number of columns of the given (or stored) DataFrame."""
        return len(self.columns(dataframe = dataframe))

    def ndata(self, dataframe = None):
        """Return the number of rows (index entries) of the given (or stored) DataFrame."""
        return len(self.index(dataframe = dataframe))

    def build_icolumns_dict(self, dataframe = None):
        if dataframe is None:
            dataframe = self.dataframe

        self.icolumns_dict = {}
        for i in range(len(dataframe.columns)):
            label = dataframe.columns[i]
            self.icolumns_dict[label] = i

        return self.icolumns_dict

    def head(self):
        self.dataframe.head()

    def get_dummies(self, data = None, drop_first = True):
        ml_dummy = copy.copy(self)
#        ml_dummy = copy.deepcopy(self)
        data_dummy = pd.get_dummies(data = data, drop_first = drop_first)
        ml_dummy.dataframe = data_dummy
#        return tkLearning(dataframe = data_dummy)
        return ml_dummy

    def average(self, dataframe = None):
        df = self.df()
        return df.mean()

    def stddev(self, dataframe = None, ddof = 1):
        df = self.df()
        return df.std(ddof = ddof)
        
    """
    def sum(self, list):
        return sum(list)

    def sum2(self, list):
        return sum([x**2 for x in list])

    def average(self, list):
        return np.mean(list)

    def sum_dev(self, list):
        avg = self.average(list)
        return sum([x - avg for x in list])

    def sum_dev2(self, list):
        avg = self.average(list)
        return sum([(x - avg)**2 for x in list])

    def variance(self, list, ddof = 1):
        return np.var(list, ddof = ddof)

    def correlation(self, list_list):
        return np.corrcoef(list_list)

    def covariance(self, list_list, ddof = 1):
        if ddof == 1:
            return np.cov(list_list, bias = True)
        else:
            return np.cov(list_list, bias = False)

    def stddev(self, list, ddof = 1):
        return np.std(list, ddof = ddof)
    """

    def prepare_var_inf(self, dataframe = None, nmesh = 101, add_contrib_var = True):
        print("")
        print("Prepare variables information:")
        df = self.df()

        labels = self.labels()
        nvars  = len(labels)
        data   = df.to_numpy().T
#        print("len=", len(data), len(data[0]))

        self.var_range        = []
        self.var_contrib      = {}
        self.var_related_vars = {}
        print("  Variable range:")
        for i in range(nvars):
            varname = labels[i]

            minmax = [min(data[i]), max(data[i])]
            self.var_range.append(minmax)
            if type(minmax[0]) is str:
                print(f"  {varname:>10}: {minmax[0]:10} - {minmax[1]:10}")
                x  = np.zeros(nmesh)
            else:
                print(f"  {varname:>10}: {minmax[0]:10.4g} - {minmax[1]:10.4g}")

                dx = (minmax[1] - minmax[0]) * 0.1
                x  = np.linspace(minmax[0] - dx, minmax[1] + dx, nmesh)

            if add_contrib_var:
                self.var_contrib[varname]      = { "i": i, "x": x }
                self.var_related_vars[varname] = [varname]

    def normalize_minmax(self, dataframe = None):
        if dataframe is None:
            dataframe = self.dataframe

        columns = dataframe.columns
        list_std = dataframe.to_numpy().tolist()
        ndata    = len(dataframe.index)
        min_list = []
        max_list = []
        for ic in range(len(columns)):
            label = columns[ic]
            list  = self.get_col_list(ic)
            if not self.check_not_include_nan(list):
                exit()

            _min = min(list)
            _max = max(list)
            min_list.append(_min)
            max_list.append(_max)
            mid = (_max + _min) * 0.5
            w = _max - _min

            if w == 0.0:
                print("")
                print("Error: The variable at ic={ic}({columns[ic]}) does not have variation")
                exit()

            for iv in range(len(list)):
                list_std[iv][ic] = 2.0 * (list[iv] - mid) / w

        df = pd.DataFrame(list_std, columns = dataframe.columns)
        ml = copy.copy(self)
        ml.dataframe = df
#        ml = tkLearning(dataframe = df)
        ml.min = min_list
        ml.max = max_list
        
        return ml, None

    def normalize_standardize(self, dataframe = None, mode = 'standardize'):
        """
        Standardise every column with sklearn's StandardScaler.

        dataframe: DataFrame to standardise; the stored one when None
        mode     : unused; kept for interface compatibility

        Returns (ml, sc): ml is a shallow copy of this object holding the
        transformed DataFrame (same column labels), sc is the fitted
        StandardScaler.
        """
        # The dataframe argument used to be ignored; the unreachable debug
        # code that followed the return statement has been removed.
        df = self.df(dataframe = dataframe)
        sc = StandardScaler()
        sc.fit(df)
        ml = copy.copy(self)
        ml.dataframe = pd.DataFrame(sc.transform(df), columns = df.columns)
        return ml, sc

    def normalize(self, dataframe = None, mode = 'standardize'):
        if mode == 'standardize':
            return self.normalize_standardize(dataframe = dataframe)
        elif mode == 'minmax':
            return self.normalize_minmax(dataframe = dataframe)

    def read_excel(self, infile):
        """
        Load `infile` with pandas.read_excel() into self.dataframe.

        Also records the path in self.infile and rebuilds the
        label -> position map. Returns the loaded DataFrame.
        """
        self.infile = infile
        self.dataframe = pd.read_excel(infile)
        self.build_icolumns_dict(self.dataframe)
        return self.dataframe

    def read_csv(self, infile):
        """
        Load `infile` with pandas.read_csv() into self.dataframe.

        Also records the path in self.infile and rebuilds the
        label -> position map. Returns the loaded DataFrame.
        """
        self.infile = infile
        self.dataframe = pd.read_csv(infile)
        self.build_icolumns_dict(self.dataframe)
        return self.dataframe

    def write_excel(self, fname):
        """
        Write the stored DataFrame to an Excel sheet through tkExcel.

        Row 1 receives the column labels; the value at data row iv and
        column ic is written to worksheet cell (iv + 2, ic + 1).
        NOTE(review): the keyword 'CloeFile' looks like a misspelling of
        'CloseFile' - confirm against tkExcel before changing it.
        """
        header = self.columns()

        book  = tkExcel(path = fname, mode = 'w', OpenFile = 1, CloeFile = False, data_only = True)
        sheet = book.ws
        for col_no, title in enumerate(header, start = 1):
            sheet.cell(row = 1, column = col_no).value = title

        for ic in range(len(header)):
            for iv in range(self.ndata()):
                sheet.cell(row = iv + 2, column = ic + 1).value = self.value(iv, ic)
        book.close()

    def plot_contrib(self, varname, ax, scaler, model, 
                x_train_std, y_train, size_train, x_test_std, y_test, size_test, 
                sizes = (2, 100), s = 5.0, alpha = 0.3, nmesh = 101, 
                plot = None, df = None, train_label = None, test_label = None, idx_train = None, idx_test = None):
        """
        Plot the model response to one variable on `ax`.

        A simulated input is built in which `varname` and the variables
        listed in self.var_related_vars[varname] follow their meshes from
        self.var_contrib (standardised with scaler.mean_ / scaler.scale_)
        while every other variable is set to zero. model.predict() of that
        input is drawn as a line against the standardised `varname`, and the
        train (red) and test (blue) points are overlaid as scatter plots.

        varname   : label of the variable to plot
        ax        : matplotlib axis to draw on
        scaler    : object providing mean_ and scale_ per column
        model     : object providing predict()
        x_*_std, y_*, size_*: scatter data and marker sizes for train/test
        sizes, s, alpha     : forwarded to seaborn.scatterplot()
        nmesh     : length of the zero column used for unrelated variables;
                    must match the mesh length stored in self.var_contrib
        plot      : optional object with add_axis(); when given, the train
                    and test data of this axis are registered with it
        df, train_label, test_label, idx_train, idx_test:
                    passed through to plot.add_axis()

        Requires prepare_var_inf() to have filled self.var_contrib and
        self.var_related_vars.
        """
        labels = self.labels()
        nvars  = len(labels)
        idx    = labels.index(varname)
        zero_list = np.zeros(nmesh, dtype = float)

        # Column positions that follow their mesh: the variable itself
        # first, then every related variable present in the DataFrame.
        i_desc = [idx]
        for var in self.var_related_vars[varname]:
            if var not in labels:
                continue
            idx1 = labels.index(var)
            if idx1 != idx:
                i_desc.append(idx1)

        # The table header is printed once, together with the first column
        if i_desc[0] == 0:
            print(f" {'var':^20}      ", end = '')
            for  l in labels:
                print(f"{l:10} ", end = '')
            print("")
        x_list = []
        print(f"  [{varname:>20}]  ", end = '')
        for i in range(nvars):
            varnamei = labels[i]
            if i in i_desc:
                data = self.var_contrib[varnamei]['x']
                data_std = (data - scaler.mean_[i]) / scaler.scale_[i]
                x_list.append(data_std)
                print(f"{'given':10} ", end  = '')
            else:
                x_list.append(zero_list)
                print(f"{'0':10} ", end  = '')
        print("")

        x_sim_std = np.vstack(x_list).T

        y_sim     = model.predict(x_sim_std)
        ax.plot(x_sim_std.T[i_desc[0]], y_sim)
        sns.scatterplot(ax = ax, x = x_train_std, y = y_train, size = size_train,
                sizes = sizes, s = s, alpha = alpha, color = 'red', legend = False)
        sns.scatterplot(ax = ax, x = x_test_std, y = y_test, size = size_test,
                sizes = sizes, s = s, alpha = alpha, color = 'blue', legend = False)
        ax.set_xlabel(varname)

        # plot defaults to None, so registration has to be optional
        if plot is not None:
            plot.add_axis({"axis": ax, "label": train_label, "x": x_train_std, "y": y_train, "df": df, "idx": idx_train})
            plot.add_axis({"axis": ax, "label": test_label,  "x": x_test_std,  "y": y_test,  "df": df, "idx": idx_test})

    def plot_contributions(self, scaler, model, 
                x_train_std, y_train, size_train, x_test_std, y_test, size_test, 
                nmesh = 101, s = 5.0, figsize = (10, 10), plot = None, df = None, idx_train =None, idx_test = None):
        """
        Draw one contribution plot per variable on a grid of subplots.

        Each column of the stored DataFrame gets its own axis, filled
        row by row through plot_contrib(). Column i of x_train_std and
        x_test_std is used as the scatter data for variable i.

        scaler, model, y_*, size_*, nmesh, s, plot, df, idx_train, idx_test
                : forwarded to plot_contrib()
        figsize : size of the created figure

        Returns (fig, ax) where ax is always a 2-D array of axes
        (subplots is called with squeeze = False), so indexing also works
        when the grid has a single row.
        """
        labels = self.labels()
        nvars  = len(labels)

        ncol = int(sqrt(nvars) + 1)
        nrow = int(nvars / ncol + 1.0001)

        # squeeze = False keeps ax 2-D; with a single variable the grid is
        # 1 x 2 and a squeezed 1-D array cannot be indexed with [ix, iy].
        fig, ax = plt.subplots(nrow, ncol, figsize = figsize, squeeze = False)

        for i, label in enumerate(labels):
            ix, iy = divmod(i, ncol)
            self.plot_contrib(label, ax[ix, iy], scaler, model, 
                        x_train_std.T[i], y_train, size_train, x_test_std.T[i], 
                        y_test, size_test, s = s, nmesh = nmesh, plot = plot, df = df, 
                        train_label = f"train[{ix}][{iy}]", test_label = f"test[{ix}][{iy}]",
                        idx_train = idx_train, idx_test = idx_test)

        plt.tight_layout()

        return fig, ax
        