import numpy as np
from math import sqrt, pow, exp, log
import os
import sys
import csv
import pandas as pd
import openpyxl
from matplotlib import pyplot as plt
from matplotlib import gridspec


from tklib.tkutils import terminate, pfloat, pint, getarg, getintarg, getfloatarg
from tklib.tkapplication import tkApplication
from tklib.tkvariousdata import tkVariousData
from tklib.tktransport.tkWeightedMobility import weighted_mobility, weighted_mobility_new, weighted_mobility_exact
from tklib.tktransport.tkDOS_FEA import tkDOS
from tklib.tktransport.tkmobility_tau import tkMobility, split_optstr


#===================================
# physical constants
#===================================
# SI units. h, e and kB use the exact values fixed by the 2019 SI
# redefinition and me the CODATA 2018 recommended value, so every constant
# comes from one consistent set; hbar is derived from h rather than typed in.
pi    = 3.14159265358979323846
e     = 1.602176634e-19    # C
kB    = 1.380649e-23       # J K^-1
me    = 9.1093837015e-31   # kg
h     = 6.62607015e-34     # J s
h_bar = h / (2.0 * pi)     # J s
hbar  = h_bar

kBe = kB / e               # V K^-1 (about 8.617333e-5)


# Global variables: defaults, overridable from the command line
infile  = "pisarenko-plot-all-data-STO.xlsx"   # input table
outfile = None                                  # derived from infile after argument parsing

# Column selectors for T (K), S (V/K) and N (cm^-3). They are passed to
# tkVariousData.FindLabelIndex / FindDataArray, so presumably either a label
# string or a column index is accepted (TODO confirm).
T_label = 0
S_label = 1
N_label = 3
rfac = 0.0      # "r" parameter, passed to tkMobility(rfac=...)


# Convergence controls passed to tkDOS.EF_from_S()
nmaxiter = 100
eps = 1.0e-10
# Not referenced elsewhere in this script
a =  2000 #max
b = -300 #min


#=============================
# Graph configuration
#=============================
fig = None
figsize  = (6, 6)
fontsize = 18
legend_fontsize = 12

app = None      # tkApplication instance, created in main()

#===================================
# Treat arguments
#===================================
# Positional command-line arguments override the defaults defined above
# (getarg presumably falls back to defval when an argument is absent):
#   1: infile   2: T_label   3: S_label   4: N_label   5: rfac (float)
infile = getarg(1, defval=infile)
T_label = getarg(2, defval=T_label)
S_label = getarg(3, defval=S_label)
N_label = getarg(4, defval=N_label)
rfac = getfloatarg(5, defval=rfac)

# The output workbook is written next to the input file:
#   <input path without extension>-meff-out.xlsx
header, ext = os.path.splitext(infile)
filebody = os.path.basename(header)
outfile = header + '-meff-out.xlsx'


#===================================
# Other functions
#===================================
def usage(app):
    """Print the command-line synopsis and an example built from the current settings.

    ``app`` is not used here; the parameter presumably matches how the
    ``usage=`` callback is invoked by tkApplication / tkVariousData (TODO confirm).
    """
    script = sys.argv[0]
    lines = [
        "",
        "Usage: python {} infile T_label S_label(V/K) N_label(cm^-3) r".format(script),
        "   ex: python {} {} {} {} {} {}".format(script, infile, T_label, S_label, N_label, rfac),
    ]
    for line in lines:
        print(line)

def savecsv(outfile, header, datalist):
    """Write column-oriented data to a CSV file.

    Args:
        outfile: Path of the CSV file to create (overwritten if it exists).
        header: Sequence of column names, written as the first row.
        datalist: Sequence of columns; ``datalist[j][i]`` is row ``i`` of
            column ``j``. Rows are emitted up to the length of the shortest
            column, and an empty ``datalist`` writes only the header row.

    Errors while opening or writing the file are reported on stdout and are
    not raised (best-effort output).
    """
    print("Write to [{}]".format(outfile))
    try:
        # ``with`` guarantees the file is closed even if writing fails.
        with open(outfile, 'w') as f:
            fout = csv.writer(f, lineterminator='\n')
            fout.writerow(header)
            # zip(*columns) transposes the list of columns into rows.
            fout.writerows(zip(*datalist))
    except OSError:
        print("line 98 Error: Can not write to [{}]".format(outfile))

