import os
import sys
import glob
from collections import OrderedDict
import math
from numpy import arange
from matplotlib import pyplot as plt

try:
    import tklib.tkimport as imp
except Exception as e:
    print()
    print("######################################################################")
    print("###########  ERROR ERROR ERROR ERROR ERROR ERROR #####################")
    print("######################################################################")
    print(f"# Failed to import [tklib.tkimport] module ({e}).")
    print(f"#  Add [tkProg]{os.sep}tklib{os.sep}python to PYTHONPATH variable")
    print(f"#  Current PYTHONPATH:", sys.path)
    print("######################################################################")
    input("Press ENTER to terminate>>\n")
    exit()

from tklib.tkcrystal.tkvasp import tkVASP


# Path of the reference directory that the other directories are compared
# against. None until main() assigns it; compare() quotes it in its messages.
reference_dir = None


def pconv(s):
    """
    Convert a string to a number when possible.

    Tries ``int`` first, then ``float``. If neither conversion succeeds, the
    input is returned unchanged.

    Examples: "3" -> 3, "3.5" -> 3.5, "abc" -> "abc".
    """
    for convert in (int, float):
        try:
            return convert(s)
        except (ValueError, TypeError, OverflowError):
            # Only conversion failures are expected here. Unlike a bare
            # `except:`, this does not swallow KeyboardInterrupt/SystemExit.
            pass
    return s

def check_difference(x1, x2, eps):
    """
    Compare two values by absolute difference and print a one-line report.

    Returns True when |x1 - x2| < eps, False otherwise (including NaN).
    Values that cannot be subtracted (e.g. the "" default used for a missing
    EF/TOTEN) are compared for exact equality instead of raising TypeError.
    This keeps one bad value from aborting the whole comparison run.
    """
    try:
        diff = abs(x1 - x2)
    except TypeError:
        ok = bool(x1 == x2)
        print(f"   non-numeric: {'Success' if ok else 'Failed'}: {x1} vs {x2}")
        return ok

    if diff < eps:
        print(f"   diff={diff:12.6g} < {eps:8.3g}: Success: {x1} vs {x2}")
        return True
    print(f"   diff={diff:12.6g} >={eps:8.3g}: Failed: {x1} vs {x2}")
    return False

def check_ratio(x1, x2, eps):
    """
    Compare two values by relative difference |x1 - x2| / |x1| and print a
    one-line report.

    Returns True when the relative difference is below *eps*. When x1 is
    zero, the relative difference is taken as 0 if x2 is also zero and as
    infinity otherwise, instead of raising ZeroDivisionError.
    """
    if x1 == 0:
        dratio = 0.0 if x2 == 0 else math.inf
    else:
        dratio = abs((x1 - x2) / x1)

    if dratio < eps:
        print(f"   ratio={dratio:12.6g} < {eps:8.3g}: Success: {x1} vs {x2}")
        return True
    print(f"   ratio={dratio:12.6g} >={eps:8.3g}: Failed: {x1} vs {x2}")
    return False

def find_file_in_parent_dirs(dir, filename):
    """
    Search *dir* and then each of its ancestors for *filename*.

    *dir* is converted to an absolute path first, so a relative start
    directory is still searched all the way up to the filesystem root.
    Without this, os.path.dirname() of a relative path eventually yields ""
    (the current directory), and ancestors of the CWD are never examined.

    Returns the path of the closest match, or None if the root is reached
    without finding the file.
    """
    current_dir = os.path.abspath(dir)
    while True:
        file_path = os.path.join(current_dir, filename)
        if os.path.isfile(file_path):
            return file_path

        parent_dir = os.path.dirname(current_dir)
        if parent_dir == current_dir:  # reached the filesystem root
            return None
        current_dir = parent_dir

