import os
import sys
import csv
import numpy as np
from numpy import sin, cos, tan, pi, exp, log, sqrt
#from scipy.optimize import minimize
#import pandas as pd
from matplotlib import pyplot as plt


from tklib.tksci.tksci import log10, e, kB
from tklib.tkutils import getarg, getintarg, getfloatarg, pconv_by_type, print_line, print_data, joinf
from tklib.tkapplication import tkApplication
from tklib.tkparams import tkParams
from tklib.tkvariousdata import tkVariousData
#from tklib.tkfilter import tkFilter
from tklib.tkgraphic.tkplotevent import tkPlotEvent
from tklib.tksci.tkFit_mxy import tkFit_mxy

from tklib.tksci.tkFit_lib import conv_input, save_data, add_history
from minimize_func_DOS import initialize_minimize_func, read_data, ycal_list, save_parameters, minimize_func, callback


# NOTE(review): this string follows the imports, so Python does not use it as
# the module docstring (__doc__ stays None); it only acts as a descriptive
# comment. Only the SIMPLEX (Nelder-Mead) path is visible in this file.
"""
Fitting by SIMPLEX and Gaussian Process regressions
"""

# Program name printed in the start-up banner by main().
ProgramName = "DOSFit_flex"


def initialize():
    """Create the application object and register every command-line option.

    A ``tkApplication`` is created with a fresh ``tkParams`` stored as
    ``app.cparams``, ``initialize_minimize_func()`` is called, and then the
    positional file arguments, task/flag/file options, fitting-variable
    definitions, optimizer options and graph settings are registered.

    Each fitting variable listed in ``app.varname`` also becomes a float
    option ``--<varname>`` whose default is the matching entry of ``app.x0``.

    Returns:
        tuple: ``(app, cparams)`` -- the application object and its parameter
        container (the same object as ``app.cparams``).
    """
    app          = tkApplication()
    app.get_argv()  # the returned (argv, narg) are not needed here
    app.cparams  = tkParams()
    cparams      = app.cparams
    cparams.debug       = 0
    cparams.print_level = 0

    initialize_minimize_func()

    # Positional file arguments; file1 selects what plot() reads.
    app.add_argument(opt = None, type = "str", var_name = 'file1', defval = '')
    app.add_argument(opt = None, type = "str", var_name = 'file2', defval = '')
    app.add_argument(opt = None, type = "str", var_name = 'file3', defval = '')

    app.add_argument(opt = "--mode", type = "str", var_name = 'mode',  opt_str = "--mode=[init|sim|fit|plot]", desc = 'task mode',
                     defval = 'fit', optional = True)

    app.add_argument(opt = "--fhistory", type = "int", var_name = 'fhistory',  opt_str = "--fhistory=[0|1]", desc = 'flag to save history file',
                     defval = 0, optional = True)
    app.add_argument(opt = "--ffitfiles", type = "int", var_name = 'ffitfiles',  opt_str = "--ffitfiles=[0|1]", desc = 'flag to save fit files',
                     defval = 0, optional = True)
    app.add_argument(opt = "--fplot", type = "int", var_name = 'fplot',  opt_str = "--fplot=[0|1]", desc = 'flag to plot graph',
                     defval = 1, optional = True)

    app.add_argument(opt = '--infile', type = "str", var_name = 'infile', opt_str = "--infile=path",  desc = 'reference (observed IV) Excel file',
                     defval = 'DOS.xlsx', optional = True)
    app.add_argument(opt = "--outfile", type = "str", var_name = 'outfile', opt_str = "--outfile=path",  desc = 'output Excel file',
                     defval = 'DOS-input.xlsx', optional = True)
    app.add_argument(opt = "--fitfile", type = "str", var_name = 'fitfile', opt_str = "--fitfile=path",  desc = 'Fitting result (.xlsx) file',
                     defval = "DOS-fit.xlsx", optional = True)

    # Fitting-variable definitions; every list is index-aligned with varname.
    # Presumably optid=1 marks a variable as optimized, dx is the initial-simplex
    # step (init_fit builds a simplex only when dx is set), kmin/kmax bound the
    # value and kpenalty weights the out-of-range penalty -- confirm in tkFit_mxy.
    app.varname    = [ "ge1",    "gn1", "gw1", "ge2",    "gn2", "gw2"]
    app.unit       = [  "eV", "/cm/eV",  "eV",  "eV", "/cm/eV",  "eV"]
    app.pk_convert = [    "",       "",    "",    "",       "",    ""]
    app.optid      = [     1,        1,     1,     1,        1,     1]
    app.x0         = [  0.15,   2.5e16,   0.1,   0.4,   5.0e15,  0.15]
    app.dx         = [  0.05,   1.0e16,  0.02,  0.05,   1.0e15,  0.02]
    app.kmin       = [  0.00,   1.0e14, 0.001,   0.3,   1.0e12,  0.02]
    app.kmax       = [   0.3,   1.0e18,   0.3,   0.5,   1.0e17,   0.5]
    app.kpenalty   = [   1.0,      1.0,   1.0,   1.0,      1.0,   1.0]
    app.y_convert = "log"

    # One float option per fitting variable, defaulting to its initial value.
    for i, varname in enumerate(app.varname):
        app.add_argument(opt = f"--{varname}", type = "float", defval = app.x0[i])

    app.add_argument(opt = "--method", type = "str", var_name = 'method',  opt_str = "--method=[nelder-mead]", desc = 'optimization algorithm',
                     defval = 'nelder-mead', optional = True)
    # NOTE(review): scipy.optimize.minimize spells the finite-difference modes
    # '2-point' / '3-point'. If tkFit_mxy hands jac straight to scipy,
    # '2-points' is not recognised there -- confirm how tkFit_mxy maps it.
    app.add_argument(opt = "--jac", type = "str", var_name = 'jac',  opt_str = "--jac=[3-points|2-points|func]", desc = 'first differential',
                     defval = '2-points', optional = True)
    app.add_argument(opt = "--nmaxiter", type = "int", var_name = 'nmaxiter',  opt_str = "--nmaxiter=int(>1)", desc = 'maximum iteration number for optimization',
                     defval = 1000, optional = True)
    app.add_argument(opt = "--tol", type = "float", var_name = 'tol',  opt_str = "--tol=val", desc = 'eps for optimization',
                     defval = 1.0e-5, optional = True)

    app.add_argument(opt = "--print_interval", type = "int", var_name = 'print_interval',  opt_str = "--print_interval=val", desc = 'print interval',
                     defval = 5, optional = True)
    app.add_argument(opt = "--plot_interval", type = "int", var_name = 'plot_interval',  opt_str = "--plot_interval=val", desc = 'plot interval',
                     defval = 5, optional = True)

    #=============================
    # Graph configuration
    #=============================
    cparams.figsize             = [8, 6]
    cparams.fontsize            = 16
    cparams.legend_fontsize     = 12
    cparams.graphupdateinterval = 10

    return app, cparams


