#!/usr/bin/env python
# -*- coding: utf-8 -*-

# ==============================================================================
#
#                           Band & DOS Plotter for VASP
#
# ==============================================================================
#
# --- 機能 ---
#  - VASPの計算結果からバンド構造(band)と状態密度(DOS)を描画します。
#  - --mode オプションにより、'band_dos'(両方), 'band'(バンドのみ), 'dos'(DOSのみ)の
#    3つの描画モードを選択できます。
#  - 各モードに応じて、最適な図のサイズが自動で設定されます。
#  - エネルギーの基準(0 eV)を価電子帯の頂上(VBM)またはバンドギャップの中央(midgap)
#    から選択できます (--reference オプション)。
#  - バンドギャップのある物質では、デフォルトでVBMとCBMに色付きの補助線が描画されます。
#    これを非表示にするには --hide_gap_lines オプションを使用します。
#  - 論文掲載に適した、高品質で見やすいデザインに自動で整形されます。
#  - 豊富なコマンドライン引数により、プロットの様々な要素を柔軟に調整可能です。
#
# --- 使い方 ---
#  1. 下記の「▼▼▼ ユーザー設定エリア ▼▼▼」で、設定値を確認します。
#  2. ターミナルで以下のコマンドを実行します。
#
#     基本コマンド (バンドとDOSをVBM基準でプロットし 'band_dos_vbm.png' に保存):
#     $ python plot_band_dos.py --mode band_dos --save
#
#     応用コマンド (DOSのみをPDOSとしてプロット、y軸は自動で-20~10に設定):
#     $ python plot_band_dos.py --mode dos --elements Si O --show
#
# --- 必要なファイル構成 ---
#  このスクリプトと同じ階層に、以下のディレクトリとファイルを配置してください。
#
#  <project_directory>/
#  ├── plot_band_dos.py   (このスクリプト)
#  ├── BAND/              (名前は下記 `BAND_DIR_NAME_PREFIX` で変更可)
#  │   ├── vasprun.xml    (VASPの出力)
#  │   └── KPOINTS        (バンド計算で使用したもの)
#  └── DOS/               (名前は下記 `DOS_DIR_NAME` で変更可)
#      ├── DOS-up.csv     (スピンアップのDOSデータ)
#      └── DOS-down.csv   (スピンダウンのDOSデータ、スピン分極計算の場合)
#
#  ※ DOSファイルについて:
#    このスクリプトは、DOSファイルのエネルギー基準がバンドギャップの中央
#    (ミッドギャップ)に設定されていることを前提としています。
#    (例: vaspkitで EFERMI = MIDGAP を指定して出力したDOSデータなど)
#
# ==============================================================================


import argparse
import sys
from pathlib import Path
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from scipy.ndimage import gaussian_filter1d
from pymatgen.io.vasp import Vasprun
from pymatgen.electronic_structure.plotter import BSPlotter
from pymatgen.electronic_structure.bandstructure import BandStructureSymmLine
import re

# ▼▼▼ ユーザー設定エリア ▼▼▼
# ==============================================================================
# このエリアでは、スクリプトの基本的な挙動やデザインを簡単に変更できます。

# --- 1. ファイルとディレクトリの設定 ---
CURRENT_DIR = Path.cwd()
BAND_DIR_NAME_PREFIX = 'Band'  # バンド計算のディレクトリ名の接頭辞
DOS_DIR_NAME = 'DOS'           # DOS計算のディレクトリ名
DEFAULT_SAVE_NAME = 'plot.png' # デフォルトの保存ファイル名

# --- 2. 計算に関する設定 ---
IS_METAL_THRESHOLD = 0.01  # このバンドギャップ以下の物質を金属と判定 (eV)

