import os
import glob
import re
from collections import OrderedDict

from tklib.tkcrystal.tkvasp import tkVASP


# File-extension groups (lower-case, including the leading dot). They are
# compared against os.path.splitext(path)[1].lower() in search_reference_files.
# Only the image/text/html/pdf groups are consumed there; the csv/excel/word
# groups are defined but not used by search_reference_files.
image_exts = ['.png', '.ico', '.jpeg', '.jpg', '.bmp']
text_exts  = ['.txt', '.md']
html_exts  = ['.html', '.htm']
csv_exts   = ['.csv']
excel_exts = ['.xlsx', '.xlsm']
word_exts  = [".docx"]
pdf_exts   = [".pdf"]

def search_reference_files(base_dir, search_dir):
    """
    Collect reference files found directly inside ``search_dir``, grouped by type.

    Only the top level of ``search_dir`` is scanned (no recursion). Each entry
    is classified by its lower-cased extension using the module-level
    ``image_exts``, ``text_exts``, ``html_exts`` and ``pdf_exts`` lists; entries
    with any other extension are ignored.

    Args:
        base_dir: Directory that the returned paths are made relative to.
        search_dir: Directory to scan. It is treated as a literal path, so
            names containing glob metacharacters (e.g. ``run[1]``) work.

    Returns:
        dict with keys "image_files", "text_files", "html_files" and
        "pdf_files"; each value is a list of paths relative to ``base_dir``,
        in sorted order.
    """
    # Escape the directory part so it is not interpreted as a glob pattern;
    # only the trailing "*" should act as a wildcard.
    fmask = os.path.join(glob.escape(search_dir), "*")

    groups = (
        ("image_files", image_exts),
        ("text_files", text_exts),
        ("html_files", html_exts),
        ("pdf_files", pdf_exts),
    )
    found = {name: [] for name, _ in groups}

    for path in sorted(glob.glob(fmask)):
        ext = os.path.splitext(path)[1].lower()
        rel_path = os.path.relpath(path, base_dir)
        for name, exts in groups:
            if ext in exts:
                found[name].append(rel_path)
                break

    return found

def flatten_dict(d, sep=':', parent_key = '', escape_keys = True, prefix = ""):
    """
    Convert a hierarchical (nested) dictionary into a flat dictionary.

    Keys of nested dictionaries are joined to their parent key with ``sep``.
    A non-empty ``prefix`` is prepended (also joined with ``sep``) to every
    top-level key, and therefore appears in every nested key as well.

    Args:
        d: Dictionary to flatten. Values that are dicts are descended into;
            all other values are copied unchanged.
        sep: Separator placed between key components.
        parent_key: Key of the enclosing dictionary (used for recursion);
            ``None`` or "" means "top level".
        escape_keys: If true, every character outside ``[a-zA-Z0-9_]``
            (including ``sep`` itself) is replaced by ``_`` so the keys are
            plain identifiers. The setting applies at every nesting level.
        prefix: Optional prefix for the top-level keys.

    Returns:
        A new flat dict.
    """
    items = {}
    for k, v in d.items():
        if parent_key is None or parent_key == "":
            new_key = k
        else:
            new_key = f"{parent_key}{sep}{k}"

        if prefix != "":
            new_key = f"{prefix}{sep}{new_key}"

        if escape_keys:
            # str(): top-level keys are not required to be strings (e.g. ints).
            new_key = re.sub(r'[^a-zA-Z0-9_]', '_', str(new_key))

        if isinstance(v, dict):
            # Propagate the caller's escape_keys choice; the prefix is already
            # part of new_key, so it must not be applied again.
            items.update(flatten_dict(v, sep, new_key, escape_keys, ""))
        else:
            items[new_key] = v
    return items

def is_cardir(path):
    """
    Tell whether ``path`` looks like a VASP calculation directory.

    Returns True only when ``path`` is a directory that contains regular
    files named INCAR, POSCAR and OUTCAR; otherwise returns False.
    """
    required_files = ('INCAR', 'POSCAR', 'OUTCAR')
    return os.path.isdir(path) and all(
        os.path.isfile(os.path.join(path, name)) for name in required_files
    )


# Structure-summary entries copied from the parsed POSCAR/CONTCAR sections
# into "InitialStructure:<key>" / "FinalStructure:<key>" by read_inf.
_STRUCTURE_KEYS = ("LatticeParameters", "Vcell", "nAtomTypes", "AtomTypes",
                   "nAtomSites", "AtomSites")


def _print_cell(cry):
    """Print the lattice parameters and cell volume of a parsed crystal object."""
    a, b, c, alpha, beta, gamma = cry.LatticeParameters()
    print("  cell: {:12.8f} {:12.8f} {:12.8f} A   {:10.6f} {:10.6f} {:10.6f}".format(
        a, b, c, alpha, beta, gamma))
    # A cell volume is in cubic angstrom.
    print("  volume: {:12.6f} A^3".format(cry.Volume()))