def read_file(fname):
    """Read the data table stored in *fname* with tkVariousData.

    Args:
        fname: Path of the input file.

    Returns:
        ``(datafile, labels, datalist)``: the tkVariousData instance plus the
        labels and data returned by its ``Read_minimum_matrix`` method.
    """
    print("")

    # Read the file named by the argument, not the module-level ``infile``.
    datafile = tkVariousData(fname)
    labels, datalist = datafile.Read_minimum_matrix(close_fp = True, usage = usage)

    return datafile, labels, datalist


def _snyder_meff(T, S, N):
    """Estimate the effective mass m*/me from the Seebeck coefficient and carrier density.

    Uses the expressions quoted from Snyder et al., Adv. Funct. Mater. 32,
    2112772 (2022).

    Args:
        T: Temperature in K.
        S: Seebeck coefficient in V/K. Only |S| enters the formulas.
        N: Carrier density in cm^-3.

    Returns:
        ``(meff, meff2, meff3)`` in units of me, or ``None`` if
        ``exp(|S|/(kB/e) - 2) - 0.17 < 0`` (S outside the formula's range).
        - ``meff``: form with the 0.924 prefactor.
        - ``meff2``: full expression with physical constants (eq. 3).
        - ``meff3``: form with the 0.857 prefactor (eq. 4).
    """
    # Reference spreadsheet formula (katase):
    # =((h^2)/(2*kB*300)*((3*n)/(16*pi^0.5))^(2/3)*(((EXP((S/1000000)/kB/e-2)-0.17)^(2/3))/(1+EXP(-5*((S/1000000)/kB/e-kB/e/(S/1000000))))+(3/pi^2*(2/pi^0.5)^(2/3)*(S/1000000)/kB/e)/(1+EXP(5*((S/1000000)/kB/e-kB/e/(S/1000000))))))/me0*1000
    SkBe = abs(S) / kBe                 # reduced Seebeck coefficient |S|/(kB/e)

    # Cap the exponential so exp() cannot overflow for very large |S|.
    if SkBe - 2.0 > 100.0:
        expSkBe2 = 1.0e300
    else:
        expSkBe2 = exp(SkBe - 2.0)
    if expSkBe2 - 0.17 < 0.0:
        return None

    # expSkBe2 >= 0.17 implies SkBe >= 2 + ln(0.17) > 0, so 1/SkBe is safe here.
    a = SkBe - 1.0 / SkBe
    B1 = pow(expSkBe2 - 0.17, 2.0/3.0)  # (exp(|S|/(kB/e) - 2) - 0.17)^(2/3)
    B2 = 1.0 + exp(-5.0 * a)
    C2 = 1.0 + exp(5.0 * a)
    Nfac = pow(N / 1.0e20, 2.0/3.0)

    meff = 0.924 * (300.0 / T) * Nfac * (3.0 * B1 / B2 + SkBe / C2)

    # N * 1.0e6 converts cm^-3 to m^-3 so A is in SI units (kg).
    A = h*h/(2*kB*T) * pow(3.0 * N * 1.0e6 / 16.0 / sqrt(pi), 2.0/3.0)
    C1 = 3/pi/pi * pow(2.0/sqrt(pi), 2.0/3.0) * SkBe
    meff2 = A / me * (B1/B2 + C1/C2)

    meff3 = 0.857 * (300.0 / T) * Nfac * (3.0 * B1 / B2 + SkBe / C2)

    return meff, meff2, meff3