# --- 3. プロットのデフォルトスタイル ---
PLOT_STYLE = {
    # フォントサイズ
    "font_size_title": 20, "font_size_labels": 18, "font_size_ticks": 16, "font_size_legend": 14,
    # 線の太さ
    "line_width_axis": 1.2, "line_width_ref_line": 1.5, "line_width_gap_line": 1.2,
    "line_width_band": 1.5, "line_width_kpath": 0.75,
    # マーカーの設定
    "marker_size_vbmcbm": 80, "marker_edge_width": 1.0,
}
# 高対称点ラベルをギリシャ文字に変換するための辞書
GREEK_LABELS = {"GAMMA": "Γ", "DELTA": "Δ", "LAMBDA": "Λ", "SIGMA": "Σ", "THETA": "Θ", "PI": "Π", "PHI": "Φ", "OMEGA": "Ω"}
# ==============================================================================
# ▲▲▲ ユーザー設定エリア ▲▲▲


# --- スクリプトの内部処理で使用するパス設定 (編集不要) ---
try:
    BAND_DIR = next(CURRENT_DIR.glob(f'{BAND_DIR_NAME_PREFIX}*'))
    VASP_FILE, KPOINTS_FILE = BAND_DIR / 'vasprun.xml', BAND_DIR / 'KPOINTS'
except StopIteration:
    VASP_FILE, KPOINTS_FILE = None, None

DOS_DIR = CURRENT_DIR / DOS_DIR_NAME
DOS_UP_FILE, DOS_DOWN_FILE = DOS_DIR / 'DOS-up.csv', DOS_DIR / 'DOS-down.csv'


class BSPlotterWithGreek(BSPlotter):
    """pymatgenのBSPlotterを拡張し、デザイン調整やエネルギーシフト機能を追加したカスタムクラス。"""
    def __init__(self, bs: BandStructureSymmLine):
        super().__init__(bs)

    def plot_on_ax(self, ax, energy_shift=0.0, vbm_cbm_marker=True):
        """指定されたAxesオブジェクト上にバンド構造を描画する。"""
        plot_data = self.bs_plot_data()

        if energy_shift != 0.0:
            for spin in plot_data['energy']:
                for i, segment in enumerate(plot_data['energy'][spin]):
                    plot_data['energy'][spin][i] = segment + energy_shift
            if 'vbm' in plot_data: plot_data['vbm'] = [(d, e + energy_shift) for d, e in plot_data['vbm']]
            if 'cbm' in plot_data: plot_data['cbm'] = [(d, e + energy_shift) for d, e in plot_data['cbm']]

        for side in ['top', 'bottom', 'left', 'right']: ax.spines[side].set_linewidth(PLOT_STYLE["line_width_axis"])
        ax.tick_params(axis='both', which='major', direction='in', top=True, right=True, length=6, width=PLOT_STYLE["line_width_axis"], labelsize=PLOT_STYLE["font_size_ticks"])

        for i, _ in enumerate(plot_data['distances']):
            for spin in plot_data['energy']:
                for band_energies in plot_data['energy'][spin][i]:
                    ax.plot(plot_data['distances'][i], band_energies, color="black", linewidth=PLOT_STYLE["line_width_band"])

        if vbm_cbm_marker:
            vbm_coords, cbm_coords = plot_data.get('vbm', []), plot_data.get('cbm', [])
            if vbm_coords:
                ax.scatter(vbm_coords[0][0], vbm_coords[0][1], color="red", marker="o", s=PLOT_STYLE["marker_size_vbmcbm"], zorder=10, edgecolor='black', linewidth=PLOT_STYLE["marker_edge_width"], label="VBM")
                for i in range(1, len(vbm_coords)): ax.scatter(vbm_coords[i][0], vbm_coords[i][1], color="red", marker="o", s=PLOT_STYLE["marker_size_vbmcbm"], zorder=10, edgecolor='black', linewidth=PLOT_STYLE["marker_edge_width"])
            if cbm_coords:
                ax.scatter(cbm_coords[0][0], cbm_coords[0][1], color="blue", marker="o", s=PLOT_STYLE["marker_size_vbmcbm"], zorder=10, edgecolor='black', linewidth=PLOT_STYLE["marker_edge_width"], label="CBM")
                for i in range(1, len(cbm_coords)): ax.scatter(cbm_coords[i][0], cbm_coords[i][1], color="blue", marker="o", s=PLOT_STYLE["marker_size_vbmcbm"], zorder=10, edgecolor='black', linewidth=PLOT_STYLE["marker_edge_width"])

        ticks_data = plot_data['ticks']
        tick_positions, tick_labels_raw = ticks_data['distance'], ticks_data['label']
        ax.set_xticks(tick_positions)
        ax.set_xticklabels([GREEK_LABELS.get(str(lbl).replace("$", "").upper(), str(lbl)) for lbl in tick_labels_raw])
        for pos in tick_positions: ax.axvline(x=pos, linestyle='--', color='gray', linewidth=PLOT_STYLE["line_width_kpath"])
        ax.set_xlabel("Wave Vector", fontsize=PLOT_STYLE["font_size_labels"])
        ax.set_xlim(min(tick_positions), max(tick_positions))
        return ax


