from tklib.tkcrystal.tkvasp import tkVASP


def _make_tolerance_checker(default_eps):
    """Return a comparator ``f(x1, x2, eps=default_eps)``.

    The comparator forwards its arguments to ``check_difference``, which is
    looked up in this module's globals at call time.
    """
    return lambda x1, x2, eps=default_eps: check_difference(x1, x2, eps)


def _register_scalars(inf_compare, inf_func, values, eps):
    """Record each ``name -> value`` pair of *values* together with a comparator."""
    for key, val in values.items():
        inf_compare[key] = val
        inf_func[key] = _make_tolerance_checker(eps)


def _register_tensor(inf_compare, inf_func, prefix, tensor, nrows, col_labels, eps):
    """Record ``tensor[i][j]`` under the key ``"<prefix>[<i>][<col_labels[j]>]"``."""
    for i in range(nrows):
        for j, label in enumerate(col_labels):
            key = f"{prefix}[{i}][{label}]"
            inf_compare[key] = tensor[i][j]
            inf_func[key] = _make_tolerance_checker(eps)


def _print_cell_summary(a, b, c, alpha, beta, gamma, volume):
    """Print the lattice parameters and the cell volume on two lines."""
    print("  cell: {:12.8f} {:12.8f} {:12.8f} A   {:10.6f} {:10.6f} {:10.6f}".format(a, b, c, alpha, beta, gamma))
    # A volume is a length cubed; the lengths above are labelled "A", so the
    # unit is A^3 (the previous label "A^-3" was wrong).
    print("  volume: {:12.6f} A^3".format(volume))


