import os
import argparse
import types
import pandas as pd
from openpyxl import load_workbook
from openpyxl.chart import ScatterChart, Series, Reference


def initialize():
    """
    Parse command-line arguments and build the run configuration.

    Returns:
        types.SimpleNamespace: `cfg` with the attributes
          - work_dir: working directory given by --work_dir
          - Emax / Emin: upper / lower energy limit of the band plot
          - template: path of the macro-enabled template workbook (.xlsm),
            joined onto work_dir
          - output_excel: path the results are written to; set to the same
            path as `template`, i.e. the template is updated in place
          - band_dat, klines, klabels, band_gap: paths of the input files
            BAND.dat, KLINES.dat, KLABELS and BAND_GAP inside work_dir
    """
    parser = argparse.ArgumentParser(description="VASPバンド構造解析とExcel出力")
    parser.add_argument(
        "--work_dir",
        type=str,
        default=".",
        help="作業ディレクトリ (デフォルト: カレントディレクトリ)"
    )
    parser.add_argument(
        "--Emax",
        type=float,
        default=9.0,
        help="バンド図のエネルギー範囲 最大値 (デフォルト: 9.0)"
    )
    parser.add_argument(
        "--Emin",
        type=float,
        default=-5.0,
        help="バンド図のエネルギー範囲 最小値 (デフォルト: -5.0)"
    )
    parser.add_argument(
        "--template",
        type=str,
        default="StandardGraph.xlsm",
        # The help text previously named a different file than the real default.
        help="マクロ付きテンプレートファイル (.xlsm) の名前 (デフォルト: StandardGraph.xlsm)"
    )
    args = parser.parse_args()

    cfg = types.SimpleNamespace()
    cfg.work_dir = args.work_dir
    cfg.Emax = args.Emax
    cfg.Emin = args.Emin

    # Template workbook; the output is written back to the same file.
    cfg.template = os.path.join(cfg.work_dir, args.template)
    cfg.output_excel = cfg.template

    # Input files, all expected inside the working directory.
    cfg.band_dat = os.path.join(cfg.work_dir, "BAND.dat")
    cfg.klines = os.path.join(cfg.work_dir, "KLINES.dat")
    cfg.klabels = os.path.join(cfg.work_dir, "KLABELS")
    cfg.band_gap = os.path.join(cfg.work_dir, "BAND_GAP")

    return cfg


def read_band_dat(filepath):
    """
    Parse BAND.dat into a row-per-line table plus the k-point/band counts.

    Layout handled by the parser:
      - line 1 is split on whitespace and stored as one row;
      - line 2 is expected to look like "<label>: <NKPTS> <NBANDS>"
        (an empty second line yields NKPTS = NBANDS = 0);
      - afterwards every line starting with "#" opens a band block: the text
        before its last token is kept as a label, the last token is parsed
        as a number, and the following NKPTS lines are read as
        whitespace-separated floats;
      - blank lines become [""] rows, any other line is stored verbatim in a
        single cell.
    Parsing stops once NBANDS band headers have been consumed.

    Args:
        filepath: Path to the BAND.dat file.

    Returns:
        tuple: (NKPTS, NBANDS, df_band), where df_band is a DataFrame built
        from the ragged list of parsed rows.

    Raises:
        ValueError: If the file has fewer than two lines, the second line
            does not provide both counts, a band header does not end in a
            number, a data value is not numeric, or the file ends before a
            band block is complete.
    """
    # Explicit encoding, matching the other readers in this file.
    with open(filepath, "r", encoding="utf-8") as file:
        raw_lines = [raw.strip() for raw in file]

    # Fail with a clear message instead of an IndexError on a truncated file.
    if len(raw_lines) < 2:
        raise ValueError("ヘッダー行が不足しています (2 行必要です)")

    band_data = []
    first_line, second_line = raw_lines[0], raw_lines[1]
    band_data.append(first_line.split() if first_line else [""])

    if second_line:
        fields = second_line.split(":")
        nums = fields[1].split() if len(fields) > 1 else []
        if len(nums) < 2:
            raise ValueError(
                f"NKPTS と NBANDS を読み取れません: '{second_line}'"
            )
        NKPTS = int(nums[0])
        NBANDS = int(nums[1])
        band_data.append([fields[0].strip(), NKPTS, NBANDS])
    else:
        band_data.append([""])
        NKPTS, NBANDS = 0, 0

    i = 2
    band_count = 0
    while i < len(raw_lines) and band_count < NBANDS:
        line = raw_lines[i]
        if line == "":
            band_data.append([""])
            i += 1
            continue
        if not line.startswith("#"):
            # Unrecognised non-blank line: keep it verbatim in one cell.
            band_data.append([line])
            i += 1
            continue

        # Band header: "<label ...> <index>".
        parts = line.split()
        label = " ".join(parts[:-1])
        try:
            index = float(parts[-1])
        except ValueError as err:
            raise ValueError(
                f"数値として解釈できません: '{parts[-1]}'"
            ) from err
        band_data.append([label, index])
        i += 1
        band_count += 1

        # The NKPTS lines after the header belong to this band; a blank line
        # inside that window is stored as [""] and still counts as one line.
        for _ in range(NKPTS):
            if i >= len(raw_lines):
                raise ValueError("数値データが不足しています")
            line_data = raw_lines[i]
            if line_data == "":
                band_data.append([""])
            else:
                band_data.append([float(n) for n in line_data.split()])
            i += 1

    return NKPTS, NBANDS, pd.DataFrame(band_data)


