import os
import sys
import types
import pandas as pd
from openpyxl import load_workbook
from openpyxl.chart import ScatterChart, Reference, Series
from openpyxl.chart.shapes import GraphicalProperties
from openpyxl.drawing.line import LineProperties
#from openpyxl.drawing.text import CharacterProperties, Font as OxmlFont
#from openpyxl.chart.text import RichText, Paragraph


# Greek-letter mapping: ASCII letter names (as they may appear in k-point
# labels, e.g. "GAMMA") -> the corresponding Greek character. Upper-case
# names map to capital letters, lower-case names to small letters.
# "LAMDA"/"lamda" presumably follow the Unicode character-name spelling
# (GREEK CAPITAL/SMALL LETTER LAMDA).
# NOTE: several keys are substrings of others ("ETA" in "BETA"/"ZETA"/"THETA",
# "PSI" in "EPSILON"/"UPSILON", likewise in lower case), so code that
# substitutes these names inside a longer string must try longer keys first.
GREEK_MAP = {
    "GAMMA": "Γ",
    "ALPHA": "Α",
    "BETA": "Β",
    "DELTA": "Δ",
    "EPSILON": "Ε",
    "ZETA": "Ζ",
    "ETA": "Η",
    "THETA": "Θ",
    "IOTA": "Ι",
    "KAPPA": "Κ",
    "LAMDA": "Λ",
    "MU": "Μ",
    "NU": "Ν",
    "XI": "Ξ",
    "OMICRON": "Ο",
    "PI": "Π",
    "RHO": "Ρ",
    "SIGMA": "Σ",
    "TAU": "Τ",
    "UPSILON": "Υ",
    "PHI": "Φ",
    "CHI": "Χ",
    "PSI": "Ψ",
    "OMEGA": "Ω",

    "gamma": "γ",
    "alpha": "α",
    "beta": "β",
    "delta": "δ",
    "epsilon": "ε",     
    "zeta": "ζ",
    "eta": "η",
    "theta": "θ",
    "iota": "ι",
    "kappa": "κ",
    "lamda": "λ",
    "mu": "μ",
    "nu": "ν",
    "xi": "ξ",
    "omicron": "ο",
    "pi": "π",  
    "rho": "ρ",
    "sigma": "σ",
    "tau": "τ",
    "upsilon": "υ",
    "phi": "φ",
    "chi": "χ",
    "psi": "ψ",
    "omega": "ω"       
}


def get_worksheet(wb, priority_name = None):
    """Pick a worksheet from *wb*.

    Returns the sheet titled *priority_name* when that name is given
    (non-empty) and present in ``wb.sheetnames``; otherwise falls back to the
    first worksheet of the workbook.
    """
    use_named = bool(priority_name) and priority_name in wb.sheetnames
    return wb[priority_name] if use_named else wb.worksheets[0]

def update_cell(wb, cell_name, val):
    """Write *val* into every cell a workbook defined name points to.

    The defined name is matched case-insensitively against the names
    iterated from ``wb.defined_names``; only the first match is used, and
    every ``(sheet, coordinate)`` pair in its ``destinations`` is updated.

    NOTE(review): iteration is assumed to yield name strings (openpyxl >= 3.1
    dict-style ``defined_names``); sheet-scoped names are not searched.

    Returns:
        True if a matching defined name was found, False otherwise.
    """
    wanted = cell_name.lower()
    match = next((nm for nm in wb.defined_names if nm.lower() == wanted), None)
    if match is None:
        return False

    for sheet_title, coord in wb.defined_names[match].destinations:
        wb[sheet_title][coord].value = val
    return True

def convert_color(color):
    """Translate a colour spec into an ``RRGGBB`` hex string.

    ``None`` is passed through unchanged. The names ``black``, ``red``,
    ``green`` and ``blue`` map to fixed hex values; any other string is
    returned with its leading ``#`` characters removed.
    """
    if color is None:
        return None

    named = {
        "black": "000000",
        "red": "FF0000",
        "green": "00FF00",
        "blue": "0000FF",
    }
    return named[color] if color in named else color.lstrip("#")

