import sys
import os
import signal
import builtins
import numpy as np
from numpy import exp, log, sin, cos, tan, arcsin, arccos, arctan, sinh, cosh, tanh, sqrt, abs
import pandas as pd
from scipy.optimize import minimize
import matplotlib.pyplot as plt


from tklib.tkutils import replace_path, getarg, getintarg, getfloatarg, pint, pfloat
from tklib.tksci.tksci import kB, e
from tklib.tkvariousdata import tkVariousData
from tklib.tkapplication import tkApplication
from tklib.tkparams import tkParams
from tklib.tksci.tksci import asin, acos, atan, degcos, degsin, degtan, degacos, degasin, degatan, arcsin, arccos, arctan
from tklib.tksci.tksci import factorial, log10, gamma, combination, eVTonm, nmToeV, Bn
from tklib.tksci.tksci import h, h_bar, hbar, e, kB, NA, c, pi, pi2, torad, todeg, basee
from tklib.tksci.tksci import me, mp, mn
from tklib.tksci.tksci import u0, e0, a0, R, F, g, G
from tklib.tksci.tksci import HartreeToeV, RyToeV, KToeV, eVToK, JToeV, eVToJ, Debye
from tklib.tksci.tkFit import tkFit
from tklib.tkgraphic.tkplotevent import tkPlotEvent
from tklib.tksci.tkFit import tkFit


#=======================================
# Program start: parameter setup and analysis routines
#=======================================
def initialize(app):
    """Create app.cparams (a tkParams object) and fill it with default values.

    update_vars() later passes cparams.__dict__ to app.read_args(), so these
    values act as the starting point for command-line overrides.
    """
    app.cparams = tkParams()

    # Attribute name -> default value, applied in this (insertion) order.
    defaults = {
        # Input file and analysis model
        'infile': '',
        'model': 'simple Arrhenius',
        # Column labels in the input file and how to interpret them
        'Tlabel': 'T(K)',
        'Plabel': 'P',
        'Ttype': 'T(K)',
        'Ptype': 'P',
        # Fitting ranges (+-1e100 means effectively unlimited)
        'xmin': -1.0e100,
        'xmax': 1.0e100,
        'Tmin': -1.0e100,
        'Tmax': 1.0e100,
        # Calculation range; '*' means "use the data range"
        'Tcalmin': '*',
        'Tcalmax': '*',
        # Figure appearance
        'figsize': [8, 8],
        'fontsize': 12,
        'legend_fontsize': 10,
    }
    for attr_name, default_value in defaults.items():
        setattr(app.cparams, attr_name, default_value)