#=============================
# Parse command-line arguments
#=============================
def update_vars(app, cparams, apply_default = True):
    """Parse the command-line arguments into ``cparams`` via ``app.read_args``.

    On a parse failure ``app.read_args`` returns ``None`` as its first item,
    with the error number and error message as the second and third items;
    the message is then shown in a banner through ``app.terminate`` together
    with the usage text.

    Returns:
        The ``(args_opt, args_idx, args_vars)`` triple from ``app.read_args``.
    """
    parsed = app.read_args(vars = cparams, check_allowed_args = True, apply_default = apply_default)
    args_opt, args_idx, args_vars = parsed

    # Success: nothing more to do.
    if args_opt is not None:
        return args_opt, args_idx, args_vars

    # Failure: args_vars carries the error message (args_idx the error number).
    banner_parts = [
        "\n\n",
        "!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\n",
        f"!  {args_vars}\n",
        "!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!",
    ]
    app.terminate("".join(banner_parts), usage = app.usage)
    return args_opt, args_idx, args_vars

#==========================================
# functions
#==========================================
def init_fit(app, cparams):
    """Build and configure the ``tkFit_mxy`` object shared by all modes.

    Steps:
      1. Derive the parameter file (``{filebody}.in``), the history file
         (``{filebody}-history.xlsx``) and the ``stop`` file path from
         ``cparams.infile``.
      2. Read the Preferences, Flags and Condition sections of the parameter
         file into ``cparams``, then re-apply the explicitly given
         command-line arguments with ``app.force_given_args``.
      3. Create the fit object, copy the file names and the fitting-variable
         definitions (set up in ``initialize``) into it, seed ``fit.pk`` from
         the command-line/default values and read the "Parameters" section.
      4. When ``fit.dx`` is set, build and print the initial simplex.
      5. Register the Jacobian choice and the callbacks used by the optimizer.

    Returns:
        The configured ``tkFit_mxy`` instance.
    """
    cparams.parameterfile = app.replace_path(cparams.infile, template = ["{dirname}", "{filebody}.in"])
    cparams.historyfile   = app.replace_path(cparams.infile, template = ["{dirname}", "{filebody}-history.xlsx"])
    cparams.stopfile      = app.replace_path(cparams.infile, template = ["{dirname}", "stop"])

    # Read every section except "Parameters". The file-name keys are passed as
    # ignore_keys (presumably so the file cannot override these paths).
    ignored_keys = ("logfile", "outfile", "parameterfile", "parameterbkfile")
    print(f"Read parameters from [{cparams.parameterfile}]")
    for section in ("Preferences", "Flags", "Condition"):
        cparams.read_parameters(cparams.parameterfile, section = section,
                    ignore_keys = list(ignored_keys))
    # Arguments given explicitly on the command line take precedence over the file.
    app.force_given_args(cparams)

    print("mode                 : {}".format(cparams.mode))
    print("fplot                : {}".format(cparams.fplot))
    print("fhistory             : {}".format(cparams.fhistory))
    print("ffitfiles            : {}".format(cparams.ffitfiles))
    print("stop path            : {}".format(cparams.stopfile))
    print("Input DOS infile     : {}".format(cparams.infile))
    print("Fitting result file  : {}".format(cparams.fitfile))
    print("outfile              : {}".format(cparams.outfile))
    print("parameter file       : {}".format(cparams.parameterfile))
    print("history file         : {}".format(cparams.historyfile))
    print()
    print("Fitting configuration")
    print(f"  method  : {cparams.method}")
    print(f"  jac     : {cparams.jac}")
    print(f"  tol     : {cparams.tol}")
    print(f"  nmaxiter: {cparams.nmaxiter}")

    fit = tkFit_mxy(method = cparams.method, tol = cparams.tol, nmaxiter = cparams.nmaxiter,
                fplot = cparams.fplot, print_interval = cparams.print_interval, plot_interval = cparams.plot_interval)
    fit.configure(print_level = 0, cparams = cparams)
    fit.copy_attributes(cparams, ["infile", "outfile", "historyfile", "fitfile"])

    # Fitting variables and their constraints, as defined in initialize().
    fit.copy_attributes(app, ["varname", "unit", "pk_convert", "optid", "dx", "kmin", "kmax", "kpenalty", "y_convert"])
    x0 = fit.build_parameter_list(fit.varname, source_obj = cparams)
    # Start from the parameter values given on the command line (or their defaults)...
    fit.pk         = x0.copy()
    # ...then read the "Parameters" section of the parameter file
    # (presumably overriding fit.pk for the keys it contains).
    fit.read_parameters(cparams.parameterfile, section = "Parameters", keys = fit.varname)

    fit.initial_simplex = None
    if fit.dx:
        fit.initial_simplex = fit.build_initial_simplex()
        for i, splx in enumerate(fit.initial_simplex):
            print(f"  {i}: ", end = '')
            for v in splx:
                print(f" {v:12.4g}", end = '')
            print()

    # Functions called back by the fitter.
    # xs is a set of x values; x_list[i][j] is the j-th value of the i-th x variable.
    # NOTE(review): `diff1` is neither defined nor imported in this file, so
    # calling fit.jac after --jac=func raises NameError. The lambda defers the
    # lookup, so only methods that actually evaluate the Jacobian are affected
    # (gradient-free methods such as Nelder-Mead presumably never call it).
    # TODO: import or implement diff1.
    if cparams.jac == 'func':
        fit.jac = lambda xk: diff1(xk, fit)
    else:
        fit.jac = cparams.jac
    fit.configure(cal_ylist = lambda pk, run = True: ycal_list(None, pk, fit, run = run),
                  minimize_func = lambda pk: minimize_func(pk, fit.x_list, fit.y_list, fit.w_list, fit),
                  callback = lambda pk: callback(pk, fit)
                 )

    # Graph-related state used by the fitter.
    fit.configure(fplot = fit.fplot, plt = plt, iter = 0, xiter = [], yfmin = [], ycal_list = None, yfin_list = None)

    fit.print_variables(heading = "Fitting parameters:")

    return fit

