import os
import sys
import argparse
import types
from openpyxl.utils import get_column_letter

from vasp_lib import *


def initialize():
    """
    Parse command-line arguments and build the run configuration.

    Returns a ``types.SimpleNamespace`` holding every command-line option
    (mode, work_dir, Emax, Emin, template, output_excel, read_vaspkit_files,
    height, kgraphsize, fontsize, chart_cell, pause) plus derived paths:

      - template: template workbook (.xlsm with macros), resolved inside work_dir
      - output_excel: output workbook path; when not given on the command line
        it defaults to ``<work_dir>/<mode>.xlsm``
      - band_dat, klines, klabels, band_gap: VASPKIT output files in work_dir
      - poscar, eigenval, incar, kpoints: VASP files in work_dir

    If ``--mode`` is not given, the mode is read interactively from stdin.
    """
    parser = argparse.ArgumentParser(description="VASPバンド構造解析とExcel出力")
    parser.add_argument(
        "--mode",
        type=str,
        default="",
        help="動作モード [band|dos] (デフォルト: キーボード入力)"
    )
    parser.add_argument(
        "--work_dir",
        type=str,
        default=".",
        help="作業ディレクトリ (デフォルト: カレントディレクトリ)"
    )
    parser.add_argument(
        "--Emax",
        type=float,
        default=9.0,
        help="バンド図のエネルギー範囲 最大値 (デフォルト: 9.0)"
    )
    parser.add_argument(
        "--Emin",
        type=float,
        default=-5.0,
        help="バンド図のエネルギー範囲 最小値 (デフォルト: -5.0)"
    )
    parser.add_argument(
        "--template",
        type=str,
        default="StandardGraph.xlsm",
        help="マクロ付きテンプレートファイル (.xlsm) の名前 (デフォルト: StandardGraph.xlsm)"
    )
    # Registered exactly once: argparse raises ArgumentError on a duplicate
    # option string, which previously crashed the script at startup.
    parser.add_argument(
        "--output_excel",
        type=str,
        default=None,
        help="出力先のExcelファイル名 (デフォルト: band.xlsm)"
    )
    parser.add_argument(
        "--read_vaspkit_files",
        type=int,
        default=1,
        help="VASPKITファイルを読み込むかどうか (デフォルト: 1)"
    )
    parser.add_argument(
        "--height",
        type=float,
        default=300.0,
        help="グラフの高さ (デフォルト: 300.0)"
    )
    parser.add_argument(
        "--kgraphsize",
        type=float,
        default=1.0/30.0,
        help="グラフサイズの係数 (デフォルト: 0.0333)"
    )
    parser.add_argument(
        "--fontsize",
        type=float,
        default=16.0,
        help="グラフのフォントサイズ (デフォルト: 16.0)"
    )
    parser.add_argument(
        "--chart_cell",
        type=str,
        default="J11",
        help="グラフを配置するセル (デフォルト: J11)"
    )
    parser.add_argument(
        "--pause",
        type=int,
        default=0,
        help="Flag to pause at termination"
    )
    args = parser.parse_args()

    # Copy every parsed option onto the config namespace.
    cfg = types.SimpleNamespace(**vars(args))

    if cfg.mode == "":
        # Strip surrounding whitespace so e.g. " band" typed at the prompt matches.
        cfg.mode = input("\nChoose band or dos>> ").strip()

    # Template and output workbooks.
    cfg.template = os.path.join(cfg.work_dir, args.template)
    if args.output_excel is None:
        # Default output name follows the selected mode, inside work_dir.
        cfg.output_excel = os.path.join(cfg.work_dir, f"{cfg.mode}.xlsm")
    else:
        # An explicit output path is used as given (not joined with work_dir).
        cfg.output_excel = args.output_excel

    # Input files produced by VASPKIT.
    cfg.band_dat = os.path.join(cfg.work_dir, "BAND.dat")
    cfg.klines = os.path.join(cfg.work_dir, "KLINES.dat")
    cfg.klabels = os.path.join(cfg.work_dir, "KLABELS")
    cfg.band_gap = os.path.join(cfg.work_dir, "BAND_GAP")

    # VASP input/output files.
    cfg.poscar = os.path.join(cfg.work_dir, "POSCAR")
    cfg.eigenval = os.path.join(cfg.work_dir, "EIGENVAL")
    cfg.incar = os.path.join(cfg.work_dir, "INCAR")
    cfg.kpoints = os.path.join(cfg.work_dir, "KPOINTS")

    return cfg

