#!/usr/bin/python
# -*- coding:utf-8 -*-
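# Plot a band structure from VASPKIT output files (KLABELS, REFORMATTED_BAND.dat).
# All bands:      python bandi.py [ymin ymax save_figure plot_figure]
# Selected bands: python bandi.py ymin ymax save_figure plot_figure band1 band2 ...
#                 (a band may also be given as an inclusive range, e.g. 3-7)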
import numpy as np
import matplotlib as mpl
from matplotlib import pyplot as plt
import matplotlib.ticker as ticker
import matplotlib.pyplot as plt
import numpy as np
import sys


from tklib.tkapplication import tkApplication
from tklib.tkutils import terminate, pint, pfloat, getarg, getintarg, getfloatarg



#------------------ FONT_setup ----------------------
# Plot appearance and default settings; ymin, ymax, save_figure and
# plot_figure may be overridden from the command line afterwards.
title = " "                  # figure title (blank by default)
colortotal = 'black'         # line colour used when every band is drawn
colormaps = [                # colours cycled through for selected bands
    'black', 'r', 'g', 'blue', 'm', 'k', 'c', 'y', 'r',
]
ncolors = len(colormaps)

# Text style shared by the axis label and the tick labels.
font = dict(family='arial', color='black', weight='normal', size=18)

# Energy window of the plot (eV relative to E_F) and output switches.
ymin, ymax = -6.0, 6.0
save_figure = plot_figure = 1

app = tkApplication()
nargs = len(sys.argv)   # argv length, program name included

# Positional overrides of the defaults above. The tklib helpers presumably
# return argv[index] converted to float/int, or the given default when that
# argument is absent -- TODO confirm against tklib.tkutils.
#   argv[1] ymin, argv[2] ymax, argv[3] save_figure, argv[4] plot_figure
#   argv[5:] band indexes to plot
ymin, ymax = getfloatarg(1, ymin), getfloatarg(2, ymax)
save_figure, plot_figure = getintarg(3, save_figure), getintarg(4, plot_figure)


#------------------- Data Read ----------------------
# Greek letter names recognised in KLABELS. Any upper-case occurrence in a
# label (e.g. GAMMA) is replaced by the mathtext command (e.g. $\Gamma$).
Greek_alphabets=['Alpha','Beta','Gamma','Delta','Epsilon','Zeta','Eta','Theta', 'Iota','Kappa','Lambda','Mu','Nu','Xi','Omicron','Pi','Rho','Sigma','Tau','Upsilon','Phi','Chi','Psi','Omega']


def _klabel_to_mathtext(label):
    """Convert a raw KLABELS point name (e.g. GAMMA, K_1) to a mathtext label."""
    # Longest names first, so THETA becomes $\Theta$ instead of being mangled
    # by its substring ETA into TH$\Eta$.
    for name in sorted(Greek_alphabets, key=len, reverse=True):
        label = label.replace(name.upper(), f'$\\{name}$')
    # Wrap the underscore and the following character in $...$ so it renders
    # as a subscript.
    n = label.find('_')
    if n > 0:
        label = label[:n] + '$' + label[n:n + 2] + '$' + label[n + 2:]
    return label


group_labels = []   # tick labels of the high-symmetry points
xtick = []          # their x-coordinates along the k-path
with open('KLABELS', 'r', encoding='utf-8', errors='ignore') as reader:
    lines = reader.readlines()[1:]   # the first line is skipped
for line in lines:
    fields = line.split()
    # Keep "<label> <coordinate>" rows; rows starting with '*' are skipped.
    if len(fields) == 2 and not line.startswith('*'):
        group_labels.append(_klabel_to_mathtext(fields[0]))
        xtick.append(float(fields[1]))
if not xtick:
    sys.exit("No high-symmetry points found in KLABELS")

# Column 0 is the x (k-path) coordinate and columns 1.. are the band curves;
# ndmin=2 keeps the array 2-D even for a single data row.
datas=np.loadtxt('REFORMATTED_BAND.dat',dtype=np.float64,ndmin=2)

#--------------------- PLOTs ------------------------
axe = plt.subplot(111)
# Grey dashed guides: a horizontal line at y = 0 (E = E_F) and a vertical
# line at every interior high-symmetry point.
axe.axhline(y=0, xmin=0, xmax=1, linestyle='--', linewidth=0.5, color='0.5')
for xk in xtick[1:-1]:
    axe.axvline(x=xk, ymin=0, ymax=1, linestyle='--', linewidth=0.5, color='0.5')

if nargs <= 5:
    # No band indexes given: draw every band column in a single colour.
    axe.plot(datas[:, 0], datas[:, 1:], linewidth=1.0, color=colortotal)

else:
    # argv[5:] lists band column indexes (column 0 is the x coordinate):
    # single integers ("7") or inclusive ranges ("3-7").
    indexes = []
    for arg in sys.argv[5:]:
        try:
            indexes.append(int(arg))
        except ValueError:  # not a plain integer: treat it as an "N-M" range
            first, _, last = arg.partition('-')
            try:
                indexes.extend(range(int(first), int(last) + 1))
            except ValueError:
                sys.exit(f"Invalid band index or range: {arg!r} (expected N or N-M)")

    nbands = len(datas[0]) - 1  # number of columns after the x column
    for k, bandi in enumerate(indexes):
        if bandi <= 0:
            continue
        if bandi > nbands:
            print(f"Skip band {bandi}: only {nbands} bands in REFORMATTED_BAND.dat")
            continue
        # NOTE(review): the (k - 1) offset gives the first listed band
        # colormaps[ncolors - 1] and the second colormaps[0]; kept as-is so
        # existing plots keep their colours.
        icolor = (k - 1) % ncolors
        axe.plot(datas[:, 0], datas[:, bandi], linewidth=1.0, color=colormaps[icolor])

axe.set_ylabel(r'${E}$-$E_{F}$ (eV)', fontdict=font, fontsize=font['size'])
axe.set_xticks(xtick)
# With no tick arguments, plt.yticks() only restyles the current y tick labels.
plt.yticks(fontsize=font['size'] - 2, fontname=font['family'])
axe.set_xticklabels(group_labels, rotation=0, fontsize=font['size'] - 2, fontname=font['family'])
axe.set_xlim((xtick[0], xtick[-1]))   # span first to last high-symmetry point
axe.set_ylim((ymin, ymax))            # energy window from settings / argv
fig = plt.gcf()
fig.set_size_inches(4, 5)
axe.set_title(title, size=22, pad=25)
fig.subplots_adjust(left=0.25, right=0.95, top=0.9, bottom=0.12)

# Minor ticks on; ticks drawn only on the bottom and left sides.
axe.minorticks_on()
axe.tick_params(top=False, bottom=True, left=True, right=False)
axe.tick_params(axis='y', which='major', direction='out', length=7, width=0.5, grid_alpha=0.5, pad=3)
axe.tick_params(axis='y', which='minor', direction='out', length=4, width=0.5, grid_alpha=0.5)
# x minor ticks are made practically invisible (tiny length/width) rather than removed.
axe.tick_params(axis='x', which='minor', direction='in', length=0.0001, width=0.01, grid_alpha=0.5)
axe.tick_params(axis='x', which='major', direction='out', length=5, width=0.5)

outputfile = 'band.png'

if save_figure:
    # Write the figure as a 300-dpi image.
    print("")
    print(f"Save figure to {outputfile}")
    plt.savefig(outputfile, dpi=300)

if plot_figure:
    # plt.pause() draws and shows the current figure and briefly runs the GUI
    # event loop without blocking.
    print("")
    print(f"plot band structure ({outputfile})")
    plt.pause(0.0001)

# Hand over to tkApplication; pause=True presumably keeps the window open
# until the user responds -- TODO confirm in tklib.
app.terminate(pause=True)
