import pandas as pd
from openpyxl import load_workbook
from openpyxl.chart import ScatterChart, Series, Reference


def read_band_dat(filepath):
    """
    Read and parse BAND.dat to extract NKPTS, NBANDS, and band data.

    Expected layout:
        * line 1: a header, kept as one row of whitespace-separated tokens;
        * line 2: ``<label>: NKPTS NBANDS``;
        * afterwards: each line starting with ``#`` opens a band block and is
          followed by exactly NKPTS data lines. Parsing stops once NBANDS band
          blocks have been read.

    Args:
        filepath: Path to the BAND.dat file.

    Returns:
        A tuple ``(NKPTS, NBANDS, df_band)``. ``df_band`` holds every parsed
        row in file order: the header tokens, ``[label, NKPTS, NBANDS]``,
        ``[band_label, band_index]`` rows, float rows for data lines, and
        ``[""]`` for blank lines.

    Raises:
        ValueError: If the file has fewer than two lines, line 2 does not
            contain ``<label>: NKPTS NBANDS``, or a band block ends early.
    """
    with open(filepath, "r", encoding="utf-8") as file:
        raw_lines = file.readlines()

    if len(raw_lines) < 2:
        raise ValueError(
            f"{filepath}: expected a header line and an NKPTS/NBANDS line"
        )

    band_data = []
    first_line = raw_lines[0].strip()
    band_data.append(first_line.split() if first_line else [""])

    # Line 2 fixes the grid size. The band loop below cannot run without it,
    # so fail with a clear message instead of an UnboundLocalError later.
    second_line = raw_lines[1].strip()
    pieces = second_line.split(":")
    nums = pieces[1].split() if len(pieces) > 1 else []
    if len(nums) < 2:
        raise ValueError(
            f"{filepath}: line 2 must look like '<label>: NKPTS NBANDS', "
            f"got {second_line!r}"
        )
    label = pieces[0].strip()
    NKPTS = int(nums[0])
    NBANDS = int(nums[1])
    band_data.append([label, NKPTS, NBANDS])

    i = 2
    band_count = 0
    while i < len(raw_lines) and band_count < NBANDS:
        line = raw_lines[i].strip()
        if line == "":
            band_data.append([""])
            i += 1
            continue
        if line.startswith("#"):
            # Band header: every token but the last forms the label, the last
            # token is the band index.
            parts = line.split()
            label = " ".join(parts[:-1])
            index = float(parts[-1])
            band_data.append([label, index])
            i += 1
            band_count += 1
            # Each band header is followed by exactly NKPTS data lines.
            for _ in range(NKPTS):
                if i >= len(raw_lines):
                    # Message means "not enough numeric data".
                    raise ValueError("数値データが不足しています")
                line_data = raw_lines[i].strip()
                if line_data == "":
                    band_data.append([""])
                else:
                    band_data.append([float(n) for n in line_data.split()])
                i += 1
        else:
            # Any other non-blank line is preserved verbatim as one cell.
            band_data.append([line])
            i += 1

    return NKPTS, NBANDS, pd.DataFrame(band_data)


def read_klines(filepath):
    """
    Read KLINES.dat as a whitespace-delimited table.

    Args:
        filepath: Path to KLINES.dat.

    Returns:
        pandas.DataFrame parsed with the first line as the column header.
    """
    # sep=r"\s+" is the documented replacement for delim_whitespace=True,
    # which is deprecated since pandas 2.2.
    return pd.read_csv(filepath, sep=r"\s+", header=0)


def read_klabels(filepath):
    """
    Parse a KLABELS file into a DataFrame, one row per line.

    The first line is split into at most two header cells. Every later line
    is mapped as follows:

    * blank line                -> ``[None, None]``
    * line starting with ``*``  -> the whole stripped line as a single cell
    * ``LABEL VALUE [...]``     -> ``[LABEL, VALUE]``; VALUE becomes a float
      when it parses as one, extra tokens are dropped
    * a single token            -> that token as a single cell
    """

    def parse_body_line(text):
        if not text:
            return [None, None]
        if text.startswith("*"):
            return [text]
        tokens = text.split()
        if len(tokens) < 2:
            return [text]
        name, raw_value = tokens[0], tokens[1]
        try:
            return [name, float(raw_value)]
        except ValueError:
            return [name, raw_value]

    with open(filepath, "r", encoding="utf-8") as handle:
        stripped_lines = [raw_line.strip() for raw_line in handle]

    records = [
        text.split(maxsplit=1) if position == 0 else parse_body_line(text)
        for position, text in enumerate(stripped_lines)
    ]
    return pd.DataFrame(records)


