# Originally written by He Xinyi
# Modified by T. Kamiya

import collections
import os
import subprocess
import sys

import numpy as np
import matplotlib as mpl
from matplotlib import pyplot as plt
import matplotlib.ticker as ticker
from pymatgen.core.structure import Structure
from pymatgen.io.vasp import inputs
from pymatgen.core.periodic_table import get_el_sp

from tklib.tkapplication import tkApplication
from tklib.tkutils import terminate, pint, pfloat, getarg, getintarg, getfloatarg


#------------------ Font setup ----------------------
# Text properties used for the axis labels; the tick-label and legend sizes
# of the plots are derived from font['size'].
font = dict(
    family=None,  # e.g. 'arial'; None means no explicit family is requested
    color='black',
    weight='normal',
    size=18.0,
)

# Default plot settings; each may be overridden by a positional
# command-line argument (see below).
xmin, xmax = -2, 2   # energy window on the E - E_F axis (eV)
ymaxt = 10           # y-axis top for the total-DOS plot
ymaxp = 2            # y-axis top for the per-element PDOS plots
ymax = 1             # y-axis top for a DOS file named on the command line

save_figure = 1      # nonzero: write the figures as PNG files
plot_figure = 1      # nonzero: display the figures

app = tkApplication()
# Positional overrides argv[1]..argv[7], in the order below.  The second
# argument of each get*arg call is presumably the fallback used when that
# position is absent (tklib.tkutils behavior -- not visible here).
xmin, xmax, ymaxt, ymaxp, ymax = (
    getfloatarg(position, fallback)
    for position, fallback in enumerate((xmin, xmax, ymaxt, ymaxp, ymax), start=1)
)
save_figure = getintarg(6, save_figure)
plot_figure = getintarg(7, plot_figure)


# Locate the tkprog helper scripts that are run with bash further below.
# $tkprog_X_path takes precedence; otherwise fall back to
# $tkProg_Root/tkprog_Linux.  The fallback is built only when needed: an
# eager default for os.environ.get() would evaluate os.path.join(None, ...)
# and raise TypeError whenever tkProg_Root is unset, even if tkprog_X_path
# itself is provided.
tkProg_Root = os.environ.get("tkProg_Root")
tkprog_X_path = os.environ.get("tkprog_X_path")
if tkprog_X_path is None:
    if tkProg_Root is None:
        sys.exit("Error: set the tkprog_X_path or tkProg_Root environment variable")
    tkprog_X_path = os.path.join(tkProg_Root, 'tkprog_Linux')
split_script = os.path.join(tkprog_X_path, 'VASP', 'He', 'gsplitpdos')
sum_script = os.path.join(tkprog_X_path, 'VASP', 'He', 'gsumpdos')


# When figures are not to be shown, switch to matplotlib's non-interactive
# Agg backend ("silent mode"); figures can still be saved to files.
if not plot_figure:
    mpl.use('Agg')

# Echo the effective settings, preceded by a blank line.
print(
    "",
    f"xmin: {xmin}",
    f"xmax: {xmax}",
    f"ymaxt: {ymaxt}",
    f"ymaxp: {ymaxp}",
    f"ymax: {ymax}",
    f"save_figure: {save_figure}",
    f"plot_figure: {plot_figure}",
    sep="\n",
)

#------------------ Total DOS and per-element PDOS ------------------------
# This section always runs, whatever the command-line arguments.


def _run_script(command):
    """Echo and run an external command given as an argument list.

    The command runs without a shell, so paths containing spaces are passed
    intact.  A failure is reported but does not stop the script; it matters
    because the DOS files read afterwards would then be missing or left over
    from an earlier run.

    Returns the exit status, or None if the command could not be started.
    """
    shown = " ".join(command)
    print("")
    # flush so the message appears before the child process's own output
    print(f"Execute [{shown}]", flush=True)
    try:
        status = subprocess.run(command).returncode
    except OSError as err:  # e.g. bash not found on PATH
        print(f"Warning: could not execute [{shown}]: {err}")
        return None
    if status != 0:
        print(f"Warning: [{shown}] exited with status {status}")
    return status


def _read_dos_table(path):
    """Read a whitespace-separated DOS table that has a one-line header.

    Returns (labels, data): the header tokens after the first one, with '_'
    shown as a space (used as legend entries), and the remaining rows as a
    2-D float64 array whose column 0 is the energy axis.
    """
    with open(path, "r") as reader:
        header = reader.readline()
    labels = [token.replace("_", " ") for token in header.split()[1:]]
    # ndmin=2 keeps a one-row table two-dimensional for column slicing.
    data = np.loadtxt(path, dtype=np.float64, skiprows=1, ndmin=2)
    return labels, data


def _draw_dos(energy, dos, labels, y_top, ylabel, title, **line_kwargs):
    """Plot DOS curve(s) against E - E_F on a new 5x4-inch figure; return it.

    dos may be 1-D (one curve) or 2-D (one curve per column).  The x range is
    the global xmin..xmax window; the y range is 0..y_top with ticks at 0,
    y_top/2 and y_top.  line_kwargs are passed on to Axes.plot.
    """
    # Draw at the final size directly; tight_layout() sets the margins.
    fig, axe = plt.subplots(figsize=(5, 4))
    axe.plot(energy, dos, **line_kwargs)
    axe.set_xlabel(r'${E}$-$E_{F}$ (eV)', fontdict=font)
    axe.set_ylabel(ylabel, fontdict=font)
    # One tick per eV over whatever window was requested, rather than a
    # fixed -5..4 list that leaves wider windows partly unlabeled.
    axe.xaxis.set_major_locator(ticker.MultipleLocator(1))
    axe.set_yticks(np.linspace(0, y_top, 3))
    axe.tick_params(labelsize=font['size'] - 2)
    axe.set_xlim(xmin, xmax)
    axe.set_ylim(0, y_top)
    axe.set_title(title)
    axe.legend(labels, loc='upper right', fontsize=font['size'] - 6)
    fig.tight_layout()
    return fig