def init(app, cparams):
    """Mode ``init``: write a parameter file from the current settings.

    The fit object is built with ``init_fit`` (which prints the effective
    configuration). ``cparams.parameterfile`` is then set to
    ``{dirname}/{filebody}.in`` derived from ``cparams.infile``, and
    ``save_parameters`` writes the parameters. The function waits for ENTER
    before returning.
    """
    for text in ("", "Make ini file"):
        print(text)
    fit_obj = init_fit(app, cparams)

    param_template = ["{dirname}", "{filebody}.in"]
    cparams.parameterfile = app.replace_path(cparams.infile, template = param_template)
    save_parameters(fit_obj, cparams)

    input("\nPress ENTER to terminate>>")
    print()

def plot(app, cparams):
    """Mode ``plot``: plot the data stored in one of the result files.

    ``cparams.file1`` selects the file:
      * ``''``        -- ``cparams.fitfile`` if it exists, otherwise
                         ``cparams.outfile`` if it exists; if neither exists
                         the program terminates with an error message;
      * ``'history'`` -- ``{filebody}-history.xlsx`` derived from
                         ``cparams.infile``; only its first y column is
                         plotted, on a logarithmic y axis;
      * ``'input'``   -- ``cparams.outfile``;
      * ``'fit'``     -- ``cparams.fitfile``;
      * anything else -- used as a path and read with ``read_data``'s own
                         default file type.

    Every y column is plotted against the first x column. Labels are taken
    as ``labels[0]`` for x and ``labels[i + 1]`` for y column ``i``, i.e. a
    single x column is assumed.
    """
    cparams.historyfile   = app.replace_path(cparams.infile, template = ["{dirname}", "{filebody}-history.xlsx"])

    # None means "explicit path": read_data then uses its default file type.
    filetype = None
    if cparams.file1 == '':
        if os.path.isfile(cparams.fitfile):
            filetype = 'fitfile'
            cparams.file1 = cparams.fitfile
        elif os.path.isfile(cparams.outfile):
            filetype = 'inputfile'
            cparams.file1 = cparams.outfile
        else:
            app.terminate(f"Error in plot: neither [{cparams.fitfile}] nor [{cparams.outfile}] exists",
                          usage = app.usage, pause = True)
            return
    elif cparams.file1 == 'history':
        filetype = 'history'
        cparams.file1 = cparams.historyfile
    elif cparams.file1 == 'input':
        filetype = 'inputfile'
        cparams.file1 = cparams.outfile
    elif cparams.file1 == 'fit':
        filetype = 'fitfile'
        cparams.file1 = cparams.fitfile

    print("")
    print("Plot DOS")
    print(f"file1   : {cparams.file1}")
    print(f"filetype: {filetype}")

    print()
    # Plain locals: previously these were stored as attributes on the
    # module-level function `fit`, which leaked state into a global object.
    if filetype is None:
        labels, x_list, y_list = read_data(cparams.file1)
    else:
        labels, x_list, y_list = read_data(cparams.file1, filetype = filetype)
    if filetype == 'history':
        # Keep only the x column and the first y column of the history file.
        labels = labels[0:2]
        y_list = [y_list[0]]
    nx = len(x_list)
    ny = len(y_list)
    print("nx:", nx)
    print("ny:", ny)
    xlabels = labels[0:nx]
    ylabels = labels[nx:]
    print("xlabels:", xlabels)
    print("ylabels:", ylabels)

    print("plot")
    fig, ax = plt.subplots(1, 1, figsize = cparams.figsize)
    ax.tick_params(labelsize = cparams.fontsize)
    # Per-series styles, cycled so that more series than entries do not raise IndexError.
    widths  = [1.0, 0.5, 0.5, 0.5, 0.5, 0.5]
    markers = ['o', 's', '^', '<', '>', '+']
    msizes  = [0.3] if filetype == 'history' else [5.0, 3.0, 3.0, 3.0, 3.0, 3.0]
    for i, y in enumerate(y_list):
        ax.plot(x_list[0], y, label = labels[i + 1],
                linewidth = widths[i % len(widths)],
                marker = markers[i % len(markers)],
                markersize = msizes[i % len(msizes)])
    if filetype == 'history':
        ax.set_yscale('log')
    ax.set_xlabel(labels[0], fontsize = cparams.fontsize)
    ax.set_ylabel(labels[1], fontsize = cparams.fontsize)
    ax.legend(fontsize = cparams.legend_fontsize)

    plt.pause(0.01)
    input("\nPress ENTER to terminate>>")
    print()

