import pandas as pd
from openpyxl import load_workbook
from openpyxl.chart import LineChart, Reference
from openpyxl import Workbook
from openpyxl.chart import ScatterChart, Series, Reference


# ===== 1. Load the input files and convert them to DataFrames =====

# --- BAND.dat: mirror the file row-by-row into df_band (written to the "BAND" sheet) ---
with open("BAND.dat", "r", encoding="utf-8") as file:
    raw_lines = file.readlines()

if len(raw_lines) < 2:
    raise ValueError("BAND.dat のヘッダー行が不足しています")

band_data = []

# Line 1: column header, split into cells (an empty line becomes one blank cell).
first_line = raw_lines[0].strip()
band_data.append(first_line.split() if first_line else [""])

# Line 2: "<label>: <NKPTS> <NBANDS>". Both counts drive the parse below, so a
# missing or malformed line is fatal. Without this check, a blank line here
# would leave NKPTS/NBANDS unbound and crash later with a NameError.
second_line = raw_lines[1].strip()
label, sep, counts = second_line.partition(":")
nums = counts.split()
if not sep or len(nums) < 2:
    raise ValueError(f"NKPTS と NBANDS を読み取れません: '{second_line}'")
label = label.strip()
NKPTS = int(nums[0])
NBANDS = int(nums[1])
band_data.append([label, NKPTS, NBANDS])

# Remaining lines: blank lines, "#"-prefixed band headers (each followed by
# NKPTS data rows), and any other text, until NBANDS headers have been seen.
i = 2
band_count = 0

while i < len(raw_lines) and band_count < NBANDS:
    line = raw_lines[i].strip()
    i += 1

    if not line:
        band_data.append([""])
        continue

    if not line.startswith("#"):
        # Unexpected text outside a band block: keep it verbatim as one cell.
        band_data.append([line])
        continue

    # Band header: every token but the last forms the label; the last is numeric.
    parts = line.split()
    try:
        index_number = float(parts[-1])
    except ValueError as err:
        raise ValueError(f"数値として解釈できません: '{parts[-1]}'") from err
    band_data.append([" ".join(parts[:-1]), index_number])
    band_count += 1

    # The next NKPTS lines belong to this band. A blank line here still uses up
    # one of the NKPTS slots and is written as a blank cell.
    for _ in range(NKPTS):
        if i >= len(raw_lines):
            raise ValueError("数値データが不足しています")
        num_line = raw_lines[i].strip()
        band_data.append([float(n) for n in num_line.split()] if num_line else [""])
        i += 1

df_band = pd.DataFrame(band_data)

# --- KLINES.dat: whitespace-delimited table whose first line is the column header ---
# `sep=r"\s+"` replaces `delim_whitespace=True`, which is deprecated since pandas 2.2.
df_klines = pd.read_csv("KLINES.dat", sep=r"\s+", header=0)

# --- KLABELS: one worksheet row per file line ---
def _klabels_row(line_no, text):
    """Turn one stripped KLABELS line into the list of cells for its row.

    Line 0 is split into its first word and the rest of the line. Blank lines
    become two empty cells, and lines starting with "*" are kept whole. Any
    other line with at least two fields becomes [label, coordinate], with the
    coordinate converted to float when possible. Single-field lines are kept
    whole.
    """
    if line_no == 0:
        return text.split(maxsplit=1)
    if not text:
        return [None, None]
    if text.startswith("*"):
        return [text]

    fields = text.split()
    if len(fields) < 2:
        return [text]
    try:
        coord = float(fields[1])
    except ValueError:
        coord = fields[1]
    return [fields[0], coord]


with open("KLABELS", "r", encoding="utf-8") as klabels_file:
    rows_klabels = [
        _klabels_row(line_no, raw.strip()) for line_no, raw in enumerate(klabels_file)
    ]

df_klabels = pd.DataFrame(rows_klabels)

# --- BAND_GAP: "key: v1 v2 ..." lines → rows [key, v1, v2, ...] ---
def _gap_cell(token):
    """Return *token* as a float when it parses as one, otherwise unchanged."""
    try:
        return float(token)
    except ValueError:
        return token


gap_data = []
with open("BAND_GAP", "r", encoding="utf-8") as f:
    lines = f.readlines()

# Skip the first and last lines of the file. NOTE(review): these look like
# decorative border lines; confirm against the BAND_GAP format.
for i, line in enumerate(lines[1:-1]):
    key, sep, value = line.partition(":")
    if not sep:
        # Lines without a "key: value" separator are ignored.
        continue
    values = value.split()
    if i == 0:
        # The first body line keeps all its values as text.
        gap_data.append([key.strip()] + values)
    else:
        # Numeric tokens become floats. A non-numeric token is kept as text
        # instead of aborting the whole script with a ValueError.
        gap_data.append([key.strip()] + [_gap_cell(v) for v in values])

# Build the DataFrame (rows may have different lengths; pandas pads with NaN/None).
df_gap = pd.DataFrame(gap_data)


# ===== 2. Parse BAND.dat again into per-band records for the chart =====

with open("BAND.dat", "r") as f:
    raw_lines = f.readlines()

