from pprint import pprint
import numpy as np
from numpy import log, log10, sqrt, exp
import openpyxl
import pandas as pd
import matplotlib.pyplot as plt


from tklib.tkutils import print_data, pint, pfloat, format_strlist
from tklib.tksci.tksci import log10, e, kB
from tklib.tksci.tkFit_mxy_flex import tkFit_mxy


# ---- File names ----
infile   = 'Hall-T.xlsx'
datafile = 'mu.xlsx'    # NOTE(review): not referenced in the code shown here

# ---- Model limits ----
# Mobility substituted for 1/muinv whenever |muinv| <= 1/max_mu, i.e. when an
# inverse-mobility term is (numerically) zero.  Also used by the plotters to
# skip such clamped curves.
max_mu = 1.0e10

# ---- Plot settings ----
figsize = (10, 8)
fontsize = 16
legend_fontsize = 12
# Colour cycle; the plotters use entries 0 and 1 for fixed curves and 2.. for
# the power-law terms.
colors = [
    'black', 'red', 'blue', 'green',
    'orange', 'darkgreen', 'darkorange', 'navy',
    'slategray', 'hotpink', 'olive', 'chocolate',
    'magenta', 'green', 'yellow', 'cyan',
]
fplot = 1    # NOTE(review): flag not referenced in the code shown here


def read_data(fit, infile, xlabel, ylabel):
    """Load one x column and one y column from *infile* via *fit*.

    Calls ``fit.read_data(infile, x_labels=[xlabel], y_labels=[ylabel])`` and
    then returns the first entry of each label/data list stored on *fit*.

    Parameters
    ----------
    fit : object
        Fitting helper (tkFit_mxy in this script) providing ``read_data`` and
        the attributes ``x_labels``, ``y_labels``, ``xdata_list``, ``ydata_list``.
    infile : str
        Path of the input file.
    xlabel, ylabel : str
        Column labels to read.

    Returns
    -------
    tuple
        (x_label, y_label, x_data, y_data) as stored on *fit* after loading.
    """
    fit.read_data(infile, x_labels = [xlabel], y_labels = [ylabel])
    return fit.x_labels[0], fit.y_labels[0], fit.xdata_list[0], fit.ydata_list[0]

def convert_xk(xk_all, linid = None):
    """Split a flat parameter vector into named model parameters.

    Layout of *xk_all*: ``[Eb, s_phi, Eop, aop, p0, a0, p1, a1, ...]``; the
    entries after the first four alternate between an exponent ``p`` and its
    prefactor ``a``.

    Parameters
    ----------
    xk_all : sequence
        Flat parameter vector as described above.
    linid : sequence, optional
        Per-entry flags parallel to *xk_all* (presumably marking linear
        parameters — confirm against the fitting code).

    Returns
    -------
    tuple
        Without *linid*: ``(Eb, s_phi, Eop, aop, pi, ai)``.
        With *linid*: ``(Eb, s_phi, Eop, aop, pi, ai, linid[3], ai_id)``,
        where ``ai_id[k]`` is 1 if the flag of the k-th prefactor equals 1,
        else 0.
    """
    Eb, s_phi, Eop, aop = xk_all[:4]
    pi, ai, ai_id = [], [], []
    for pos in range(4, len(xk_all), 2):
        p_pos, a_pos = pos, pos + 1
        pi.append(xk_all[p_pos])
        ai.append(xk_all[a_pos])
        if linid is not None:
            ai_id.append(1 if linid[a_pos] == 1 else 0)

    if linid is None:
        return Eb, s_phi, Eop, aop, pi, ai
    return Eb, s_phi, Eop, aop, pi, ai, linid[3], ai_id
    
