# Originally written by He Xinyi
# Modified by T. Kamiya
# 2025/6/10: Modified by T. Kamiya to use the axes returned by .get_plot() to alter font sizes

import numpy
import matplotlib.pyplot as plt
from pymatgen.io.vasp.outputs import Vasprun
from pymatgen.electronic_structure.plotter import BSDOSPlotter, BSPlotter, BSPlotterProjected, DosPlotter

from tklib.tkapplication import tkApplication
from tklib.tkutils import terminate, pint, pfloat, getarg, getintarg, getfloatarg

# Default settings. Each one can be overridden by a positional
# command-line argument (parsed below).
vb_energy_range = 4.0         # BSDOSPlotter vb_energy_range (argument 1)
cb_energy_range = 4.0         # BSDOSPlotter cb_energy_range (argument 2)
egrid_interval  = 1           # BSDOSPlotter egrid_interval (argument 3)
output_png_path = 'band.png'  # output image path (argument 4)
save_figure = 1               # nonzero: save the figure (argument 5)
plot_figure = 1               # nonzero: display the figure (argument 6)
fontsize = 16                 # axis-label / tick-label font size (argument 7)
fontsize_legend = 12          # legend-entry font size (argument 8)

app    = tkApplication()
# The get*arg(index, default) helpers come from tklib; they presumably return
# `default` when the positional argument is absent (behavior not visible here).
vb_energy_range = getfloatarg(1, vb_energy_range)
cb_energy_range = getfloatarg(2, cb_energy_range)
# Parsed as a float so fractional energy-grid spacings are accepted too.
egrid_interval  = getfloatarg(3, egrid_interval)
# The output path is a string; getintarg() would try to treat it as an integer.
output_png_path = getarg     (4, output_png_path)
save_figure     = getintarg  (5, save_figure)
plot_figure     = getintarg  (6, plot_figure)
fontsize        = getintarg  (7, fontsize)
fontsize_legend = getintarg  (8, fontsize_legend)


# Parse vasprun.xml only once. Projected eigenvalues are needed for the
# element-projected band structure, and the same parse also provides the
# complete DOS, so a second (unprojected) parse would be redundant work on
# what can be a very large file.
bs_vasprun  = Vasprun("./vasprun.xml", parse_projected_eigen=True)
dos_vasprun = bs_vasprun
dos_data    = dos_vasprun.complete_dos
# line_mode=True: build a line-mode (k-path) band structure for the band plot.
bs_data     = bs_vasprun.get_band_structure(line_mode=True)


# Combined band-structure + DOS figure, both projected onto elements.
# The energy-window and grid settings parsed from the command line are passed
# through so that they actually take effect. The script defaults are intended
# to match BSDOSPlotter's own defaults — verify against the installed pymatgen.
ploter = BSDOSPlotter(
    dos_projection  = 'elements',
    bs_projection   = 'elements',
    vb_energy_range = vb_energy_range,
    cb_energy_range = cb_energy_range,
    egrid_interval  = egrid_interval,
)
# Axes returned by get_plot(); their fonts are restyled further below.
axes_list = ploter.get_plot(bs = bs_data, dos = dos_data)

# Alternative 1 (disabled): element-projected band structure only.
# ploter = BSPlotterProjected(bs = bs_data)
# ploter.get_elt_projected_plots(zero_to_efermi = True, ylim = None,
#                                vbm_cbm_marker = False)

# Alternative 2 (disabled): element-projected bands without a DOS panel.
# ploter = BSDOSPlotter(bs_projection = 'elements', dos_projection = None,
#     #fig_size=(6, 8),
#     vb_energy_range = vb_energy_range,
#     cb_energy_range = cb_energy_range,
#     egrid_interval = egrid_interval,
#     bs_legend = 'best')
# ploter.get_plot(bs = bs_data, dos = None)

fig = plt.gcf()   # current figure (presumably the one get_plot() drew on)

# Optional figure-size tweaks (disabled):
#fig.set_size_inches(3, 3)
#fig.set_figwidth(12)   # width only: 12 inches
#fig.set_figheight(4)   # height only: 4 inches
#fig.set_dpi(150)

# get_plot() may return a sequence of axes (band + DOS panels) or a single
# Axes depending on the arguments / pymatgen version. Normalize to a list so
# the loop below never tries to iterate a bare Axes; for any other return
# value fall back to every axes on the current figure.
if isinstance(axes_list, plt.Axes):
    axes_list = [axes_list]
elif not isinstance(axes_list, (list, tuple)):
    axes_list = fig.get_axes()

for axes in axes_list:
    # Axis-label font sizes.
    axes.xaxis.label.set_fontsize(fontsize)
    axes.yaxis.label.set_fontsize(fontsize)
    # Tick-label font size for both major and minor ticks.
    axes.tick_params(axis="both", which="both", labelsize=fontsize)
    # Legend text, only on panels that actually carry a legend.
    leg = axes.get_legend()
    if leg is not None:
        for text in leg.get_texts():
            text.set_fontsize(fontsize_legend)


# Save the figure to the configured path (previously hard-coded to
# 'band.png', which ignored output_png_path).
if save_figure:
    plt.savefig(output_png_path)

if plot_figure:
    # A short pause draws and displays the figure without blocking, so
    # execution continues to app.terminate() below.
    plt.pause(0.0001)

# tklib shutdown; pause=True presumably waits for the user before exiting
# (behavior not visible here).
app.terminate(pause = True)