def update_vars(app):
    """Register the command-line options and read them into app.cparams.

    Each option is stored in the cparams attribute named by ``var_name``.
    When app.read_args() signals an error (args_opt is None; args_vars then
    carries the message), the message and usage are shown via app.terminate().
    """
    cparams = app.cparams

    app.add_argument(opt = None, type = "str", var_name = 'infile', opt_str = "infile",  desc = 'Input file',
                     defval = "Hall-T.xlsx", optional = True)

    app.add_argument(opt = "-model", type = "str", var_name = 'model',  opt_str = "-model='simple Arrhenius'",
                     desc = 'model: [simple Arrhenius|percolation|3rd order|4th order]',
                     defval = "simple Arrhenius", optional = True)

    app.add_argument(opt = "-Tlabel", type = "str", var_name = 'Tlabel',  opt_str = "-Tlabel=T(K)",
                     desc = 'Label of T-related data in the input file',
                     defval = "T(K)", optional = True)
    app.add_argument(opt = "-Plabel", type = "str", var_name = 'Plabel',  opt_str = "-Plabel=P",
                     desc = 'Label of property-related data in the input file',
                     defval = "P", optional = True)

    app.add_argument(opt = "-Ttype", type = "str", var_name = 'Ttype',  opt_str = "-Ttype=T(K)",
                     desc = 'Type of T-related data: [T(K)|T(C)|1/T|1000/T]',
                     defval = "T(K)", optional = True)
    app.add_argument(opt = "-Ptype", type = "str", var_name = 'Ptype',  opt_str = "-Ptype=P",
                     desc = 'Type of property-related data: [P|log10(P)|log_e(P)]',
                     defval = "P", optional = True)

    # Fitting limits. Defaults are +-1.0e100 (effectively unlimited), matching
    # initialize(); the former +-1.0e-100 collapsed the range to ~0 and
    # excluded every data point from the fit.
    app.add_argument(opt = "-xmin", type = "double", var_name = 'xmin',  opt_str = "-xmin=-1.0e100",
                     desc = 'Lower limit of T-related data (x) for fitting', defval = "-1.0e100", optional = True)
    app.add_argument(opt = "-xmax", type = "double", var_name = 'xmax',  opt_str = "-xmax=1.0e100",
                     desc = 'Upper limit of T-related data (x) for fitting', defval = "1.0e100", optional = True)
    app.add_argument(opt = "-Tmin", type = "double", var_name = 'Tmin',  opt_str = "-Tmin=-1.0e100",
                     desc = 'Lower limit of T for fitting', defval = "-1.0e100", optional = True)
    app.add_argument(opt = "-Tmax", type = "double", var_name = 'Tmax',  opt_str = "-Tmax=1.0e100",
                     desc = 'Upper limit of T for fitting', defval = "1.0e100", optional = True)
    # Calculation range; '*' lets execute() fall back to the data's T range.
    app.add_argument(opt = "-Tcalmin", type = "str", var_name = 'Tcalmin',  opt_str = "-Tcalmin='*'",
                     desc = 'Lower limit of T for calculation', defval = "*", optional = True)
    app.add_argument(opt = "-Tcalmax", type = "str", var_name = 'Tcalmax',  opt_str = "-Tcalmax='*'",
                     desc = 'Upper limit of T for calculation', defval = "*", optional = True)

    args_opt, args_idx, args_vars = app.read_args(vars = cparams.__dict__, check_allowed_args = True)
    if args_opt is None:
        # On failure read_args() returns (None, error_no, error_message).
        error_message = args_vars
        app.terminate("\n\n"
                   +  "!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\n"
                   + f"!  {error_message}\n"
                   +  "!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!",
                   usage = app.usage, pause = True)

#    app.set_usage(usage_str)


def _convert_T(Ttype, xX):
    """Convert the T-related input column to absolute temperature T (K).

    Ttype is one of 'T(K)', 'T(C)', '1/T', '1000/T' (the -Ttype choices).
    Returns None for any other Ttype so the caller can report the error.
    """
    if Ttype == 'T(K)':
        return xX
    if Ttype == 'T(C)':
        return [v + 273.15 for v in xX]
    if Ttype == '1/T':
        return [1.0 / v for v in xX]
    if Ttype == '1000/T':
        return [1000.0 / v for v in xX]
    return None


def _convert_P(Ptype, yY):
    """Convert the property-related input column to the linear property P.

    Ptype is one of 'P', 'log10(P)', 'log_e(P)' (the -Ptype choices).
    Returns None for any other Ptype so the caller can report the error.
    """
    if Ptype == 'P':
        return yY
    if Ptype == 'log10(P)':
        return [np.power(10.0, v) for v in yY]
    if Ptype == 'log_e(P)':
        return [exp(v) for v in yY]
    return None


def _derivative(x, y):
    """Return dy/dx at each point of (x, y); requires at least two points.

    Uses one-sided differences at both end points and central differences
    (y[i+1] - y[i-1]) / (x[i+1] - x[i-1]) at interior points.
    """
    n = len(x)
    d = [(y[1] - y[0]) / (x[1] - x[0])]
    for i in range(1, n - 1):
        d.append((y[i + 1] - y[i - 1]) / (x[i + 1] - x[i - 1]))
    # Backward difference at the last point (previously divided by zero).
    d.append((y[n - 1] - y[n - 2]) / (x[n - 1] - x[n - 2]))
    return d