def sim(app, cparams):
    """Mode ``sim``: evaluate the DOS model for the current parameters, without fitting.

    The reference data is read from ``cparams.infile`` and a copy is written
    to ``fit.outfile``. The model is evaluated once via
    ``minimize_func(..., run = False)`` and a table is printed. E / DOS(in) /
    DOS(ini) are saved to ``fit.fitfile``. When ``fit.fplot`` is set, the
    measured and simulated curves are plotted.
    """
    print("")
    print("Perform DOS calculation")
    fitter = init_fit(app, cparams)

    print()
    print(f"Read input data from [{cparams.infile}]")
    loaded = read_data(cparams.infile)
    fitter.labels, fitter.x_list, fitter.y_list = loaded
    print(f"Save input data repeat to [{fitter.outfile}]")
    save_data(fitter.outfile, fitter.labels, fitter.x_list + fitter.y_list)

    print()
    print("Configure fit (tkFit_mxy object):")
    fitter.configure(xlabels = [fitter.labels[0]], ylabels = [fitter.labels[1]], w_list = None)

    start_pk = fitter.extract_parameters()
    minimize_func(start_pk, fit = fitter, run = False)
    # Assumes minimize_func stores the computed curves in fit.yc_list -- TODO confirm.
    fitter.yini_list = fitter.yc_list

    print_data(labels = fitter.labels + ['ini'],
               data_list = fitter.x_list + fitter.y_list + fitter.yini_list,
               label_format = '{:^15}', data_format = '{:>15.4g}',
               header = "Simulated IV:", nmax = 20, print_level = 0)

    print(f"Save input and initial data to [{fitter.fitfile}]")
    columns = [fitter.x_list[0], fitter.y_list[0], fitter.yini_list[0]]
    save_data(fitter.fitfile, ['E', 'DOS(in)', 'DOS(ini)'], columns)

    if not fitter.fplot:
        return

    print("plot")
    fig, ax = plt.subplots(1, 1, figsize = cparams.figsize)
    ax.tick_params(labelsize = cparams.fontsize)
    energies = fitter.x_list[0]
    ax.plot(energies, fitter.y_list[0],    label = 'measured',  linestyle = '', marker = 'o', markersize = 1.5)
    ax.plot(energies, fitter.yini_list[0], label = 'simulated', linewidth = 0.5)
    ax.set_xlabel(fitter.labels[0], fontsize = cparams.fontsize)
    ax.set_ylabel(fitter.labels[1], fontsize = cparams.fontsize)
    ax.legend(fontsize = cparams.legend_fontsize)
    plt.pause(0.01)
    input("\nPress ENTER to terminate>>")
    print()

