import sys
import os
import numpy as np
from numpy import exp, log, log10, sin, cos, tan, arcsin, arccos, arctan, sinh, cosh, tanh, sqrt, abs
import pandas as pd
from scipy.optimize import minimize
import matplotlib.pyplot as plt
import matplotlib.widgets as wg


from tklib.tkutils import replace_path, getarg, getintarg, getfloatarg, pint, pfloat
from tklib.tkapplication import tkApplication
from tklib.tkparams import tkParams
from tklib.tksci.tksci import asin, acos, atan, degcos, degsin, degtan, degacos, degasin, degatan, arcsin, arccos, arctan
from tklib.tksci.tksci import factorial, gamma, combination, eVTonm, nmToeV, Bn
from tklib.tksci.tksci import h, h_bar, hbar, e, kB, NA, c, pi, pi2, torad, todeg, basee
from tklib.tksci.tksci import me, mp, mn
from tklib.tksci.tksci import u0, e0, a0, R, F, g, G
from tklib.tksci.tksci import HartreeToeV, RyToeV, KToeV, eVToK, JToeV, eVToJ, Debye
from tklib.tksci.tkFit import tkFit
from tklib.tksci.tkFit_m import tkFit_m


#=======================================
# プログラム開始
#=======================================
def initialize(app):
    """Create the parameter container and load the default settings.

    A fresh tkParams object is attached to ``app`` as ``app.cparams`` and
    every default listed below is assigned to it, in the listed order.
    """
    cparams = tkParams()
    app.cparams = cparams

    defaults = (
        # Optimization method.  Choices noted by the author:
        #   nelder-mead    Downhill simplex
        #   powell         Modified Powell
        #   cg             conjugate gradient (Polak-Ribiere method)
        #   bfgs           BFGS method
        ("method", "bfgs"),

        # ---- Source parameters to be fitted ----
        ("infile", 'peak.xlsx'),
        ("nx", None),
        ("mode", 'fit'),
        # Model expression: 'x' is the list of input columns, 'p' the parameters.
        ("func", "p[0]*exp(-((x[0]-p[1])/p[2])**2)"),
        ("p0s", "1.3,0.6,0.1"),
        ("fit_range", "-1e100:1e100, -1e100:1e100, -1e100:1e100"),
        ("olabel", ''),
        ("xlabel", ''),
        ("ylabel", ''),
        ("zlabel", ''),
        # Small displacement of a variable used for numerical differentiation.
        ("h", 0.01),
        ("maxiter", 100),
        ("tol", 1.0e-5),

        # ---- Graph parameters ----
        ("ngdata", 51),
        ("xgmin", -4.0),
        ("xgmax", 4.0),
        ("tsleep", 0.01),
        ("figsize", (10, 5)),
        ("fontsize", 16),
        ("legend_fontsize", 10),
        ("outputinterval", 5),
        ("graphupdateinterval", 10),
    )
    for name, value in defaults:
        setattr(cparams, name, value)

def update_vars(app):
    """Override the defaults in ``app.cparams`` from command-line arguments.

    Each entry of the table below is processed in order: the reader
    (getarg / getintarg / getfloatarg from tklib.tkutils) is called with the
    argument position and the current attribute value, and its return value
    is stored back on the same attribute.
    """
    cparams = app.cparams

    # (argument position, attribute name, reader)
    positional = (
        (1, "mode", getarg),
        (2, "infile", getarg),
        (3, "method", getarg),
        (4, "func", getarg),
        (5, "p0s", getarg),
        (6, "fit_range", getarg),
        (7, "maxiter", getintarg),
        (8, "tol", getfloatarg),
        (9, "h", getfloatarg),
        (10, "olabel", getarg),
        (11, "xlabel", getarg),
        (12, "ylabel", getarg),
        (13, "zlabel", getarg),
    )
    for position, name, reader in positional:
        setattr(cparams, name, reader(position, getattr(cparams, name)))

#    app.add_argument(opt = "-s", type = "str", var_name = 'script_list_name',  opt_str = "-s=script_list_name", desc = 'Script list file',
#                     defval = "default", optional = True);
#    app.add_argument(opt = "-i", type = "float", var_name = 'idx', opt_str = "-i=integer", desc = 'Index to specify alpha',
#                     defval = 5, optional = False);
#    app.add_argument(opt = None, type = "str", var_name = 'infile', opt_str = "infile",  desc = 'Input file', 
#                     defval = "input.txt", optional = False);
#    app.add_argument(opt = None, type = "str",   var_name = 'nmax', opt_str = "nmax",      desc = 'Max number of iterations',
#                    defval = 5, optional = False);

#    args_opt, args_idx, args_vars = app.read_args(vars = cparams.__dict__, check_allowed_args = True)
#    if args_opt is None:
#        error_no = args_idx
#        error_message = args_vars
#        app.terminate("\n\n"
#                   +  "!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\n"
#                   + f"!  {error_message}\n"
#                   +  "!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!",
#                   usage = app.usage)