def execute(app):
    """Run the Arrhenius analysis using the parameters in app.cparams.

    Steps: read the T- and P-related columns from cparams.infile, convert
    them to T (K) and P, fit log10(P) against 1000/T with a polynomial
    whose order is selected by cparams.model, print P0, Ea and (for order
    >= 2) sigma_phi, save the data to the '-out.xlsx' file, and show a
    2x2 figure. The run ends through app.terminate().
    """
    cparams = app.cparams

    # NOTE(review): the log-file name relies on tkApplication.replace_path's
    # default template; confirm it never resolves to infile itself, because
    # the file is opened below with mode 'w'.
    cparams.logfile = app.replace_path(cparams.infile)
    cparams.outfile    = app.replace_path(cparams.infile, "{dirname}/{filebody}-out.xlsx")
    cparams.outfitfile = app.replace_path(cparams.infile, "{dirname}/{filebody}-fit.xlsx")
    print(f"Open logfile [{cparams.logfile}]")
    # Presumably duplicates stdout output into the log file - confirm in tkApplication.
    app.redict(targets = ["stdout", cparams.logfile], mode = 'w')

    cparams.xmin = pfloat(cparams.xmin)
    cparams.xmax = pfloat(cparams.xmax)
    cparams.Tmin = pfloat(cparams.Tmin)
    cparams.Tmax = pfloat(cparams.Tmax)

    print("")
    print("#======================================================")
    print("# Analyze activation energy etc by Arrhenius plot")
    print("#======================================================")
    print(f"infile : {cparams.infile}")
    print(f"outfile: {cparams.outfile}")
    print(f"logfile: {cparams.logfile}")
    print(f"model : {cparams.model}")
    print(f"Tlabel: {cparams.Tlabel}")
    print(f"Plabel: {cparams.Plabel}")
    print(f"Ttype : {cparams.Ttype}")
    print(f"Ptype : {cparams.Ptype}")
    print(f"Fitting x range: {cparams.xmin} - {cparams.xmax}")
    print(f"Fitting T range: {cparams.Tmin} - {cparams.Tmax}")

    if '***' in cparams.model:
        app.terminate("Error: Choose model", pause = True)

    print(f"")
    print(f"Read [{cparams.infile}]")
    datafile = tkVariousData(cparams.infile)
    labels, datalist = datafile.Read_minimum_matrix(close_fp = True, usage = app.usage)
    label_x, xX = datafile.FindDataArray(cparams.Tlabel, flag = 'i')
    label_y, yY = datafile.FindDataArray(cparams.Plabel, flag = 'i')

    T = _convert_T(cparams.Ttype, xX)
    if T is None:
        app.terminate(f"\nInvalid Ttype [{cparams.Ttype}]", usage = app.usage, pause = True)

    # Bug fix: the 'log_e(P)' branch previously tested an undefined name
    # (Ttype) and raised NameError; _convert_P checks Ptype consistently.
    P = _convert_P(cparams.Ptype, yY)
    if P is None:
        app.terminate(f"\nInvalid Ptype [{cparams.Ptype}]", usage = app.usage, pause = True)
    print("T(K)=", T)
    print("P   =", P)

    # Arrhenius-plot coordinates for every data point.
    T1000 = [1000.0 / v for v in T]
    log10P = [log(v) / log(10.0) for v in P]

    # Select the points used for fitting: the x limits apply to the raw
    # T-related column, the T limits to the converted temperature.
    x_fit = []
    y_fit = []
    P_sel = []
    for x_raw, T_K, x_arr, y_arr, P_val in zip(xX, T, T1000, log10P, P):
        if cparams.xmin <= x_raw <= cparams.xmax and cparams.Tmin <= T_K <= cparams.Tmax:
            x_fit.append(x_arr)
            y_fit.append(y_arr)
            P_sel.append(P_val)

    # Polynomial order in 1000/T per model; any other model is a straight line.
    norder = {'percolation': 2, '3rd order': 3, '4th order': 4}.get(cparams.model, 1)
    print("")
    print(f"Least-squares fitting with {norder}-th order polynomial of 1000/T")
    print("Data to be fitted:")
    print(f"  {'1000/T':12} {'log10(P)':12}")
    for xv, yv in zip(x_fit, y_fit):
        print(f"  {xv:12.4g} {yv:12.4g}")

    # np.polyfit fails on an empty set and is underdetermined below norder+1 points.
    if len(x_fit) < norder + 1:
        app.terminate(f"\nError: {len(x_fit)} data point(s) in the fitting range; "
                      f"at least {norder + 1} are required for a {norder}-th order fit",
                      usage = app.usage, pause = True)

    # np.polyfit returns coefficients highest power first: ci[norder] is c0.
    ci = np.polyfit(x_fit, y_fit, norder)
    print("Coefficients:")
    for i in range(norder + 1):
        print(f"  c{i}={ci[norder - i]:12.4g}")

    poly = np.poly1d(ci)
    log10Pfit = poly(x_fit)
    log10Pcal = poly(T1000)