def read_band_gap(filepath):
    """
    Read a BAND_GAP file as ``key: value [value ...]`` rows.

    The first and last lines of the file are skipped, as are lines without a
    colon. For every remaining line the text before the first ``:`` becomes
    the key (stripped). Each whitespace-separated value token after it is
    converted to ``float`` when it parses as a number and is kept as text
    otherwise, so a descriptive word value no longer crashes the parse
    regardless of which line it appears on.

    Args:
        filepath: Path to the BAND_GAP file.

    Returns:
        A tuple ``(gap_data, df_gap)``. ``gap_data`` is a list of
        ``[key, value, ...]`` lists and ``df_gap`` is ``pd.DataFrame(gap_data)``.
    """

    def _as_number(token):
        # Numeric tokens become floats; anything else stays as text.
        try:
            return float(token)
        except ValueError:
            return token

    with open(filepath, "r", encoding="utf-8") as f:
        lines = f.readlines()

    gap_data = []
    # First/last lines are skipped -- presumably decorative border lines;
    # verify against the BAND_GAP format if it changes.
    for line in lines[1:-1]:
        key, sep, value = line.partition(":")
        if not sep:
            continue
        gap_data.append([key.strip()] + [_as_number(v) for v in value.split()])
    return gap_data, pd.DataFrame(gap_data)


def extract_bands(filepath, NKPTS, NBANDS):
    """
    Extract band number and corresponding (k-point, energy) data for all bands.

    Scanning starts at the third line of the file. Each line starting with
    ``#`` opens a band whose number is its last token (as ``int``). The next
    NKPTS lines are that band's data rows; blank lines among them are
    consumed but not stored. Scanning stops after NBANDS bands.

    Args:
        filepath: Path to BAND.dat.
        NKPTS: Number of data lines following each band header.
        NBANDS: Maximum number of bands to read.

    Returns:
        A list of dicts with keys ``'number'`` (int) and ``'data'`` (list of
        float lists, one per non-blank data line).

    Raises:
        ValueError: If the file ends before a band's NKPTS data lines
            (consistent with ``read_band_dat``).
    """
    with open(filepath, "r", encoding="utf-8") as f:
        lines = f.readlines()

    bands = []
    i = 2
    while i < len(lines) and len(bands) < NBANDS:
        line = lines[i].strip()
        i += 1
        if not line.startswith("#"):
            continue
        band_number = int(line.split()[-1])
        block = lines[i:i + NKPTS]
        if len(block) < NKPTS:
            raise ValueError(
                f"{filepath}: band {band_number} has {len(block)} data "
                f"lines, expected {NKPTS}"
            )
        band_data = [
            [float(n) for n in row.split()] for row in block if row.strip()
        ]
        i += NKPTS
        bands.append({"number": band_number, "data": band_data})
    return bands


def export_to_excel(filename, df_band, df_klines, df_klabels, df_gap):
    """
    Write the parsed tables to ``filename``, one worksheet per table.

    Sheets are written in the order BAND, KLINES, KLABELS, BAND_GAP using the
    openpyxl engine. No sheet includes the DataFrame index, and only the
    KLINES sheet keeps the DataFrame's column header row.
    """
    layout = (
        ("BAND", df_band, False),
        ("KLINES", df_klines, True),
        ("KLABELS", df_klabels, False),
        ("BAND_GAP", df_gap, False),
    )
    with pd.ExcelWriter(filename, engine="openpyxl") as writer:
        for sheet_name, frame, keep_header in layout:
            frame.to_excel(
                writer, sheet_name=sheet_name, index=False, header=keep_header
            )