def read_klines(filepath):
    """
    Read KLINES.dat as a whitespace-delimited table.

    The first line of the file is used as the column header.

    Args:
        filepath: Path to the KLINES.dat file.

    Returns:
        pandas.DataFrame: One row per non-blank data line.
    """
    # sep=r"\s+" is the documented equivalent of the deprecated
    # delim_whitespace=True option.
    return pd.read_csv(filepath, sep=r"\s+", header=0)


def read_klabels(filepath):
    """
    Load the KLABELS file into a DataFrame, one row per input line.

    Row rules (each line is stripped first):
      - first line: split once on whitespace (at most two cells);
      - blank line: [None, None];
      - line starting with "*": the whole line in a single cell;
      - line with two or more tokens: [first token, second token], the
        second converted to float when it parses as one;
      - anything else: the whole line in a single cell.

    Args:
        filepath: Path to the KLABELS file (read as UTF-8).

    Returns:
        pandas.DataFrame: The rows described above.
    """

    def _to_row(text):
        # Convert one stripped, non-header line into its list of cells.
        if text == "":
            return [None, None]
        if text.startswith("*"):
            return [text]
        tokens = text.split()
        if len(tokens) < 2:
            return [text]
        name, raw_value = tokens[0], tokens[1]
        try:
            return [name, float(raw_value)]
        except ValueError:
            # Non-numeric second token is kept as text.
            return [name, raw_value]

    with open(filepath, "r", encoding="utf-8") as handle:
        stripped_lines = [raw.strip() for raw in handle]

    table = []
    if stripped_lines:
        header, body = stripped_lines[0], stripped_lines[1:]
        table.append(header.split(maxsplit=1))
        table.extend(_to_row(text) for text in body)
    return pd.DataFrame(table)


def read_band_gap(filepath):
    """
    Parse the "key: values" lines of a BAND_GAP file.

    The first and the last line of the file are ignored. Of the remaining
    lines only those containing ":" are used: the text before the first
    colon becomes the key and the rest is split on whitespace. Values on the
    very first remaining line are kept as strings; on every other line they
    are converted to float.

    Args:
        filepath: Path to the BAND_GAP file (read as UTF-8).

    Returns:
        tuple: (gap_data, df_gap) where gap_data is a list of
        [key, value, ...] lists and df_gap is a DataFrame built from it.
    """
    with open(filepath, "r", encoding="utf-8") as handle:
        body = handle.readlines()[1:-1]

    records = []
    for position, text in enumerate(body):
        name, colon, remainder = text.partition(":")
        if not colon:
            continue
        tokens = remainder.strip().split()
        # Only the first body line keeps its values as text.
        if position != 0:
            tokens = [float(token) for token in tokens]
        records.append([name.strip(), *tokens])
    return records, pd.DataFrame(records)