# P(T) = P0 * exp[-e * Ea / kT + e^2 * s0^2 / 2 / kB^2 / T^2]
# log10P(T) = log10(P0) - log_10(e) * e / kB / 1000.0 * Ea * (1000 / T) + log_10(e) * e^2 / kB^2 / 1000^2 * s0^2 / 2 * (1000 / T)^2
# c1 = -log10(e) * e / kB / 1000 * Ea: Ea = -c1 * ln(10) * kB / e * 1000.0
# c2 = log10(e) * (e / kB / 1000)^2 /2 * s0^2: s0 = sqrt(c2 * ln(10) * (1000 kB/e)^2 * 2)
    P0 = np.power(10.0, ci[norder])
    Ea = -ci[norder-1] * kB / e * 1000.0 * log(10.0)
    print("")
    print(f"  P0={P0:12.4g}")
    print(f"  Ea=phi_0={Ea:12.4g} eV")
    if norder >= 2:
        sigma_phi2 = ci[norder-2] * (1000.0 * kB / e)**2 * 2.0 * log(10.0)
        if sigma_phi2 < 0.0:
            print(f"\n  ***Warning: sigma_phi^2 is negative: {sigma_phi2:12.4g} eV^2")
        else:
            sigma_phi = sqrt(sigma_phi2)
            print(f"  sigma_phi={sigma_phi:12.4g} eV")

    # Scores compare only the fitted points, so both series have the same
    # length and are paired point-by-point (previously all input points were
    # compared against the fitted subset).
    fit = tkFit()
    P_fit = [pow(10.0, v) for v in log10Pfit]
    fit.print_scores(heading = "\nScores between P(input) and P(fit)", y1 = P_sel, y2 = P_fit)
    fit.print_scores(heading = "\nScores between log10(P(input)) and log10(P(fit))", y1 = y_fit, y2 = log10Pfit)

    print("")
    print( "Calculate data from the fitting result")
    # '*' (the default) falls back to the T range of the data.
    Tcalmin = pfloat(cparams.Tcalmin, defval = min(T))
    Tcalmax = pfloat(cparams.Tcalmax, defval = max(T))
    nT = 101
    Tstep = (Tcalmax - Tcalmin) / (nT - 1)
    print(f"  T range: {Tcalmin:8.3f} - {Tcalmax:8.3f} K")
    T_plot = [Tcalmin + i * Tstep for i in range(nT)]
    x_plot = [1000.0 / v for v in T_plot]
    y_plot = poly(x_plot)
    P_plot = [pow(10.0, v) for v in y_plot]
    # Local activation energy from the slope d log10(P) / d(1000/T).
    diff_plot = _derivative(x_plot, y_plot)
    Ea_plot = [-d * kB / e * 1000.0 * log(10.0) for d in diff_plot]

    print("")
    print(f"Save to [{cparams.outfile}]")
    fit.to_excel(cparams.outfile, [label_x, label_y, 'T(K)', 'P', '1000/T (K^-1)', 'log10(P)', 'log10(P)(cal)',
                                   '', 'T (K)', '1000/T (K^-1)', 'P(cal)', 'log10(P)(cal)'],
                                  [xX, yY, T, P, T1000, log10P, log10Pcal, [], T_plot, x_plot, P_plot, y_plot])