def plot_dos_on_ax(ax, df_up, df_down, energy_shift, args, ymin, ymax, is_shared_y_axis=False):
    """一つのAxesオブジェクトにDOSを描画する。"""
    energies, is_spin_polarized = df_up.iloc[:, 0] + energy_shift, df_down is not None
    delta_e = np.mean(np.diff(df_up.iloc[:, 0])) if len(df_up) > 1 else 0

    def smear_dos(dos_series):
        if args.dos_smearing > 0 and delta_e > 0: return gaussian_filter1d(dos_series, sigma=args.dos_smearing / delta_e)
        return dos_series

    ax.tick_params(axis='both', which='major', direction='in', top=True, right=True, length=6, width=PLOT_STYLE["line_width_axis"], labelsize=PLOT_STYLE["font_size_ticks"])
    for side in ['top', 'bottom', 'left', 'right']: ax.spines[side].set_linewidth(PLOT_STYLE["line_width_axis"])
    ax.set_xlabel("Density of States", fontsize=PLOT_STYLE["font_size_labels"])

    if is_shared_y_axis:
        ax.set_ylabel("")
        plt.setp(ax.get_yticklabels(), visible=False)

    if is_spin_polarized: ax.axvline(0, color='black', linewidth=0.5)

    def plot_spin_pair(dos_up, dos_down, label, **kwargs):
        smeared_up = smear_dos(dos_up)
        if is_spin_polarized and dos_down is not None:
            smeared_down = smear_dos(dos_down)
            line, = ax.plot(smeared_up, energies, label=f'{label} (Up)', **kwargs)
            ax.plot(-smeared_down, energies, color=line.get_color(), **{k:v for k,v in kwargs.items() if k != 'label'})
        else:
            ax.plot(smeared_up, energies, label=label, **kwargs)

    if args.total_dos:
        total_cols = [col for col in df_up.columns if 'total' in col.lower()]
        if total_cols: plot_spin_pair(df_up[total_cols].sum(axis=1), df_down[total_cols].sum(axis=1) if is_spin_polarized else None, 'Total', color='black')

    if args.elements:
        available_elements = {col.split()[0] for col in df_up.columns if len(col.split()) > 1}
        for el in args.elements:
            if el not in available_elements: print(f"⚠️  警告: 元素 '{el}' のデータが見つかりません。"); continue
            if args.orbitals:
                for orb in args.orbitals:
                    col_name = f"{el} {orb}"
                    if col_name in df_up.columns: plot_spin_pair(df_up[col_name], df_down.get(col_name) if is_spin_polarized else None, f"{el} ({orb})")
            else:
                col_name = f"{el} total"
                if col_name in df_up.columns: plot_spin_pair(df_up[col_name], df_down.get(col_name) if is_spin_polarized else None, f"{el}")

    # xmaxが指定されていない場合、表示範囲内のDOSの最大値に基づいて自動で設定
    if args.xmax is not None:
        final_xmax = args.xmax
    else:
        max_dos_in_view = 0
        if ax.lines:
            # yminとymaxの範囲内にあるDOSの最大値を取得
            dos_values_in_view = [
                np.max(np.abs(line.get_xdata()[(line.get_ydata() >= ymin) & (line.get_ydata() <= ymax)]))
                for line in ax.lines if np.any((line.get_ydata() >= ymin) & (line.get_ydata() <= ymax))
            ]
            if dos_values_in_view:
                max_dos_in_view = max(dos_values_in_view)

        # 最大値に10%の余白を持たせる。最大値が0ならデフォルト値(10)を設定
        final_xmax = max_dos_in_view * 1.1 if max_dos_in_view > 0 else 10

    if args.xmin is not None:
        final_xmin = args.xmin
    elif not is_spin_polarized:
        final_xmin = 0
    else:
        # スピン分極計算の場合、xminはxmaxの負の値とする
        final_xmin = -final_xmax
    ax.set_xlim(final_xmin, final_xmax)

    handles, labels = ax.get_legend_handles_labels()
    if labels:
        if is_spin_polarized:
            filtered_labels, filtered_handles = [l for l in labels if '(Down)' not in l], [h for h, l in zip(handles, labels) if '(Down)' not in l]
            ax.legend(filtered_handles, [l.replace(' (Up)', '') for l in filtered_labels], fontsize=PLOT_STYLE["font_size_legend"], frameon=False)
        else:
            ax.legend(fontsize=PLOT_STYLE["font_size_legend"], frameon=False)
    return ax