def lsqfunc(Eop, pi, idx, T):
    """Evaluate the idx-th basis function of the inverse-mobility model at T.

    idx == 0 selects the optical-phonon-like term ``1 / (exp(Eop*e/(kB*T)) - 1)``
    (0.0 for T == 0); idx >= 1 selects the power-law term ``T**(-pi[idx-1])``.

    Parameters
    ----------
    Eop : float
        Energy of the idx == 0 term; multiplied by e/kB, so presumably in eV.
    pi : sequence of float
        Exponents of the power-law terms.
    idx : int
        Basis-function index, valid range 0..len(pi).
    T : float
        Temperature (K).

    Returns
    -------
    float

    Raises
    ------
    SystemExit
        If *idx* is outside 0..len(pi) (an error message is printed first).
    """
    if idx == 0:
        if T == 0.0:
            return 0.0
        k = Eop * e / kB / T
        # For k > 70, 1/(exp(k) - 1) < 1e-30: return 0 directly, which also
        # keeps exp() far from overflow for very large k.
        if k > 70.0:
            return 0.0
        return 1.0 / (exp(k) - 1.0)

    ip = idx - 1
    # Reject negative indices too: pi[-n] would otherwise silently wrap around.
    if ip < 0 or ip >= len(pi):
        print(f"\nError in lsqfunc: function index out of range (idx = {idx}, valid 0..{len(pi)})\n")
        raise SystemExit(1)

    return pow(T, -pi[ip])

def cal_linearpart(Tlist, mulist, Eb, s_phi):
    """Divide the grain-boundary factor by the measured mobility at each T.

    For each ``T = Tlist[i]`` returns
    ``exp(-Eb*x + (s_phi*x)**2 / 2) / mulist[i]`` with ``x = e / kB / T``.

    Parameters
    ----------
    Tlist : sequence of float
        Temperatures (K).
    mulist : sequence of float
        Mobilities, indexed in parallel with *Tlist*.
    Eb, s_phi : float
        Barrier height and its spread (multiplied by e/kB, presumably eV).

    Returns
    -------
    list of float
    """
    result = []
    for idx, T in enumerate(Tlist):
        x = e / kB / T
        exponent = -Eb * x + (s_phi * x)**2 / 2.0
        result.append(exp(exponent) / mulist[idx])
    return result
    
def muinv_op(T, Eop, Aop):
    """Return the optical-phonon inverse mobility ``Aop / (exp(Eop*e/(kB*T)) - 1)``.

    Parameters
    ----------
    T : float
        Temperature (K).
    Eop : float
        Phonon energy (multiplied by e/kB, presumably eV).
    Aop : float
        Prefactor.
    """
    bose_denominator = exp(Eop * (e / kB / T)) - 1.0
    return Aop / bose_denominator

def muinv_p(T, ai, pi):
    """Return the power-law inverse mobility ``ai * T**(-pi)``."""
    return ai * T ** (-pi)

def KGB(T, VB, s_phi):
    """Return the grain-boundary factor ``exp(-VB*x + (s_phi*x)**2 / 2)``, ``x = e/(kB*T)``.

    Parameters
    ----------
    T : float
        Temperature (K).
    VB : float
        Barrier height (multiplied by e/kB, presumably eV).
    s_phi : float
        Spread of the barrier height (same units as VB).
    """
    ekBT = e / kB / T
    barrier_term = -VB * ekBT
    spread_term = (s_phi * ekBT)**2 / 2.0
    return exp(barrier_term + spread_term)