def fit(app, cparams):
    """Mode ``fit``: fit the DOS model parameters to the reference data.

    Steps:
      1. Read ``cparams.infile`` and write a copy of the input to ``fit.outfile``.
      2. Evaluate the objective at the initial parameters, print the initial
         data and save E / DOS(in) / DOS(ini) to ``fit.fitfile``.
      3. When ``fit.fplot`` is set, draw the initial plot.
      4. Run ``fit.minimize`` with ``cparams.method``, using the Jacobian
         setting and the initial simplex prepared by ``init_fit``.
      5. Pass the optimized parameters to ``fit.retrieve_parameter_list``
         (target ``fit.cparams``) and save them with ``save_parameters``.
      6. Save E / DOS(in) / DOS(ini) / DOS(fin) to ``fit.fitfile`` and, when
         plotting, draw the final curves.

    Note: inside this function the name ``fit`` refers to the local
    ``tkFit_mxy`` object, not to this function.
    """
    print("")
    print("Fitting to TFT simulation")
    fit = init_fit(app, cparams)

    print()
    print(f"Read input data from [{cparams.infile}]")
    fit.labels, fit.x_list, fit.y_list = read_data(cparams.infile)
    print(f"Save input data repeat to [{fit.outfile}]")
    save_data(fit.outfile, fit.labels, fit.x_list + fit.y_list)

    print()
    print("Configure fit (tkFit_mxy object):")
    # A single x column and a single y column are configured here.
    fit.configure(xlabels = [fit.labels[0]], ylabels = [fit.labels[1]], w_list = None)

    print("")
    print(f"Calculate initial IV curve")
    optpk = fit.extract_parameters()
    fini = fit.minimize_func(optpk)
    # Assumes the objective call left the computed curves in fit.yc_list -- TODO confirm in minimize_func_DOS.
    fit.yini_list = fit.yc_list

    print(f"Initial function: fmin={fini:10.4g}")
    print_data(labels = ["x", "y_obs", "y_ini"], 
            data_list = [fit.x_list[0], fit.y_list[0], fit.yini_list[0]], 
            label_format = '{:^15}', data_format = '{:>15.4g}', header = "Initial data:", nmax = 20, print_level = 0)
    print("fini=", fini)

    print(f"Save input and initial data to [{fit.fitfile}]")
    save_data(fit.fitfile, ['E', 'DOS(in)', 'DOS(ini)'], [fit.x_list[0], fit.y_list[0], fit.yini_list[0]])
    
