# Originally written by He Xinyi
# Modified by T. Kamiya
# 2025/6/10 Modified by T. Kamiya: use .get_plot() and alter font sizes

import numpy
import matplotlib.pyplot as plt
from pymatgen.io.vasp.outputs import Vasprun
from pymatgen.electronic_structure.plotter import BSDOSPlotter, BSPlotter, BSPlotterProjected, DosPlotter

from tklib.tkapplication import tkApplication
from tklib.tkutils import terminate, pint, pfloat, getarg, getintarg, getfloatarg

# Default settings.  Each one may be overridden by a positional command-line
# argument (see the get*arg calls below for the argument index).
vb_energy_range = 4.0         # valence-band energy range [eV]
cb_energy_range = 4.0         # conduction-band energy range [eV]
egrid_interval  = 1           # energy grid interval [eV]
output_png_path = 'band.png'  # file name for the saved figure
save_figure = 1               # nonzero: save the figure as a PNG
plot_figure = 1               # nonzero: display the figure
fontsize = 16                 # font size of axis labels and tick labels
fontsize_legend = 12          # font size of legend entries

app    = tkApplication()

# Command-line overrides: get*arg(index, default).  Presumably each helper
# returns `default` when argument `index` is absent (tklib helper; confirm).
vb_energy_range = getfloatarg(1, vb_energy_range)
cb_energy_range = getfloatarg(2, cb_energy_range)
# The energy grid interval may be fractional (e.g. 0.5 eV), so parse a float.
egrid_interval  = getfloatarg(3, egrid_interval)
# The output path is a string; parsing it as an int was a bug.
output_png_path = getarg     (4, output_png_path)
save_figure     = getintarg  (5, save_figure)
plot_figure     = getintarg  (6, plot_figure)
fontsize        = getintarg  (7, fontsize)
fontsize_legend = getintarg  (8, fontsize_legend)


# Parse vasprun.xml only once.  The projected eigenvalues needed for the
# element-projected band structure and the complete DOS come from the same
# parse, so a second full parse of the (often large) file is unnecessary.
dos_vasprun = bs_vasprun = Vasprun("./vasprun.xml", parse_projected_eigen=True)
dos_data    = dos_vasprun.complete_dos
bs_data     = bs_vasprun.get_band_structure(line_mode=True)


# Band structure projected onto elements, together with the DOS projected
# onto orbitals.  The energy window and grid interval come from the
# command-line settings; their defaults (4 eV, 4 eV, 1 eV) equal pymatgen's
# own defaults, so a run without arguments looks the same as before.
plotter = BSDOSPlotter(dos_projection = 'orbitals', bs_projection = 'elements',
                       vb_energy_range = vb_energy_range,
                       cb_energy_range = cb_energy_range,
                       egrid_interval  = egrid_interval)
axes = plotter.get_plot(bs = bs_data, dos = dos_data)


# Alternative plotting approaches, kept for reference.  These string
# literals are never executed.
"""
plotter = BSPlotterProjected(bs = bs_data)
plotter.get_elt_projected_plots(zero_to_efermi = True, ylim = None, \
                                    vbm_cbm_marker = False)
"""

"""
plt = BSDOSPlotter(bs_projection = 'elements', dos_projection = None,
    #fig_size=(6, 8),
    vb_energy_range = vb_energy_range,
    cb_energy_range = cb_energy_range,
    egrid_interval = egrid_interval,
    bs_legend = 'best')

plt.get_plot(bs = bs_data, dos = None)
"""

# Adjust font sizes after pymatgen has drawn the plot by walking every Axes
# of every open figure.
fig = plt.gcf()   # current figure
axes = plt.gca()  # current Axes (referenced only by the disabled snippet below)

fig_nums = plt.get_fignums()      # e.g. [1, 2, 3]
all_figs = [plt.figure(num) for num in fig_nums]
# figure.axes is the list of Axes objects belonging to that Figure.  A
# distinct loop name is used so that `fig` above is not overwritten.
all_axes = [ax for figure in all_figs for ax in figure.axes]

# Report what was found.
print("Figures:", all_figs)
print("Total axes:", len(all_axes))
for iax, ax in enumerate(all_axes):
    nlines = len(ax.get_lines())
    ncollections = len(ax.collections)
    print("  ", ax, "  nlines=", nlines, "  nscatters=", ncollections)
    if nlines + ncollections == 0:
        # Remove Axes that hold neither lines nor collections.  Removal is
        # safe here because we iterate over the separate all_axes list.
        # (Alternative: ax.set_visible(False) would hide it instead.)
        ax.remove()
        continue

    ax.xaxis.label.set_fontsize(fontsize)
    ax.yaxis.label.set_fontsize(fontsize)
    # Tick-label font size for both major and minor ticks.
    ax.tick_params(axis="both", which="both", labelsize=fontsize)

    if iax == 1:
        # Rebuild the legend of the second Axes with the legend font size.
        # Lines whose label starts with "_" are matplotlib's "no legend
        # entry" marker; recent matplotlib versions warn about and drop them
        # when they are passed explicitly, so skip them up front.
        lines = [line for line in ax.get_lines()
                 if not line.get_label().startswith("_")]
        labels = [line.get_label() for line in lines]
        ax.legend(lines, labels, fontsize=fontsize_legend)

# Disabled alternative: adjust fonts on the current Axes only (and optionally
# resize the figure).  Kept for reference; this string literal is never executed.
"""
#fig.set_size_inches(3, 3)
#fig.set_figwidth(12)   # 幅のみ 12 インチ
#fig.set_figheight(4)   # 高さのみ 4 インチ
#fig.set_dpi(150)
axes.xaxis.label.set_fontsize(fontsize)
axes.yaxis.label.set_fontsize(fontsize)
#axes.set_xlabel("Wavevector", fontsize=fontsize)
#axes.set_ylabel("Energy (eV)", fontsize=fontsize)
axes.tick_params(axis="both", which="major", labelsize=fontsize)     # フォントサイズ
axes.tick_params(axis="both", which="minor", labelsize=fontsize)
#axes.lines[0].set_color("red")
#lines = axes.get_lines()
#labels = [line.get_label() for line in lines]
#axes.legend(lines, labels, fontsize=fontsize_legend)
leg = axes.get_legend()
if leg is not None:
    for text in leg.get_texts():
        text.set_fontsize(fontsize_legend)

#plt.rc("xtick", labelsize=fontsize)
#plt.rc("ytick", labelsize=fontsize)
#plt.rcParams["legend.fontsize"] = fontsize_legend
"""

# Save the current figure to the configured output path (command-line
# argument 4) rather than a hard-coded file name.
if save_figure:
    plt.savefig(output_png_path)

# Briefly run the GUI event loop so the figure window is drawn and shown.
if plot_figure:
    plt.pause(0.0001)

# NOTE(review): presumably waits for the user before exiting (tklib helper; confirm).
app.terminate(pause = True)