def _copy_structure(inf, section, label):
    """Copy the structure summary of one parsed file into inf as '<label>:<key>'."""
    for key in _STRUCTURE_KEYS:
        inf[f"{label}:{key}"] = section[key]


def read_inf(car_dir, print_level=1):
    """
    Read the VASP files in ``car_dir`` and add summary entries to the result.

    The files are parsed with ``tkVASP.read_files``. The returned mapping is
    then augmented with:

    * "InitialStructure:<key>" copied from the POSCAR section (if present),
    * "FinalStructure:<key>" copied from the CONTCAR section (if present),
    * "EIGENVAL:EV", "EIGENVAL:EC", "EIGENVAL:Eg" band edges, when both
      OUTCAR and EIGENVAL were parsed (set to None if the search fails).

    Args:
        car_dir: Directory containing the VASP input/output files.
        print_level: Non-zero to print summaries. It is also forwarded to
            ``read_files`` and ``find_band_edges_from_eigenval``.

    Returns:
        The mapping returned by ``tkVASP.read_files`` with the entries above
        added. The same kind of object is returned on every path; the
        previous version sometimes returned a 4-tuple instead.
    """
    vasp = tkVASP()
    inf = vasp.read_files(
        car_dir,
        ["INCAR", "POSCAR", "POTCAR", "KPOINTS", "CONTCAR", "OUTCAR",
         "DOSCAR", "EIGENVAL", "EIGENVAL_OPT"],
        EF = 0.0, normalize_E = True, unit = '/cm3', data_for_bandedges = None,
        exit_by_error = False, print_level = print_level, terminate = None)
    if not inf:
        return inf

    # Structures do not depend on OUTCAR, so summarize them first.
    inf_poscar = inf.get("POSCAR")
    if inf_poscar:
        if print_level:
            print()
            print("POSCAR:")
            _print_cell(inf_poscar["crystal"])
        _copy_structure(inf, inf_poscar, "InitialStructure")

    inf_contcar = inf.get("CONTCAR")
    if inf_contcar:
        if print_level:
            print()
            print("CONTCAR:")
            _print_cell(inf_contcar["crystal"])
        _copy_structure(inf, inf_contcar, "FinalStructure")

    outcarinf = inf.get("OUTCAR")
    if not outcarinf:
        # EF and ISPIN come from OUTCAR and are needed for the band-edge
        # search below, so nothing more can be derived without it.
        return inf

    ISPIN = outcarinf["ISPIN"]
    EF = outcarinf.get("EF", "")
    TOTEN = outcarinf.get("TOTEN", "")
    if print_level:
        print()
        print("OUTCAR:")
        print(f"ISPIN: {ISPIN}")
        print(f"EF: {EF} eV")
        print(f"TOTEN: {TOTEN} eV")

    eigenvalinf = inf.get("EIGENVAL")
    if not eigenvalinf:
        return inf

    if print_level:
        print()
        print("EIGENVAL:")
        print("k points in EIGENVAL:")
        print("nk=", eigenvalinf["nk"])
        print("nLevels=", eigenvalinf["nLevels"])

    bandedgeinf = vasp.find_band_edges_from_eigenval(
        EF0 = EF, eigenvalinf = eigenvalinf, ISPIN = ISPIN, print_level = print_level)
    if not bandedgeinf or bandedgeinf.get('EV', None) is None:
        print("Warning in read_inf.read_inf(): Could not get band edges")
        for key in ("EV", "EC", "Eg"):
            inf[f"EIGENVAL:{key}"] = None
        return inf

    for key in ("EV", "EC", "Eg"):
        inf[f"EIGENVAL:{key}"] = bandedgeinf[key]
        if print_level:
            print(f"{key}: {bandedgeinf[key]} eV")

    return inf

def _reference_file_entries(files_list):
    """Map the result of search_reference_files() onto template variable names."""
    return {
        "ImageFiles": files_list["image_files"],
        "TextFiles":  files_list["text_files"],
        "HTMLFiles":  files_list["html_files"],
        "PDFFiles":   files_list["pdf_files"],
    }


def _first_available(inf_dict, keys):
    """Return (key, inf) for the first key whose entry in inf_dict is truthy, else (None, None)."""
    for key in keys:
        inf = inf_dict.get(key)
        if inf:
            return key, inf
    return None, None