def terminate(cfg, message=None):
    """
    Finish a run: print an optional message, then optionally wait for ENTER.

    This does not exit the process; callers must return or exit themselves.

    Args:
        cfg: configuration namespace; a truthy ``cfg.pause`` requests a pause.
        message: text printed first; nothing is printed when it is None.
    """
    if message is not None:
        print(message)

    if not cfg.pause:
        return
    input("\nPress ENTER to terminate>>\n")

def reformat_workbook(wb):
    """
    Clear the template workbook so fresh data and charts can be written.

    - Empties the values of columns A-E of "Sheet1" (cell formatting is kept).
    - Removes every chart from every sheet.
    - Deletes sheets left over from a previous vasp2xlsm run
      (BAND, KLINES, KLABELS, BAND_GAP, GRAPH).

    Returns the "Sheet1" worksheet.
    """
    ws_main = get_worksheet(wb, "Sheet1")

    # The VBA macros refer to named cells on this sheet. Renaming the sheet
    # with openpyxl breaks those name links, so its title is left unchanged.

    # Blank out columns A-E on the main sheet.
    for col in ws_main.iter_cols(min_col=1, max_col=5, min_row=1, max_row=ws_main.max_row):
        for cell in col:
            cell.value = None

    # Remove all charts from every sheet. The list is cleared in place: calling
    # remove() while iterating the same list skips every other chart.
    for sheet in wb.sheetnames:
        charts = getattr(wb[sheet], "_charts", None)
        if charts:
            charts.clear()

    # Drop sheets added by an earlier vasp2xlsm run.
    for sheetname in ["BAND", "KLINES", "KLABELS", "BAND_GAP", "GRAPH"]:
        if sheetname in wb.sheetnames:
            wb.remove(wb[sheetname])

    return ws_main

def update_workbook(cfg, wb):
    """
    Copy the graph parameters from cfg into the template workbook.

    Writes cfg.fontsize, cfg.width and cfg.height to the template entries
    "FontSize", "width" and "height" (presumably named cells) via update_cell.
    """
    # NOTE(review): initialize() defines no --width option, so cfg.width is
    # presumably set elsewhere (e.g. by the vasp_lib readers) — confirm.
    print()
    print("update parameters")
    targets = (
        ("FontSize", "fontsize"),
        ("width", "width"),
        ("height", "height"),
    )
    for cell_name, attr_name in targets:
        # Attributes are read one at a time so the update order matches the list.
        update_cell(wb, cell_name, getattr(cfg, attr_name))

def write_band_data(cfg, band_data, wb, k_list, E_list2D):
    """
    Create the "BAND" sheet holding the band energies E(k).

    Layout: row 1 is the header ("k", "E(#1)", "E(#2)", ...). From row 2 on,
    column A holds k and column ib + 1 holds band number ib (1-based).

    Args:
        cfg: configuration (unused here).
        band_data: object providing ``nk`` (k-point count) and ``nband``.
        wb: workbook to add the sheet to.
        k_list: k coordinates, indexed by k point.
        E_list2D: energies indexed as ``E_list2D[band][kpoint]``.
    """
    sheet = wb.create_sheet(title="BAND")

    sheet.cell(row=1, column=1, value="k")
    band_columns = range(2, band_data.nband + 2)
    for col in band_columns:
        sheet.cell(row=1, column=col, value=f"E(#{col - 1})")

    for row, ik in enumerate(range(band_data.nk), start=2):
        sheet.cell(row=row, column=1, value=k_list[ik])
        for col in band_columns:
            sheet.cell(row=row, column=col, value=E_list2D[col - 2][ik])

