import os
import argparse
import sys
import types
import csv
from openpyxl import load_workbook
from openpyxl.chart import ScatterChart, Series, Reference

# ギリシャ文字マッピング辞書
GREEK_MAP = {
    "GAMMA": "Γ",
    "ALPHA": "Α",
    "BETA": "Β",
    "DELTA": "Δ",
    "EPSILON": "Ε",
    "ZETA": "Ζ",
    "ETA": "Η",
    "THETA": "Θ",
    "IOTA": "Ι",
    "KAPPA": "Κ",
    "LAMDA": "Λ",
    "MU": "Μ",
    "NU": "Ν",
    "XI": "Ξ",
    "OMICRON": "Ο",
    "PI": "Π",
    "RHO": "Ρ",
    "SIGMA": "Σ",
    "TAU": "Τ",
    "UPSILON": "Υ",
    "PHI": "Φ",
    "CHI": "Χ",
    "PSI": "Ψ",
    "OMEGA": "Ω",
    #小文字も追加
    "gamma": "γ",
    "alpha": "α",
    "beta": "β",
    "delta": "δ",
    "epsilon": "ε",     
    "zeta": "ζ",
    "eta": "η",
    "theta": "θ",
    "iota": "ι",
    "kappa": "κ",
    "lamda": "λ",
    "mu": "μ",
    "nu": "ν",
    "xi": "ξ",
    "omicron": "ο",
    "pi": "π",  
    "rho": "ρ",
    "sigma": "σ",
    "tau": "τ",
    "upsilon": "υ",
    "phi": "φ",
    "chi": "χ",
    "psi": "ψ",
    "omega": "ω"       
}


def initialize():
    """
    Parse command-line arguments and return a SimpleNamespace `cfg` containing:
      - work_dir: 作業ディレクトリ
      - Emax: バンド図のエネルギー最大値
      - Emin: バンド図のエネルギー最小値
      - template: マクロ付きテンプレートファイル (.xlsm) のパス
      - output_excel: 出力先のExcelファイル名
      - file paths: template (.xlsm) と出力ファイルパスなど
    """
    parser = argparse.ArgumentParser(description="VASPバンド構造解析とExcel出力")
    parser.add_argument(
        "--work_dir",
        type=str,
        default=".",
        help="作業ディレクトリ (デフォルト: カレントディレクトリ)"
    )
    parser.add_argument(
        "--Emax",
        type=float,
        default=9.0,
        help="バンド図のエネルギー範囲 最大値 (デフォルト: 9.0)"
    )
    parser.add_argument(
        "--Emin",
        type=float,
        default=-5.0,
        help="バンド図のエネルギー範囲 最小値 (デフォルト: -5.0)"
    )
    parser.add_argument(
        "--template",
        type=str,
        default="StandardGraph.xlsm",
        help="マクロ付きテンプレートファイル (.xlsm) の名前 (デフォルト: StandardGraph.xlsm)"
    )
    parser.add_argument(
        "--output_excel",
        type=str,
        default=None,
        help="出力先のExcelファイル名 (デフォルト: combined_output.xlsm)"
    )
    args = parser.parse_args()

    cfg = types.SimpleNamespace()
    cfg.work_dir = args.work_dir
    cfg.Emax = args.Emax
    cfg.Emin = args.Emin

    # テンプレートファイルと出力ファイルを設定
    cfg.template = os.path.join(cfg.work_dir, args.template)
    #cfg.output_excel is Noneのとき、次のデフォルトを代入
    if args.output_excel is None:
        cfg.output_excel = os.path.join(cfg.work_dir, "combined_output.xlsm")
    else:
        cfg.output_excel = args.output_excel

    # 入力ファイルパス
    cfg.band_dat = os.path.join(cfg.work_dir, "BAND.dat")
    cfg.klines = os.path.join(cfg.work_dir, "KLINES.dat")
    cfg.klabels = os.path.join(cfg.work_dir, "KLABELS")
    cfg.band_gap = os.path.join(cfg.work_dir, "BAND_GAP")
    # EIGENVAL, INCAR, KPOINTSのファイルパスを作成
    cfg.eigenval = os.path.join(cfg.work_dir, "EIGENVAL")
    cfg.incar = os.path.join(cfg.work_dir, "INCAR")
    cfg.kpoints = os.path.join(cfg.work_dir, "KPOINTS")
    
    return cfg