def read_condition_md(condition_md_path, print_level=1):
    """
    Read a condition.md file and return its key/value entries as a dict.

    Blank lines and lines starting with '#' are skipped. Every other line is
    split at the first ':' or '=' (whichever comes first). The stripped left
    part is the key; the right part is converted with pconv(). Lines that
    contain neither separator are ignored.

    Returns an empty dict when the path is None, does not exist, or cannot
    be opened. With print_level > 1 the parsed entries are printed.
    """
    if condition_md_path is None or not os.path.isfile(condition_md_path):
        print(f"  **Warning: condition.md [{condition_md_path}] not found")
        return {}

    try:
        f = open(condition_md_path, 'r', encoding='utf-8', errors='replace')
    except OSError as e:
        print(f"Error: Failed to open condition.md file ({e})")
        return {}

    inf = {}
    # `with` guarantees the file is closed even if parsing raises.
    with f:
        for line in f:
            line = line.strip()
            if not line or line.startswith('#'):
                continue  # skip blank lines and comment/heading lines

            # Split at the earliest separator. This accepts both "key: value"
            # and "key = value". The old ':'-first check silently dropped
            # '='-only lines.
            positions = [p for p in (line.find(':'), line.find('=')) if p >= 0]
            if not positions:
                continue
            sep = min(positions)
            key, value = line[:sep], line[sep + 1:]

            inf[key.strip()] = pconv(value.strip())

    if print_level > 1:
        print(f"Read condition.md from {condition_md_path}:")
        for key, value in inf.items():
            print(f"  {key}: {value}")

    return inf