def _plot_results(Nlist, Slist, meff_list):
    """Plot S and m* against carrier density in two stacked panels.

    The panels share a logarithmic x axis (N_Hall in cm^-3), and each gets a
    dashed red zero line spanning the autoscaled x range.
    """
    fig = plt.figure(figsize = figsize)
    spec = gridspec.GridSpec(ncols = 1, nrows = 2, height_ratios=[1, 1])
    ax1 = fig.add_subplot(spec[0])
    ax2 = fig.add_subplot(spec[1], sharex = ax1)

    ax1.tick_params(labelbottom = False, labelsize = fontsize)
    ax2.tick_params(labelsize = fontsize)

    ax1.plot(Nlist, Slist, label = '$S$ (V/K)', linestyle = '', marker = 'o')
    ax1.set_ylabel("$S$ (V/K)", fontsize = fontsize)
    ax1.set_xscale('log')

    ax2.plot(Nlist, meff_list, label = '$m_{eff}$', linestyle = '', marker = 's')
    ax2.set_xlabel("$N_{Hall}$ (cm$^{-3}$)", fontsize = fontsize)
    ax2.set_ylabel("$m_{eff}$", fontsize = fontsize)
    ax2.set_xscale('log')

    xlim = ax1.get_xlim()
    ax1.plot(xlim, [0.0, 0.0], linestyle = 'dashed', linewidth = 0.5, color = 'red')
    ax2.plot(xlim, [0.0, 0.0], linestyle = 'dashed', linewidth = 0.5, color = 'red')

    plt.tight_layout()
    plt.pause(0.1)


def main():
    """Estimate m* from S and N for every row of the input table and report it.

    Steps:
    - Open a log file through tkApplication and validate the column selectors.
    - Read T (K), S (V/K) and N (cm^-3) from ``infile``.
    - Estimate m* for each row with ``_snyder_meff``.
    - Cross-check each estimate with the tkDOS / tkMobility transport model.
    - Save the table to ``outfile`` (Excel).
    - Plot S and m* against N.
    """
    global app

    app = tkApplication()
    logfile = app.replace_path(infile)
    print(f"Open logfile [{logfile}]")
    app.redirect(targets = ["stdout", logfile], mode = 'w')

    print("")
    print("input file :", infile)
    print("output file:", outfile)
    print("T_label    :", T_label)
    print("S_label    :", S_label)
    print("N_label    :", N_label)
    print("rfac       :", rfac)
    print("")

    # Labels containing '***' are unselected placeholders. isinstance() is
    # required because the selectors may also be integer indices (defaults).
    if isinstance(T_label, str) and '***' in T_label:
        app.terminate("Error: Choose T", usage = usage, pause = True)
    if isinstance(S_label, str) and '***' in S_label:
        app.terminate("Error: Choose S", usage = usage, pause = True)
    if isinstance(N_label, str) and '***' in N_label:
        app.terminate("Error: Choose N", usage = usage, pause = True)

    if isinstance(S_label, str) and ('uK' in S_label or 'microK' in S_label):
        app.terminate(f"Error: The unit of S must be 'V/K'. The given label is [{S_label}]", usage = usage, pause = True)

    print("Read data from [{}]".format(infile))
    datafile, labels, datalist = read_file(infile)
    T_idx = datafile.FindLabelIndex(T_label, flag = 'i')
    S_idx = datafile.FindLabelIndex(S_label, flag = 'i')
    N_idx = datafile.FindLabelIndex(N_label, flag = 'i')
    print("T index: ", T_idx)
    print("S index: ", S_idx)
    print("N index: ", N_idx)
    Tlabel, Tlist         = datafile.FindDataArray(T_label, flag = 'i')
    Slabel, Slist         = datafile.FindDataArray(S_label, flag = 'i')         # V/K
    Nlabel, Nlist         = datafile.FindDataArray(N_label, flag = 'i')         # cm-3

    charge = 1.0
    mu_e = tkMobility(charge = charge, meff = None, rfac = rfac, l0 = 1.0e-10)  # l0 in m
    dos = tkDOS(EV = 0.0, EC = 1.1, EA = 0.0, NA = 0.0, ED = 0.0, ND = 0.0, EF0 = 0.0, dEFmin = -0.2, dEFmax = 2.0,
                meeff = None, mheff = None)
    klatt = 5.0

    print("")
    print(f"{'T(K)':8}  {'S(V/K)':10}  {'N(cm^-3)':10}")
    for T, S, N in zip(Tlist, Slist, Nlist):
        print(f"{T:8.3g}  {S:10.4g}  {N:10.4g}")

    meff_list = []
    for ic, (T, S, N) in enumerate(zip(Tlist, Slist, Nlist)):
        print(f"{ic:3d}: {T=} K  {S=} V/K  {N=} cm^-3")

        estimates = _snyder_meff(T, S, N)
        if estimates is None:
            print()
            print(f"Warning: exp(|S|/(kB/e) - 2) - 0.17 < 0: S is out of range")
            print()
            meff_list.append(None)
            continue
        meff, meff2, meff3 = estimates
        S = abs(S)      # the report and the validation below use |S|

        print(f"T    = {T} K")
        print(f"S    = {S*1.0e6} uV/K")
        print(f"N    = {N} cm^-3")
        print(f"m* estimated from Snyder et al., Adv. Funct. Mater. 32, 2112772 (2022)")
        print(f"  meff(abst) = {meff:10.6g} me")
        print(f"  meff(eq.3) = {meff2:10.6g} me")
        print(f"  meff(eq.4) = {meff3:10.6g} me")
        print(f"  Validation:")

        # Feed each estimate back into the transport model; every calculation
        # in this loop must use the estimate under test (_meff).
        for _meff in [meff, meff2, meff3]:
            r = mu_e.rfac
            mu_e.meff = _meff
            dos.set_meeff(mu_e.meff, T)
            dos.set_mheff(mu_e.meff, T)

            # Diagnostics; currently only consumed by the disabled prints below.
            Sndeg = dos.cal_S_nondegenerated_from_Ne(n = N, rfac = r, charge = mu_e.charge)
            Sdeg  = dos.cal_S_degenerated_from_Ne(T = T, n = N, rfac = r, charge = mu_e.charge)

            EFdeg = dos.EF0K_from_N_meff(N = N, meff = dos.meeff)

            EF, diffEF, ret = dos.EF_from_electrondensity(N, T, EF0 = dos.EF0, dEF = 0.01, dump = 0.0, epsEF = 1.0e-5, maxiter = 100)
            xe  = (EF - dos.EC) * e / kB / T
            sigmae, ne, mue, tau_avge, Se, kappae, kappa_tote, Le, PFe, ZTe, infe = \
                dos.cal_transport_S(xe, T, _meff, dos.NC, mu_e, klatt = klatt, validate_error_str = None, charge = mu_e.charge)

            print(f"    me*: {mu_e.meff:10.6g} me")