#    app.set_usage(usage_str)


#==========================================
# functions
#==========================================
def cal_y(xd, pk, app):
    """Evaluate the model expression stored in ``app.cparams.func``.

    The expression string is passed to eval() with this module's globals
    (so the imported math functions and constants are usable) and with the
    local names ``x`` bound to ``xd`` and ``p`` bound to ``pk``.

    NOTE: the expression comes from user input (defaults / command line)
    and is executed as Python code; do not feed it untrusted text.
    """
    expression = app.cparams.func
    local_names = {"x": xd, "p": pk}
    return eval(expression, globals(), local_names)

def _parse_fit_range(spec):
    """Parse a fit-range specification into a list of [min, max] pairs.

    ``spec`` is a comma-separated list of ``min:max`` items, for example
    ``"-1e100:1e100, 0:10"``.  Each item must contain exactly one ':';
    otherwise the tuple unpacking raises ValueError.
    """
    pairs = []
    for item in spec.split(','):
        lower, upper = item.split(':')
        pairs.append([float(lower), float(upper)])
    return pairs


def fit(app):
    """Run the fitting workflow described by ``app.cparams``.

    Steps, in order:
      1. Convert the string parameters ``p0s`` and ``fit_range`` to numbers
         and derive the output file names from ``infile``.
      2. Read the data through a tkFit_m object and evaluate / plot the
         initial function.
      3. In 'sim' mode stop here; otherwise minimize with ``cparams.method``.
      4. Print the result, save the fitted data, the convergence history and
         the parameters, then draw the final plot and terminate the app.

    Side effects on ``cparams``: ``p0s`` is replaced by a list of floats and
    deleted before the parameters are saved; ``fit_range`` is replaced by a
    list and later by its string form; ``out_file``, ``convergence_file``,
    ``parameter_file`` and ``fig`` are added.

    Args:
        app: application object providing ``cparams``, ``replace_path``,
            ``redirect`` and ``terminate``.
    """
    cparams = app.cparams

    # NOTE(review): presumably this mirrors stdout into a log file whose name
    # is derived from the input file -- confirm against tkApplication.
    app.log_path = app.replace_path(cparams.infile)
    app.redirect(["stdout", app.log_path], "w")

    # Command-line values arrive as strings; convert them once here.
    cparams.p0s = [float(s) for s in cparams.p0s.split(',')]
    cparams.fit_range = _parse_fit_range(cparams.fit_range)

    cparams.out_file         = replace_path(cparams.infile, "{dirname}/{filebody}-fitting.xlsx")
    cparams.convergence_file = replace_path(cparams.infile, "{dirname}/{filebody}-convergence.xlsx")
    cparams.parameter_file   = replace_path(cparams.infile, "{dirname}/{filebody}.prm")

    print("")
    print("Fitting program to given function")
    print(f"mode            : {cparams.mode}")
    print(f"method          : {cparams.method}")
    print(f"input file      : {cparams.infile}")
    print(f"  fit_range     : {cparams.fit_range}")
    print(f"output file     : {cparams.out_file}")
    print(f"convergence file: {cparams.convergence_file}")
    print(f"parameter file  : {cparams.parameter_file}")
    print(f"func            : {cparams.func}")
    print(f"Target          : {cparams.olabel}")
    print(f"x               : {cparams.xlabel}")
    print(f"y               : {cparams.ylabel}")
    print(f"z               : {cparams.zlabel}")
    print(f"tol             : {cparams.tol}")
    print(f"maxiter         : {cparams.maxiter}")

    # '***' marks a placeholder that the user was expected to replace.
    # sys.exit() is used instead of the site-provided exit(), which is not
    # guaranteed to exist in every interpreter setup.
    if '***' in cparams.method:
        input("\nError: Choose method\n")
        sys.exit()

    if '***' in cparams.func:
        input("\nError: Choose or input function\n")
        sys.exit()

    fitter = tkFit_m(tol = cparams.tol, nmaxiter = cparams.maxiter,
                print_interval = cparams.outputinterval, plot_interval = cparams.graphupdateinterval)

    # '---' means "this input column is not used".
    read_labels = [label
                   for label in (cparams.xlabel, cparams.ylabel, cparams.zlabel)
                   if label != '---']

    print("")
    print(f"Read [{cparams.infile}]")
    # Transpose so that row 0 holds all minima and row 1 all maxima.
    ranges = np.array(cparams.fit_range).T
    print("  For ranges:", ranges)
    # NOTE(review): no 'usage' function is defined or imported in this file
    # as seen here, so this callback would raise NameError if read_data ever
    # invokes it -- confirm and point it at the intended usage printer.
    fitter.read_data(cparams.infile, xlabels = read_labels, ylabel = cparams.olabel,
                    xmins = ranges[0], xmaxs = ranges[1], usage = lambda: usage(app))

    xd = fitter.x_list
    xd_labels = fitter.xlabels
    nvars = len(cparams.p0s)

    print("")
    print("x ranges:")
    for i, column in enumerate(xd):
        xmin = min(column)
        xmax = max(column)
        print(f"  x[{i}]: {xmin} - {xmax}")
        print(f"    fit in: {cparams.fit_range[i][0]} - {cparams.fit_range[i][1]}")

    fitter.varname = [f'p({i})' for i in range(nvars)]
    fitter.unit    = ['' for _ in range(nvars)]
    fitter.pk      = [pfloat(value) for value in cparams.p0s]
    fitter.optid   = [1 for _ in range(nvars)]

    fitter.func = lambda x_list, pk: cal_y(x_list, pk, app)

    #=============================
    # Initial function
    #=============================
    print("")
    print("Calculate initial function:")
    yini = fitter.cal_ylist(fitter.pk)
    fini = fitter.minimize_func(fitter.pk)
    print("")
    fitter.print_data(heading = "Initial functions", yini = yini)

    #=============================
    # Show the graph
    #=============================
    print("")
    print("plot")
    cparams.fig, axes = plt.subplots(1, 2, figsize = cparams.figsize)

    fitter.initial_plot(data_axis = axes[0], error_axis = axes[1], yini = yini, fmin = fini,
                fig = cparams.fig, fontsize = cparams.fontsize)

    if cparams.mode == 'sim':
        app.terminate("", pause = True)
        # Simulation mode never optimizes; return explicitly in case
        # terminate() hands control back instead of exiting the process.
        return

    #=============================
    # Optimization
    #=============================
    print("")
    print(f"Nonlinear least-squares fitting by [{cparams.method}]")
    fitter.print_variables()
    print(f"  tol={fitter.tol}")
    print(f"  nmaxiter={fitter.nmaxiter}")

    pfin, ffin, success, res = fitter.minimize(cparams.method)
    if success:
        print("")
        print(f"\nConverged at iteration: {res.nit}")
    else:
        print("***Warning: Function did not converge")

    print("Final parameters")
    for i in range(nvars):
        print(f"  {fitter.varname[i]:10}: {pfin[i]:10.4g} {fitter.unit[i]}")
    print(f"    f={ffin:12.6g}")

    print("")
    print(f"Optimized at S2={ffin:12.6g}:")
    fitter.print_variables()

    #========================================
    # Final result
    #========================================
    yfin = fitter.cal_ylist(pfin)
    print("")
    fitter.print_scores(heading = "\nScores between y(input) and y(fit)", y1 = fitter.y, y2 = yfin)

    print("")
    print(f"Save results to [{cparams.out_file}]")
    # Write the fitted data to out_file.  The original passed
    # convergence_file here, so the data was overwritten by the convergence
    # table just below and out_file was never created.
    fitter.to_excel(cparams.out_file,
            [*xd_labels, cparams.olabel, 'initial', 'final'], [*xd, fitter.y, yini, yfin])

    print("")
    print(f"Save convergence process to [{cparams.convergence_file}]")
    fitter.to_excel(cparams.convergence_file,
            ['iter', 'MSE'], [fitter.iter_list, fitter.fmin_list])

    print("")
    print(f"Save parameters to [{cparams.parameter_file}]")
    cparams.fit_range = str(cparams.fit_range)
    cresults = tkParams()
    # Values are stored as comma-separated text without the list brackets.
    cresults.initial_values = str(cparams.p0s).replace('[', '').replace(']', '')
    # Convert to plain floats first: with NumPy >= 2 the repr of a NumPy
    # scalar is 'np.float64(...)', which would otherwise leak into the file.
    final_values = [float(value) for value in res.x]
    cresults.final_values   = str(final_values).replace('[', '').replace(']', '')
    cresults.MAE            = res.fun
    cresults.iteration      = res.nit
    # Presumably removed so that the parsed list is not written into the
    # 'Parameters' section; the initial values are saved under 'Fitting'.
    del cparams.p0s

    cparams.save_parameters(cparams.parameter_file, section = 'Parameters', sort_by_keys = False, update_commandline = False, IsPrint = False)
    cresults.save_parameters(cparams.parameter_file, section = 'Fitting', sort_by_keys = False, update_commandline = False, IsPrint = False)

    #=============================
    # Show the graph
    #=============================
    print("")
    print("Plot optimized")
    fitter.finalize_plot(yfin, iter = res.nit, fmin = ffin)

    app.terminate("", pause = True)


#==========================================
# Main routine
#==========================================
def main():
    """Script entry point: build the application, set it up and run the fit.

    The two setup steps each print a progress message before running; the
    fit itself is started afterwards with the same application object.
    """
    app = tkApplication()

    setup_steps = (
        ("Initialize parameters", initialize),
        ("Update parameters by command-line arguments", update_vars),
    )
    for message, step in setup_steps:
        print(message)
        step(app)

    fit(app)


# Run the fitting workflow only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