def read_vasp_info(car_dir, print_level=1):
    """
    Read the VASP files in *car_dir* and collect the quantities to compare.

    condition.md is looked up in *car_dir* or its nearest ancestor. Its parsed
    content is stored in ``inf["condition_md"]``.

    Parameters
    ----------
    car_dir : str
        Directory holding the INCAR/POSCAR/OUTCAR/... files.
    print_level : int
        0 suppresses this function's progress output.

    Returns
    -------
    tuple (inf, inf_meta, inf_compare, inf_func)
        inf         : result of tkVASP.read_files() plus "condition_md".
        inf_meta    : informational OUTCAR entries (compiler, mpi, timings)
                      that are displayed but not compared.
        inf_compare : quantity name -> value to compare with the reference.
        inf_func    : quantity name -> comparator ``f(x1, x2, eps=<default>)``
                      built on check_difference (prints, returns bool).
        When OUTCAR is absent or empty, the last three are empty dicts.
    """
    condition_md_path = find_file_in_parent_dirs(car_dir, "condition.md")
    if print_level > 0:
        print(f"Read [condition.md path: {condition_md_path}]")
    inf_condition = read_condition_md(condition_md_path, print_level=print_level)

    vasp = tkVASP()

    inf = vasp.read_files(
        car_dir,
        ["INCAR", "POSCAR", "POTCAR", "KPOINTS", "CONTCAR", "OUTCAR",
         "DOSCAR", "EIGENVAL", "EIGENVAL_OPT"],
        EF=0.0, normalize_E=True, unit='/cm3', data_for_bandedges=None,
        exit_by_error=False, print_level=0, terminate=None)

    inf["condition_md"] = inf_condition
    inf_meta = {}
    inf_compare = {}
    inf_func = {}

    def register(key, value, eps):
        # Record a value to compare and an absolute-difference comparator.
        # `eps=eps` binds this call's tolerance as the lambda's default,
        # while still allowing callers to override it.
        inf_compare[key] = value
        inf_func[key] = lambda x1, x2, eps=eps: check_difference(x1, x2, eps)

    # Without a usable OUTCAR there is nothing to compare.
    outcarinf = inf.get("OUTCAR", {})
    if outcarinf is None or outcarinf == {}:
        return inf, inf_meta, inf_compare, inf_func

    for key in ["compiler", "mpi", "ncores_k", "ncores_bandf",
                "Total CPU time", "User time", "System time", "Elapsed time"]:
        val = outcarinf.get(key, None)
        if val is not None:
            inf_meta[key] = val

    # --- CONTCAR: lattice parameters and cell volume ---
    if "CONTCAR" in inf:
        if print_level:
            print()
            print("CONTCAR:")
        cry = inf.get("crystal", None)
        if cry:
            a, b, c, alpha, beta, gamma = cry.LatticeParameters()
            Vcell = cry.Volume()
            lattice = {"a": a, "b": b, "c": c,
                       "alpha": alpha, "beta": beta, "gamma": gamma,
                       "Vcell": Vcell}
            for key, value in lattice.items():
                register(key, value, 1.0e-6)

            if print_level:
                print("  cell: {:12.8f} {:12.8f} {:12.8f} A   {:10.6f} {:10.6f} {:10.6f}".format(a, b, c, alpha, beta, gamma))
                print("  volume: {:12.6f} A^-3".format(Vcell))

    # --- OUTCAR ---
    # The early return above guarantees a non-empty OUTCAR entry here.
    if print_level:
        print()
        print("OUTCAR:")

    ISPIN = outcarinf["ISPIN"]
    EF = outcarinf.get("EF", "")
    TOTEN = outcarinf.get("TOTEN", "")
    register("EF", EF, 1.0e-4)
    register("TOTEN", TOTEN, 1.0e-4)

    born_charges = outcarinf.get("born_charges", None)
    if born_charges:
        for iion, bc in enumerate(born_charges):
            for i in range(3):
                for j in range(3):
                    register(f"born_charges[{iion}][{i}][{j}]", bc[i][j], 1.0e-4)

    piezo_static_localeff = outcarinf.get("piezo_static_localeff", None)
    if piezo_static_localeff:
        for i in range(3):
            for j, label in enumerate(("XX", "YY", "ZZ", "XY", "YZ", "ZX")):
                register(f"piezo_static[{i}][{label}]", piezo_static_localeff[i][j], 1.0e-4)

    # 3x3 dielectric tensors: (OUTCAR key, prefix of the comparison key)
    for src_key, prefix in (("eps_static_localeff", "eps_static_e"),
                            ("eps_static_ionic", "eps_static_ionic")):
        tensor = outcarinf.get(src_key, None)
        if tensor:
            for i in range(3):
                for j in range(3):
                    register(f"{prefix}[{i}][{j}]", tensor[i][j], 1.0e-4)

    if print_level:
        print(f"ISPIN: {ISPIN}")
        print(f"EF: {EF} eV")
        print(f"TOTEN: {TOTEN} eV")

    # --- EIGENVAL: band edges ---
    if "EIGENVAL" in inf:
        if print_level:
            print()
            print("EIGENVAL:")
        eigenvalinf = inf.get("EIGENVAL", None)
        if eigenvalinf:
            nk = eigenvalinf["nk"]
            nLevels = eigenvalinf["nLevels"]

            if print_level:
                print("k points in EIGENVAL:")
                print("nk=", nk)
                print("nLevels=", nLevels)

            bandedgeinf = vasp.find_band_edges_from_eigenval(
                EF0=EF, eigenvalinf=eigenvalinf, ISPIN=ISPIN, print_level=print_level)

            # (jtanaka) A missing band-edge key is reported, not fatal. The
            # warning is printed regardless of print_level.
            try:
                EV = bandedgeinf["EV"]
                EC = bandedgeinf["EC"]
                Eg = bandedgeinf["Eg"]
            except KeyError as e:
                print(f"    **Warning: Band edge analysis failed. Key not found: {e}.")
            else:
                for key, value in (("EV", EV), ("EC", EC), ("Eg", Eg)):
                    register(key, value, 1.0e-6)

                if print_level:
                    print(f"EV: {EV} eV")
                    print(f"EC: {EC} eV")
                    print(f"Eg: {Eg} eV")

    return inf, inf_meta, inf_compare, inf_func

