#!/usr/bin/env python

"""
## Information
Author: Xinyi He
Modified by: Toshio Kamiya
Date: 2024-1-17
Version: 1.0.0
Git URL: 
Description: Plot ALMODE ZT tensor average (direction-averaged |n|, kappa_e, sigma, S, PF and ZT versus chemical potential)
"""

# Usage: python3 plotZTave.py interpolation.condtens [T(K)] [tau(s)] [klat_x/ave] [klat_y] [klat_z]  (klat in W/mK)

import sys
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import matplotlib.patches as patches
import subprocess
import re
import os
from pymatgen.io.vasp.inputs import Poscar


from tklib.tkcrystal.tkvasp import tkVASP



# ---- User settings: edit this section ---------------------------------------
# Physical defaults. T, tau and the lattice thermal conductivities may also be
# given on the command line after the input file name.
T = 300          # temperature (K)
tau = 10e-15     # relaxation time (s), i.e. 10 fs
klat = 1         # lattice thermal conductivity, x or average (W/mK)
klaty = 1        # lattice thermal conductivity, y (W/mK)
klatz = 1        # lattice thermal conductivity, z (W/mK)

# Axis ranges. Several of the upper limits below are replaced by autoscaled
# values in the plotting section.
xmini = -1       # lower x limit (eV)
xmaxm = 2        # upper x limit (eV)
nmax = 8E21      # |n| upper limit (cm^-3)
kappamax = 10    # kappa_e upper limit (W/mK)
sigmamax = 2E4   # sigma upper limit (S/cm)
Seebeckmax = 3   # Seebeck upper limit (mV/K)
PFmax = 30       # power-factor upper limit (uW/cmK^2)
ZTmax = 1.5      # ZT upper limit

# ------------------------------------------------------------------------------


def pfloat(s, defval=None):
    """Convert *s* to float, returning *defval* if it is not a number.

    Used to tell numeric command-line arguments apart from file names.

    Args:
        s: value to convert (typically a command-line string).
        defval: value returned when the conversion fails.

    Returns:
        ``float(s)`` on success, otherwise *defval*.
    """
    try:
        return float(s)
    # Only conversion failures are expected; a bare except would also swallow
    # KeyboardInterrupt / SystemExit.
    except (TypeError, ValueError, OverflowError):
        return defval


argv = sys.argv
nargs = len(argv)

# The optional numbers follow the file name in exactly this order.
usage = ("Usage: python plotZTave.py interpolation.condtens "
         "[T] [tau] [kappa_lattice(x/ave)] [kappa_lattice(y)] [kappa_lattice(z)]")

# sys.argv always contains the script name, so "no arguments" means nargs < 2.
if nargs < 2:
    print(usage)
    sys.exit(1)

# The first non-numeric argument is the input table; `offset` records its
# position so the optional numeric arguments after it can be located.
offset = 0
files = []
for i in range(1, nargs):
    if pfloat(argv[i], None) is None:
        print("f=", i, argv[i])
        files.append(argv[i])
        offset = i
        break

if not files:
    print("Error: no input file found among the arguments.")
    print(usage)
    sys.exit(1)

# Optional overrides of the defaults defined near the top of the file; a
# non-numeric value leaves the corresponding default unchanged.
if nargs > 1 + offset:
    T = pfloat(argv[1 + offset], T)
if nargs > 2 + offset:
    tau = pfloat(argv[2 + offset], tau)
if nargs > 3 + offset:
    klat = pfloat(argv[3 + offset], klat)
if nargs > 4 + offset:
    klaty = pfloat(argv[4 + offset], klaty)
if nargs > 5 + offset:
    klatz = pfloat(argv[5 + offset], klatz)

# Echo the resolved inputs so the user can check how the arguments were parsed.
print()
print("files:")
for path in files:
    print(f"  {path}")
settings = (
    f"T: {T} K",
    f"tau: {tau} s",
    f"kappa_lattice(x/ave): {klat} W/m/K",
    f"kappa_lattice(y)    : {klaty} W/m/K",
    f"kappa_lattice(z)    : {klatz} W/m/K",
)
for line in settings:
    print(line)


# Subplot margins (figure fractions). Not referenced by the active plotting
# code; kept for manual layout tuning with plt.subplots_adjust.
left=0.2;right=0.95;top=0.94;bottom=0.15

