"""
## Information
Author: Takumi Nose
Modified by: Haruto Minamishima, Toshio Kamiya
Date: 2024-1-15, 2024-3-8
Version: 1.1.0

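Description: Calculate the density-of-states effective mass, the band effective mass and the multiplicity from BoltzTraP2 output (interpolation.trace / interpolation.condtens) together with VASP OUTCAR/EIGENVAL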
"""

import os
import re
import argparse
from pathlib import Path
import numpy as np
import openpyxl
import pandas as pd
import matplotlib.pyplot as plt


###############################
# User settings
###############################

# Offsets (eV) of the Fermi level from the band edge used for the summary row
# (passed as `dis` to get_para_cbm / get_para_vbm).
e_dis = [0.0, 0.1, 0.2, 0.3]  # eV

# Offset range (eV) scanned by calc_both for the table and the plots.
Emin = 0.05
Emax = 0.4
Estep = 0.01

# Temperature used in calc_md.
T = 300  # K

# Workbook that calc_both writes the energy scan to (inside cwdir).
outxlsx = 'mass_boltztrap2.xlsx'

###############################
# Physical constants (SI)
###############################
h = 6.62607015e-34     # J s, Planck constant
k_B = 1.380649e-23     # J/K, Boltzmann constant
e = 1.602176634e-19    # C, elementary charge
m0 = 9.1093837015e-31  # kg, free-electron mass

# Offset in calc_md's exponent (e|S|/k_B - r - 2).
# NOTE(review): presumably the scattering parameter; confirm the convention.
r = 0

# Relaxation time, fixed at 10 fs.
# NOTE(review): not read by any function in this file.
tau = 1e-14  # s

# PF = seebeck ^ 2 * sigma
###############################


def main():
    """
    Command-line entry point.

    The first cell of the output row is taken from the parent of the current
    working directory: the trailing ``mp-<digits>_<non-space>`` part of that
    path if present, otherwise ``"sample"``. The options are then parsed and
    the work is dispatched to calc_both / calc_trace / calc_tens.

    Side effect: the process working directory is changed to this script's
    directory. With ``-t plain`` it is changed back to the original one.
    """
    cwd = Path.cwd()
    original_dir = str(cwd)

    os.chdir(os.path.dirname(os.path.abspath(__file__)))
    matnamepath = str(cwd.parent)
    # Raw string: "\d" and "\S" in a plain literal are invalid escape
    # sequences (SyntaxWarning since Python 3.12). The pattern is unchanged.
    matname = re.search(r"mp-\d+_\S+$", matnamepath)
    m_list = [matname.group(0) if matname else "sample"]

    parser = argparse.ArgumentParser(description="")
    parser.add_argument("-m", "--mode", help="trace or both ('both' means trace and tensor) ", default="both")
    parser.add_argument("-t", "--outputtype", help="output type ('' or plain)", default='')
    parser.add_argument("-f", "--filedir", help="output dir ('')", default='')
    args = parser.parse_args()
    if args.outputtype == 'plain':
        filedir = ""
        os.chdir(original_dir)
    else:
        filedir = args.filedir

    if args.mode in ("both", "plot"):
        calc_both(filedir, cwd, e_dis, m_list, mode=args.mode)
    elif args.mode == "trace":
        calc_trace(filedir, cwd, e_dis, m_list)
    elif args.mode == "tensor":
        calc_tens(filedir, cwd, e_dis, m_list)
    else:
        print("incorrect mode")


def get_bandedges(cwdir: Path):
    """
    Read band-edge information from VASP's OUTCAR and EIGENVAL.

    The Fermi level is taken from the *last* line containing ``E-fermi`` in
    OUTCAR (third whitespace-separated field). The cell volume comes from the
    *first* line containing ``volume of cell`` (fifth field). Every eigenvalue
    in EIGENVAL (all k-points, bands and spin channels) is then compared with
    that Fermi level.

    Args:
        cwdir (Path | str): directory containing OUTCAR and EIGENVAL

    Returns:
        tuple: (eup, edown, vol)
            eup:   lowest eigenvalue strictly above E-fermi
                   (E-fermi + 99.0 if there is none)
            edown: highest eigenvalue strictly below E-fermi
                   (E-fermi - 99.0 if there is none)
            vol:   cell volume as printed in OUTCAR

    Raises:
        ValueError: if OUTCAR has no ``E-fermi`` or no ``volume of cell`` line.
    """
    cwdir = Path(cwdir)
    path_outcar = cwdir / "OUTCAR"

    ef = None
    vol = None
    with open(path_outcar, "r") as f:
        for line in f:
            if "E-fermi" in line:
                # Keep overwriting: the last occurrence wins.
                ef = float(line.split()[2])
            if vol is None and "volume of cell" in line:
                # Only the first occurrence is used.
                vol = float(line.split()[4])
    if ef is None:
        raise ValueError(f"'E-fermi' not found in {path_outcar}")
    if vol is None:
        raise ValueError(f"'volume of cell' not found in {path_outcar}")

    with open(cwdir / "EIGENVAL", "r") as f:
        lines = f.readlines()

    nspin = int(lines[0].split()[3])
    nkpoint = int(lines[5].split()[1])
    nband = int(lines[5].split()[2])

    eup = ef + 99.0
    edown = ef - 99.0
    # Each k-point occupies nband + 2 lines. The k-point header is at
    # 7 + ik*(nband+2) and its nband eigenvalue lines follow. Column 0 is
    # the band index and columns 1..nspin are the energies.
    stride = nband + 2
    for ik in range(nkpoint):
        first = 8 + ik * stride
        for ib in range(nband):
            cols = lines[first + ib].split()
            for s in range(nspin):
                energy = float(cols[s + 1])
                if energy > ef and energy < eup:
                    eup = energy
                if energy < ef and energy > edown:
                    edown = energy
    return (eup, edown, vol)



def raw2df_trace(cwdir: Path, vol: float):
    """
    Load ``interpolation.trace`` and return it as a DataFrame with
    unit-converted columns added.

    The first line (header) is skipped and the rest is split on whitespace
    into the columns EF[Ry], T, N[e/uc], DOS, S, sigma/tau0, RH, k/tau0, cv,
    chi. Two columns are appended:
      - EF[eV]  = EF[Ry] * 13.6057
      - N[1/m3] = N[e/uc] / vol * 1e30
        (the 1e30 factor assumes vol is in Angstrom^3, as from get_bandedges)

    Args:
        cwdir (Path | str): directory containing interpolation.trace
        vol (float): cell volume (third value returned by get_bandedges)

    Returns:
        pd.DataFrame: the parsed data plus the two converted columns
    """
    path_result = Path(cwdir) / "interpolation.trace"
    # sep=r"\s+" replaces delim_whitespace=True, which pandas deprecated in 2.2.
    df = pd.read_table(str(path_result), header=None, skiprows=1, sep=r"\s+")
    df.columns = ["EF[Ry]", "T", "N[e/uc]", "DOS", "S", "sigma/tau0", "RH", "k/tau0", "cv", "chi"]
    ry_to_ev = 13.6057
    df["EF[eV]"] = df["EF[Ry]"] * ry_to_ev
    df["N[1/m3]"] = df["N[e/uc]"] / vol * 1E30
    return df



def raw2df_tens(cwdir, vol):
    """
    Load ``interpolation.condtens`` and return it as a DataFrame with
    unit-converted columns added.

    The first line (header) is skipped and the rest is split on whitespace
    into 30 columns: EF[Ry], T, N[e/uc], then the nine components
    (xx, xy, xz, yx, yy, yz, zx, zy, zz) of sigma/tau0, S and k/tau0, in
    that order. Two columns are appended:
      - EF[eV]  = EF[Ry] * 13.6057
      - N[1/m3] = N[e/uc] / vol * 1e30
        (the 1e30 factor assumes vol is in Angstrom^3, as from get_bandedges)

    Args:
        cwdir (Path | str): directory containing interpolation.condtens
        vol (float): cell volume (third value returned by get_bandedges)

    Returns:
        pd.DataFrame: the parsed data plus the two converted columns
    """
    path_result = Path(cwdir) / "interpolation.condtens"
    # sep=r"\s+" replaces delim_whitespace=True, which pandas deprecated in 2.2.
    df_tensor = pd.read_table(str(path_result), header=None, skiprows=1, sep=r"\s+")

    components = [a + b for a in "xyz" for b in "xyz"]  # xx, xy, ..., zz
    df_tensor.columns = (["EF[Ry]", "T", "N[e/uc]"]
                         + [f"sigma/tau0{c}" for c in components]
                         + [f"S{c}" for c in components]
                         + [f"k/tau0{c}" for c in components])

    ry_to_ev = 13.6057
    df_tensor["EF[eV]"] = df_tensor["EF[Ry]"] * ry_to_ev
    df_tensor["N[1/m3]"] = df_tensor["N[e/uc]"] / vol * 1E30
    return df_tensor



def get_para_cbm(cbm, dis, df, mode):
    """
    Pick the transport values used for the conduction-band effective mass.

    The row selected is the one whose ``EF[eV]`` is closest to ``cbm - dis``
    (the first such row on ties, as with ``idxmin``).

    Args:
        cbm (float): conduction-band-minimum energy in eV (first value of
            get_bandedges)
        dis (float): offset in eV; the target Fermi level is ``cbm - dis``
        df (pd.DataFrame): output of raw2df_trace / raw2df_tens. It must
            contain ``EF[eV]``, ``N[1/m3]``, ``"S" + mode`` and
            ``"sigma/tau0" + mode``. It is not modified.
        mode (str): "" for the trace data, or "xx" / "yy" / "zz"

    Returns:
        tuple: (n, S, sigmatau) from the selected row, taken from the columns
        ``N[1/m3]``, ``"S" + mode`` and ``"sigma/tau0" + mode``
    """
    # Kept as a local Series so the caller's DataFrame is left untouched
    # (this used to be stored in a "slide" column of df).
    slide = df["EF[eV]"] - cbm + dis
    npoint = slide.abs().idxmin()
    n = df.at[npoint, "N[1/m3]"]
    S = df.at[npoint, "S" + mode]
    sigmatau = df.at[npoint, "sigma/tau0" + mode]
    return (n, S, sigmatau)



def get_para_vbm(vbm, dis, df, mode):
    """
    Pick the transport values used for the valence-band effective mass.

    The row selected is the one whose ``EF[eV]`` is closest to ``vbm + dis``
    (the first such row on ties, as with ``idxmin``).

    Args:
        vbm (float): valence-band-maximum energy in eV (second value of
            get_bandedges)
        dis (float): offset in eV; the target Fermi level is ``vbm + dis``
        df (pd.DataFrame): output of raw2df_trace / raw2df_tens. It must
            contain ``EF[eV]``, ``N[1/m3]``, ``"S" + mode`` and
            ``"sigma/tau0" + mode``. It is not modified.
        mode (str): "" for the trace data, or "xx" / "yy" / "zz"

    Returns:
        tuple: (n, S, sigmatau) from the selected row, taken from the columns
        ``N[1/m3]``, ``"S" + mode`` and ``"sigma/tau0" + mode``
    """
    # Kept as a local Series so the caller's DataFrame is left untouched
    # (this used to be stored in a "slide" column of df).
    slide = df["EF[eV]"] - vbm - dis
    npoint = slide.abs().idxmin()
    n = df.at[npoint, "N[1/m3]"]
    S = df.at[npoint, "S" + mode]
    sigmatau = df.at[npoint, "sigma/tau0" + mode]
    return (n, S, sigmatau)



def calc_md(n, S) -> float:
    """
    Density-of-states effective mass, in units of m0, from the carrier
    density and the Seebeck coefficient.

    Evaluates

        md = h^2 / (2 pi k_B T) * |n/2 * exp(e|S|/k_B - r - 2)|^(2/3) / m0

    using the module-level constants h, k_B, T, e, r and m0.

    Args:
        n: carrier density (callers pass the N[1/m3] column)
        S: Seebeck coefficient; only its magnitude is used

    Returns:
        float: md / m0
    """
    prefactor = h**2 / (2 * np.pi * k_B * T)
    exponent = e * abs(S) / k_B - r - 2
    density_term = np.power(abs(0.5 * n * np.exp(exponent)), 2/3)
    return prefactor * density_term / m0

def calc_m(n, sigmatau):
    """
    Band effective mass, in units of m0, from the carrier density and
    sigma/tau.

    Evaluates  m = |n * e^2 / (sigma/tau)| / m0  with the module-level
    constants e and m0.

    Args:
        n: carrier density (callers pass the N[1/m3] column)
        sigmatau: electrical conductivity divided by the relaxation time

    Returns:
        float: m / m0
    """
    reciprocal = 1 / sigmatau
    return abs(n * e**2 * reciprocal) / m0

def calc_multiplicity(md, m):
    """
    Multiplicity from the density-of-states mass and the band mass:
    ``(md / m) ** (3/2)``, evaluated with ``np.power``.

    Args:
        md (float): density-of-states effective mass
        m (float): band effective mass

    Returns:
        float: multiplicity
    """
    mass_ratio = md / m
    return np.power(mass_ratio, 1.5)



def at_one_ev(n, S, sigmatau):
    """
    Derive the three mass-related quantities from one set of transport values.

    Args:
        n (float): carrier density
        S (float): Seebeck coefficient
        sigmatau (float): electrical conductivity divided by relaxation time

    Returns:
        tuple: (md, m, multiplicity)
            md: density-of-states effective mass (calc_md)
            m: band effective mass (calc_m)
            multiplicity: (md / m) ** 1.5 (calc_multiplicity)
    """
    dos_mass = calc_md(n, S)
    band_mass = calc_m(n, sigmatau)
    return (dos_mass, band_mass, calc_multiplicity(dos_mass, band_mass))

def cal_meff(target, vasp_result, E, df, mode):
    """
    Compute (md, m, multiplicity) for one band edge at one offset.

    Args:
        target (str): 'cb' selects get_para_cbm; any other value selects
            get_para_vbm
        vasp_result (float): band-edge energy in eV, passed on as cbm / vbm
        E (float): offset in eV, passed on as ``dis``
        df (pd.DataFrame): trace or tensor DataFrame
        mode (str): "" for trace, or "xx" / "yy" / "zz"

    Returns:
        tuple: (mDOS, mband, M) as produced by at_one_ev
    """
    picker = get_para_cbm if target == 'cb' else get_para_vbm
    n, S, sigma = picker(vasp_result, E, df, mode)
    mDOS, mband, M = at_one_ev(n, S, sigma)
    return mDOS, mband, M


def calc_trace(filedir, cwdir, edis, mlist):
    """
    Compute trace-based effective masses and multiplicities and append them
    as one row to ``<filedir>/md_list_trace.xlsx``.

    Row layout: the entries of *mlist*; then md, m, M for the CB at each
    offset in *edis*; then an empty cell; then md, m, M for the VB at each
    offset.

    Args:
        filedir (str | Path): directory containing md_list_trace.xlsx (the
            workbook must already exist)
        cwdir (Path): directory with OUTCAR, EIGENVAL and interpolation.trace
        edis (list[float]): offsets (eV) of EF from the band edge
        mlist (list): leading cells of the row; the caller's list is not
            modified
    """
    cbm, vbm, vol = get_bandedges(cwdir)
    df = raw2df_trace(cwdir, vol)

    row = list(mlist)
    # The loop variable used to be `e` while the body read an undefined `i`
    # (NameError on the first iteration).
    for dis in edis:
        row.extend(at_one_ev(*get_para_cbm(cbm, dis, df, "")))
    row.append("")
    for dis in edis:
        row.extend(at_one_ev(*get_para_vbm(vbm, dis, df, "")))

    # main() passes filedir as a str, so wrap it before using "/".
    xlsxdir = Path(filedir) / "md_list_trace.xlsx"
    wb = openpyxl.load_workbook(str(xlsxdir))
    ws = wb.active
    ws.append(row)
    wb.save(str(xlsxdir))



def calc_tens(filedir, cwdir, edis, mlist):
    """
    Compute tensor-component (xx/yy/zz) effective masses and multiplicities
    and append them as one row to ``<filedir>/md_list_tens.xlsx``.

    Row layout: the entries of *mlist*. Then, for each offset in *edis*, nine
    CB values: md_xx, md_yy, md_zz, m_xx, m_yy, m_zz, M_xx, M_yy, M_zz. Then
    an empty cell, followed by the same nine values for the VB at each
    offset.

    Args:
        filedir (str | Path): directory containing md_list_tens.xlsx (the
            workbook must already exist)
        cwdir (Path): directory with OUTCAR, EIGENVAL and
            interpolation.condtens
        edis (list[float]): offsets (eV) of EF from the band edge
        mlist (list): leading cells of the row; the caller's list is not
            modified
    """
    cbm, vbm, vol = get_bandedges(cwdir)
    df_tens = raw2df_tens(cwdir, vol)
    axes = ("xx", "yy", "zz")

    def _nine_values(getter, edge, dis):
        # Group by quantity, then by axis: [md_xx, md_yy, md_zz, m_xx, ..., M_zz]
        results = [at_one_ev(*getter(edge, dis, df_tens, ax)) for ax in axes]
        return [res[k] for k in range(3) for res in results]

    row = list(mlist)
    for dis in edis:
        row.extend(_nine_values(get_para_cbm, cbm, dis))
    row.append("")
    for dis in edis:
        row.extend(_nine_values(get_para_vbm, vbm, dis))

    # main() passes filedir as a str, so wrap it before using "/".
    xlsxdir = Path(filedir) / "md_list_tens.xlsx"
    wb = openpyxl.load_workbook(str(xlsxdir))
    ws = wb.active
    ws.append(row)
    wb.save(str(xlsxdir))



def calc_both(filedir, cwdir, edis, mlist, mode = 'both'):
    """
    Compute effective masses and multiplicities from both the trace and the
    tensor (xx/yy/zz) BoltzTraP2 data, save them and plot them.

    1. Summary row: for each offset in *edis*, the CB values and then the VB
       values (12 per offset, layout as in _meff_block) are appended to a
       copy of *mlist*. The row is appended to ``<filedir>/md_list.xlsx`` if
       that workbook exists. Otherwise ``label:value`` lines are written to
       ``mass.txt`` in the current working directory.
    2. Energy scan: the same quantities for offsets Emin..Emax (step Estep,
       module settings) are written to ``<cwdir>/<outxlsx>`` and plotted.
       The function then blocks on ``input()`` until Enter is pressed.

    Args:
        filedir (str | Path): directory expected to contain md_list.xlsx
        cwdir (Path): directory with OUTCAR, EIGENVAL, interpolation.trace
            and interpolation.condtens
        edis (list[float]): offsets (eV) of EF from the band edge
        mlist (list): leading cells of the summary row (e.g. material name);
            the caller's list is not modified
        mode (str): kept for interface compatibility; not used here. The
            plot is always produced.
    """
    cbm, vbm, vol = get_bandedges(cwdir)
    df = raw2df_trace(cwdir, vol)
    df_tens = raw2df_tens(cwdir, vol)

    row, labels = _summary_row(cbm, vbm, edis, df, df_tens, mlist)
    _save_summary(filedir, row, labels)

    data_list, scan_labels = _scan_energy(cbm, vbm, df, df_tens)
    _save_scan(cwdir, data_list, scan_labels)
    _plot_scan(cwdir, data_list, scan_labels)


def _meff_block(target, edge, E, df, df_tens):
    """Return the 12 values for one band edge and offset, grouped by quantity:
    [md, md_xx, md_yy, md_zz, m, m_xx, m_yy, m_zz, M, M_xx, M_yy, M_zz]
    (the un-suffixed entries come from the trace data)."""
    results = [cal_meff(target, edge, E, df, "")]
    results += [cal_meff(target, edge, E, df_tens, ax) for ax in ("xx", "yy", "zz")]
    return [res[k] for k in range(3) for res in results]


def _meff_labels(E, *names):
    """Labels matching _meff_block's layout, e.g. 'e=0.1:mde*', 'e=0.1:mde*_xx', ..."""
    return [f"e={E}:{name}{suffix}" for name in names for suffix in ("", "_xx", "_yy", "_zz")]


def _summary_row(cbm, vbm, edis, df, df_tens, mlist):
    """Build the md_list.xlsx row and a label list aligned with it entry for entry."""
    row = list(mlist)
    labels = ["sample"]
    for E in edis:
        row.extend(_meff_block('cb', cbm, E, df, df_tens))
        labels.extend(_meff_labels(E, "mde*", "me*", "Me"))
    # Empty separator cell between the CB and VB sections. The labels get a
    # matching placeholder so both lists stay aligned (they used to drift by
    # one here).
    row.append("")
    labels.append("")
    for E in edis:
        # The VB must use the VBM; this previously passed the CBM.
        row.extend(_meff_block('vb', vbm, E, df, df_tens))
        labels.extend(_meff_labels(E, "mdh*", "mh*", "Mh"))
    return row, labels


def _save_summary(filedir, row, labels):
    """Append *row* to <filedir>/md_list.xlsx, or write 'label:value' lines to
    mass.txt when that workbook does not exist. Write failures only warn."""
    xlsxdir = os.path.join(str(filedir), "md_list.xlsx")
    print()
    if os.path.isfile(xlsxdir):
        print(f"[{xlsxdir}] exists")
        wb = openpyxl.load_workbook(xlsxdir)
        ws = wb.active
        ws.append(row)

        print()
        print(f"Add m* data to [{xlsxdir}]")
        try:
            wb.save(xlsxdir)
        except Exception as exc:  # best effort, e.g. file locked by Excel
            print()
            print(f"  Warning: Could not save [{xlsxdir}]")
            print(f"  {exc}")
    else:
        outfile = 'mass.txt'

        print()
        print(f"[{xlsxdir}] does not exist")
        print(f"Save m* data to [{outfile}]")
        try:
            with open(outfile, 'w') as fp:
                for label, value in zip(labels, row):
                    if not label:  # CB/VB separator cell
                        continue
                    fp.write(f"{label}:{value}\n")
        except OSError as exc:
            print()
            print(f"  Warning: Could not write to [{outfile}]")
            print(f"  {exc}")


def _scan_energy(cbm, vbm, df, df_tens):
    """Evaluate all 24 quantities for offsets Emin..Emax (step Estep).
    Returns (data, labels), where data has shape (n_offsets, 25) and column 0
    is the offset."""
    Elist = np.arange(Emin, Emax + 1.0e-4, Estep)
    labels = ["E (eV)", "<mdos_e*>", "mdos_e*_xx", "mdos_e*_yy", "mdos_e*_zz",
                        "<m_e*>",    "m_e*_xx",    "m_e*_yy",    "m_e*_zz",
                        "<Mult_e>",  "Mult_e_xx",  "Mult_e_yy",  "Mult_e_zz",
                        "<mdos_h*>", "mdos_h*_xx", "mdos_h*_yy", "mdos_h*_zz",
                        "<m_h*>",    "m_h*_xx",    "m_h*_yy",    "m_h*_zz",
                        "<Mult_h>",  "Mult_h_xx",  "Mult_h_yy",  "Mult_h_zz",
              ]
    # Sized from Elist itself. A separately computed count could disagree
    # with np.arange's float rounding and leave uninitialized rows.
    data_list = np.empty([len(Elist), 25])
    for iE, E in enumerate(Elist):
        print("E=", iE, E)
        data_list[iE] = ([E]
                         + _meff_block('cb', cbm, E, df, df_tens)
                         + _meff_block('vb', vbm, E, df, df_tens))
    return data_list, labels


def _save_scan(cwdir, data_list, labels):
    """Write the energy-scan table (header row + data rows) to <cwdir>/<outxlsx>."""
    outxlsx_path = os.path.join(str(cwdir), outxlsx)
    print()
    print(f"Save effective masses/M to [{outxlsx_path}]")
    print("labels:", labels)

    wb = openpyxl.Workbook()
    ws = wb.active
    ws.append(labels)
    for data in data_list:
        ws.append(data.tolist())
    wb.save(outxlsx_path)


def _plot_scan(cwdir, data_list, labels):
    """Plot the scan in a 2x3 grid (CB top, VB bottom) and wait for Enter."""
    print()
    print("plot")
    fig, ax = plt.subplots(2, 3, figsize=(12, 8))
    ax[0, 0].set_title(cwdir)

    # NOTE(review): the x value is the offset `dis`. get_para_cbm targets
    # EF = E_CBM - dis and get_para_vbm targets EF = E_VBM + dis; confirm
    # the sign convention in these axis labels.
    xlabel_cb = "$E - E_{CBM}$ (eV)"
    xlabel_vb = "$E_{VBM} - E$ (eV)"
    panels = [
        # (row, col, first data column, y label, x label)
        (0, 0, 1, "Electron DOS m*", xlabel_cb),
        (0, 1, 5, "Electron band m*", xlabel_cb),
        (0, 2, 9, "CB multiplicity", xlabel_cb),
        (1, 0, 13, "Hole DOS m*", xlabel_vb),   # starts at 13 (<mdos_h*>), was 14
        (1, 1, 17, "Hole band m*", xlabel_vb),
        (1, 2, 21, "VB multiplicity", xlabel_vb),
    ]
    for irow, icol, first, ylabel, xlabel in panels:
        axis = ax[irow, icol]
        for i in range(first, first + 4):
            axis.plot(data_list[:, 0], data_list[:, i], label=labels[i])
        axis.set_xlabel(xlabel)
        axis.set_ylabel(ylabel)
        axis.legend()

    # Previously written as `plt.tight_layout` without a call, i.e. a no-op.
    fig.tight_layout()
    plt.pause(0.0001)
    input("Press Enter to terminate>>")



if __name__ == "__main__":
    # Run the command-line entry point only when executed as a script;
    # importing this module does not trigger any computation.
    main()