def cal_mu_components(T, Eb, s_phi, Eop, aop, pi, ai):
    """Decompose the model mobility at a single temperature.

    The in-grain inverse mobility is the sum of the optical-phonon term
    (muinv_op) and one power-law term per (pi[k], ai[k]) pair (muinv_p);
    the total mobility is the grain-boundary factor times the in-grain
    mobility.

    Parameters
    ----------
    T : float
        Temperature (K).
    Eb, s_phi : float
        Grain-boundary barrier height and spread.
    Eop, aop : float
        Optical-phonon energy and prefactor.
    pi, ai : sequence of float
        Power-law exponents and prefactors (ai indexed in parallel with pi).

    Returns
    -------
    tuple
        (mu_tot, Kgb, mu_ingrain, muop, mu_pi). ``muop`` and each entry of
        ``mu_pi`` are clamped to ``max_mu`` when the corresponding inverse
        mobility has magnitude <= 1/max_mu; ``mu_ingrain`` is not clamped.
    """
    ekBT = e / kB / T
    Kgb = exp(-Eb * ekBT + (s_phi * ekBT)**2 / 2.0)

    # Inverse mobilities at or below this magnitude are treated as "no scattering".
    inv_floor = 1.0 / max_mu

    muop_inv = muinv_op(T, Eop, aop)
    muop = 1.0 / muop_inv if abs(muop_inv) > inv_floor else max_mu

    # Matthiessen-style sum of inverse mobilities, accumulated term by term.
    muinv_tot = muop_inv
    mu_pi = []
    for k in range(len(pi)):
        inv_k = muinv_p(T, ai[k], pi[k])
        muinv_tot += inv_k
        mu_pi.append(1.0 / inv_k if abs(inv_k) > inv_floor else max_mu)

    mu_ingrain = 1.0 / muinv_tot
    return Kgb * mu_ingrain, Kgb, mu_ingrain, muop, mu_pi

def cal_mu(T, Eb, s_phi, Eop, aop, pi, ai, rettype = 'tot'):
    """Return the model mobility at temperature T.

    With ``rettype == 'tot'`` only the total mobility is returned; for any
    other value the full tuple from cal_mu_components() is returned:
    ``(mu_tot, mu_KGB, mu_ingrain, mu_op, mu_pi)``.
    """
    components = cal_mu_components(T, Eb, s_phi, Eop, aop, pi, ai)
    if rettype != 'tot':
        return tuple(components)
    return components[0]


def cal_mu_list(T_list, Eb, s_phi, Eop, aop, pi, ai, rettype = 'tot'):
    """Evaluate the mobility model over a list of temperatures.

    Parameters
    ----------
    T_list : iterable of float
        Temperatures (K).
    Eb, s_phi, Eop, aop, pi, ai
        Model parameters, passed unchanged to cal_mu().
    rettype : str
        'tot' returns only the total mobilities; any other value returns all
        components.

    Returns
    -------
    list
        For ``rettype == 'tot'``: ``[mu_tot(T) for T in T_list]``.
    tuple
        Otherwise: ``(mutot_list, muKGB_list, muingrain_list, muop_list, mupi_list)``
        where the first four hold one value per temperature and
        ``mupi_list[k][iT]`` is the k-th power-law mobility at ``T_list[iT]``.
    """
    if rettype == 'tot':
        # cal_mu returns a bare scalar in this mode, so only the totals exist.
        return [cal_mu(T, Eb, s_phi, Eop, aop, pi, ai, rettype = 'tot') for T in T_list]

    mutot_list = []
    muKGB_list = []
    muingrain_list = []
    muop_list = []
    mupi_list = []
    for T in T_list:
        mu_tot, mu_KGB, mu_ingrain, mu_op, mu_pi = cal_mu(T, Eb, s_phi, Eop, aop, pi, ai, rettype = rettype)
        mutot_list.append(mu_tot)
        muKGB_list.append(mu_KGB)
        muingrain_list.append(mu_ingrain)
        muop_list.append(mu_op)
        mupi_list.append(mu_pi)

    # Transpose [temperature][term] -> [term][temperature].
    mupi_list = [list(row) for row in zip(*mupi_list)]

    return mutot_list, muKGB_list, muingrain_list, muop_list, mupi_list

def save_data(path, labels, data_list):
    """Write column data to an Excel file.

    Parameters
    ----------
    path : str
        Output file path.
    labels : sequence of str
        Column headers, one per entry of *data_list*.
    data_list : sequence of sequences
        ``data_list[j]`` is the j-th column; it is transposed into rows.
    """
    print()
    print(f"Save data to [{path}]")
    table = np.array(data_list).T
    frame = pd.DataFrame(table, columns = labels)
    frame.to_excel(path, index = False, header = True)