def convert_linestyle(linestyle):
    """Map a line-style name to a DrawingML preset dash value.

    ``"dashed"`` becomes ``"lgDash"`` (long dash); every other value,
    including ``None`` and matplotlib-style ``"--"``, becomes ``"solid"``.
    """
    return "lgDash" if linestyle == "dashed" else "solid"

# Convert a length in points to EMU (English Metric Units), the unit used for
# chart line widths in this file.
# 1 inch = 914400 EMU and 1 pt = 1/72 inch, so 1 pt = 12700 EMU.
def convert_size(s):
    """Return *s* points as an integer number of EMU, or ``None`` for ``None``.

    ``None`` passes through (as in ``convert_color``) so callers can leave a
    width unset instead of crashing. The product is rounded rather than
    truncated: ``12700 * 0.29`` evaluates slightly below 3683 in floating
    point, and plain ``int()`` would yield 3682.
    """
    if s is None:
        return None
    return int(round(12700 * s))

class tkExcelChart():
    """Small matplotlib-flavoured wrapper around an openpyxl chart.

    Only ``style='scatter'`` is supported. The openpyxl chart object is
    exposed as ``self.chart`` so callers can tweak anything not wrapped here.
    """

    def __init__(self, style = 'scatter'):
        """Create the underlying chart.

        Raises:
            ValueError: for any *style* other than ``'scatter'``. (Previously
                ``self.chart`` was silently set to ``False`` and the first
                method call failed with an unrelated ``AttributeError``.)
        """
        if style != 'scatter':
            raise ValueError(
                f"unsupported chart style {style!r}; only 'scatter' is available")
        self.chart = ScatterChart()
        # NOTE(review): some openpyxl 3.1.x releases hide chart axes unless
        # `axis.delete = False` is set explicitly -- verify with the installed
        # version if the axes do not show up in Excel.

    def add_chart(self, ws, position = "A1"):
        """Place the chart on worksheet *ws*, anchored at cell *position*."""
        ws.add_chart(self.chart, position)

    def set_figsize(self, width, height):
        """Set the chart size (openpyxl measures chart width/height in cm)."""
        self.chart.width = width
        self.chart.height = height

    def set_title(self, title):
        """Set the chart title text."""
        self.chart.title = title

    def set_xlabel(self, label):
        """Set the x-axis title text."""
        self.chart.x_axis.title = label

    def set_ylabel(self, label):
        """Set the y-axis title text."""
        self.chart.y_axis.title = label

    def set_legend(self, style = ''):
        """Remove the legend when *style* is ``None``; other values are a no-op."""
        if style is None:
            self.chart.legend = None

    def add_plot(self, ws, xrange, yrange, title_from_data = False, title = None,
                 width = None, color = "000000", linestyle = "-"):
        """Append one X/Y series to the chart.

        Args:
            ws: worksheet holding the data.
            xrange, yrange: ``(column, first_row, last_row)`` triples used as
                openpyxl ``Reference`` bounds for the X and Y data.
            title_from_data, title: passed through to openpyxl ``Series``.
            width: line width in points; ``None`` (the default) leaves the
                width unset instead of crashing in the point->EMU conversion.
            color: named colour or hex RGB string (see ``convert_color``).
            linestyle: ``"dashed"`` for a long-dash line, anything else solid.
        """
        ls = convert_linestyle(linestyle)
        cl = convert_color(color)

        x_ref = Reference(ws, min_col = xrange[0], min_row = xrange[1], max_row = xrange[2])
        y_ref = Reference(ws, min_col = yrange[0], min_row = yrange[1], max_row = yrange[2])
        series = Series(y_ref, x_ref, title_from_data = title_from_data, title = title)

        line = series.graphicalProperties.line
        if width is not None:
            line.width = convert_size(width)
        line.solidFill = cl
        line.prstDash  = ls
        self.chart.series.append(series)

    def set_xlim(self, range):
        """Fix the x-axis scale to ``(min, max)`` taken from *range*."""
        self.chart.x_axis.scaling.min = range[0]
        self.chart.x_axis.scaling.max = range[1]

    def set_ylim(self, range):
        """Fix the y-axis scale to ``(min, max)`` taken from *range*."""
        self.chart.y_axis.scaling.min = range[0]
        self.chart.y_axis.scaling.max = range[1]

    def set_xgrid(self, style = None):
        """Assign *style* to the x-axis major gridlines (``None`` clears them)."""
        self.chart.x_axis.majorGridlines = style

    def set_ygrid(self, style = None):
        """Assign *style* to the y-axis major gridlines (``None`` clears them)."""
        self.chart.y_axis.majorGridlines = style

    def _set_ticks(self, axis, style, which):
        """Apply tick-mark *style* to the major and/or minor ticks of *axis*."""
        if style is None:
            style = "none"
        if which not in ("major", "minor", "both"):
            raise ValueError(f"which must be 'major', 'minor' or 'both', got {which!r}")
        if which in ("major", "both"):
            axis.majorTickMark = style
        if which in ("minor", "both"):
            axis.minorTickMark = style

    # set_xticks used to be defined twice (major, then minor), so only the
    # second definition (minor ticks) was reachable. One method now handles
    # both; the default "both" still sets the minor ticks as before.
    def set_xticks(self, style = "none", which = "both"):
        """Set x-axis tick marks (e.g. "in", "out", "cross", "none"; None -> "none").

        *which* selects ``"major"``, ``"minor"`` or ``"both"`` (default).
        """
        self._set_ticks(self.chart.x_axis, style, which)

    def set_yticks(self, style = "none", which = "both"):
        """Set y-axis tick marks; same arguments as ``set_xticks``."""
        self._set_ticks(self.chart.y_axis, style, which)

    # ------------------------------------------------------------------
    # Original note: the following settings do not work as written. They
    # are kept for reference; they use a bare `chart`/`cfg` rather than
    # `self.chart`, and the font classes whose imports are commented out at
    # the top of the file.
    #
    #   # Y-axis tick marks pointing inward
    #   chart.y_axis.majorTickMark = "in"
    #   chart.y_axis.minorTickMark = "in"
    #   # To show the tick labels:
    #   chart.x_axis.tickLblPos = "nextTo"
    #   chart.y_axis.tickLblPos = "nextTo"
    #   # 2 pt solid black border around the whole chart
    #   # (1 pt = 12700 EMU -> 2 pt = 25400 EMU)
    #   border = GraphicalProperties(
    #       ln=LineProperties(w=25400, solidFill="000000")
    #   )
    #   chart.spPr = border
    #
    #   # --- build the font objects ---
    #   excel_font = OxmlFont(typeface="ＭＳ ゴシック")
    #   font_prop = CharacterProperties(
    #       latin=excel_font, ea=excel_font, cs=excel_font,
    #       sz=cfg.fontsize * 100
    #   )
    #   para = Paragraph(endParaRPr=font_prop)
    #   rich_txt = RichText(p=[para])
    #
    #   # --- apply the font to the titles ---
    #   chart.title.tx.rich = rich_txt
    #   chart.x_axis.title.tx.rich = rich_txt
    #   chart.y_axis.title.tx.rich = rich_txt
    # ------------------------------------------------------------------