def read_vasp_info(car_dir, print_level=1):
    """Read the VASP files in *car_dir* and collect values for later comparison.

    The files are read with ``tkVASP.read_files`` (always called with
    ``print_level=0``).  If the result has no usable "OUTCAR" entry (missing,
    ``None`` or an empty dict) the function returns immediately with empty
    ``inf_meta``/``inf_compare``/``inf_func`` and nothing is printed.

    Args:
        car_dir: Location passed unchanged to ``tkVASP.read_files``.
        print_level: When truthy, a summary of every processed section is
            printed to stdout; it is also forwarded to
            ``find_band_edges_from_eigenval``.

    Returns:
        A tuple ``(inf, inf_meta, inf_compare, inf_func)``:

        * ``inf`` -- the mapping returned by ``read_files``, extended in place
          with ``"InitialStructure:*"``, ``"RelaxedStructure:*"`` (only when a
          CONTCAR structure is available) and ``"EIGENVAL:EV/EC/Eg"`` entries.
        * ``inf_meta`` -- compiler / MPI / core-count / timing entries copied
          from the OUTCAR information when present.
        * ``inf_compare`` -- name -> value for every quantity to be compared.
        * ``inf_func`` -- name -> comparator ``f(x1, x2, eps=<default>)`` that
          delegates to ``check_difference``.

    Raises:
        KeyError: If ``inf`` has no "crystal" entry, the OUTCAR information has
            no "ISPIN" entry, or the EIGENVAL / band-edge information lacks an
            expected key.
    """
    vasp = tkVASP()

    inf = vasp.read_files(
        car_dir,
        ["INCAR", "POSCAR", "POTCAR", "KPOINTS", "CONTCAR", "OUTCAR", "DOSCAR", "EIGENVAL", "EIGENVAL_OPT"],
        EF=0.0, normalize_E=True, unit='/cm3', data_for_bandedges=None,
        exit_by_error=False, print_level=0, terminate=None)

    inf_meta = {}
    inf_compare = {}
    inf_func = {}

    # Without OUTCAR information nothing below is evaluated.
    outcarinf = inf.get("OUTCAR", {})
    if outcarinf is None or outcarinf == {}:
        return inf, inf_meta, inf_compare, inf_func

    meta_keys = ("compiler", "mpi", "ncores_k", "ncores_bandf",
                 "Total CPU time", "User time", "System time", "Elapsed time")
    for key in meta_keys:
        val = outcarinf.get(key, None)
        if val is not None:
            inf_meta[key] = val

    cell_names = ("a", "b", "c", "alpha", "beta", "gamma")

    # --- POSCAR: initial structure -----------------------------------------
    cry = inf["crystal"]
    a, b, c, alpha, beta, gamma = cry.LatticeParameters()
    Vcell = cry.Volume()
    # Printing is gated by print_level, like every other section.
    if print_level:
        print()
        print("POSCAR:")
        _print_cell_summary(a, b, c, alpha, beta, gamma, Vcell)
    for name, val in zip(cell_names, (a, b, c, alpha, beta, gamma)):
        inf[f"InitialStructure:{name}"] = val

    # --- CONTCAR: relaxed structure ------------------------------------------
    if "CONTCAR" in inf.keys():
        if print_level:
            print()
            print("CONTCAR:")
        cry = inf.get("CONTCAR", None)
        if cry:
            a, b, c, alpha, beta, gamma = cry.LatticeParameters()
            Vcell = cry.Volume()
            _register_scalars(
                inf_compare, inf_func,
                {"a": a, "b": b, "c": c,
                 "alpha": alpha, "beta": beta, "gamma": gamma,
                 "Vcell": Vcell},
                1.0e-6)

            if print_level:
                _print_cell_summary(a, b, c, alpha, beta, gamma, Vcell)

            # Only store relaxed values that really come from CONTCAR;
            # otherwise the POSCAR values would be mislabelled as relaxed.
            for name, val in zip(cell_names, (a, b, c, alpha, beta, gamma)):
                inf[f"RelaxedStructure:{name}"] = val

    # --- OUTCAR (known to be present and non-empty at this point) -----------
    if print_level:
        print()
        print("OUTCAR:")

    ISPIN = outcarinf["ISPIN"]
    EF = outcarinf.get("EF", "")
    TOTEN = outcarinf.get("TOTEN", "")
    _register_scalars(inf_compare, inf_func, {"EF": EF, "TOTEN": TOTEN}, 1.0e-4)

    born_charges = outcarinf.get("born_charges", None)
    if born_charges:
        for iion, bc in enumerate(born_charges):
            _register_tensor(inf_compare, inf_func, f"born_charges[{iion}]",
                             bc, 3, range(3), 1.0e-4)

    piezo_static_localeff = outcarinf.get("piezo_static_localeff", None)
    if piezo_static_localeff:
        _register_tensor(inf_compare, inf_func, "piezo_static",
                         piezo_static_localeff, 3,
                         ["XX", "YY", "ZZ", "XY", "YZ", "ZX"], 1.0e-4)

    eps_static_localeff = outcarinf.get("eps_static_localeff", None)
    if eps_static_localeff:
        _register_tensor(inf_compare, inf_func, "eps_static_e",
                         eps_static_localeff, 3, range(3), 1.0e-4)

    eps_static_ionic = outcarinf.get("eps_static_ionic", None)
    if eps_static_ionic:
        _register_tensor(inf_compare, inf_func, "eps_static_ionic",
                         eps_static_ionic, 3, range(3), 1.0e-4)

    if print_level:
        print(f"ISPIN: {ISPIN}")
        print(f"EF: {EF} eV")
        print(f"TOTEN: {TOTEN} eV")

    # --- EIGENVAL: band edges -----------------------------------------------
    if "EIGENVAL" in inf.keys():
        if print_level:
            print()
            print("EIGENVAL:")
        eigenvalinf = inf.get("EIGENVAL", None)
        if eigenvalinf:
            nk = eigenvalinf["nk"]
            nLevels = eigenvalinf["nLevels"]

            if print_level:
                print("k points in EIGENVAL:")
                print("nk=", nk)
                print("nLevels=", nLevels)

            bandedgeinf = vasp.find_band_edges_from_eigenval(
                EF0=EF, eigenvalinf=eigenvalinf, ISPIN=ISPIN, print_level=print_level)
            EV = bandedgeinf["EV"]
            EC = bandedgeinf["EC"]
            Eg = bandedgeinf["Eg"]
            _register_scalars(inf_compare, inf_func,
                              {"EV": EV, "EC": EC, "Eg": Eg}, 1.0e-6)

            inf["EIGENVAL:EV"] = EV
            inf["EIGENVAL:EC"] = EC
            inf["EIGENVAL:Eg"] = Eg

            if print_level:
                print(f"EV: {EV} eV")
                print(f"EC: {EC} eV")
                print(f"Eg: {Eg} eV")

    # NOTE: EIGENVAL_OPT and DOSCAR are requested from read_files() above but
    # are neither summarized nor added to the comparison tables here.

    return inf, inf_meta, inf_compare, inf_func