def parse_arguments():
    """コマンドラインからの引数を解析し、ヘルプメッセージを生成する。"""
    parser = argparse.ArgumentParser(description="VASP計算結果からバンド構造とDOSを描画するスクリプト。", formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    g_mode = parser.add_argument_group('描画モードの設定')
    g_mode.add_argument("--mode", type=str, default="band_dos", choices=["band_dos", "band", "dos"], help="描画モードを選択: 'band_dos' (両方), 'band' (バンド構造のみ), 'dos' (DOSのみ)。")

    g_display = parser.add_argument_group('表示・保存に関する設定')
    g_display.add_argument("--reference", type=str, default="vbm", choices=["vbm", "midgap"], help="プロットのエネルギー基準(y=0)を 'vbm' (価電子帯頂上) または 'midgap' (バンドギャップ中央) に設定します。")
    g_display.add_argument("--ymin", type=float, default=-5.0, help="y軸(エネルギー)の最小値 (eV)")
    g_display.add_argument("--ymax", type=float, default=9.0, help="y軸(エネルギー)の最大値 (eV)")
    g_display.add_argument("--title", type=str, help="図全体のタイトル")
    g_display.add_argument("--width", type=float, default=10, help="図全体の幅 (インチ)。--modeに応じて自動設定: band_dos(10), band(6), dos(6)")
    g_display.add_argument("--height", type=float, default=6, help="図全体の高さ (インチ)。--modeに応じて自動設定: band_dos(6), band(5), dos(6)")
    g_display.add_argument("--save", nargs="?", const=DEFAULT_SAVE_NAME, default=None, help=f"画像をファイルに保存します。ファイル名を省略した場合は 'mode_reference.png' の形式となります。")
    g_display.add_argument("--show", action="store_true", help="画面にプロットを表示します。")
    g_display.add_argument("--band_dos_ratio", nargs=2, type=int, default=[3, 1], metavar=('BAND', 'DOS'), help="バンドとDOSの横方向の比率 (--mode band_dos の場合のみ有効)。")

    g_band = parser.add_argument_group('バンド構造プロットに関する設定')
    g_band.add_argument("--no-vbmcbm", action="store_true", help="VBM/CBMのマーカー(赤丸/青丸)を非表示にします。")
    g_band.add_argument("--band_legend", action="store_true", help="バンド図に凡例(VBM/CBM)を表示します。")
    g_band.add_argument("--hide_gap_lines", dest="add_gap_lines", action="store_false", help="デフォルトで表示されるVBMとCBMの色付き点線を非表示にし、代わりにy=0の黒い基準線を表示します。")

    g_dos = parser.add_argument_group('DOSプロットに関する設定')
    g_dos.add_argument("--dos-smearing", type=float, default=0.05, help="DOSのガウシアンスメアリングのσ値 (eV, 0=適用しない)。")
    g_dos.add_argument("--xmin", type=float, default=None, help="DOSプロットのx軸の最小値を指定します。")
    g_dos.add_argument("--xmax", type=float, default=None, help="DOSプロットのx軸の最大値を指定します。")
    g_dos.add_argument("--elements", nargs="+", metavar="El", help="射影DOS(PDOS)を描画する元素を指定します (例: Si O)。")
    g_dos.add_argument("--orbitals", nargs="+", metavar="orb", help="--elements と併用し、指定元素の軌道(s, p, d, f)を描画します。")
    g_dos.add_argument("--no-total-dos", dest="total_dos", action="store_false", help="Total DOSを描画しません。")
    return parser.parse_args()

def add_reference_lines(ax, args, is_metal, band_gap):
    """プロットにエネルギー基準線やギャップ線を追加するヘルパー関数。"""
    if args.add_gap_lines and not is_metal:
        vbm_y, cbm_y = (0.0, band_gap) if args.reference == 'vbm' else (-band_gap / 2.0, band_gap / 2.0)
        ax.axhline(y=vbm_y, color="red", linestyle=":", linewidth=PLOT_STYLE["line_width_gap_line"])
        ax.axhline(y=cbm_y, color="blue", linestyle=":", linewidth=PLOT_STYLE["line_width_gap_line"])
    else:
        ax.axhline(y=0, color="black", linestyle="--", linewidth=PLOT_STYLE["line_width_ref_line"])

def _add_band_legend(ax):
    """バンドプロットにVBM/CBMの凡例を、より堅牢な方法で追加する。"""
    handles, labels = ax.get_legend_handles_labels()
    # 凡例に追加したい要素のラベルとハンドルを辞書に格納
    legend_elements = {label: handle for handle, label in zip(handles, labels) if label in ["VBM", "CBM"]}
    if legend_elements:
        # 'VBM'を先頭に表示するために順序を制御
        sorted_labels = sorted(legend_elements.keys(), key=lambda x: (x != 'VBM', x))
        sorted_handles = [legend_elements[label] for label in sorted_labels]
        ax.legend(
            handles=sorted_handles,
            labels=sorted_labels,
            fontsize=PLOT_STYLE["font_size_legend"],
            frameon=False,
            loc='upper right'
        )

def format_title_with_subscripts(title_str: str) -> str:
    """
    タイトル文字列内の化学式を、数字が下付き文字になるようにフォーマットする。
    例: "OPDOS in SnO2" は "OPDOS in SnO$_2$" に変換される。
    """
    if not title_str:
        return ""
    # 文字列内の数字を探し、LaTeXの下付き文字形式に置換する
    # スペースは維持されるため、"OPDOS in SnO2" のようなタイトルも正しく扱える
    return re.sub(r'(\d+)', r'$_{\1}$', title_str)

def main():
    """メイン実行関数"""
    args = parse_arguments()

    # --- モードに応じた設定調整 ---
    # `dos` モードの時のみ、y軸のデフォルト範囲をPDOS用に変更する
    if args.mode == 'dos':
        # --ymin と --ymax の argparseでのデフォルト値
        ymin_default = -5.0
        ymax_default = 9.0
        # ユーザーが引数で値を指定していない場合のみ、PDOS用のデフォルト値に上書きする
        if args.ymin == ymin_default:
            args.ymin = -20.0
            print(f"ℹ️  DOSモードを検出: --yminが指定されていないため、デフォルト値を {args.ymin} eV に変更します。")
        if args.ymax == ymax_default:
            args.ymax = 10.0
            print(f"ℹ️  DOSモードを検出: --ymaxが指定されていないため、デフォルト値を {args.ymax} eV に変更します。")

    # ユーザーが図のサイズを明示的に指定していない場合、モードに応じて調整
    user_set_width = args.width != 10.0
    user_set_height = args.height != 6.0
    if args.mode == 'band' and not user_set_width and not user_set_height:
        args.width, args.height = 6.0, 5.0
    elif args.mode == 'dos' and not user_set_width and not user_set_height:
        args.width, args.height = 6.0, 6.0

    # --- ファイル存在チェック ---
    is_dos_mode = args.mode in ['band_dos', 'dos']
    if not VASP_FILE or not VASP_FILE.exists():
        print(f"❌ エラー: vasprun.xmlファイルが見つかりません: {VASP_FILE or 'BAND*/vasprun.xml'}")
        sys.exit(1)
    if not KPOINTS_FILE or not KPOINTS_FILE.exists():
        print(f"❌ エラー: KPOINTSファイルが見つかりません: {KPOINTS_FILE}")
        sys.exit(1)
    if is_dos_mode and not DOS_UP_FILE.exists():
        print(f"❌ エラー: DOSファイルが見つかりません: {DOS_UP_FILE}")
        sys.exit(1)

    # --- データ読み込みとエネルギー基準の計算 ---
    df_up, df_down = None, None
    try:
        print("✅ バンド関連データを読み込み中...")
        vasprun = Vasprun(VASP_FILE, parse_projected_eigen=True)
        band = vasprun.get_band_structure(kpoints_filename=KPOINTS_FILE, line_mode=True)
        band_gap_info = band.get_band_gap()
        band_gap = band_gap_info['energy'] if band_gap_info['energy'] is not None else 0.0
        is_metal = band_gap <= IS_METAL_THRESHOLD

        if is_dos_mode:
            print("✅ DOSデータを読み込み中...")
            df_up = pd.read_csv(DOS_UP_FILE, skipinitialspace=True)
            df_down = pd.read_csv(DOS_DOWN_FILE, skipinitialspace=True) if DOS_DOWN_FILE.exists() else None

        print("✅ エネルギー基準を計算中...")
        # 凡例のラベルを設定
        y_axis_label = r"$E - E_{F}$ (eV)" if is_metal else r"$E - E_{VBM}$ (eV)" if args.reference == 'vbm' else r"$E - E_{midgap}$ (eV)"
        
        # --- エネルギーシフト量の計算 ---
        # バンド構造: pymatgenにより、VBMが0 eVに設定されている。
        # DOSデータ: スクリプトの前提として、ミッドギャップが0 eVに設定されている。
        # この2つの基準のズレを補正し、--referenceで指定された最終的な基準に合わせる。
        if args.reference == 'vbm':
            # 最終基準: VBM=0 eV
            # -> バンド: そのまま (シフト 0)
            # -> DOS: ミッドギャップ基準からVBM基準へ。DOSのVBMは-gap/2なので、+gap/2シフトする。
            band_energy_shift = 0.0
            dos_energy_shift = 0.0 if is_metal else band_gap / 2.0
            print(f"   -> 基準を VBM に設定しました。 (Band Gap = {band_gap:.3f} eV)")
        else:  # reference == 'midgap'
            # 最終基準: ミッドギャップ=0 eV
            # -> バンド: VBM基準からミッドギャップ基準へ。バンドのミッドギャップは+gap/2なので、-gap/2シフトする。
            # -> DOS: そのまま (シフト 0)
            band_energy_shift = 0.0 if is_metal else -band_gap / 2.0
            dos_energy_shift = 0.0
            if is_metal:
                print("   -> 金属のため、実質的にVBM(Ef)基準でプロットします。")
            else:
                print(f"   -> 基準をミッドギャップに設定しました。 (Band Gap = {band_gap:.3f} eV)")

    except Exception as e:
        print(f"❌ ファイル読み込み・解析中にエラーが発生しました: {e}"), sys.exit(1)

    # --- プロット生成 ---
    print("✅ プロットを生成中...")
    fig = plt.figure(figsize=(args.width, args.height))

    if args.mode == 'band_dos':
        gs = gridspec.GridSpec(nrows=1, ncols=2, width_ratios=args.band_dos_ratio, wspace=0.05)
        ax_band = fig.add_subplot(gs[0, 0])
        ax_dos = fig.add_subplot(gs[0, 1], sharey=ax_band)

        # Band plot
        plotter = BSPlotterWithGreek(band)
        plotter.plot_on_ax(ax_band, energy_shift=band_energy_shift, vbm_cbm_marker=not args.no_vbmcbm)
        ax_band.set_ylabel(y_axis_label, fontsize=PLOT_STYLE["font_size_labels"])
        if args.band_legend: _add_band_legend(ax_band)

        # DOS plot
        plot_dos_on_ax(ax_dos, df_up, df_down, energy_shift=dos_energy_shift, args=args, ymin=args.ymin, ymax=args.ymax, is_shared_y_axis=True)

        # Common settings
        ax_band.set_ylim(args.ymin, args.ymax)
        add_reference_lines(ax_band, args, is_metal, band_gap)
        add_reference_lines(ax_dos, args, is_metal, band_gap)
        
        # band_dosモードの時のみ、図全体のタイトル(suptitle)を設定
        if args.title:
            formatted_title = format_title_with_subscripts(args.title)
            fig.suptitle(formatted_title, fontsize=PLOT_STYLE["font_size_title"])


    elif args.mode == 'band':
        ax_band = fig.add_subplot(1, 1, 1)
        plotter = BSPlotterWithGreek(band)
        plotter.plot_on_ax(ax_band, energy_shift=band_energy_shift, vbm_cbm_marker=not args.no_vbmcbm)
        ax_band.set_ylabel(y_axis_label, fontsize=PLOT_STYLE["font_size_labels"])
        if args.band_legend: _add_band_legend(ax_band)

        ax_band.set_ylim(args.ymin, args.ymax)
        add_reference_lines(ax_band, args, is_metal, band_gap)
        
        # bandモードの場合、プロットエリアに直接タイトルを設定
        if args.title:
            formatted_title = format_title_with_subscripts(args.title)
            ax_band.set_title(formatted_title, fontsize=PLOT_STYLE["font_size_title"], y=1.02)


    elif args.mode == 'dos':
        ax_dos = fig.add_subplot(1, 1, 1)
        plot_dos_on_ax(ax_dos, df_up, df_down, energy_shift=dos_energy_shift, args=args, ymin=args.ymin, ymax=args.ymax)

        ax_dos.set_ylabel(y_axis_label, fontsize=PLOT_STYLE["font_size_labels"])
        ax_dos.set_ylim(args.ymin, args.ymax)
        add_reference_lines(ax_dos, args, is_metal, band_gap)

        # dosモードの場合、プロットエリアに直接タイトルを設定
        if args.title:
            formatted_title = format_title_with_subscripts(args.title)
            ax_dos.set_title(formatted_title, fontsize=PLOT_STYLE["font_size_title"], y=1.02)

    # --- レイアウト調整、保存、表示 ---
    # band_dosモードでタイトルがある場合のみ、レイアウト用の矩形(rect)を指定
    if args.mode == 'band_dos' and args.title:
        plt.tight_layout(rect=[0, 0, 1, 0.96])
    else:
        # 他のモードでは自動でレイアウトを調整
        plt.tight_layout()

    if args.save:
        try:
            save_name = f"{args.mode}_{args.reference}.png" if args.save == DEFAULT_SAVE_NAME else args.save
            plt.savefig(save_name, dpi=300, bbox_inches='tight', pad_inches=0.2)
            print(f"✅ 画像を保存しました: {save_name}")
        except Exception as e:
            print(f"❌ 保存エラー: {e}")

    if args.show or not args.save:
        plt.show()

    if not args.show and args.save:
        plt.close(fig)

if __name__ == "__main__":
    main()