def read_band_dat(filepath):
    """
    Read and parse BAND.dat to extract NKPTS, NBANDS, and band data.
    Returns a tuple: (NKPTS, NBANDS, band_list)
      - band_list: 各行をリストにしたリスト（空行は [""]、ヘッダー行は [ラベル, 数値]、数値行は [float1, float2, ...]）
    """
    if not os.path.exists(filepath):
        print(f"Error: {filepath} が見つかりません: {filepath}")
        sys.exit(1)

    print(f"band.datファイルを読み込みます: {filepath}")

    with open(filepath, "r") as file:
        raw_lines = file.readlines()

    band_list = []
    first_line = raw_lines[0].strip() if raw_lines else ""
    band_list.append(first_line.split() if first_line else [""])

    if len(raw_lines) < 2:
        print(f"Error: {filepath} の内容が不正です。行数が不足しています。")
        sys.exit(1)

    second_line = raw_lines[1].strip()
    if second_line:
        parts = second_line.split(":")
        if len(parts) < 2:
            print(f"Error: {filepath} の2行目のフォーマットが不正です。")
            sys.exit(1)
        nums = parts[1].strip().split()
        if len(nums) < 2:
            print(f"Error: {filepath} の2行目に NKPTS, NBANDS がありません。")
            sys.exit(1)
        NKPTS = int(nums[0])
        NBANDS = int(nums[1])
        band_list.append([parts[0].strip(), NKPTS, NBANDS])
    else:
        band_list.append([""])
        NKPTS, NBANDS = 0, 0

    i = 2
    band_count = 0
    while i < len(raw_lines) and band_count < NBANDS:
        line = raw_lines[i].strip()
        if line == "":
            band_list.append([""])
            i += 1
            continue
        if line.startswith("#"):
            parts = line.split()
            if len(parts) < 2 or not parts[-1].replace(".", "", 1).isdigit():
                print(f"Error: {filepath} ヘッダー行が不正です: {line}")
                sys.exit(1)
            label = " ".join(parts[:-1])
            index = float(parts[-1])
            band_list.append([label, index])
            i += 1
            band_count += 1
            for _ in range(NKPTS):
                if i >= len(raw_lines):
                    print(f"Error: {filepath} の数値データが不足しています。")
                    sys.exit(1)
                line_data = raw_lines[i].strip()
                if line_data == "":
                    band_list.append([""])
                else:
                    nums = line_data.split()
                    if not all(item.replace(".", "", 1).replace("-", "", 1).isdigit() for item in nums):
                        print(f"Error: {filepath} の数値行に不正な文字があります: {line_data}")
                        sys.exit(1)
                    band_list.append([float(n) for n in nums])
                i += 1
        else:
            band_list.append([line] if line else [""])
            i += 1

    return NKPTS, NBANDS, band_list


def read_klines(filepath):
    """
    Read KLINES.dat using csv.reader (delimiter=whitespace).
    Returns klines_list: [ [header_col1, header_col2], [val1, val2], ... ]
    """
    if not os.path.exists(filepath):
        print(f"Error: {filepath} が見つかりません")
        sys.exit(1)

    print(f"KLINES.datファイルを読み込みます: {filepath}")
    klines_list = []
    with open(filepath, newline="") as csvfile:
        reader = csv.reader(csvfile, delimiter=" ", skipinitialspace=True)
        for row in reader:
            if not row:
                continue
            klines_list.append(row)
    if not klines_list:
        print(f"Error: {filepath} にデータがありません")
        sys.exit(1)
    return klines_list


