import os
import platform
import re
import numpy as np
from numpy import log, abs, sqrt, exp
from scipy import interpolate
import pandas as pd
import openpyxl
import time
from datetime import datetime


from tklib.tkfile import tkFile
from tklib.tkutils import pint, pfloat
from tklib.tkparams import tkParams
from tklib.tksci.tkFit_lib import conv_input, save_data, add_history


# Module-level tklib parameter container.
# NOTE(review): `config` is not referenced anywhere in this file's visible
# code — confirm whether other modules import it before removing.
config = tkParams()

def initialize_minimize_func():
    """Placeholder hook for minimizer initialization; currently a no-op.

    The host OS name is queried via ``platform.system()`` but not used yet.
    It is bound to ``os_name`` (not ``os``) so the local name no longer
    shadows the ``os`` module imported at file level.
    """
    os_name = platform.system()  # noqa: F841 -- reserved for OS-specific setup
    return None

def read_data(infile, filetype = ''):
    """Read a column table from an Excel (.xlsx) or tab-separated text file.

    Parameters
    ----------
    infile : str or os.PathLike
        Input path. A name ending in ``.xlsx`` (any letter case) is read with
        ``pandas.read_excel``. Anything else is read as tab-delimited text
        with ``pandas.read_csv``.
    filetype : str
        ``''`` (default) returns only the second column as y data.
        Any other value returns every column after the first as y data.

    Returns
    -------
    labels : list
        Column headers of the table.
    x_list : list
        One-element list holding the first column's values.
    y_list : list
        ``[second column]`` when ``filetype == ''``, otherwise one array per
        remaining column.
    """
    print(f"Read data from [{infile}]")
    # Read infile into a DataFrame with pandas. The extension test is
    # case-insensitive so "DATA.XLSX" is not mis-parsed as TSV. str() lets
    # pathlib.Path inputs work as well.
    if str(infile).lower().endswith('.xlsx'):
        df = pd.read_excel(infile)
    else:
        df = pd.read_csv(infile, delimiter = '\t')
    labels = list(df.columns)

    # The first column is always the abscissa.
    E_list = df[labels[0]].values
    if filetype == '':
        data_list = [df[labels[1]].values]
    else:
        data_list = [df[label].values for label in labels[1:]]
    return labels, [E_list], data_list

def _clip_params_to_bounds(xk0, fit):
    """Return a copy of xk0 with each penalized parameter clipped into [kmin, kmax]."""
    xk = list(xk0)
    # `is None` rather than truthiness: kpenalty may be a numpy array.
    kp = getattr(fit, 'kpenalty', None)
    if kp is None:
        return xk
    kmin = fit.kmin
    kmax = fit.kmax
    for i in range(len(kp)):
        if xk[i] < kmin[i]:
            print(f"**Warning: [{fit.varname[i]}={xk[i]}] is smaller than [{kmin[i]}]. Replace with [{kmin[i]}].")
            xk[i] = kmin[i]
        elif xk[i] > kmax[i]:
            print(f"**Warning: [{fit.varname[i]}={xk[i]}] is larger than [{kmax[i]}]. Replace with [{kmax[i]}].")
            xk[i] = kmax[i]
    return xk


def _sum_model_components(x, xk, varname):
    """Evaluate the sum of the components encoded by varname at abscissa x."""
    y = np.zeros(len(x))
    ip = 0
    nvar = len(xk)
    while ip < nvar:
        head = varname[ip][0]
        if head == 'g':
            # Gaussian n0 * exp(-((E - e0) / w)**2), parameters (e0, n0, w).
            e0, n0, w = xk[ip], xk[ip + 1], xk[ip + 2]
            kexp = (x - e0) / w
            y += n0 * np.exp(-kexp * kexp)
            ip += 3
        elif head == 'e':
            # Exponential n0 * exp(-E / w), parameters (n0, w).
            n0, w = xk[ip], xk[ip + 1]
            y += n0 * np.exp(-x / w)
            ip += 2
        else:
            print(f"\nError: Invalid head char in varname #{ip} [{varname[ip]}]\n")
            # Non-zero status: the bare exit() used before reported success (0).
            raise SystemExit(1)
    return y