def save(wb, outfile):
    """Save workbook *wb* to *outfile*, reporting the outcome on stdout.

    Any exception raised while saving (or while printing the success
    message) is caught and reported instead of propagated.

    Returns:
        True on success, False if an exception was caught.
    """
    succeeded = True
    try:
        wb.save(outfile)
        print(f"グラフを追加し、Excelファイルに保存しました: {outfile}")
    except Exception as err:
        print(f"Error: Excelファイル [{outfile}] の保存中にエラーが発生しました: {err}")
        succeeded = False
    return succeeded

def _template_has_macros(path):
    """Return True when *path* has a macro-enabled Excel extension (.xlsm/.xltm)."""
    return os.path.splitext(path)[1].lower() in (".xlsm", ".xltm")

def read_template(infile):
    """Load the Excel template *infile* with openpyxl.

    VBA content is preserved only for macro-enabled templates (.xlsm/.xltm).
    openpyxl marks any workbook loaded with ``keep_vba=True`` as
    macro-enabled when it is saved, so using it for a plain .xlsx produced a
    file that Excel refuses to open under an .xlsx name.

    Returns:
        The loaded workbook, or ``None`` (after printing an error) if the
        file does not exist.
    """
    if not os.path.exists(infile):
        print(f"Error: テンプレートファイルが見つかりません: {infile}")
        return None

    print(f"テンプレートファイルを読み込みます: {infile}")
    wb = load_workbook(infile, keep_vba=_template_has_macros(infile))

    return wb