def read_klabels(filepath):
    """
    Parse KLABELS file into a list-of-lists, replacing English names with Greek letters.
    Returns klabels_list: [ [col1, col2], ... ] もしくは [ [label], ... ] の混在。
    """
    if not os.path.exists(filepath):
        print(f"Error: {filepath} が見つかりません")
        sys.exit(1)

    print(f"KLABELSファイルを読み込みます: {filepath}")
    klabels_list = []
    with open(filepath, "r", encoding="utf-8") as f:
        for i, line in enumerate(f):
            stripped = line.strip()
            if i == 0:
                entries = stripped.split(maxsplit=1)
                if entries and entries[0].upper() in GREEK_MAP:
                    entries[0] = GREEK_MAP[entries[0].upper()]
                klabels_list.append(entries)
            elif stripped == "":
                klabels_list.append([None, None])
            elif stripped.startswith("*"):
                label = stripped
                if label.upper() in GREEK_MAP:
                    label = GREEK_MAP[label.upper()]
                klabels_list.append([label])
            else:
                parts = stripped.split()
                if len(parts) >= 2:
                    label = parts[0]
                    val = parts[1]
                    if label.upper() in GREEK_MAP:
                        label = GREEK_MAP[label.upper()]
                    if not val.replace(".", "", 1).replace("-", "", 1).isdigit():
                        klabels_list.append([label, val])
                    else:
                        klabels_list.append([label, float(val)])
                else:
                    label = parts[0]
                    if label.upper() in GREEK_MAP:
                        label = GREEK_MAP[label.upper()]
                    klabels_list.append([label])
    if not klabels_list:
        print(f"Error: {filepath} にデータがありません")
        sys.exit(1)
    return klabels_list


def read_band_gap(filepath):
    """
    Read BAND_GAP file and return parsed data as gap_list: [ [key, val1, val2, ...], ... ].
    """
    if not os.path.exists(filepath):
        print(f"Error: {filepath} が見つかりません")
        sys.exit(1)

    print(f"BAND_GAPファイルを読み込みます: {filepath}")
    with open(filepath, "r", encoding="utf-8") as f:
        lines = f.readlines()
    if len(lines) < 2:
        print(f"Error: {filepath} の内容が不正です。行数が不足しています。")
        sys.exit(1)

    gap_list = []
    for i, line in enumerate(lines[1:-1]):
        if ":" in line:
            key, value = line.split(":", 1)
            key = key.strip()
            values = value.strip().split()
            if i == 0:
                gap_list.append([key] + values)
            else:
                if not all(item.replace(".", "", 1).replace("-", "", 1).isdigit() for item in values):
                    print(f"Error: BAND_GAP の数値変換に失敗: {values}")
                    sys.exit(1)
                gap_list.append([key] + [float(v) for v in values])
    if not gap_list:
        print(f"Error: {filepath} にデータがありません")
        sys.exit(1)
    return gap_list


def extract_bands(filepath, NKPTS, NBANDS):
    """
    Extract band number and corresponding (k-point, energy) data for all bands.
    Returns a list of dicts: [{"number": バンド番号, "data": [[k-point, energy], ...]}, ...].
    """
    if not os.path.exists(filepath):
        print(f"Error: {filepath} が見つかりません")
        sys.exit(1)

    print(f"BAND.datファイルを読み込みます: {filepath}")
    with open(filepath, "r") as f:
        lines = f.readlines()

    bands = []
    i = 2
    band_count = 0
    while i < len(lines) and band_count < NBANDS:
        line = lines[i].strip()
        if line.startswith("#"):
            parts = line.split()
            if len(parts) < 2 or not parts[-1].replace(".", "", 1).isdigit():
                print(f"Error: BAND.dat のバンド番号行が不正です: {line}")
                sys.exit(1)
            band_number = int(parts[-1])
            i += 1
            band_data = []
            data_rows_read = 0
            while data_rows_read < NKPTS and i < len(lines):
                line_data = lines[i].strip()
                if line_data:
                    nums = line_data.split()
                    if not all(item.replace(".", "", 1).replace("-", "", 1).isdigit() for item in nums):
                        print(f"Error: BAND.dat の数値行に不正な文字があります: {line_data}")
                        sys.exit(1)
                    band_data.append([float(n) for n in nums])
                    data_rows_read += 1
                i += 1
            bands.append({"number": band_number, "data": band_data})
            band_count += 1
        else:
            i += 1

    return bands