def compare(host, task, inf_dict, print_level=1):
    """
    Compare one task of *host* against the same task in the reference.

    ``inf_dict[host][task]["inf_compare"]`` holds the target values.
    ``inf_dict["reference"][task]`` holds the reference "inf_compare" values
    and the "inf_func" comparators. For each target key, the reference
    comparator is called as ``inf_func[key](target_value, reference_value)``
    and prints its own result.

    Missing reference or target data is reported and the task is skipped.
    A reference that has values but no comparators terminates the program.
    Keys present on only one side are reported instead of raising KeyError.
    print_level is accepted for interface compatibility and is unused.
    """
    ref = inf_dict.get("reference", {}).get(task, None)
    if ref is None:
        print(f"  **Warning: reference task [{task}] is not found in [{reference_dir}]. Skip.")
        return

    ref_comp = ref.get("inf_compare", None)
    if ref_comp is None:
        print(f"  ref_comp is not given for [{reference_dir}][{task}]. Skip.")
        return

    ref_func = ref.get("inf_func", None)
    if ref_func is None:
        print(f"  ref_comp is given but ref_func is not given for [{reference_dir}][{task}]. Terminate.")
        sys.exit()

    target = inf_dict[host][task]
    target_comp = target.get("inf_compare", None)
    if target_comp is None:
        print(f"  target_comp is not given for [{host}][{task}]. Skip.")
        return

    for key, val in target_comp.items():
        print(f"  for [{key:25}]:", end="")
        func = ref_func.get(key)
        if key not in ref_comp or func is None:
            print("   not found in reference. Skip.")
            continue
        func(val, ref_comp[key])

    # Report reference quantities the target did not produce; without this
    # they would go unnoticed.
    for key in ref_comp:
        if key not in target_comp:
            print(f"  for [{key:25}]:   not found in target. Skip.")

def main():
    """
    Command-line entry point.

    Each command-line argument is a glob pattern. Matching paths are
    de-duplicated, keeping the first occurrence and the original order. The
    first matching *directory* is the reference, and every later directory is
    compared against it. Inside each directory, every sub-directory that
    contains an INCAR file is one task, identified by its path relative to
    that directory.
    """
    global reference_dir

    # OrderedDict.fromkeys drops duplicate matches while preserving order.
    dirs = list(OrderedDict.fromkeys(
        path for pattern in sys.argv[1:] for path in glob.glob(pattern)))
    print("Search dirs:", dirs)

    print()
    print(f"Test VASP results in {dirs}")
    count = 0
    inf_dict = {}
    for idx, top_dir in enumerate(dirs):
        if not os.path.isdir(top_dir):
            print(f"Warning: [{top_dir}] is not a directory.")
            continue

        print()
        print(f"{idx:03d}: In [{top_dir}]")
        # Use the first *valid* directory as the reference rather than dirs[0].
        # Otherwise a non-directory first match would leave no reference, and
        # every later comparison would fail.
        is_reference = "reference" not in inf_dict
        if is_reference:
            host = "reference"
            reference_dir = top_dir
            print("  Used as reference:")
        else:
            host = top_dir
            print(f"  Compared with reference in [{reference_dir}]")

        inf_dict[host] = {}
        for root, _, files in os.walk(top_dir):
            if "INCAR" not in files:
                continue

            task = os.path.relpath(root, top_dir)
            inf_dict[host][task] = {}
            print()
            print(f"  [{task}]:")

            # (jtanaka) A task that cannot be read is reported and skipped
            # instead of aborting the whole run.
            try:
                inf, inf_meta, inf_compare, inf_func = read_vasp_info(root, print_level=0)
            except IndexError as e:  # e.g. reading fails when DOSCAR is missing
                print(f"[Warning] DOSCAR could not be read in {root}: {e}")
                inf, inf_meta, inf_compare, inf_func = {}, {}, {}, {}
            except Exception as e:  # any other unexpected failure
                print(f"[Warning] Failed to read VASP info in {root}: {e}")
                inf, inf_meta, inf_compare, inf_func = {}, {}, {}, {}

            if not inf_compare:
                continue

            for key, value in inf_meta.items():
                print(f"    {key}: {value}")

            inf_dict[host][task].update(
                inf=inf, inf_meta=inf_meta,
                inf_compare=inf_compare, inf_func=inf_func)

            if not is_reference:
                compare(host, task, inf_dict, print_level=1)

            count += 1

    print()
    print(f"{count} task(s) with comparable data were read.")
    sys.exit()


if __name__ == "__main__":
    main()