def read_band_dat(filepath, band_data):
    """Read a band-structure data file (BAND.dat layout) into *band_data*.

    Layout expected by the parser below:
      * line 2: ``<text>: <nk> <nband>``;
      * from line 4: one block per band with *nk* rows ``k  E``;
      * consecutive blocks are separated by two lines (a blank line and a
        band-index header), which are skipped.

    Rows of each band are sorted by k (stable sort), and every band must end
    up with the same k list.

    Sets on *band_data*: ``nk``, ``nband``, ``k_list``, ``E_list2D`` (one
    energy list per band, ordered like ``k_list``), ``kmin``/``kmax`` and
    ``Emin_all``/``Emax_all``.

    Returns:
        *band_data*, or ``None`` if the file is missing or declares zero
        k-points/bands. Exits the process with status 1 if bands disagree on
        their k list.
    """
    if not os.path.exists(filepath):
        print(f"Error: {filepath} が見つかりません")
        return None

    print(f"band.datファイルを読み込みます: {filepath}")

    with open(filepath, "r", encoding="utf-8") as file:
        lines = file.readlines()

    counts = lines[1].split(':')[1].split()
    band_data.nk = int(counts[0])
    band_data.nband = int(counts[1])
    # Without at least one k-point and one band there is nothing to take
    # min/max of (previously a NameError / ValueError further down).
    if band_data.nk <= 0 or band_data.nband <= 0:
        print(f"Error: {filepath} にk点またはバンドのデータがありません "
              f"(nk={band_data.nk}, nband={band_data.nband})")
        return None

    k_list0 = None
    E_list2D = []
    idx = 3  # first data row of the first band
    for _ in range(band_data.nband):
        rows = []
        for _ in range(band_data.nk):
            fields = lines[idx].split()
            rows.append((float(fields[0]), float(fields[1])))
            idx += 1

        # Sort by k (stable, so duplicated k at path joints keep file order).
        rows.sort(key=lambda r: r[0])
        k_list = [k for k, _ in rows]
        E_list = [E for _, E in rows]

        if k_list0 is None:
            k_list0 = k_list
        elif k_list != k_list0:
            print(f"Error: k_list と k_list0 の要素が異なります: {k_list} != {k_list0}")
            sys.exit(1)

        E_list2D.append(E_list)

        idx += 2    # skip the blank line and the Band-index header

    band_data.k_list = k_list
    band_data.E_list2D = E_list2D
    band_data.kmin, band_data.kmax = min(k_list), max(k_list)
    band_data.Emin_all = min(min(E) for E in E_list2D)
    band_data.Emax_all = max(max(E) for E in E_list2D)
    print("kmin,kmax=", band_data.kmin, band_data.kmax)
    print("Emin,Emax=", band_data.Emin_all, band_data.Emax_all)

    return band_data

def read_klines(filepath):
    """Read a k-path separator-lines file (KLINES layout) as ``[x, y]`` pairs.

    Every line that is not blank and does not start with ``#`` contributes
    its first two whitespace-separated fields, converted to float.

    Returns:
        List of ``[x, y]`` float pairs, or ``None`` if the file is missing.
        Exits the process with status 1 if the file holds no data rows.
    """
    if not os.path.exists(filepath):
        print(f"Error: {filepath} が見つかりません")
        return None

    print(f"{filepath}ファイルを読み込みます")
    klines_list = []
    with open(filepath, "r", encoding="utf-8") as fp:
        for row in fp:
            fields = row.split()
            # A raw line always contains its "\n", so test the split fields
            # (the old `if not row` never skipped blank lines); also skip
            # header/comment lines, which cannot be parsed as numbers.
            if not fields or fields[0].startswith("#"):
                continue
            # list.append takes a single argument: store the pair as one item.
            klines_list.append([float(fields[0]), float(fields[1])])
    if not klines_list:
        print(f"Error: {filepath} にデータがありません")
        sys.exit(1)
    return klines_list