def write_klines(cfg, band_data, wb, klabels_list):
    """
    Create the "KLINES" sheet with vertical lines at the k-path boundaries.

    For each usable (label, k) entry two rows are written: (k, Emin_all) and
    (k, Emax_all). One blank row is left before the next pair, presumably so
    the single scatter series plotted from this sheet breaks between lines.

    Side effect: stores ``klabels_list`` on ``cfg.klabels_list``, which
    write_bandgap_data reads.

    Args:
        cfg: configuration namespace (mutated, see above).
        band_data: object providing ``Emin_all`` and ``Emax_all``.
        wb: workbook to add the sheet to.
        klabels_list: rows of (label, k). The first row is treated as a header
            and skipped, as are rows without exactly two entries or with an
            empty/None k. k is converted with float().
    """
    print()
    print("Add k lines")
    sheet = wb.create_sheet(title="KLINES")
    cfg.klabels_list = klabels_list
    sheet.cell(row=1, column=1, value="k(boundary)")
    sheet.cell(row=1, column=2, value="Emin/Emax")

    boundaries = (
        float(entry[1])
        for entry in klabels_list[1:]
        if len(entry) == 2 and entry[1] is not None and entry[1] != ""
    )
    for n, k in enumerate(boundaries):
        bottom = 3 * n + 2
        sheet.cell(row=bottom, column=1, value=k)
        sheet.cell(row=bottom, column=2, value=band_data.Emin_all)
        sheet.cell(row=bottom + 1, column=1, value=k)
        sheet.cell(row=bottom + 1, column=2, value=band_data.Emax_all)
 
def write_klabels(cfg, wb, klabels_list):
    """
    Write the high-symmetry k-point labels to columns A-C of "Sheet1".

    Row 1 holds the headers "k", "E" and "label". Each usable entry of
    klabels_list (after its header row) is written as (k, cfg.Emin, label) on
    row ``index + 2``, where index is its position in klabels_list[1:].
    Skipped entries therefore leave their row empty. Entries are skipped when
    they do not have exactly two items or their k is None or "".
    """
    sheet = get_worksheet(wb, "Sheet1")
    offset = 0  # column offset of the label table on the sheet

    for col, header in enumerate(("k", "E", "label"), start=offset + 1):
        sheet.cell(row=1, column=col, value=header)

    for row, entry in enumerate(klabels_list[1:], start=2):
        if len(entry) != 2 or entry[1] is None or entry[1] == "":
            continue
        label, k = entry
        sheet.cell(row=row, column=offset + 1, value=k)
        sheet.cell(row=row, column=offset + 2, value=cfg.Emin)
        sheet.cell(row=row, column=offset + 3, value=label)