def ycal_list(xd_list, xk0, fit = None, run = True):
    """Compute the model curve for parameter vector xk0.

    The model is a sum of components selected by the first character of each
    entry in ``fit.varname``:
    ``'g'`` starts a Gaussian (3 parameters) and ``'e'`` starts an
    exponential (2 parameters).

    Each data point now gets each component exactly once. Previously the sum
    was accumulated once per data point, so the curve was scaled by the
    number of points.

    Parameters
    ----------
    xd_list : list or None
        List whose first element is the abscissa array. ``None`` falls back
        to ``fit.x_list``.
    xk0 : sequence of float
        Model parameters. Not modified; penalized entries are clipped on a copy.
    fit : fit object
        Required despite the ``None`` default. It provides ``varname`` and
        optionally ``kpenalty``/``kmin``/``kmax``.
    run : bool
        Unused; kept for interface compatibility.

    Returns
    -------
    list
        ``[y]`` where ``y`` is a list of floats. The same object is stored
        in ``fit.yc_list``, and the abscissa list used is stored in
        ``fit.xc_list``.

    Raises
    ------
    SystemExit
        With status 1 if a varname starts with an unknown character.
    """
    xk = _clip_params_to_bounds(xk0, fit)
    x_src = xd_list if xd_list is not None else fit.x_list
    x = np.asarray(x_src[0], dtype=float)
    y = _sum_model_components(x, xk, fit.varname)

    fit.xc_list = x_src
    fit.yc_list = [y.tolist()]
    return fit.yc_list

def save_parameters(fit, cparams):
    """Write current settings and fitted parameters to ``cparams.parameterfile``.

    Four sections are written, in this order:
      * "Preferences": ``cparams`` entries other than the flag, condition,
        fit-variable and file-name keys. These are passed as ``exclude_keys``;
        their exact semantics are defined by the tklib ``save_parameters``.
      * "Flags":       ``fplot``, ``fhistory``, ``ffitfiles`` from ``cparams``.
      * "Condition":   minimizer settings from ``cparams``.
      * "Parameters":  the variables named in ``fit.varname``, saved from ``fit``.
    """
    print(f"Save parameters to [{cparams.parameterfile}]")
    flag_keys = ["fplot", "fhistory", "ffitfiles"]
    condition_keys = ["mode", "method", "jac", "nmaxiter", "tol", "y_convert"]
    file_keys = ["infile", "calfile", "logfile", "outfile", "parameterfile",
                 "stopfile", "historyfile", "fitfile"]
    # list() lets fit.varname be any sequence (tuple, numpy array).
    # Plain `+` concatenation only worked when varname was a list.
    exclude_keys = condition_keys + flag_keys + list(fit.varname) + file_keys
    cparams.save_parameters(cparams.parameterfile, section = "Preferences", exclude_keys = exclude_keys)
    cparams.save_parameters(cparams.parameterfile, section = "Flags",      keys = flag_keys)
    cparams.save_parameters(cparams.parameterfile, section = "Condition",  keys = condition_keys)
    fit.save_parameters    (cparams.parameterfile, section = "Parameters", keys = fit.varname)