# Earlier revisions took efermi, the gap and the cell volume from vaspkit output
# and pymatgen's Poscar; they are now read through tkVASP below.

CAR_path = '.'
vasp = tkVASP()
base_path = vasp.getdir(CAR_path)

print("")
POSCAR_path  = vasp.get_POSCAR(base_path)
OUTCAR_path  = vasp.get_OUTCAR(base_path)
EIGENVAL_path = vasp.get_VASPPath(base_path, 'EIGENVAL')
print("CAR dir(ideal) : ", CAR_path)
print("  POSCAR  : ", POSCAR_path)
print("  OUTCAR  : ", OUTCAR_path)
print("  EIGENVAL: ", EIGENVAL_path)

cry = vasp.read_poscar(POSCAR_path)
lattice = cry.LatticeParameters()
Volume = cry.Volume()
# Assumes Volume is in Angstrom^3 (1 A^3 = 1e-24 cm^3), so V is in cm^3 and the
# per-cell carrier count divided by V is a density in cm^-3 -- confirm the unit
# returned by tkcrystal's Volume().
V = Volume * 1.0e-24

outcarinf = vasp.read_outcar_inf(OUTCAR_path)
efermi = outcarinf["EF"]
bandedgeinf2 = vasp.gbandedges(CAR_path, OUTCAR_path, EIGENVAL_path, efermi)
gap = bandedgeinf2["Eg"]

print(f"Fermi level: {efermi} eV")
print(f"Bandgap: {gap} eV")
# Volume * 1e-24 is a volume in cm^3 (not cm^-3).
print(f"Volume: {Volume}E-24 cm^3")

# Font settings: font.family -> 'arial', font.sans-serif list -> Helvetica.
mpl.rcParams['font.family'] = 'arial'
mpl.rcParams['font.sans-serif'] = ['Helvetica']
# Line colour and line-style palettes (not referenced by the current plots,
# which pass their colours explicitly).
color = [
    'k', 'b', 'g', 'r', 'm', 'c', 'y', 'r',
    'darkred', 'darkblue', 'darkgreen', 'darkmagenta',
]
lsty = ['-'] * 4 + ['--'] * 4 + ['-'] * 4

def change_xscale(array, ef=None, ry2ev=13.6):
    """Convert energies to eV and shift them by a reference energy, in place.

    Each element is replaced by ``element * ry2ev - ef``. Elements may be plain
    numbers or NumPy arrays (one energy grid per input file).

    Args:
        array: mutable sequence of energies. These are assumed to be in Rydberg,
            since 13.6 is the Ry -> eV factor.
        ef: reference energy in eV that is subtracted after scaling. When None,
            the module-level ``efermi`` is used (the original behaviour).
        ry2ev: scale factor. The default 13.6 keeps the original results; pass
            13.605693 for the CODATA Rydberg energy.

    Returns:
        The same *array* object, modified in place.
    """
    if ef is None:
        ef = efermi
    for i, value in enumerate(array):
        array[i] = value * ry2ev - ef
    return array

# Column indices into each table in dos_merged (file column 0, the energy, is
# dropped on load):
#   1 = carrier count per cell (divided by V below),
#   2 = sigma/tau, 11 = Seebeck, 20 = kappa/tau for the first axis.
# Adding +4 or +8 selects the other two axes. These are presumably the yy and zz
# diagonal elements of each 3x3 tensor; confirm against the file header.
_N_COL = 1
_SIGMA_COLS = (2, 6, 10)
_SEEBECK_COLS = (11, 15, 19)
_KAPPA_COLS = (20, 24, 28)


def _axis_mean(terms):
    """Return the average of three per-axis arrays, computed as x/3 + y/3 + z/3."""
    x, y, z = terms
    return x / 3 + y / 3 + z / 3