def export_to_excel_template(cfg, band_list, klines_list, klabels_list, gap_list):
    """
    マクロ付きテンプレート（.xlsm）を読み込み、openpyxl を使って各シートを追加・書き込みする。
    ・cfg.template にはテンプレートファイルパス (.xlsm) が入っている
    ・pandas を使わず、list-of-lists の内容をセル単位で openpyxl に書き込む
    """
    if not os.path.exists(cfg.template):
        print(f"Error: テンプレートファイルが見つかりません: {cfg.template}")
        sys.exit(1)

    if not os.access(cfg.output_excel, os.W_OK):
        print(f"Error: 出力ファイルに書き込み権限がありません: {cfg.output_excel}")
        sys.exit(1)

    print(f"テンプレートファイルを読み込みます: {cfg.template}")
    wb = load_workbook(cfg.template, keep_vba=True)

    # BAND シートを作成・上書き
    if "BAND" in wb.sheetnames:
        wb.remove(wb["BAND"])
    ws_band = wb.create_sheet(title="BAND")
    for r_idx, row in enumerate(band_list, start=1):
        for c_idx, value in enumerate(row, start=1):
            ws_band.cell(row=r_idx, column=c_idx, value=value)

    # KLINES シートを作成・上書き
    if "KLINES" in wb.sheetnames:
        wb.remove(wb["KLINES"])
    ws_klines = wb.create_sheet(title="KLINES")
    header = klines_list[0]
    data_rows = klines_list[1:]
    for c_idx, col_name in enumerate(header, start=1):
        ws_klines.cell(row=1, column=c_idx, value=col_name)
    for r_idx, row in enumerate(data_rows, start=2):
        for c_idx, value in enumerate(row, start=1):
            # 数値文字列なら float に変換
            if isinstance(value, str) and value.replace(".", "", 1).replace("-", "", 1).isdigit():
                ws_klines.cell(row=r_idx, column=c_idx, value=float(value))
            else:
                ws_klines.cell(row=r_idx, column=c_idx, value=value)

    # KLABELS シートを作成・上書き
    if "KLABELS" in wb.sheetnames:
        wb.remove(wb["KLABELS"])
    ws_klabels = wb.create_sheet(title="KLABELS")
    for r_idx, row in enumerate(klabels_list, start=1):
        for c_idx, value in enumerate(row, start=1):
            ws_klabels.cell(row=r_idx, column=c_idx, value=value)

    # BAND_GAP シートを作成・上書き
    if "BAND_GAP" in wb.sheetnames:
        wb.remove(wb["BAND_GAP"])
    ws_gap = wb.create_sheet(title="BAND_GAP")
    for r_idx, row in enumerate(gap_list, start=1):
        for c_idx, value in enumerate(row, start=1):
            ws_gap.cell(row=r_idx, column=c_idx, value=value)

    # 保存 (マクロ保持)
    print(f"Excelファイルに書き込みます: {cfg.output_excel}")   
    wb.save(cfg.output_excel)