def write_bandgap_data(cfg, band_data, wb):
    """
    Create the "BAND_GAP" sheet: gap value, EV/EC guide lines and band edges.

    Columns (row 1 holds the headers):
      A  Eg                band_data.Eg (row 2)
      B  k                 x values of the guide lines: kmin, each numeric
                           k boundary, then kmax
      C  EV (eV)           band_data.EV - band_data.EF on every k row
      D  EC (eV)           band_data.EC - band_data.EF on every k row
      E  EF_original (eV)  band_data.EF (row 2)
      F  iHOMO, G iLUMO    band indices (row 2)
      H  VBM position, I CBM position
                           band_data.VBM_positions / CBM_positions, one per
                           row from row 2, when those attributes exist

    The k boundaries are taken from ``cfg.klabels_list`` (set by write_klines).
    If that attribute is missing, ``band_data.klabels_list`` is used instead,
    so the function also works when write_klines has not run.

    Returns the workbook.
    """
    ws_gap = wb.create_sheet(title="BAND_GAP")
    ev = band_data.EV - band_data.EF
    ec = band_data.EC - band_data.EF

    headers = ("Eg", "k", "EV (eV)", "EC (eV)", "EF_original (eV)",
               "iHOMO", "iLUMO", "VBM position", "CBM position")
    for col, header in enumerate(headers, start=1):
        ws_gap.cell(row=1, column=col, value=header)

    ws_gap.cell(row=2, column=1, value=band_data.Eg)
    ws_gap.cell(row=2, column=5, value=band_data.EF)
    ws_gap.cell(row=2, column=6, value=band_data.iHOMO)
    ws_gap.cell(row=2, column=7, value=band_data.iLUMO)

    klabels_list = getattr(cfg, "klabels_list", None)
    if klabels_list is None:
        klabels_list = getattr(band_data, "klabels_list", None) or []

    # Only (label, k) pairs with a non-string k are used as x positions; this
    # skips the header row and any textual entries.
    k_boundaries = [
        data[1] for data in klabels_list
        if len(data) == 2
        and data[0] is not None
        and data[1] is not None
        and not isinstance(data[1], str)
    ]
    # One row per x position; EV and EC are constant, giving horizontal lines.
    k_points = [band_data.kmin, *k_boundaries, band_data.kmax]
    for irow, k in enumerate(k_points, start=2):
        ws_gap.cell(row=irow, column=2, value=k)
        ws_gap.cell(row=irow, column=3, value=ev)
        ws_gap.cell(row=irow, column=4, value=ec)

    for col, attr in ((8, "VBM_positions"), (9, "CBM_positions")):
        positions = getattr(band_data, attr, None)
        if positions is None:
            continue
        for irow, val in enumerate(positions, start=2):
            ws_gap.cell(row=irow, column=col, value=val)

    return wb

def add_band_diagram(cfg, band_data, wb):
    """
    Build the band-structure scatter chart and place it on "Sheet1".

    Series are read from sheets written earlier:
      - every band column of "BAND" against its k column (thin black lines),
      - the k-boundary lines of "KLINES" (thin black lines),
      - the EV and EC guide lines of "BAND_GAP" (red, dashed).

    Axis limits are x = [band_data.kmin, band_data.kmax] and
    y = [cfg.Emin, cfg.Emax]. The chart is anchored at cfg.chart_cell.
    Returns the workbook.
    """
    chart = tkExcelChart('scatter')
    chart.set_title("Band structure")
    chart.set_xlabel("k")
    chart.set_ylabel("Energy (eV)")
    chart.set_figsize(cfg.width * cfg.kgraphsize, cfg.height * cfg.kgraphsize)
    chart.set_legend(None)

    # E(k): one series per band column of the BAND sheet (column A is k).
    ws_band = wb["BAND"]
    max_row = ws_band.max_row
    max_col = ws_band.max_column
    for col in range(2, max_col + 1):
        chart.add_plot(ws_band, (1, 2, max_row), (col, 2, max_row),
                        title_from_data = False,
                        width = 0.5, color = "#000000", linestyle = "-")

    # Vertical lines at the k-path boundaries (KLINES sheet). The reference
    # covers exactly the rows present in the sheet, including the blank
    # separator rows. It is not sized as 3 * nk + 1, which scales with the
    # number of k points rather than the number of boundary lines.
    ws_klines = wb["KLINES"]
    max_row_kl = ws_klines.max_row
    if max_row_kl >= 2:  # row 1 is the header; skip when no lines were written
        chart.add_plot(ws_klines, (1, 2, max_row_kl), (2, 2, max_row_kl),
                            title_from_data = False,
                            width = 0.25, color = "#000000", linestyle = "-")

    # Dashed red lines marking EV and EC (BAND_GAP columns C and D against B).
    ws_gap = wb["BAND_GAP"]
    max_rowbg = ws_gap.max_row
    chart.add_plot(ws_gap, (2, 2, max_rowbg), (3, 2, max_rowbg),
                        title_from_data = False,
                        width = 0.25, color = "red", linestyle = "dashed")
    chart.add_plot(ws_gap, (2, 2, max_rowbg), (4, 2, max_rowbg),
                        title_from_data = False,
                        width = 0.25, color = "red", linestyle = "dashed")

    chart.set_xlim([band_data.kmin, band_data.kmax])
    chart.set_ylim([cfg.Emin, cfg.Emax])

    chart.set_xgrid(None)
    chart.set_xticks(None)
    chart.set_ygrid(None)

    ws_main = get_worksheet(wb, "Sheet1")
    chart.add_chart(ws_main, cfg.chart_cell)

    return wb