def extract_bands(filepath, NKPTS, NBANDS):
    """
    Collect the numeric rows of each band block in a BAND.dat-style file.

    The first two lines are skipped. Every later line starting with "#" is a
    band header whose last token is the integer band number; the non-blank
    lines after it are parsed as lists of floats until NKPTS rows have been
    gathered or the file ends. Lines outside a band block that do not start
    with "#" are ignored. At most NBANDS bands are read.

    Args:
        filepath: Path to the band data file.
        NKPTS: Number of data rows to read per band.
        NBANDS: Maximum number of bands to read.

    Returns:
        list[dict]: One {"number": int, "data": list of float lists} entry
        per band, in file order.
    """
    with open(filepath, "r") as handle:
        remaining = iter(handle.readlines()[2:])

    collected = []
    while len(collected) < NBANDS:
        header = next(remaining, None)
        if header is None:
            break
        header = header.strip()
        if not header.startswith("#"):
            continue

        number = int(header.split()[-1])
        points = []
        # Pull lines only while rows are still needed, so the line after the
        # last data row is left for the outer loop.
        while len(points) < NKPTS:
            raw = next(remaining, None)
            if raw is None:
                break
            fields = raw.split()
            if fields:
                points.append([float(field) for field in fields])
        collected.append({"number": number, "data": points})
    return collected


def export_to_excel_template(cfg, df_band, df_klines, df_klabels, df_gap):
    """
    Write the parsed input tables into the macro-enabled template workbook.

    The workbook at cfg.template is opened with keep_vba=True so its macros
    are retained. The sheets "BAND", "KLINES", "KLABELS" and "BAND_GAP" are
    recreated (an existing sheet of the same name is removed first) and
    filled cell by cell through openpyxl rather than through pandas' Excel
    writer. Only the KLINES sheet gets a header row (the DataFrame's column
    names); the other sheets contain the raw rows only. Missing values
    (None/NaN) are left as empty cells.

    Args:
        cfg: Configuration object; only cfg.template (path of the .xlsm
            file) is used, both as source and as save target.
        df_band: DataFrame written to the "BAND" sheet.
        df_klines: DataFrame written to the "KLINES" sheet, with header.
        df_klabels: DataFrame written to the "KLABELS" sheet.
        df_gap: DataFrame written to the "BAND_GAP" sheet.
    """
    wb = load_workbook(cfg.template, keep_vba=True)

    _write_rows(_replace_sheet(wb, "BAND"), df_band.values)

    ws_klines = _replace_sheet(wb, "KLINES")
    _write_rows(ws_klines, [list(df_klines.columns)])
    _write_rows(
        ws_klines,
        df_klines.itertuples(index=False, name=None),
        start_row=2,
    )

    _write_rows(_replace_sheet(wb, "KLABELS"), df_klabels.values)
    _write_rows(_replace_sheet(wb, "BAND_GAP"), df_gap.values)

    # Saved back onto the template file; the VBA project is kept.
    wb.save(cfg.template)


def _replace_sheet(wb, title):
    """Remove the sheet named `title` if present, then create and return a new one."""
    if title in wb.sheetnames:
        wb.remove(wb[title])
    return wb.create_sheet(title=title)


def _write_rows(ws, rows, start_row=1):
    """Write each row of `rows` into `ws` from column 1, beginning at `start_row`."""
    for r_idx, row in enumerate(rows, start=start_row):
        for c_idx, value in enumerate(row, start=1):
            # Ragged input rows are padded with None/NaN by pandas; Excel has
            # no NaN number, so such cells are simply left empty.
            if pd.isna(value):
                continue
            ws.cell(row=r_idx, column=c_idx, value=value)