def plot(T_list, mu_list, Tcal_list = None, mucal_list = None):
    """Plot observed mobility vs temperature, optionally with a calculated curve.

    Blocks until ENTER is pressed.

    Parameters
    ----------
    T_list, mu_list : sequence of float
        Observed temperatures (K) and mobilities, drawn as markers.
    Tcal_list, mucal_list : sequence of float, optional
        Calculated curve; drawn only when both are given and non-empty.
    """
    _, ax = plt.subplots(1, 1, figsize = figsize)
    ax.tick_params(labelsize = fontsize)

    ax.plot(T_list, mu_list, label = 'obs', linestyle = '', marker = 'o')
    # Explicit None/len tests: `if arr:` raises ValueError for NumPy arrays.
    has_cal = (Tcal_list is not None and mucal_list is not None
               and len(Tcal_list) > 0 and len(mucal_list) > 0)
    if has_cal:
        ax.plot(Tcal_list, mucal_list, label = 'cal')

    ax.set_xlabel('T (K)', fontsize = fontsize)
    ax.set_ylabel(r'$\mu$ (m$^2$/Vs)', fontsize = fontsize)

    ax.legend(fontsize = fontsize)

    plt.pause(1.0e-5)
    input("\nPress ENTER to terminate>>")


def plot_muT_decomposed(ax, pi, xT, ymu, ymucal = None, ymu_ingrain = None, ymuop = None, ymupi = None, markersize = 3.0):
    """Plot observed mobility vs T together with the model and its components.

    Parameters
    ----------
    ax : matplotlib Axes
        Target axes.
    pi : sequence of float
        Power-law exponents; used only to label the *ymupi* curves.
    xT : sequence of float
        Temperatures (K), shared by every curve.
    ymu : sequence of float
        Observed mobilities, drawn as markers.
    ymucal, ymu_ingrain, ymuop : sequence of float, optional
        Calculated total, in-grain and optical-phonon mobilities.
    ymupi : sequence of sequences, optional
        One mobility curve per power-law term (``ymupi[i]`` pairs with ``pi[i]``).
    markersize : float
        Marker size for the observed data.

    The *ymu_ingrain* curve and any *ymupi* curve whose maximum reaches
    ``max_mu`` (the clamp value used for vanishing inverse mobility) are
    skipped. When *ymucal* is given, the y-range is set to
    ``[0, 10 * max(ymucal)]``.
    """
    def nonempty(seq):
        # Explicit test instead of `if seq:` so NumPy arrays work too
        # (their truth value is ambiguous for more than one element).
        return seq is not None and len(seq) > 0

    ax.plot(xT, ymu, label = r'$\mu(obs)$', linestyle = '', marker = 'o', markersize = markersize)

    ylim = None
    if nonempty(ymucal):
        ylim = [0.0, max(ymucal) * 10.0]
        ax.plot(xT, ymucal, label = r'$\mu(cal)$', color = 'red', linewidth = 1.0, linestyle = '-')

    if nonempty(ymu_ingrain) and max(ymu_ingrain) < max_mu:
        ax.plot(xT, ymu_ingrain, label = r'$\mu_{in-grain}$', linewidth = 1.0, linestyle = 'dashed',
                color = colors[0])

    if nonempty(ymuop):
        ax.plot(xT, ymuop, label = r'$\mu_{op}$', linewidth = 1.0, linestyle = 'dashed', color = colors[1],
                marker = 'o', markersize = 2.0)

    if nonempty(ymupi):
        ioffset = 2
        for i, _ymu in enumerate(ymupi):
            if max(_ymu) >= max_mu:
                continue
            # Wrap around the colour cycle instead of raising IndexError
            # when there are more terms than colours.
            color = colors[(i + ioffset) % len(colors)]
            ax.plot(xT, _ymu, label = rf'$\mu$(pi={pi[i]})', linewidth = 1.0, linestyle = 'dashed', color = color,
                    marker = 'o', markersize = 2.0)

    if ylim is not None:
        ax.set_ylim(ylim)
    ax.set_xlabel('T (K)', fontsize = fontsize)
    ax.set_ylabel(r'$\mu$ (m$^2$/Vs)', fontsize = fontsize)

    ax.legend(fontsize = legend_fontsize)