def add_dos_sheet(cfg, wb, sheet_name, dos_inf):
    """
    Add a worksheet holding one DOS table.

    Args:
        cfg: configuration (unused here).
        wb: workbook to add the sheet to.
        sheet_name: title of the new sheet (e.g. "TDOS", "PDOS_<element>", "ITDOS").
        dos_inf: dict with keys "xlabel" (only printed), "ylabels" (column
            titles), "x_list" (energies) and "y_list" (one sequence of DOS
            values per energy).

    Layout: row 1 holds "E (eV)" followed by the ylabels. Each following row
    holds an energy in column A and its DOS values from column B on.
    """
    print(f"Add worksheet [{sheet_name}]")
    sheet = wb.create_sheet(title = sheet_name)

    x_label, y_labels = dos_inf["xlabel"], dos_inf["ylabels"]
    energies, dos_rows = dos_inf["x_list"], dos_inf["y_list"]
    print("labels=", x_label, y_labels)

    sheet.cell(row=1, column=1, value = "E (eV)")
    for col, text in enumerate(y_labels, start=2):
        sheet.cell(row=1, column=col, value = text)

    for idx, energy in enumerate(energies):
        row = idx + 2
        sheet.cell(row=row, column=1, value = energy)
        for col, dos in enumerate(dos_rows[idx], start=2):
            sheet.cell(row=row, column=col, value = dos)

def make_a_chart(plottype, title, xlabel, ylabel, figsize, xlim, ylim):
    """
    Create a tkExcelChart with its title, axis labels and size already set.

    Args:
        plottype: chart type passed to tkExcelChart (e.g. "scatter").
        title: chart title.
        xlabel, ylabel: axis titles.
        figsize: sequence unpacked into chart.set_figsize (width, height).
        xlim, ylim: axis ranges passed to set_xlim / set_ylim; a falsy value
            (e.g. None) leaves that axis untouched.

    Returns the configured chart; it is not placed on any sheet yet.
    """
    chart = tkExcelChart(plottype)

    chart.set_title(title)
    chart.set_xlabel(xlabel)
    chart.set_ylabel(ylabel)
    chart.set_figsize(*figsize)

    if xlim:
        chart.set_xlim(xlim)
    if ylim:
        chart.set_ylim(ylim)

    return chart
    