def callback(pk, fit, run = True):
    """Per-iteration minimizer callback.

    On each call it checks the early-stop conditions, prints progress,
    refreshes the live plot, and checkpoints the parameters to disk.

    Parameters
    ----------
    pk : sequence of float
        Current parameter vector from the minimizer. It is converted with
        ``fit.recover_parameters`` before printing, plotting and saving.
    fit : fit object
        Supplies ``cparams``, ``stop_flag``, ``iter`` and the
        print/plot intervals, plus the plotting handles when ``fplot`` is set.
    run : bool
        Forwarded to ``fit.cal_ylist``.

    Returns
    -------
    bool
        False to request that the fit stop: the stop flag was already set,
        the stop file exists, or the graph window was closed. True otherwise.
    """
    cparams = fit.cparams

    if fit.stop_flag:
        return False

    # Early stop #1: the user created cparams.stopfile.
    exist_stopfile = os.path.isfile(cparams.stopfile)
    print("exist_stopfile:", cparams.stopfile, exist_stopfile)
    if exist_stopfile:
        print()
        print(f"Message: Found stop file [{cparams.stopfile}] for early stop.")
        print()
        fit.stop_flag = True
        return False

    fplot = getattr(fit, 'fplot', True)
    if fplot:
        # Early stop #2: the window owned by the current figure manager no
        # longer matches fit.window.
        w = fit.plt.get_current_fig_manager().window
        if w != fit.window:
            print()
            print(f"Message: Graph window is closed for early stop.")
            print()
            fit.stop_flag = True
            return False

    recoverpk = fit.recover_parameters(pk, set_member = False)

    if fit.iter % fit.print_interval == 0:
        print(f"iter: {fit.iter}")
        for i, value in enumerate(recoverpk):
            print(f"  {fit.varname[i]:10}: {value:10.4g} {fit.unit[i]}")
        f = fit.minimize_func(pk)
        print(f"    f={f:12.6g}")

        # getattr (not fit.get) matches the fplot lookup above and does not
        # require fit to implement a dict-style get().
        if getattr(fit, 'fmin_list', None) is not None:
            fit.iter_list.append(fit.iter + 1)
            fit.fmin_list.append(f)

    if fplot and fit.iter % fit.plot_interval == 0:
        # Named yc_list so it does not shadow the module-level ycal_list().
        yc_list = fit.cal_ylist(recoverpk, run = run)

        # With exactly one x array, plot against it. Otherwise (several x
        # arrays, or none recorded) plot against the point index.
        x_list = getattr(fit, 'x_list', None)
        single_x = x_list is not None and len(x_list) == 1
        for i in range(len(fit.y_list)):
            xdata = x_list[0] if single_x else range(len(yc_list[i]))
            fit.fit_data_list[i][0].set_data(xdata, yc_list[i])

        if getattr(fit, 'fmin_list', None) is not None:
            fit.error_data[0].set_data(fit.iter_list, fit.fmin_list)
            fit.error_axis.set_xlim([min(fit.iter_list), max(fit.iter_list)])
            fit.error_axis.set_ylim([min(fit.fmin_list), max(fit.fmin_list)])

        fit.plt.tight_layout()
        fit.plt.subplots_adjust(top = fit.plot_region[0], bottom = fit.plot_region[1])
        fit.plt.pause(0.0001)

    fit.iter += 1

    # Checkpoint: copy the recovered parameters into cparams, then rewrite
    # the parameter file.
    fit.retrieve_parameter_list(recoverpk, fit.varname, target = cparams)
    save_parameters(fit, cparams)

    return True

def _bound_penalty(fit, xkr):
    """Clip xkr in place into [kmin, kmax] and return the summed quadratic penalty."""
    kp = fit.kpenalty
    p_tot = 0.0
    if kp is None:
        return p_tot
    kmin = fit.kmin
    kmax = fit.kmax
    for i in range(len(kp)):
        if xkr[i] < kmin[i]:
            d = xkr[i] - kmin[i]
            p = kp[i] * d * d
            p_tot += p
            print(f"**Warning: [{fit.varname[i]}={xkr[i]}] is smaller than [{kmin[i]}]. Add penalty [{p:10.3g}] to fmin")
            xkr[i] = kmin[i]
        elif xkr[i] > kmax[i]:
            d = xkr[i] - kmax[i]
            p = kp[i] * d * d
            p_tot += p
            print(f"**Warning: [{fit.varname[i]}={xkr[i]}] is larger than [{kmax[i]}]. Add penalty [{p:10.3g}] to fmin")
            xkr[i] = kmax[i]
    return p_tot