def read_klabels(filepath, greek_map=None):
    """Read high-symmetry point labels and their k-path coordinates.

    The first line is a header and is skipped. Reading stops at the first
    line that is shorter than two characters after stripping (e.g. a blank
    line). Lines with fewer than two fields are ignored. Greek letter names
    inside a label are replaced by the Greek character, e.g.
    ``"X|GAMMA"`` -> ``"X|Γ"``.

    Args:
        filepath: path of the labels file.
        greek_map: optional ``name -> character`` mapping; defaults to
            ``GREEK_MAP``.

    Returns:
        List of ``[label, coordinate]`` pairs, or ``None`` if the file is
        missing.
    """
    if not os.path.exists(filepath):
        print(f"Error: {filepath} が見つかりません")
        return None

    if greek_map is None:
        greek_map = GREEK_MAP
    # Apply longer names first: "ETA" is a substring of "THETA" (and "PSI" of
    # "EPSILON"), so applying keys in dict order turned "THETA" into "THΗ".
    names = sorted(greek_map, key=len, reverse=True)

    print(f"{filepath}ファイルを読み込みます")
    klabels_list = []
    with open(filepath, "r", encoding="utf-8") as f:
        for i, line in enumerate(f):
            if i == 0: continue  # header line

            stripped = line.strip()
            if len(stripped) < 2: break  # end of the label table

            fields = stripped.split()
            if len(fields) < 2: continue

            label = fields[0]
            for name in names:
                label = label.replace(name, greek_map[name])

            klabels_list.append([label, float(fields[1])])

    return klabels_list


def read_band_gap(filepath, band_data):
    """Parse a BAND_GAP summary file and store its values on *band_data*.

    Only lines containing ``:`` are considered; they are split into a key
    (left of the first colon) and whitespace-separated values. The first
    matching rule wins:

      * key contains "Character"            -> ``TransitionType`` (string)
      * "Band Gap"                          -> ``Eg``
      * "Eigenvalue" and "VBM" / "CBM"      -> ``EV`` / ``EC``
      * "Fermi Energy"                      -> ``EF``
      * "Location" and "VBM" / "CBM"        -> ``VBM_positions`` /
                                               ``CBM_positions`` (float lists)
      * "HOMO"                              -> ``iHOMO``, ``iLUMO`` (ints)

    Returns:
        *band_data*, or ``None`` (after printing an error) if the file is
        missing.
    """
    if not os.path.exists(filepath):
        print(f"Error: {filepath} が見つかりません")
        return None

    print(f"BAND_GAPファイルを読み込みます: {filepath}")
    with open(filepath, "r", encoding="utf-8") as f:
        content = f.readlines()

    for raw in content:
        if ":" not in raw:
            continue

        head, _, tail = raw.partition(":")
        name = head.strip()
        fields = tail.strip().split()
        is_vbm = "VBM" in name
        is_cbm = "CBM" in name

        if "Character" in name:
            band_data.TransitionType = fields[0]
        elif "Band Gap" in name:
            band_data.Eg = float(fields[0])
        elif "Eigenvalue" in name and is_vbm:
            band_data.EV = float(fields[0])
        elif "Eigenvalue" in name and is_cbm:
            band_data.EC = float(fields[0])
        elif "Fermi Energy" in name:
            band_data.EF = float(fields[0])
        elif "Location" in name and is_vbm:
            band_data.VBM_positions = list(map(float, fields))
        elif "Location" in name and is_cbm:
            band_data.CBM_positions = list(map(float, fields))
        elif "HOMO" in name:
            # The small offset guards against a value printed just below an
            # integer being truncated down by int().
            band_data.iHOMO = int(float(fields[0]) + 0.0001)
            band_data.iLUMO = int(float(fields[1]) + 0.0001)

    return band_data