def plot_muT_weight(ax, pi, xT, wmugb = None, wmu_ingrain = None, wmuop = None, wmupi = None, markersize = 3.0):
    """Plot the relative weight of each mobility contribution vs T.

    Parameters
    ----------
    ax : matplotlib Axes
        Target axes.
    pi : sequence of float
        Power-law exponents; used only to label the *wmupi* curves.
    xT : sequence of float
        Temperatures (K).
    wmugb, wmuop : sequence of float, optional
        Grain-boundary and optical-phonon weights; each is drawn only when given.
    wmu_ingrain : optional
        Accepted for interface symmetry; not plotted.
    wmupi : sequence of sequences, optional
        One weight curve per power-law term (``wmupi[i]`` pairs with ``pi[i]``).
    markersize : float
        Accepted for interface symmetry; not used.
    """
    if wmugb is not None:
        ax.plot(xT, wmugb, label = r'$w_{GB}$', linewidth = 1.0, linestyle = 'dashed', color = colors[0])

    if wmuop is not None:
        ax.plot(xT, wmuop, label = r'$w_{op}$', linewidth = 1.0, linestyle = 'dashed', color = colors[1],
                marker = 'o', markersize = 2.0)

    # Explicit None/len test so NumPy arrays work as well as lists.
    if wmupi is not None and len(wmupi) > 0:
        ioffset = 2
        for i, wmu in enumerate(wmupi):
            # Wrap around the colour cycle instead of raising IndexError.
            color = colors[(i + ioffset) % len(colors)]
            ax.plot(xT, wmu, label = rf'w(pi={pi[i]})', linewidth = 1.0, linestyle = 'dashed', color = color,
                    marker = 'o', markersize = 2.0)

    ax.set_xlabel('T (K)', fontsize = fontsize)
    ax.set_ylabel(r'contribution', fontsize = fontsize)
    ax.legend(fontsize = legend_fontsize)

def cal_weight_list(T_list, pi, muKGB_list, muop_list, mupi_list):
    """Compute, and print as a table, the weight of each mobility contribution.

    For each temperature index iT::

        totinv = 1/|mu_op| + 1/|mu_0| + 1/|mu_1| + ...   (summed in that order)
        w_gb   = 1 - K_GB
        w_op   = K_GB / mu_op / totinv
        w_k    = K_GB / mu_k  / totinv

    Parameters
    ----------
    T_list : sequence of float
        Temperatures (K).
    pi : sequence
        Power-law exponents; only used for column headers (pi[k] labels mupi_list[k]).
    muKGB_list, muop_list : sequence of float
        K_GB and mu_op per temperature.
    mupi_list : sequence of sequences
        ``mupi_list[k][iT]`` is the k-th power-law mobility at ``T_list[iT]``.

    Returns
    -------
    tuple
        (ywmugb, ywmuop, ywmupi): one weight per temperature for the first two;
        ``ywmupi[k][iT]`` for the power-law terms (same layout as mupi_list).
    """
    print("")
    labels = [f'w({pi[k]})' for k in range(len(mupi_list))]
    print(f"{'T(K)':6} {'w,gb':10} {'w,op':10} " + "".join(f"{label:10} " for label in labels))

    ywmugb = []
    ywmuop = []
    per_T = []
    for iT, T in enumerate(T_list):
        kgb = muKGB_list[iT]
        mu_op = muop_list[iT]

        # Accumulate sequentially to keep the float rounding order fixed.
        totinv = 1.0 / abs(mu_op)
        for column in mupi_list:
            totinv += 1.0 / abs(column[iT])

        w_gb = 1.0 - kgb
        w_op = kgb / mu_op / totinv
        w_terms = [kgb / column[iT] / totinv for column in mupi_list]

        ywmugb.append(w_gb)
        ywmuop.append(w_op)
        per_T.append(w_terms)

        row = f"{T:6.3g} {w_gb:10.4g} {w_op:10.4g} " + "".join(f"{w:10.4g} " for w in w_terms)
        print(row)

    # Transpose [temperature][term] -> [term][temperature].
    ywmupi = [list(col) for col in zip(*per_T)]

    return ywmugb, ywmuop, ywmupi