# Line 1 is the column header. Line 2 is "<label>: <NKPTS> <NBANDS>".
header = raw_lines[0].strip().split()
parts = raw_lines[1].strip().split(":")
numbers = parts[1].strip().split()
NKPTS = int(numbers[0])   # number of k-points per band
NBANDS = int(numbers[1])  # number of bands

# bands: list of {"number": int band index, "data": list of float rows}
bands = []
cursor = 2
total_lines = len(raw_lines)

while cursor < total_lines and len(bands) < NBANDS:
    text = raw_lines[cursor].strip()
    cursor += 1
    if not text.startswith("#"):
        # Blank lines and any other text outside a band block are ignored.
        continue

    # Band header such as "# ... <index>": the last token is the integer index.
    band_index = int(text.split()[-1])

    # Collect up to NKPTS non-blank rows. Blank lines are skipped and do not count.
    rows = []
    while len(rows) < NKPTS and cursor < total_lines:
        data_line = raw_lines[cursor].strip()
        cursor += 1
        if data_line:
            rows.append([float(tok) for tok in data_line.split()])

    bands.append({"number": band_index, "data": rows})


# ===== 3. Write the parsed tables to Excel with pandas (openpyxl adds the chart later) =====
excel_filename = "combined_output.xlsx"

with pd.ExcelWriter(excel_filename, engine="openpyxl") as writer:
    # (sheet name, DataFrame, write column header?). Only KLINES keeps its header row.
    sheet_specs = (
        ("BAND", df_band, False),
        ("KLINES", df_klines, True),
        ("KLABELS", df_klabels, False),
        ("BAND_GAP", df_gap, False),
    )
    for sheet_name, frame, with_header in sheet_specs:
        frame.to_excel(writer, sheet_name=sheet_name, index=False, header=with_header)


# ===== 4. Add a "グラフ" sheet holding the chart's source data =====
wb = load_workbook(excel_filename)
ws_chart = wb.create_sheet(title="グラフ")

# Columns A-C: header in row 1, then one row per data row of every band, band
# after band. NOTE(review): a data row with more than two values spills into
# column D onward and would collide with the KLINES columns E/F below.
ws_chart.append(["Band Number", "k-point", "Energy"])
for band in bands:
    number = band["number"]
    for row in band["data"]:
        ws_chart.append([number] + row)

# Columns E-F: the first two KLINES.dat columns, with their header in row 1.
ws_chart.cell(row=1, column=5, value=df_klines.columns[0])
ws_chart.cell(row=1, column=6, value=df_klines.columns[1])
klines_rows = df_klines.iloc[:, :2].itertuples(index=False, name=None)
for row_idx, (first_value, second_value) in enumerate(klines_rows, start=2):
    ws_chart.cell(row=row_idx, column=5, value=first_value)
    ws_chart.cell(row=row_idx, column=6, value=second_value)

# Columns H onward: one BAND_GAP entry per row, starting at row 1.
# Every gap_data entry is a list ([key] + values), so each is spread across columns.
for row_idx, gap_row in enumerate(gap_data, start=1):
    for col_idx, item in enumerate(gap_row, start=8):
        ws_chart.cell(row=row_idx, column=col_idx, value=item)
        
# ===== 5. Build the band-structure scatter chart and save the workbook =====

# Fixed energy window shown on the y-axis.
E_MIN, E_MAX = -5, 9

chart = ScatterChart()
chart.title = "Band Structure"
chart.y_axis.title = "Energy"
chart.x_axis.title = "k-point"
# openpyxl 3.1+ hides chart axes unless `delete` is explicitly set to False.
chart.x_axis.delete = False
chart.y_axis.delete = False
# One series per band would flood the legend with "Series N" entries.
chart.legend = None

# One series per band. A single series spanning every band makes Excel draw a
# line from the last k-point of one band back to the first k-point of the next.
# Data starts at row 2 (row 1 is the header), and each band occupies
# len(data) consecutive rows in the order the bands were written to columns A-C.
first_row = 2
for record in bands:
    n_rows = len(record["data"])
    if n_rows == 0:
        continue
    last_row = first_row + n_rows - 1
    x_ref = Reference(ws_chart, min_col=2, min_row=first_row, max_row=last_row)  # x: k-point (column B)
    y_ref = Reference(ws_chart, min_col=3, min_row=first_row, max_row=last_row)  # y: energy (column C)
    chart.series.append(Series(y_ref, x_ref))
    first_row = last_row + 1

# KLINES series: x from column E, y from column F (the first two KLINES.dat columns).
klines_last_row = len(df_klines) + 1
x_values_line = Reference(ws_chart, min_col=5, min_row=2, max_row=klines_last_row)
y_values_line = Reference(ws_chart, min_col=6, min_row=2, max_row=klines_last_row)
chart.series.append(Series(y_values_line, x_values_line))

# The x-axis spans the k-point range of all bands; the y-axis uses the fixed window.
k_values = [row[0] for record in bands for row in record["data"]]
if k_values:
    chart.x_axis.scaling.min = min(k_values)
    chart.x_axis.scaling.max = max(k_values)
chart.y_axis.scaling.min = E_MIN
chart.y_axis.scaling.max = E_MAX

# Place the chart at cell M2 and save.
ws_chart.add_chart(chart, "M2")
wb.save(excel_filename)