def read_poscar(path, dos_data):
    """Store the element symbols of a POSCAR file on *dos_data*.

    The sixth line (index 5) of *path* is split on whitespace and stored as
    ``dos_data.elements``. NOTE(review): this assumes the element-symbol line
    is present (VASP 5 layout); files without it would yield the atom counts.

    Returns:
        *dos_data*, or ``None`` (after printing an error) if *path* is missing.
    """
    if os.path.exists(path):
        print(f"Read [{path}]")
        with open(path, "r", encoding="utf-8") as f:
            sixth_line = f.readlines()[5]
        dos_data.elements = sixth_line.split()
        return dos_data

    print(f"Error: Can not find {path}")
    return None

def read_doscar(path, element, dos_data):
    """Load a whitespace-separated DOS table whose first row holds column names.

    Args:
        path: file to read.
        element: when truthy, ``"<element> "`` is prefixed to every Y column
            name.
        dos_data: not used by this function.

    Returns:
        dict with ``xlabel`` (first column name), ``ylabels`` (the other
        column names), ``x_list`` (first column values) and ``y_list`` (the
        other columns, one list per row); ``None`` (after printing an error)
        if *path* is missing.
    """
    if not os.path.exists(path):
        print(f"Error: Can not find {path}")
        return None

    print(f"Read [{path}]")
    table = pd.read_csv(path, sep = r'\s+', header=0)
    columns = table.columns

    y_names = columns[1:].tolist()
    if element:
        y_names = [f"{element} {name}" for name in y_names]

    return {
        "xlabel": columns[0],
        "ylabels": y_names,
        "x_list": table.iloc[:, 0].tolist(),
        "y_list": table.iloc[:, 1:].values.tolist(),
    }

def read_dos_files(cfg):
    """Collect POSCAR elements plus total, integrated and per-element DOS tables.

    Reads ``cfg.poscar`` and, from ``cfg.work_dir``, ``TDOS.dat``,
    ``ITDOS.dat`` and ``PDOS_<element>.dat`` for every element listed in the
    POSCAR.

    Returns:
        SimpleNamespace with ``elements``, ``TDOS``, ``ITDOS`` and one
        ``PDOS_<element>`` attribute per element (stored as returned by
        ``read_doscar``, so ``None`` for a missing PDOS file); ``None`` if the
        POSCAR, TDOS.dat or ITDOS.dat cannot be read.
    """
    dos_data = read_poscar(cfg.poscar, types.SimpleNamespace())
    if dos_data is None:
        return None
    print("elements:", dos_data.elements)

    def _load(stem, element=None):
        # One table from the work directory: <work_dir>/<stem>.dat
        return read_doscar(os.path.join(cfg.work_dir, f"{stem}.dat"), element, dos_data)

    # Total and integrated DOS are mandatory: stop at the first one missing.
    for attr in ("TDOS", "ITDOS"):
        table = _load(attr)
        setattr(dos_data, attr, table)
        if table is None:
            return None

    for element in dos_data.elements:
        stem = f"PDOS_{element}"
        setattr(dos_data, stem, _load(stem, element))

    return dos_data

def read_band_files(cfg):
    """Read the band data, k-point labels and band-gap summary named in *cfg*.

    Uses ``cfg.band_dat``, ``cfg.klabels`` and ``cfg.band_gap``; the labels
    are stored as ``klabels_list`` on the returned namespace.

    Returns:
        The populated SimpleNamespace, or ``None`` if any of the three files
        cannot be read.
    """
    band_data = read_band_dat(cfg.band_dat, types.SimpleNamespace())
    if band_data is None:
        return None

    labels = read_klabels(cfg.klabels)
    band_data.klabels_list = labels
    if labels is None:
        return None

    return read_band_gap(cfg.band_gap, band_data)


def main():
    """Script entry point; currently a placeholder that does nothing."""
    return None


if __name__ == "__main__":
    main()