def _print_mu_table(T_list, pi, muKGB_list, muop_list, mupi_list):
    """Print K_GB, mu_op and every power-law mobility for each temperature."""
    print("")
    print(f"{'T(K)':6} {'KGB':10} {'mu,op':10} ", end = '')
    for i in range(len(mupi_list)):
        s = f'mu({pi[i]})'
        print(f"{s:10} ", end = '')
    print()

    for iT, T in enumerate(T_list):
        # Named kgb (not KGB) so the module-level KGB() function is not shadowed.
        kgb = muKGB_list[iT]
        muop = muop_list[iT]

        print(f"{T:6.3g} {kgb:10.4g} {muop:10.4g} ", end = '')
        for mu in mupi_list:
            print(f"{mu[iT]:10.4g} ", end = '')
        print()


def main():
    """Read T-mu data, evaluate the mobility model, print tables and plot.

    Reads columns 'T' and 'mu' from the module-level ``infile``, evaluates the
    model for the hard-coded parameters below, prints the component table
    and the weight table, then shows two panels (decomposed mobility and
    contribution weights). Blocks until ENTER is pressed.
    """
    # The input path comes from the module-level `infile` constant.
    xlabel = 'T'
    ylabel = "mu"

    fit = tkFit_mxy()
    print("")
    print("infile : ", infile)
    print("xlabel : ", xlabel)
    print("ylabel : ", ylabel)

    print("")
    print(f"Read data from [{infile}]")
    Tlabel, mulabel, T_list, mu_list = read_data(fit, infile, xlabel, ylabel)

    print("x_labels=", fit.x_labels)
    print("y_labels=", fit.y_labels)
    print("T:", T_list)
    print("mu:", mu_list)

    # Model parameters (energies are multiplied by e/kB downstream, presumably eV).
    Eb = 0.02
    s_phi = 0.0
    Eop = 0.0446
    aop = 0.0
    pi    = [          0.0, -1.0, 1.0, 1.5]
    ai    = [   0.00725428,  0.0, 0.0, 0.0]
    optid = [1,          1,    1,   1,   1]    # NOTE(review): currently unused in this script

    mucal_list, muKGB_list, muingrain_list, muop_list, mupi_list = cal_mu_list(
        T_list, Eb, s_phi, Eop, aop, pi, ai, rettype = 'all')
    _print_mu_table(T_list, pi, muKGB_list, muop_list, mupi_list)

    ywmugb, ywmuop, ywmupi = cal_weight_list(T_list, pi, muKGB_list, muop_list, mupi_list)

    # ---- plot ----
    _, axes = plt.subplots(1, 2, figsize = figsize)
    axes[0].tick_params(labelsize = fontsize)
    axes[1].tick_params(labelsize = fontsize)

    plot_muT_decomposed(axes[0], pi, T_list, mu_list, ymucal = mucal_list, ymu_ingrain = muingrain_list,
                        ymuop = muop_list, ymupi = mupi_list)
    plot_muT_weight(axes[1], pi, T_list, wmugb = ywmugb, wmu_ingrain = None, wmuop = ywmuop, wmupi = ywmupi)

    plt.tight_layout()
    plt.pause(1.0e-5)
    input("\nPress ENTER to terminate>>")


# Run the analysis only when executed as a script, not on import.
if __name__ == "__main__":
    main()