import os
import sys
import csv
import numpy as np
from numpy import sin, cos, tan, pi, exp, log, sqrt
#from scipy.optimize import minimize
#import pandas as pd
from matplotlib import pyplot as plt


from tklib.tksci.tksci import log10, e, kB
from tklib.tkutils import getarg, getintarg, getfloatarg, pconv_by_type, print_line, print_data, joinf
from tklib.tkapplication import tkApplication
from tklib.tkparams import tkParams
from tklib.tkvariousdata import tkVariousData
#from tklib.tkfilter import tkFilter
from tklib.tkgraphic.tkplotevent import tkPlotEvent
from tklib.tksci.tkFit_mxy import tkFit_mxy

from minimize_func_CV import read_data, save_data, ycal_list, minimize_func


"""
Fitting by SIMPLEX and Gaussian Process regressions
"""



# --------------------------------------------------------
# Optimization algorithm (module-level defaults)
#   alternative: method = "nelder-mead"
# fit() below takes method / tol / nmaxiter from cparams,
# not from these names.
# --------------------------------------------------------
method = "cg"
maxiter = 100
tol = 1e-5

# ------------------------------------------
# Source parameters to be fitted
# ------------------------------------------
# Initial values of the optimization variables.
# Data parameters: I0, x0, w
x0 = [1.3, 0.6, 0.1]

# ------------------------------------------
# File configurations
# ------------------------------------------
initial_csv = "initial.csv"
final_csv = "final.csv"
conv_csv = "convergence.csv"


# ------------------------------------------
# Graph parameters
# ------------------------------------------
fplot = 1  # handed to the fitter as fit.fplot inside fit()
ngdata = 51
xgmin, xgmax = -4.0, 4.0
ygmin, ygmax = -4.0, 4.0
tsleep = 0.3


def initialize():
    """Create the application object and register its parameters and options.

    Builds a ``tkApplication``, attaches a fresh ``tkParams`` to it as
    ``app.cparams``, stores the per-parameter fitting metadata on ``app``,
    registers every command-line option, and sets the graph defaults.

    Returns:
        tuple: ``(app, cparams)`` where ``cparams`` is ``app.cparams``.
    """
    #================================
    # Global variables
    #================================
    app = tkApplication()
    # The returned (argv, narg) pair is not used here; the call is kept
    # because any side effects it has are not visible from this file.
    app.get_argv()
    app.cparams = tkParams()
    cparams = app.cparams
    cparams.debug = 0
    cparams.print_level = 0

    # NOTE(review): --print_interval / --plot_interval are also registered
    # below with a default of 5; which value wins presumably depends on
    # read_args(apply_default=...) -- confirm.
    cparams.print_interval = 10
    cparams.plot_interval = 1

    # Per-parameter metadata, one column per fitted parameter; fit() copies
    # these lists onto the fitter object unchanged.
    # NOTE(review): optid presumably marks which parameters are optimized
    # (1) or held fixed (0), and kmin/kmax/kpenalty look like bounds and
    # penalty weights -- confirm against tkFit_mxy.
    app.varname    = [   "S",    "ND",    "NA", "dn",  "dp"]
    app.unit       = [  "m2",  "cm-3",  "cm-3",  "m",   "m"]
    app.pk_convert = [    "",      "",      "",   "",    ""]
    app.optid      = [      0,       1,      1,    0,     0]
    app.kmin       = [    0.0,    0.01,   0.01,  1e-6, 1e-6]
    app.kmax       = [ 1.0e-2,  1.0e23, 1.0e23,  1e-5, 1e-5]
    app.kpenalty   = [    1.0,     1.0,    1.0,   1.0,  1.0]

    def add_option(name, type_name, opt_str, desc, defval):
        # Every option is optional and is stored under the same name as its flag.
        app.add_argument(opt = f"--{name}", type = type_name, var_name = name,
                         opt_str = opt_str, desc = desc, defval = defval, optional = True)

    # File options
    add_option("infile",   "str", "input .xlsx file",  "input Excel file",  "CV_meas.xlsx")
    add_option("outfile",  "str", "output .xlsx file", "output Excel file", "out.xlsx")
    add_option("datafile", "str", "input/calculated data .xlsx file", "Data file", "CV.xlsx")

    add_option("mode", "str", "--mode=[fit]", "task mode", "fit")

    # Model parameters (initial values for the fit)
    # NOTE(review): the default area is r_electrode squared; if r_electrode is
    # a radius of a circular electrode the area would be pi * r**2 -- confirm
    # the electrode geometry before changing.
    r_electrode = 1095e-6  # 850e-6 m
    add_option("S",  "float", "--S=val",  "Electrode area in m2", r_electrode * r_electrode)
    add_option("ND", "float", "--ND=val", "Donor density in n-layer in cm-3", 1.2e16)
    add_option("NA", "float", "--NA=val", "Acceptor density in p-layer in cm-3", 1.8e18)
    add_option("dn", "float", "--dn=val", "n-layer thickness in m", 3.1e-6)
    add_option("dp", "float", "--dp=val", "p-layer thickness in m", 375e-9)

    # Optimizer settings
    add_option("method",   "str",   "--method=[nelder-mead]", "optimization algorithm", "nelder-mead")
    add_option("nmaxiter", "int",   "--nmaxiter=int(>1)", "maximum iteration number for optimization", 1000)
    add_option("tol",      "float", "--tol=val", "eps for optimization", 1.0e-5)

    # Progress reporting
    add_option("print_interval", "int", "--print_interval=val", "print interval", 5)
    add_option("plot_interval",  "int", "--plot_interval=val",  "plot interval", 5)

    #=============================
    # Graph configuration
    #=============================
    cparams.figsize             = [8, 6]
    cparams.fontsize            = 12
    cparams.legend_fontsize     = 8
    cparams.graphupdateinterval = 10

    return app, cparams