def _rms_deviation(ymeas, ysim, y_convert, eps = 1.0e-20):
    """RMS of (ymeas - ysim), taken on log values when y_convert == 'log'."""
    ymeas = np.asarray(ymeas, dtype=float)
    ysim = np.asarray(ysim, dtype=float)
    if y_convert == 'log':
        # Non-positive values have no logarithm, so floor them at eps.
        ymeas = np.where(ymeas <= 0.0, eps, ymeas)
        ysim = np.where(ysim <= 0.0, eps, ysim)
        d = np.log(ymeas) - np.log(ysim)
    else:
        d = ymeas - ysim
    return float(np.sqrt(np.sum(d * d) / len(d)))


def minimize_func(xk, x_list = None, y_list = None, w_list = None, fit = None, run = True):
    """Objective function: RMS deviation of the model from the data plus bound penalties.

    Parameters
    ----------
    xk : sequence of float
        Parameter vector from the minimizer. It is converted with
        ``fit.recover_parameters`` before evaluation.
    x_list, y_list : list or None
        Data lists whose first elements are used. ``None`` falls back to
        ``fit.x_list`` / ``fit.y_list``.
    w_list : any
        Unused; accepted for interface compatibility (no weighting applied).
    fit : fit object
        Required. It supplies ``stop_flag``, ``kpenalty``/``kmin``/``kmax``,
        ``y_convert`` and ``cparams``.
    run : bool
        Forwarded to ``ycal_list``.

    Returns
    -------
    float
        RMS deviation plus penalty, or ``1.0e300`` once a stop was requested.
        Also stored in ``fit.fmin``.

    Raises
    ------
    ValueError
        If ``x_list[0]`` is empty.

    Notes
    -----
    Fit files are only written when ``cparams.fhistory`` is on, because
    their names use the history iteration number.
    """
    if fit.stop_flag:
        return 1.0e300

    if x_list is None:
        x_list = fit.x_list
    if y_list is None:
        y_list = fit.y_list

    print()
    nx = len(x_list[0])
    if nx == 0:
        raise ValueError("minimize_func: x_list[0] is empty; nothing to fit")
    xkr = fit.recover_parameters(xk, set_member = False)
    # Unclipped copy, recorded in the history file.
    xkr0 = xkr.copy()
    p_tot = _bound_penalty(fit, xkr)

    yc_list = ycal_list(x_list, xkr, fit, run = run)
    # Use the y_list argument (it was previously ignored in favor of fit.y_list).
    fmin = _rms_deviation(y_list[0][:nx], yc_list[0][:nx], fit.y_convert) + p_tot

    # Named hist_iter and initialized up front. The old name `iter` fell back
    # to the builtin when history was off, and that crashed the fit-file name
    # formatting.
    hist_iter = None
    if fit.cparams.fhistory:
        print(f"Add fmin and parameters to [{fit.historyfile}]")
        hist_iter = add_history(fit.historyfile, fit.varname, xkr0, fmin)
        if hist_iter is not None:
            print(f"#{hist_iter}: fmin={fmin:12.4g} " + "".join(f" {v:12.6g}" for v in xk))
        else:
            print(f"  Warning: Failed to add to [{fit.historyfile}]")
            print(f"           Check write permission / disk space etc")

    if fit.cparams.ffitfiles and hist_iter is not None:
        save_path = f'fit{hist_iter:04}.xlsx'
        print(f"Save last input and calculation data  to [{save_path}]")
        if hasattr(fit, 'yini_list'):
            ret = save_data(save_path, ['E', 'DOS(in)', 'DOS(ini)', 'DOS(fin)'], [fit.x_list[0], fit.y_list[0], fit.yini_list[0], yc_list[0]])
        else:
            ret = save_data(save_path, ['E', 'DOS(in)', 'DOS(fin)'], [fit.x_list[0], fit.y_list[0], yc_list[0]])
        if not ret:
            print(f"  Warning: Failed to save to [{save_path}]")
            print(f"           Check write permission / disk space etc")

    fit.fmin = fmin
    fit.yc_list = yc_list

    return fmin