#========================================
# plot
#========================================
    if fit.fplot:
        # One row per y column, two columns. With a single y column the axes
        # array is 1-D, so axes[0] and axes[1] are the two panels (with more
        # rows they would be whole rows). axes[1] is passed both as a data
        # axis and as the error axis; the 'error' plot event is removed below.
        fig, axes = plt.subplots(len(fit.y_list), 2, figsize = cparams.figsize)
#    axes = [axes]
#    axes = axes.flatten()

        fit.initial_plot(data_axes = [axes[0], axes[1]], error_axis = axes[1], yini_list = fit.yini_list, label_ini = 'initial',
                    fmin = fini, plt = plt, fig = fig,
                    fontsize = cparams.fontsize)
        fit.plot_event.remove('error')

#========================================
# Optimize
#========================================
    print("")
    print("Optimize:")
    fit.print_variables(heading = "Variables")
    pfin, ffin, success, res = fit.minimize(cparams.method, jac = fit.jac, initial_simplex = fit.initial_simplex)
    fit.pk = pfin.copy()
    if success:
        print(f"Converged at iteration: {res.nit}")
    else:
        print(f"Function did not converge")

#========================================
# Final result
#========================================
    # Hand the optimized values to fit.cparams (presumably written per varname -- confirm in tkFit_mxy).
    fit.retrieve_parameter_list(pfin, fit.varname, target = fit.cparams)

    fit.print_variables(heading = "Final parameters:", fmin = ffin)

    print("")
    print(f"Save parameters to [{cparams.parameterfile}]")
    save_parameters(fit, cparams)

    # Recompute the model curves at the optimized parameters.
    yfin_list = fit.cal_ylist(pfin, run = True)
    print("")
    print_data(labels = [fit.xlabels[0], "y_obs", "y_cal"], 
            data_list = [fit.x_list[0], fit.y_list[0], yfin_list[0]], 
                         label_format = '{:^15}', data_format = '{:>15.4g}', header = "Final data:", nmax = 20, print_level = 0)