def gen_species_dictionary(atomic_number_uniq):
    """Map each atomic number to its 1-based position in atomic_number_uniq."""
    return {num: position for position, num in enumerate(atomic_number_uniq, start=1)}


print("Total and PDOS mode")
# gsplitpdos produces TotalDOS and the per-atom PDOS data used below.
_run_script(["bash", split_script])

#------------------ plot Total DOS ------------------
print(" ")
print("Read Total_DOS from TotalDOS")
tdos_labels, tdos_data = _read_dos_table('TotalDOS')
# Only the first DOS column (and its label) is plotted as the total DOS.
fig = _draw_dos(tdos_data[:, 0], tdos_data[:, 1], tdos_labels[:1], ymaxt,
                r'TDOS (states/eV)', 'Total DOS')

if save_figure:
    print("save Total_DOS fig: TotalDOS.png")
    fig.savefig('TotalDOS.png', dpi=300)

if plot_figure:
    print("plot Total DOS (TotalDOS.png)")
    plt.pause(0.0001)
else:
    plt.close(fig)  # nothing will be displayed; avoid piling up open figures

#------------------ sum PDOS for each element ------------------
print("")
print("Read structure from POSCAR")
structure = Structure.from_file("POSCAR")

site_numbers = np.array(structure.atomic_numbers)
# Unique atomic numbers in order of first appearance in POSCAR.
atomic_numbers_uniq = list(collections.OrderedDict.fromkeys(structure.atomic_numbers))
species_index = gen_species_dictionary(atomic_numbers_uniq)
print(f"atomic_numbers_uniq: {atomic_numbers_uniq}")
print(f"species_index: {species_index}")

species_symbols = [str(get_el_sp(num)) for num in atomic_numbers_uniq]

# gsumpdos receives the element symbol followed by the 1-based site indices
# of every atom of that element; presumably it writes PDOS_<symbol>, which
# is read below.
for num, symbol in zip(atomic_numbers_uniq, species_symbols):
    sites = 1 + np.flatnonzero(site_numbers == num)
    _run_script(["bash", sum_script, symbol, *map(str, sites)])

#------------------ plot PDOS ------------------
print("")
for symbol in species_symbols:
    file = "PDOS_" + symbol
    print(f"Read PDOS from [{file}]")
    pdos_labels, pdos_data = _read_dos_table(file)
    fig = _draw_dos(pdos_data[:, 0], pdos_data[:, 1:], pdos_labels, ymaxp,
                    r'PDOS (states/eV)', f'PDOS {file}', linewidth=1.0)

    outputfile = f'{file}.png'
    if save_figure:
        print(f"save PDOS fig: {outputfile}")
        fig.savefig(outputfile, dpi=300)

    if plot_figure:
        print("")
        print(f"plot PDOS ({outputfile})")
        plt.pause(0.0001)
    else:
        plt.close(fig)  # nothing will be displayed; avoid piling up open figures

#------------------ plot a single DOS file named on the command line ----------
# Runs only when exactly one argument is given AND it names an existing file.
# argv[1] is also read as xmin at startup, so a lone numeric argument such
# as "-3" is a plot setting, not a file name; without the isfile() test the
# script would try to open it and stop with FileNotFoundError.
if len(sys.argv) == 2 and os.path.isfile(sys.argv[1]):
    file = sys.argv[1]
    print("")
    print("Total DOS mode")
    print(f"Read DOS from [{file}]")
    with open(file, "r") as reader:
        legend = reader.readline()
    # Header tokens after the first are the legend labels ('_' shown as space).
    legend_s = tuple(label.replace("_", " ") for label in legend.split()[1:])
    # ndmin=2 keeps a one-row table two-dimensional for the slicing below.
    datas = np.loadtxt(file, dtype=np.float64, skiprows=1, ndmin=2)

    # Draw at the final 5x4-inch size; tight_layout() sets the margins.
    fig, axe = plt.subplots(figsize=(5, 4))
    axe.plot(datas[:, 0], datas[:, 1:], linewidth=1.0)  # column 0 is the energy
    axe.set_xlabel(r'${E}$-$E_{F}$ (eV)', fontdict=font)
    axe.set_ylabel(r'PDOS (states/eV)', fontdict=font)
    # One x tick per eV over the whole requested window, rather than a fixed
    # -5..4 list; y ticks at 0, ymax/2 and ymax.
    axe.xaxis.set_major_locator(ticker.MultipleLocator(1))
    axe.set_yticks(np.linspace(0, ymax, 3))
    axe.tick_params(labelsize=font['size'] - 2)
    axe.set_xlim(xmin, xmax)
    axe.set_ylim(0, ymax)
    axe.legend(legend_s, loc='upper right', fontsize=font['size'] - 6)
    fig.tight_layout()

    outputfile = f'{file}.png'
    if save_figure:
        print(f"save DOS fig: {outputfile}")
        fig.savefig(outputfile, dpi=300)

    if plot_figure:
        print(f"plot DOS ({outputfile})")
        plt.pause(0.0001)

app.terminate(pause=True)

    
    