def add_chart_and_save(cfg, bands, df_klines, gap_data, NKPTS, NBANDS):
    """
    Add a chart sheet with the band-structure scatter plot and save.

    The workbook at cfg.template is opened with keep_vba=True. A sheet named
    "グラフ" is (re)created; an existing sheet of that name is removed first
    so repeated runs on the same workbook do not pile up copies. The sheet
    receives:
      - columns A-C: band number, k-point, energy for every band row;
      - columns E-F: the first two columns of df_klines, with header;
      - columns H onward: the rows of gap_data.
    A scatter chart anchored at M2 plots the band data and the KLINES data
    as two series. The y axis is limited to [cfg.Emin, cfg.Emax]; the x axis
    spans the smallest to largest k-point value found in `bands`.

    Args:
        cfg: Configuration with template, Emin, Emax and, optionally,
            output_excel (save target; falls back to cfg.template).
        bands: List of {"number": ..., "data": [[k, energy, ...], ...]}.
        df_klines: DataFrame with at least two columns.
        gap_data: List of rows (lists) written verbatim.
        NKPTS: Unused; kept so existing callers keep working. The chart
            range is derived from the rows actually written instead.
        NBANDS: Unused; kept for the same reason as NKPTS.
    """
    wb = load_workbook(cfg.template, keep_vba=True)

    sheet_title = "グラフ"
    if sheet_title in wb.sheetnames:
        wb.remove(wb[sheet_title])
    ws_chart = wb.create_sheet(title=sheet_title)
    ws_chart.append(["Band Number", "k-point", "Energy"])

    # Band rows go to columns A-C; count them so the chart references match
    # exactly what was written.
    n_band_rows = 0
    for band in bands:
        for row in band["data"]:
            ws_chart.append([band["number"]] + row)
            n_band_rows += 1

    # KLINES data goes to columns E and F.
    ws_chart.cell(row=1, column=5, value=df_klines.columns[0])
    ws_chart.cell(row=1, column=6, value=df_klines.columns[1])
    n_line_rows = len(df_klines)
    line_points = zip(df_klines.iloc[:, 0], df_klines.iloc[:, 1])
    for r_idx, (x_value, y_value) in enumerate(line_points, start=2):
        ws_chart.cell(row=r_idx, column=5, value=x_value)
        ws_chart.cell(row=r_idx, column=6, value=y_value)

    # BAND_GAP rows go to column H and onward.
    for row_index, value in enumerate(gap_data, start=1):
        for col_index, item in enumerate(value, start=8):
            ws_chart.cell(row=row_index, column=col_index, value=item)

    chart = ScatterChart()
    chart.title = "Band Structure"
    chart.y_axis.title = "Energy"
    chart.x_axis.title = "k-point"

    # A series is only added when it has at least one data row.
    if n_band_rows:
        last_band_row = n_band_rows + 1
        x_vals = Reference(ws_chart, min_col=2, min_row=2, max_row=last_band_row)
        y_vals = Reference(ws_chart, min_col=3, min_row=2, max_row=last_band_row)
        chart.series.append(Series(y_vals, x_vals))

    if n_line_rows:
        last_line_row = n_line_rows + 1
        x_vals_line = Reference(ws_chart, min_col=5, min_row=2, max_row=last_line_row)
        y_vals_line = Reference(ws_chart, min_col=6, min_row=2, max_row=last_line_row)
        chart.series.append(Series(y_vals_line, x_vals_line))

    # Energy window comes from the configuration.
    chart.y_axis.scaling.max = cfg.Emax
    chart.y_axis.scaling.min = cfg.Emin

    # X range is taken from the band data itself.
    x_all = [row[0] for band in bands for row in band["data"]]
    if x_all:
        chart.x_axis.scaling.max = max(x_all)
        chart.x_axis.scaling.min = min(x_all)

    ws_chart.add_chart(chart, "M2")
    wb.save(getattr(cfg, "output_excel", cfg.template))

def main():
    """
    Run the full pipeline.

    Parses the command line, reads BAND.dat, KLINES.dat, KLABELS and
    BAND_GAP, writes them as sheets into the template workbook, and finally
    adds the chart sheet.
    """
    cfg = initialize()
    NKPTS, NBANDS, df_band = read_band_dat(cfg.band_dat)
    df_klines = read_klines(cfg.klines)
    df_klabels = read_klabels(cfg.klabels)
    gap_data, df_gap = read_band_gap(cfg.band_gap)
    bands = extract_bands(cfg.band_dat, NKPTS, NBANDS)

    # Add the data sheets to the template using openpyxl only, then the chart.
    export_to_excel_template(cfg, df_band, df_klines, df_klabels, df_gap)
    add_chart_and_save(cfg, bands, df_klines, gap_data, NKPTS, NBANDS)


if __name__ == "__main__":
    main()