#    fit.print_scores(heading = "\nScores between y(input) and y(fit)", y1 = fit.y, y2 = yfin)
    print(f"Save input, initial, and final data to [{cparams.fitfile}]")
    save_data(fit.fitfile, ["E", "DOS(in)", "DOS(ini)", "DOS(fin)"], [fit.x_list[0], fit.y_list[0], fit.yini_list[0], yfin_list[0]])


#=============================
# Display the graph
#=============================
    if fit.fplot:
        print("")
        print("Plot optimized")
        # res.nit is used here whether or not the optimizer reported success.
        fit.finalize_plot(yfin_list, iter = res.nit, fmin = ffin)

        fit.layout()
#       plt.tight_layout()
        plt.pause(0.01)

    print()
    input("Press ENTER to terminate>>")


def main():
    """Program entry point: parse arguments, open the log file and run the selected mode.

    Arguments are parsed twice. The first pass applies defaults, so that
    ``cparams.infile`` is known for deriving the log-file path. After stdout
    is redirected to both the console and the log file, the arguments are
    parsed again without defaults. The ``--mode`` value then selects one of
    ``init``, ``sim``, ``fit`` or ``plot``; any other value terminates with
    an error.
    """
    print()
    print("#========================================================")
    print(f"#  {ProgramName}")
    print("#========================================================")

    app, cparams = initialize()
    update_vars(app, cparams, apply_default = True)

    # NOTE(review): replace_path is called without a template here, so the
    # log-file name depends on replace_path's default -- confirm it never
    # resolves to cparams.infile itself (mode 'w' would truncate it).
    cparams.logfile = app.replace_path(cparams.infile)
    print(f"Open logfile [{cparams.logfile}]")
    app.redirect(targets = ["stdout", cparams.logfile], mode = 'w')

    # Re-read the arguments (without defaults) so explicitly given values are used.
    update_vars(app, cparams, apply_default = False)

    mode_handlers = {
        'init': init,
        'sim':  sim,
        'fit':  fit,
        'plot': plot,
    }
    handler = mode_handlers.get(cparams.mode)
    if handler is not None:
        handler(app, cparams)
    else:
        app.terminate("Error in main: Invalide mode [{}]".format(cparams.mode), usage = app.usage, pause = True)


if __name__ == "__main__":
    main()