def add_chart_and_save(cfg, bands, klines_list, gap_list, NKPTS, NBANDS):
    """
    テンプレートにグラフ用シートを追加し、バンド構造の散布図を作成する。
    Energy range は cfg.Emin, cfg.Emax で指定
    出力先ファイルは cfg.output_excel
    """
    #ファイルの書き込み許可確認
    if not os.access(cfg.output_excel, os.W_OK):
        print(f"Error: 出力ファイルに書き込み権限がありません: {cfg.output_excel}")
        sys.exit(1)
    wb = load_workbook(cfg.output_excel, keep_vba=True)
    ws_chart = wb.create_sheet(title="グラフ")
    ws_chart.append(["Band Number", "k-point", "Energy"])

    # 各バンドの k-point と Energy を書き込み
    for band in bands:
        for row in band["data"]:
            ws_chart.append([band["number"]] + row)

    # KLINES データを E, F 列に書き込み
    header = klines_list[0]
    data_rows = klines_list[1:]
    # ヘッダー行セルに書き込み
    ws_chart.cell(row=1, column=5, value=header[0])
    ws_chart.cell(row=1, column=6, value=header[1])
    # データ行
    for i, row in enumerate(data_rows, start=2):
        val0, val1 = row[0], row[1]
        if isinstance(val0, str) and val0.replace(".", "", 1).replace("-", "", 1).isdigit():
            ws_chart.cell(row=i, column=5, value=float(val0))
        else:
            ws_chart.cell(row=i, column=5, value=val0)
        if isinstance(val1, str) and val1.replace(".", "", 1).replace("-", "", 1).isdigit():
            ws_chart.cell(row=i, column=6, value=float(val1))
        else:
            ws_chart.cell(row=i, column=6, value=val1)

    # BAND_GAP データを H+ 列に書き込み
    for row_index, row in enumerate(gap_list, start=1):
        for col_offset, item in enumerate(row, start=8):
            ws_chart.cell(row=row_index, column=col_offset, value=item)

    # 散布図作成
    chart = ScatterChart()
    chart.title = "Band Structure"
    chart.y_axis.title = "Energy"
    chart.x_axis.title = "k-point"

    x_vals = Reference(ws_chart, min_col=2, min_row=2, max_row=NKPTS * NBANDS + 1)
    y_vals = Reference(ws_chart, min_col=3, min_row=2, max_row=NKPTS * NBANDS + 1)
    series_band = Series(y_vals, x_vals)
    chart.series.append(series_band)

    # KLINES のラインを追加
    num_klines = len(data_rows)
    x_vals_line = Reference(ws_chart, min_col=5, min_row=2, max_row=num_klines + 1)
    y_vals_line = Reference(ws_chart, min_col=6, min_row=2, max_row=num_klines + 1)
    series_line = Series(y_vals_line, x_vals_line)
    chart.series.append(series_line)

    # Energy range を cfg で制御
    chart.y_axis.scaling.max = cfg.Emax
    chart.y_axis.scaling.min = cfg.Emin

    # X 軸範囲は全バンドデータから自動計算
    x_all = [row[0] for band in bands for row in band["data"]]
    if x_all:
        chart.x_axis.scaling.max = max(x_all)
        chart.x_axis.scaling.min = min(x_all)

    ws_chart.add_chart(chart, "M2")
    print(f"グラフを追加し、Excelファイルに保存しました: {cfg.output_excel}")
    wb.save(cfg.output_excel)


def main():
    cfg = initialize()

    # 各ファイルをパースしてリストに変換
    NKPTS, NBANDS, band_list = read_band_dat(cfg.band_dat)
    klines_list = read_klines(cfg.klines)
    klabels_list = read_klabels(cfg.klabels)
    gap_list = read_band_gap(cfg.band_gap)
    bands = extract_bands(cfg.band_dat, NKPTS, NBANDS)

    # テンプレートに各シートを追加してデータを書き込む
    export_to_excel_template(cfg, band_list, klines_list, klabels_list, gap_list)

    # グラフ作成および保存
    add_chart_and_save(cfg, bands, klines_list, gap_list, NKPTS, NBANDS)


if __name__ == "__main__":
    main()