#=============================
# Handle command-line arguments
#=============================
def update_vars(app, cparams, apply_default = True):
    """Read the command-line arguments into ``cparams``; terminate on failure.

    Args:
        app: application object providing ``read_args``, ``terminate`` and
            ``usage``.
        cparams: parameter object passed to ``read_args`` as ``vars``.
        apply_default: forwarded unchanged to ``read_args``.

    Returns:
        None. When ``read_args`` reports a failure (first element of its
        result is ``None``), ``app.terminate`` is called with a framed
        error message and ``app.usage``.
    """
    args_opt, args_idx, args_vars = app.read_args(
        vars=cparams,
        check_allowed_args=True,
        apply_default=apply_default,
    )
    if args_opt is not None:
        return

    # In the failure case the second slot carries an error number (unused
    # here) and the third slot carries the message text.
    error_message = args_vars
    framed = (
        "\n\n"
        "!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\n"
        f"!  {error_message}\n"
        "!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!"
    )
    app.terminate(framed, usage=app.usage)


#==========================================
# functions
#==========================================
def fit(app, cparams):
    """Fit the model parameters to the measured data and report the result.

    Steps, in order: derive the parameter/history file names from
    ``cparams.infile``, load the stored "Preferences" section, read the
    measured data, evaluate and plot the initial model, run the optimizer,
    then save and plot the final parameters and curves. Blocks on ENTER
    before returning.

    Args:
        app: application object; supplies ``replace_path`` and the
            per-parameter metadata lists (``varname``, ``unit``,
            ``pk_convert``, ``optid``, ``kmin``, ``kmax``, ``kpenalty``).
        cparams: parameter object with the run settings. Its ``S``, ``ND``,
            ``NA``, ``dn`` and ``dp`` attributes are overwritten with the
            fitted values and written back to the parameter file.
    """
    # File names are built from cparams.infile through replace_path templates.
    cparams.parameterfile = app.replace_path(cparams.infile, template=["{dirname}", "{filebody}.in"])
    cparams.historyfile = app.replace_path(cparams.infile, template=["{dirname}", "{filebody}-history.xlsx"])

    print("")
    print(f"Read [{cparams.parameterfile}]")
    cparams.read_parameters(
        cparams.parameterfile,
        section="Preferences",
        ignore_keys=["logfile", "outfile", "parameterfile", "parameterbkfile"],
    )

    print("")
    print(f"infile               : {cparams.infile}")
    print(f"parameter file       : {cparams.parameterfile}")
    print(f"history file       : {cparams.historyfile}")
    print(f"mode                 : {cparams.mode}")

    print("")
    print("Peak fitting")
    print("mode   : ", cparams.mode)
    print("infile : ", cparams.infile)
    print("outfile: ", cparams.outfile)
    print("parameter file : ", cparams.parameterfile)
    print("data file       : ", cparams.datafile)

    # The fitter object carries the data, the parameter metadata and the
    # plotting state for the whole run.
    fitter = tkFit_mxy(
        tol=cparams.tol,
        nmaxiter=cparams.nmaxiter,
        print_interval=cparams.print_interval,
        plot_interval=cparams.plot_interval,
    )
    fitter.infile = cparams.infile
    fitter.datafile = cparams.datafile
    fitter.historyfile = cparams.historyfile

    print()
    print(f"Read input data from [{cparams.infile}]")
    labels, x_values, y_values = read_data(cparams.infile)
    fitter.labels = labels
    fitter.x_list = x_values
    fitter.y_list = y_values
    fitter.w_list = None
    fitter.xlabels = [fitter.labels[0]]
    fitter.ylabels = [fitter.labels[1]]

    print(f"Save input data to [{fitter.datafile}]")
    save_data(fitter.datafile, fitter.labels, fitter.x_list + fitter.y_list)

    # Initial parameter vector, in the same order as app.varname.
    initial_pk = [cparams.S, cparams.ND, cparams.NA, cparams.dn, cparams.dp]
    fitter.varname = app.varname
    fitter.unit = app.unit
    fitter.pk = list(initial_pk)
    fitter.pk_convert = app.pk_convert
    fitter.optid = app.optid
    fitter.kmin = app.kmin
    fitter.kmax = app.kmax
    fitter.kpenalty = app.kpenalty
    n_params = len(fitter.pk)

    # x_list[i][j] is the j-th value of the i-th x variable.
    def _model_curves(pk):
        # Model y values for parameter vector pk over the stored x data.
        return ycal_list(fitter.x_list, pk, fitter)

    def _objective(pk):
        # Scalar to be minimized for parameter vector pk.
        return minimize_func(pk, fitter.x_list, fitter.y_list, fitter.w_list, fitter)

    fitter.cal_ylist = _model_curves
    fitter.minimize_func = _objective

    # Plot / iteration bookkeeping on the fitter.
    fitter.fplot = fplot
    fitter.plt = plt
    fitter.iter = 0
    fitter.xiter = []
    fitter.yfmin = []
    fitter.ycal_list = None
    fitter.yfin_list = None

    print()
    fitter.print_variables(heading="Fitting parameters:")
    print("Fitting configuration")
    print(f"  method  : {cparams.method}")
    print(f"  tol     : {cparams.tol}")
    print(f"  nmaxiter: {cparams.nmaxiter}")

    # Evaluate the model and the objective at the starting point.
    optpk = fitter.extract_parameters()
    fitter.yini_list = fitter.cal_ylist(fitter.pk)
    fini = fitter.minimize_func(optpk)

    print("")
    print(f"Initial function: fmin={fini:10.4g}")
    print_data(
        labels=["x", "y_obs", "y_cal"],
        data_list=[fitter.x_list[0], fitter.y_list[0], fitter.yini_list[0]],
        label_format='{:^15}',
        data_format='{:>15.4g}',
        header="Initial data:",
        nmax=20,
        print_level=0,
    )
    print("fini=", fini)

    print(f"Save input and initial data to [{fitter.datafile}]")
    save_data(fitter.datafile, fitter.labels + ['ini'], fitter.x_list + fitter.y_list + fitter.yini_list)

    #========================================
    # plot
    #========================================
    # NOTE(review): axes[0] / axes[1] are single Axes only when subplots()
    # returns a 1-D array, i.e. when len(y_list) == 1 -- confirm that more
    # than one y data set is never passed here.
    fig, axes = plt.subplots(len(fitter.y_list), 2, figsize=cparams.figsize)

    fitter.initial_plot(
        data_axes=[axes[0], axes[1]],
        error_axis=axes[1],
        yini_list=fitter.yini_list,
        label_ini='initial',
        fmin=fini,
        plt=plt,
        fig=fig,
        fontsize=cparams.fontsize,
    )
    fitter.plot_event.remove('error')

    #========================================
    # Optimize
    #========================================
    print("")
    print("Optimize:")
    fitter.print_variables(heading="Variables")
    # NOTE(review): scipy's finite-difference keyword is '3-point'; whether
    # tkFit_mxy.minimize accepts '3-points' is not visible here -- confirm.
    pfin, ffin, success, res = fitter.minimize(cparams.method, jac='3-points')
    print(f"Converged at iteration: {res.nit}" if success else "Function did not converge")

    #========================================
    # Final result
    #========================================
    fitter.pk = pfin.copy()
    cparams.S, cparams.ND, cparams.NA, cparams.dn, cparams.dp = pfin

    print("Final parameters")
    for idx in range(n_params):
        print(f"  {fitter.varname[idx]:10}: {pfin[idx]:10.4g} {fitter.unit[idx]}")
    print(f"    f={ffin:12.6g}")

    print("")
    print(f"Save parameters to [{cparams.parameterfile}]")
    cparams.save_parameters(cparams.parameterfile, section="Preferences")

    yfin_list = fitter.cal_ylist(pfin)
    print("")
    print_data(
        labels=[fitter.xlabels[0], "y_obs", "y_cal"],
        data_list=[fitter.x_list[0], fitter.y_list[0], yfin_list[0]],
        label_format='{:^15}',
        data_format='{:>15.4g}',
        header="Final data:",
        nmax=20,
        print_level=0,
    )
    print(f"Save input, initial, and final data to [{cparams.datafile}]")
    save_data(
        fitter.datafile,
        fitter.labels + ['ini', 'fin'],
        fitter.x_list + fitter.y_list + fitter.yini_list + yfin_list,
    )

    #=============================
    # Display the graph
    #=============================
    print("")
    print("Plot optimized")
    fitter.finalize_plot(yfin_list, iter=res.nit, fmin=ffin)

    fitter.layout()
    plt.pause(0.01)

    # Keep the process (and the figure window) alive until the user confirms.
    print()
    input("Press ENTER to terminate>>")
    

def main():
    """Script entry point: configure the application, open the log, run the task.

    The command-line arguments are read twice: once with defaults applied so
    that ``cparams.infile`` is available for naming the log file, and once
    more without defaults after output has been redirected. Only the
    ``fit`` mode is supported; any other mode ends in ``app.terminate``.
    """
    print()

    app, cparams = initialize()
    update_vars(app, cparams, apply_default = True)

    # The log file path is derived from the input file name.
    cparams.logfile = app.replace_path(cparams.infile)
    print(f"Open logfile [{cparams.logfile}]")
    app.redirect(targets = ["stdout", cparams.logfile], mode = 'w')

    # Check the arguments again so that explicitly given parameters are used.
    update_vars(app, cparams, apply_default = False)

    if cparams.mode != 'fit':
        app.terminate(f"Error in main: Invalid mode [{cparams.mode}]", usage = app.usage, pause = True)
        return

    fit(app, cparams)


if __name__ == "__main__":
    main()