#=========================================
# plot
#=========================================
    fig, axes = plt.subplots(2, 2, figsize = cparams.figsize)
    plot_event = tkPlotEvent(plt)

    axes = axes.flatten()
    for ax in axes:
        ax.tick_params(labelsize = cparams.fontsize)

    # (0) raw input columns
    XY_data = axes[0].plot(xX, yY, linestyle = '', marker = 'o', markerfacecolor = 'black', markersize = 5.0)
    axes[0].set_xlabel(label_x, fontsize = cparams.fontsize)
    axes[0].set_ylabel(label_y, fontsize = cparams.fontsize)

    # (1) P vs T with the fitted curve
    TP_data = axes[1].plot(T, P, linestyle = '', marker = 'o', markerfacecolor = 'black', markersize = 5.0)
    axes[1].plot(T_plot, P_plot, linestyle = '-', color = 'red', linewidth = 0.5)
    axes[1].set_xlabel('$T$ (K)', fontsize = cparams.fontsize)
    axes[1].set_ylabel('$P$',     fontsize = cparams.fontsize)

    # (2) Arrhenius plot: all points, fitted points, and the fit
    arr_data = axes[2].plot(T1000, log10P, label = 'raw data', linestyle = '', marker = 'o', markerfacecolor = 'black', markersize = 5.0)
    axes[2].plot(x_fit, y_fit, label = 'raw data (fitted)', linestyle = '', marker = 'o', markerfacecolor = 'red', markersize = 8.0)
    axes[2].plot(x_plot, y_plot, label = 'fitted', color = 'red', linewidth = 0.5)
    axes[2].set_xlabel('$1000/T$ (K$^{-1}$)', fontsize = cparams.fontsize)
    axes[2].set_ylabel('$log_{10}(P)$',       fontsize = cparams.fontsize)
    axes[2].legend(fontsize = cparams.legend_fontsize)

    # (3) local activation energy along the calculated curve
    Ea_data = axes[3].plot(x_plot, Ea_plot, label = 'Ea (eV)', linestyle = '-', linewidth = 0.5, color = 'black')
    axes[3].set_xlabel('$1000/T$ (K$^{-1}$)', fontsize = cparams.fontsize)
    axes[3].set_ylabel('$E_a$ (eV)',          fontsize = cparams.fontsize)
    axes[3].legend(fontsize = cparams.legend_fontsize)

    plot_event.add_data({"label": "X-Y plot",       "plot_type": "2D", "axis": axes[0], "data": XY_data,
                    "xlist": datalist, "xlabels": labels})
    plot_event.add_data({"label": "T-P plot",       "plot_type": "2D", "axis": axes[1], "data": TP_data,
                    "xlist": datalist, "xlabels": labels})
    plot_event.add_data({"label": "Arrhenius plot", "plot_type": "2D", "axis": axes[2], "data": arr_data,
                    "xlist": datalist, "xlabels": labels})
    plot_event.add_data({"label": "Ea",             "plot_type": "2D", "axis": axes[3], "data": Ea_data})
    plot_event.register_event(fig, event = "button_press_event",
                    callback = lambda event: plot_event.onclick(event))

    plt.tight_layout()
    plt.pause(0.001)

    app.terminate("", usage = None, pause = True)


#==========================================
# Main routine
#==========================================
def main():
    """Script entry point: build the application, load parameters, run the analysis."""
    application = tkApplication()

    # Parameter set-up happens in two stages: built-in defaults first, then
    # command-line overrides.
    setup_steps = (
        ("Initialize parameters", initialize),
        ("Update parameters by command-line arguments", update_vars),
    )
    for message, step in setup_steps:
        print(message)
        step(application)

    execute(application)


if __name__ == "__main__":
    main()