#            print(f"      EF-EC: {EF-dos.EC:10.6g} eV   diffEF: {diffEF:10.6g} eV")
#            print(f"      Ne   : {ne:10.6g} cm^-3")
#            print(f"      S    : {Se:10.6g} uV/K")
#            print(f"      EFdeg: {EFdeg:10.6g} eV")
#            print(f"      Sdeg : {Sdeg*1e6:10.6g} uV/K")
#            print(f"      Sndeg: {Sndeg*1e6:10.6g} uV/K")
#            print(f"    sigma: {sigmae:10.6g}")

            xe  = EFdeg * e / kB / T
            sigmae, ne, mue, tau_avge, Se, kappae, kappa_tote, Le, PFe, ZTe, infe = \
                dos.cal_transport_S(xe, T, _meff, dos.NC, mu_e, klatt = klatt, validate_error_str = None, charge = mu_e.charge)
            print(f"      S from EFdeg: {Se:10.6g} uV/K")

            # S in V/K
            EF_S = dos.EF_from_S(T, abs(S), mu_e.rfac, polarity = 'h', EFmin = -2.0, EFmax = 1.0, print_level = 0, eps = eps, nmaxiter = nmaxiter)
            if EF_S is None:
                app.terminate(f"Error in S2m::main(): EF calculation did not converge from S={S:10.4g} V/K and r={mu_e.rfac} at T={T:8.3g} K.", pause = True)

            print(f"      Calculate EF from S: {EF_S:10.6g} eV for S={S:10.6g} V/K")
            meff_S = dos.meeff_from_Ne_EF0K(N, EF_S)   # N in cm^-3
            print(f"        m* from EF_S: {meff_S:10.6g} me")

        print()
        meff_list.append(meff)

    df = pd.DataFrame(np.array([Tlist,  Nlist,      Slist,    meff_list]).T,
                     columns = ["T(K)", "N(cm^-3)", "S(V/K)", "meff(me)"])
    print("")
    print("Save meff data to [{}]".format(outfile))
    df.to_excel(outfile, index = False, header = True)

    #=============================
    # Plot results
    #=============================
    print("")
    _plot_results(Nlist, Slist, meff_list)

    app.terminate("", usage = usage, pause = True)


if __name__ == '__main__':
    # Run the analysis only when executed as a script, not on import.
    main()