def read_inf_all(base_dir, subdirs, args, print_level = 1):
    """
    Build the template replacement dictionary for a calculation directory.

    Without ``subdirs``, ``base_dir`` itself is read with read_inf() and
    flattened with ``args.prefix`` as the key prefix.

    With ``subdirs``, it is a comma-separated list of glob patterns relative
    to ``base_dir``. Every matching VASP directory (see is_cardir) is read,
    and its results are flattened with the directory's base name as prefix.
    Global entries (SYSTEM, initial and relaxed structure) are then taken
    from the first available subdirectory in a fixed preference order.

    Args:
        base_dir: Root directory of the calculations.
        subdirs: Comma-separated glob patterns, or None/"" for base_dir only.
        args: Object providing ``prefix`` (and ``template_file`` when
            ``subdirs`` is given).
        print_level: Forwarded to read_inf().

    Returns:
        dict whose keys are escaped by flatten_dict() so they can be used as
        template variable names.
    """
    if not subdirs:
        replace_dict = _reference_file_entries(
            search_reference_files(base_dir = base_dir, search_dir = base_dir))
        inf = read_inf(base_dir, print_level = print_level)
        replace_dict.update(flatten_dict(inf, escape_keys = True, prefix = args.prefix))
        return replace_dict

    structure_items = ("LatticeParameters", "Vcell", "nAtomTypes", "AtomTypes",
                       "nAtomSites", "AtomSites")

    # Pre-seed every global key with None so templates can always reference them.
    replace_dict = {
        "base_dir":         base_dir,
        "base_path":        os.path.abspath(base_dir),
        "subdirs":          subdirs,
        "subdirs_searched": None,
        "prefix":           args.prefix,
        "template_path":    args.template_file,
        "SYSTEM_Source":    None,
        "SYSTEM":           None,
    }
    for label in ("InitialStructure", "RelaxedStructure"):
        replace_dict[f"{label}_Source"] = None
        for item in structure_items:
            replace_dict[f"{label}_{item}"] = None

    replace_dict.update(_reference_file_entries(
        search_reference_files(base_dir = base_dir, search_dir = base_dir)))

    # Expand each pattern relative to base_dir; the OrderedDict removes
    # duplicate matches while keeping first-seen order.
    files_dict = OrderedDict()
    for pattern in subdirs.split(","):
        print(f"**split subdirs [{pattern}]")
        pattern = pattern.strip()
        if not pattern:
            # Skip empty entries (e.g. "a,,b" or a trailing comma), which would
            # otherwise match base_dir itself.
            continue
        for path in glob.glob(os.path.join(base_dir, pattern)):
            files_dict[path] = ""
    dirs = [path for path in files_dict if is_cardir(path)]

    print()
    print("base_dir:", base_dir)
    print("Sub directories to search:", dirs)
    inf_dict = {}
    subdirs_inf = []
    seen_paths = set()
    for car_dir in dirs:
        inf_base = _reference_file_entries(
            search_reference_files(base_dir = base_dir, search_dir = car_dir))
        # normpath drops a trailing separator (pattern "SCF/"), which would
        # otherwise make the prefix an empty string.
        prefix = os.path.basename(os.path.normpath(car_dir))

        inf2 = read_inf(car_dir, print_level = print_level)
        if inf2:
            fullpath = os.path.abspath(car_dir)
            if fullpath not in seen_paths:
                subdirs_inf.append({"path": fullpath,
                                    "dir": os.path.dirname(fullpath),
                                    "last_dir": os.path.basename(fullpath)})
                seen_paths.add(fullpath)
            # Keys of inf2 are not Jinja2-compatible (they contain ':' etc.).
            inf_dict[prefix] = inf2
            inf_base.update(inf2)

        # flatten_dict escapes the keys, making replace_dict Jinja2-compatible.
        replace_dict.update(flatten_dict(inf_base, escape_keys = True, prefix = prefix))

    replace_dict["subdirs_searched"] = subdirs_inf

    print()
    print(f"Search global parameters from {list(inf_dict.keys())}:")
    # Global settings: which subdirectory provides SYSTEM and the initial structure.
    candidates = ["MD", "VCRelax1", "VCRelax", "SCFDFT", "SCF", "DOS", "BandX", "Phonon"]
    key, inf_global = _first_available(inf_dict, candidates)
    if inf_global:
        print(f"  SYSTEM is taken from [{key}]")
        replace_dict["SYSTEM_Source"] = key
        replace_dict["SYSTEM"] = (inf_global.get("INCAR") or {}).get("SYSTEM")

        print(f"  Initial structure is taken from [{key}]")
        replace_dict["InitialStructure_Source"] = key
        for item in structure_items:
            replace_dict[f"InitialStructure_{item}"] = inf_global.get(f"InitialStructure:{item}")

    key, inf_rstruct = _first_available(inf_dict, ["VCRelax", "VCRelax1"])
    if inf_rstruct:
        print(f"  Relaxed structure is taken from [{key}]")
        replace_dict["RelaxedStructure_Source"] = key
        for item in structure_items:
            replace_dict[f"RelaxedStructure_{item}"] = inf_rstruct.get(f"FinalStructure:{item}")

    return replace_dict
    