if __name__ == '__main__':

    energy_axis = []
    dos_merged = []

    for file in files:
        # ndmin=2 keeps a one-row file two-dimensional, so the column slicing works.
        table = np.loadtxt(file, dtype=float, ndmin=2)
        energy_axis.append(table[:, 0])
        dos_merged.append(table[:, 1:])

    # Energies -> eV relative to the Fermi level read from OUTCAR.
    energy_axis = change_xscale(energy_axis)

    # Six vertically stacked panels that share the x (chemical potential) axis.
    fig, axes = plt.subplots(6, 1, figsize=(4, 6), sharex=True, gridspec_kw={'hspace': 0.2})
    ax_n, ax_kappa, ax_sigma, ax_seebeck, ax_pf, ax_zt = axes
    curves = list(zip(energy_axis, dos_merged))
    k_lattice = (klat, klaty, klatz)

    # |n| (cm^-3), log scale.
    for energy, data in curves:
        ax_n.plot(energy, abs(data[:, _N_COL] / V), color='black')
    nmax = ax_n.get_ylim()[1]
    ax_n.text(xmini + 0.05, nmax * 0.01, '|n| (cm$^{-3}$)', size=10, alpha=1)
    ax_n.set_ylim((1.0e18, nmax))
    ax_n.set_yscale("log")

    # Electronic thermal conductivity (W/mK).
    # NOTE(review): if these columns are BoltzTraP's kappa0, the electronic
    # thermal conductivity is usually kappa0 - T*sigma*S^2; confirm which
    # quantity is intended here and in the ZT denominator.
    for energy, data in curves:
        ax_kappa.plot(energy, _axis_mean([data[:, c] * tau for c in _KAPPA_COLS]),
                      color='red')
    ax_kappa.set_ylim(bottom=0)
    kappamax = ax_kappa.get_ylim()[1]
    ax_kappa.text(xmini + 0.05, kappamax * 4 / 5, r'$\kappa$${_e}$ (W/(mK))', size=10, alpha=1)

    # Electrical conductivity (S/cm). The /100 converts per-metre to per-cm.
    # The upper limit is the user-set sigmamax rather than an autoscaled value.
    for energy, data in curves:
        ax_sigma.plot(energy, _axis_mean([data[:, c] * tau / 100 for c in _SIGMA_COLS]),
                      color='red')
    ax_sigma.set_ylim((0, sigmamax))
    ax_sigma.text(xmini + 0.05, sigmamax * 4 / 5, r'$\sigma$' + " (" + r'S/cm' + ")",
                  size=10, alpha=1)

    # Seebeck coefficient (mV/K).
    for energy, data in curves:
        ax_seebeck.plot(energy, _axis_mean([1000 * data[:, c] for c in _SEEBECK_COLS]),
                        color='red')
    Seebeckmax = ax_seebeck.get_ylim()[1]
    ax_seebeck.text(xmini + 0.05, Seebeckmax * 1 / 2, r'S (mV/K)', size=10, alpha=1)

    # Power factor S^2 * sigma (uW/cmK^2).
    for energy, data in curves:
        pf_terms = [10000 * data[:, s] * data[:, s] * data[:, c] * tau
                    for s, c in zip(_SEEBECK_COLS, _SIGMA_COLS)]
        ax_pf.plot(energy, _axis_mean(pf_terms), color='red')
    ax_pf.set_ylim(bottom=0)
    PFmax = ax_pf.get_ylim()[1]
    ax_pf.text(xmini + 0.05, PFmax * 3 / 4, r'PF($\mu$W/cmK${^2}$)', size=10, alpha=1)

    # ZT per axis, with that axis's lattice thermal conductivity added.
    for energy, data in curves:
        zt_terms = [tau * T * data[:, s] * data[:, s] * data[:, c] / (tau * data[:, k] + kl)
                    for s, c, k, kl in zip(_SEEBECK_COLS, _SIGMA_COLS, _KAPPA_COLS, k_lattice)]
        ax_zt.plot(energy, _axis_mean(zt_terms), color='red')
    ax_zt.set_xlabel("u (eV)", fontsize=11)
    ZTmax = ax_zt.get_ylim()[1]
    ax_zt.set_xlim((xmini, xmaxm))  # shared by all panels through sharex
    ax_zt.set_ylim((0, ZTmax))
    ax_zt.text(xmini + 0.05, ZTmax * 3 / 4, "ZT", size=10, alpha=1)

    # NOTE: `gap` is computed but not drawn. It could be shaded on a panel with
    # ax.add_patch(patches.Rectangle((0, ymin), gap, height, alpha=0.1)).

    # tight_layout must run before savefig; otherwise only the on-screen window
    # gets the adjusted layout and the saved PNG keeps the default one.
    plt.tight_layout()

    outfile = 'ZTave.png'
    print()
    print(f"Save figure to [{outfile}]")
    plt.savefig(outfile, dpi=300)

    plt.pause(0.001)
    input("\nPress ENTER to terminate>>\n")