def add_dos_graph(cfg, dos_data, wb):
    """
    Add the DOS charts to the first worksheet of the workbook.

    1. A total-DOS chart anchored at cfg.chart_cell. It shows the TDOS curve
       from the "TDOS" sheet plus, for each element, the last column of its
       "PDOS_<element>" sheet.
    2. One chart per element with every column of "PDOS_<element>"; the last
       column is drawn thicker and in black. These charts are placed side by
       side on row 27, 9 columns apart.

    Energy axes are limited to [cfg.Emin, cfg.Emax]. Returns the workbook.
    """
    figsize = (cfg.width * cfg.kgraphsize, cfg.height * cfg.kgraphsize)
    ws_main = wb.worksheets[0]

    # Total DOS, with each element's last PDOS column overlaid.
    chart = make_a_chart("scatter", "Density of states", "Energy (eV)", "DOS",
                         figsize, [cfg.Emin, cfg.Emax], None)
    ws = wb["TDOS"]
    max_row = ws.max_row
    chart.add_plot(ws, (1, 2, max_row), (2, 2, max_row),
                        title_from_data = False, title = "TDOS",
                        width = 1.5, color = "#000000", linestyle = "-")

    for e in dos_data.elements:
        ws = wb[f"PDOS_{e}"]
        max_row = ws.max_row
        # Use the last column, the one the per-element charts below also
        # emphasise (presumably the element total). A hard-coded column 11 is
        # only the last column when the sheet has exactly 11 columns. Reading
        # ws.cell() beyond the data would also create a cell and widen max_column.
        icol = ws.max_column
        label = ws.cell(row=1, column=icol).value
        chart.add_plot(ws, (1, 2, max_row), (icol, 2, max_row),
                        title_from_data = False, title = label,
                        width = 0.5, color = None, linestyle = "-")

    chart.add_chart(ws_main, cfg.chart_cell)

    # Per-element PDOS charts, placed side by side on row 27.
    icol_skip = 9  # column spacing between consecutive element charts
    for ie, e in enumerate(dos_data.elements):
        ws = wb[f"PDOS_{e}"]
        max_row = ws.max_row
        max_col = ws.max_column

        chart = make_a_chart("scatter", f"{e} density of states", "Energy (eV)", "DOS",
                             figsize, [cfg.Emin, cfg.Emax], None)

        for icol in range(2, max_col + 1):
            label = ws.cell(row=1, column=icol).value
            is_last = icol == max_col
            chart.add_plot(ws, (1, 2, max_row), (icol, 2, max_row),
                        title_from_data = False, title = label,
                        width = 1.5 if is_last else 0.5,
                        color = "black" if is_last else None,
                        linestyle = "-")

        col_name = get_column_letter(icol_skip * ie + 1)
        chart.add_chart(ws_main, col_name + "27")

    return wb

def main():
    """
    Entry point: build a band-structure or DOS workbook from VASP/VASPKIT output.

    Mode "band": reads the band files, then writes the BAND, KLINES and
    BAND_GAP sheets, the k labels on "Sheet1", and the band diagram.
    Mode "dos": reads the DOS files, then writes the TDOS, PDOS_<element> and
    ITDOS sheets and the DOS charts.
    The workbook is saved to cfg.output_excel. Any other mode prints an error.
    """
    cfg = initialize()

    print()
    if cfg.mode == "band":
        print("Make band excel file from template\n")

        band_data = read_band_files(cfg)
        if band_data is None:
            # Nothing to write. terminate() does not exit, so return here
            # instead of failing on band_data.nk below.
            terminate(cfg)
            return

        print("nK=", band_data.nk)
        print("nBand=", band_data.nband)

        wb = read_template(cfg.template)
        reformat_workbook(wb)
        update_workbook(cfg, wb)

        write_band_data(cfg, band_data, wb, band_data.k_list, band_data.E_list2D)
        write_klines(cfg, band_data, wb, band_data.klabels_list)
        write_klabels(cfg, wb, band_data.klabels_list)
        write_bandgap_data(cfg, band_data, wb)
        add_band_diagram(cfg, band_data, wb)

        save(wb, cfg.output_excel)
    elif cfg.mode == "dos":
        print("Make dos excel file from template\n")

        dos_data = read_dos_files(cfg)
        if dos_data is None:
            # Same as above: stop instead of failing on dos_data.TDOS.
            terminate(cfg)
            return

        wb = read_template(cfg.template)
        reformat_workbook(wb)
        update_workbook(cfg, wb)

        add_dos_sheet(cfg, wb, "TDOS", dos_data.TDOS)
        for e in dos_data.elements:
            dos_name = f"PDOS_{e}"
            add_dos_sheet(cfg, wb, dos_name, getattr(dos_data, dos_name))
        add_dos_sheet(cfg, wb, "ITDOS", dos_data.ITDOS)

        add_dos_graph(cfg, dos_data, wb)

        save(wb, cfg.output_excel)
    else:
        # Report and stop; falling through would call terminate() (and pause) twice.
        terminate(cfg, f"\nError: Invalide mode [{cfg.mode}].\n")
        return

    terminate(cfg)


if __name__ == "__main__":
    main()