def add_chart_and_save(
    filename, bands, df_klines, gap_data, NKPTS, NBANDS, *, y_min=-5, y_max=9
):
    """
    Add a new sheet with k-point/energy data and a band-structure scatter chart.

    The new sheet ("グラフ") is laid out as:
        * columns A-C: a header row, then one row per band data point
          (band number followed by that point's values);
        * columns E-F: the first two KLINES columns, header names in row 1;
        * columns H onward: the ``gap_data`` rows, starting at row 1.

    Each band is plotted as its own series, so the chart does not draw a line
    from the last point of one band back to the first point of the next. The
    KLINES data is added as one more series.

    Args:
        filename: Path of an existing workbook; modified and saved in place.
        bands: Sequence of dicts with keys ``"number"`` and ``"data"``; each
            ``"data"`` entry is a sequence whose first two items are written
            to the k-point and Energy columns.
        df_klines: DataFrame whose first two columns are written to E-F.
        gap_data: Rows copied cell by cell into columns H onward.
        NKPTS: Kept for backward compatibility. Series ranges are taken from
            the rows actually written, since ``bands`` may hold fewer than
            NKPTS * NBANDS points.
        NBANDS: Kept for backward compatibility (see NKPTS).
        y_min: Lower energy-axis limit (keyword-only; default -5).
        y_max: Upper energy-axis limit (keyword-only; default 9).
    """
    wb = load_workbook(filename)
    ws_chart = wb.create_sheet(title="グラフ")
    ws_chart.append(["Band Number", "k-point", "Energy"])

    # Record the sheet rows each band occupies so it can get its own series.
    band_row_ranges = []
    next_row = 2  # row 1 holds the header appended above
    for band in bands:
        points = band["data"]
        for point in points:
            ws_chart.append([band["number"]] + list(point))
        if points:
            band_row_ranges.append((next_row, next_row + len(points) - 1))
            next_row += len(points)

    # KLINES data in columns E and F.
    ws_chart.cell(row=1, column=5, value=df_klines.columns[0])
    ws_chart.cell(row=1, column=6, value=df_klines.columns[1])
    klines_pairs = zip(df_klines.iloc[:, 0], df_klines.iloc[:, 1])
    for row_index, (first_value, second_value) in enumerate(klines_pairs, start=2):
        ws_chart.cell(row=row_index, column=5, value=first_value)
        ws_chart.cell(row=row_index, column=6, value=second_value)

    # BAND_GAP data starting at column H.
    for row_index, values in enumerate(gap_data, start=1):
        for col_index, item in enumerate(values, start=8):
            ws_chart.cell(row=row_index, column=col_index, value=item)

    chart = ScatterChart()
    chart.title = "Band Structure"
    chart.y_axis.title = "Energy"
    chart.x_axis.title = "k-point"
    # openpyxl 3.1+ marks chart axes as deleted unless told otherwise.
    chart.x_axis.delete = False
    chart.y_axis.delete = False

    for first_row, last_row in band_row_ranges:
        x_vals = Reference(ws_chart, min_col=2, min_row=first_row, max_row=last_row)
        y_vals = Reference(ws_chart, min_col=3, min_row=first_row, max_row=last_row)
        chart.series.append(Series(y_vals, x_vals))

    n_klines = len(df_klines)
    if n_klines:
        x_vals_line = Reference(ws_chart, min_col=5, min_row=2, max_row=n_klines + 1)
        y_vals_line = Reference(ws_chart, min_col=6, min_row=2, max_row=n_klines + 1)
        chart.series.append(Series(y_vals_line, x_vals_line))

    # The k-axis spans the first band's k-range; this assumes every band
    # shares the same k-path -- TODO confirm for all inputs.
    if bands and bands[0]["data"]:
        k_points = [point[0] for point in bands[0]["data"]]
        chart.x_axis.scaling.min = min(k_points)
        chart.x_axis.scaling.max = max(k_points)
    chart.y_axis.scaling.min = y_min
    chart.y_axis.scaling.max = y_max

    ws_chart.add_chart(chart, "M2")
    wb.save(filename)


if __name__ == "__main__":
    # Workbook shared by both writing steps below.
    excel_filename = "combined_output.xlsx"

    # Parse every input first; nothing is written until all reads succeed.
    NKPTS, NBANDS, df_band = read_band_dat("BAND.dat")
    df_klines = read_klines("KLINES.dat")
    df_klabels = read_klabels("KLABELS")
    gap_data, df_gap = read_band_gap("BAND_GAP")
    bands = extract_bands("BAND.dat", NKPTS=NKPTS, NBANDS=NBANDS)

    # Write the raw tables, then reopen the workbook to add the chart sheet.
    export_to_excel(
        excel_filename,
        df_band=df_band,
        df_klines=df_klines,
        df_klabels=df_klabels,
        df_gap=df_gap,
    )
    add_chart_and_save(
        excel_filename,
        bands,
        df_klines,
        gap_data,
        NKPTS=NKPTS,
        NBANDS=NBANDS,
    )