import os
import sys
try:
    from dotenv import load_dotenv
except:
    print("\nWarning: Could not import dotenv")
    input("Install: pip install dotenv")

try:
    import tklib.tkimport as imp
except Exception as e:
    print()
    print("######################################################################")
    print("###########  ERROR ERROR ERROR ERROR ERROR ERROR #####################")
    print("######################################################################")
    print(f"# Failed to import [tklib.tkimport] module ({e}).")
    print(f"#  Add [tkProg]{os.sep}tklib{os.sep}python to PYTHONPATH variable")
    print(f"#  Current PYTHONPATH:", sys.path)
    print("######################################################################")
    input("Press ENTER to terminate>>")
    exit()

np   = imp.import_lib("numpy",      stop_by_error = False)
pd   = imp.import_lib("pandas",     stop_by_error = False)
pmg  = imp.import_lib("pymatgen",   stop_by_error = False)
json = imp.import_lib("json",       stop_by_error = False)
mjs  = imp.import_lib("monty.json", stop_by_error = False)
imp.messages(stop_by_error = True)

from numpy import sqrt, exp, log, log10
from matplotlib import pyplot as plt
from monty.json import MontyEncoder

from mp_api.client import MPRester
from pymatgen.core.structure import Structure
from pymatgen.core import Composition
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
from pymatgen.io.cif import CifWriter, CifParser

from pymatgen.electronic_structure.core import Spin 
from emmet.core.electronic_structure import BSPathType

from pymatgen.electronic_structure.plotter import BSPlotter, DosPlotter
from pymatgen.phonon.plotter import PhononBSPlotter, PhononDosPlotter


from tklib.tkutils import print_data, pint, pfloat
from tklib.tkvariousdata import tkVariousData
from tklib.tkapplication import tkApplication
from tklib.tkparams import tkParams


PROGRAM_NAME = "Search Materials Project App"

script_path = os.path.abspath(__file__)
config_path = "translate.env"

# load_dotenv is only bound when the optional dotenv import at the top of the
# file succeeded. Look it up defensively so that a missing package skips the
# .env loading instead of raising NameError here.
_load_dotenv = globals().get("load_dotenv")

print()
if _load_dotenv:
    # First .env file: general configuration (may define account_inf_path).
    if os.path.isfile(config_path):
        print(f"config_path: {config_path}")
    else:
        print(f"Warning: config_path {config_path} is not found")
    _load_dotenv(dotenv_path=config_path)

    # Second .env file: account information, expected to provide MP_APIKEY.
    account_inf_path = os.getenv("account_inf_path", "accounts.env")
    if os.path.isfile(account_inf_path):
        print(f"account_inf_path: {account_inf_path}")
    else:
        print(f"Warning: account_inf_path {account_inf_path} is not found")
    _load_dotenv(dotenv_path=account_inf_path)

# Materials Project API key taken from the environment (possibly populated by
# the .env files loaded above).
MP_APIKEY = os.getenv("MP_APIKEY")
if not MP_APIKEY:
    # The variable that is actually read is MP_APIKEY, so name it in the warning.
    print("\nWarning: MP_APIKEY environment variable is not set.")

def initialize(app):
    """Create the parameter container for *app* and load its argument config.

    A fresh ``tkParams`` object is stored in ``app.cfg`` with a default
    ``figsize`` of ``[8, 6]``.  The argument-configuration file path is built
    by ``app.replace_path`` from the ``{dirname}`` / ``{filebody}_arg_config.xlsx``
    template and passed to ``app.read_arg_config_from_file``.

    Args:
        app: Application object providing ``replace_path``,
            ``read_arg_config_from_file`` and ``terminate``.

    Returns:
        The ``app.cfg`` parameter object.
    """
    app.cfg = tkParams()

    app.cfg.figsize = [8, 6]

    arg_config_file = app.replace_path(None, template = ["{dirname}", "{filebody}_arg_config.xlsx"])

    print(f"Read config file [{arg_config_file}]")
    # NOTE(review): the result of read_arg_config_from_file() is not inspected;
    # the guard below only tests the path itself.  Its return contract is not
    # visible here -- confirm whether the check should use the result instead.
    app.read_arg_config_from_file(arg_config_file)
    if not arg_config_file:
        # f-string so that the path is interpolated into the message.
        app.terminate(f"\nError in initialize(): Can not read [{arg_config_file}]")

    return app.cfg


def get_MPRester(API_KEY = None):
    """Return an ``MPRester`` client, or None when no API key is available.

    Args:
        API_KEY: Materials Project API key.  When None or empty, the
            ``MP_APIKEY`` environment variable is used instead.

    Returns:
        An ``MPRester`` instance, or None if no non-empty key could be found.
    """
    if not API_KEY:
        API_KEY = os.getenv('MP_APIKEY')
    # Treat an empty environment value like a missing one instead of handing
    # an empty key to the client.
    if not API_KEY:
        print("\nError: Can not get MP API Key from the environment var MP_APIKEY\n")
        return None

    # A constructor call never yields None, so no further None check is made
    # (and the key is therefore never echoed to the console).
    return MPRester(API_KEY)

def get_material_data(formula, chemsys, material_ids, space_group, fuse_arg_line, arg_line, mpr):
    """Query ``mpr.materials.summary.search`` and collect the matching entries.

    The search mode is chosen in this order:

    1. ``fuse_arg_line`` truthy: evaluate ``arg_line`` as the argument list.
    2. ``material_ids`` non-empty: comma-separated material ids.
    3. ``formula`` starting with ``mp-``: comma-separated material ids.
    4. ``chemsys`` only, or ``chemsys`` together with ``formula``.
    5. ``formula`` containing ``,``: list of elements.
    6. ``formula`` containing ``-``: treated as a chemical system.
    7. otherwise: plain formula search.

    Args:
        formula, chemsys, material_ids: Search strings; None is treated as ''.
        space_group: Space-group number (int-convertible) or symbol; ignored
            when None or ''.
        fuse_arg_line: When truthy, use ``arg_line`` verbatim.
        arg_line: Python source for the arguments of ``summary.search``.
        mpr: Client object exposing ``materials.summary.search``.

    Returns:
        ``(mid_list, material_list, elements_list)``; ``elements_list`` repeats
        the requested element list (or None) once per hit.  Returns
        ``(None, None, None)`` when nothing was found.
    """
    # Unset text arguments arrive as None; normalise so the string tests
    # below cannot fail on them.
    formula = formula or ''
    chemsys = chemsys or ''
    material_ids = material_ids or ''

    elements_in = None

    kwargs = {}
    if space_group is not None and space_group != '':
        try:
            kwargs["spacegroup_number"] = int(space_group)
        except (TypeError, ValueError):
            # Not a number: pass it on as a space-group symbol.
            kwargs["spacegroup_symbol"] = space_group

    search = mpr.materials.summary.search

    if fuse_arg_line:
        cmd = f"mpr.materials.summary.search({arg_line})"
        print(f"Use User-specified arg_line: [{arg_line}]")
        print(f"  evaluate in get_material_data(): [{cmd}]")
        # SECURITY: arg_line is executed as Python code.  Only use this mode
        # with trusted input.  The expression refers to the local name `mpr`.
        search_results = eval(cmd)
    elif material_ids != '':
        mid_list = material_ids.replace(' ', '').split(',')
        print(f"Search for material_ids=", mid_list)
        search_results = search(material_ids = mid_list, fields = [], **kwargs)
    elif formula[:3] == 'mp-':
        mid_list = formula.replace(' ', '').split(',')
        print(f"Search for material_ids=", mid_list)
        search_results = search(material_ids = mid_list, fields = [], **kwargs)
    elif formula == '' and chemsys != '':
        print(f"Search for chemsys={chemsys}")
        search_results = search(chemsys = chemsys, fields = [], **kwargs)
    elif chemsys != '' and formula != '':
        print(f"Search for chemsys={chemsys} and formula={formula}")
        search_results = search(chemsys = chemsys, formula = formula, fields = [], **kwargs)
    elif ',' in formula:
        elements_in = formula.replace(' ', '').split(',')
        print(f"Search for elements=", elements_in)
        search_results = search(elements = elements_in, fields = [], **kwargs)
    elif '-' in formula:
        print(f"Search for chemsys={formula}")
        search_results = search(chemsys = formula, fields = [], **kwargs)
    else:
        print(f"Search for formula={formula}")
        search_results = search(formula = formula, fields = [], **kwargs)

    if not search_results:
        if fuse_arg_line:
            print(f"No data found for {cmd}")
        else:
            print(f"No data found for {formula}")
        return None, None, None

    material_list = list(search_results)
    mid_list = [str(res.material_id) for res in material_list]
    elements_list = [elements_in] * len(material_list)

    return mid_list, material_list, elements_list

def get_structure(material_id, mpr, formula, dE, elements_in):
    """Fetch the summary entry for *material_id* and return its structure.

    The entry is requested with every field listed in
    ``mpr.materials.summary.available_fields``.  When *dE* is given (neither
    None nor ''), the query is restricted to
    ``energy_above_hull = (0.0, pfloat(dE, defval = None))``.

    If *elements_in* is non-empty, the structure is rejected when it contains
    an element that is not in that list.

    Returns:
        A ``Structure`` carrying the extra attributes ``material_id``,
        ``cmaterial`` (the summary entry), ``cformula`` (reduced formula) and
        ``celements`` (element symbols), or None when no entry was found or
        the element filter rejected it.
    """
    summary = mpr.materials.summary
    available_fields = summary.available_fields
    de_given = dE is not None and dE != ""

    query = {"material_ids": [material_id]}
    if de_given:
        query["energy_above_hull"] = (0.0, pfloat(dE, defval = None))
    query["fields"] = available_fields
    material_data = summary.search(**query)

    if not material_data:
        print(f"No structure data found for material_id {material_id}")
        if de_given:
            print(f"  Probably due to the limitation 0 <= dE <= {dE} eV")
        return None

    entry = material_data[0]
    # Rebuild the structure from its dict form.
    structure = Structure.from_dict(entry.structure.as_dict())

    composition = structure.composition
    cformula = composition.reduced_formula
    elements_res = [el.symbol for el in composition.elements]

    if elements_in:
        not_included = [sym for sym in elements_res if sym not in elements_in]
        if not_included:
            print(f"  ** One or more elements are not included for [{cformula}] in the given condition [{formula}]:", not_included)
            return None
        print(f"  ** Found [{cformula}] [{material_id}]")

    # Attach the metadata that callers read back from the structure object.
    structure.material_id = material_id
    structure.cmaterial = entry
    structure.cformula = cformula
    structure.celements = elements_res

    return structure

def save_structure(structure, symmetrize, outfile, idx, mpr, output_format = 'cif', prec = 1.0e-3, formula = None):
    """Write *structure* to *outfile* in the requested format.

    Args:
        structure: pymatgen ``Structure`` to write.
        symmetrize: When truthy, the CIF is written with ``symprec = prec``.
            Only applies to CIF output.
        outfile: Destination path.
        idx, mpr, formula: Accepted for interface compatibility; not used.
        output_format: ``'poscar'`` (case-insensitive) writes a VASP POSCAR via
            ``structure.to``; any other value writes a CIF, as before.
        prec: Symmetry tolerance handed to ``CifWriter`` when symmetrizing.
    """
    if str(output_format).lower() == "poscar":
        # The caller names this file POSCAR / *.POSCAR, so the content has to
        # be POSCAR text rather than CIF.
        structure.to(filename = outfile, fmt = "poscar")
        return

    if symmetrize:
        writer = CifWriter(structure, symprec = prec)
    else:
        writer = CifWriter(structure)

    writer.write_file(outfile)

def save_meta_text(material, outfile, mpr):
    """Write the non-None summary fields of *material* as ``key: value`` lines.

    Args:
        material: Object whose attributes are read with ``getattr``.
        outfile: Path of the text file to (over)write.
        mpr: Client whose ``materials.summary.available_fields`` lists the
            attribute names to export.
    """
    available_fields = mpr.materials.summary.available_fields

    # The context manager closes the file even if formatting a value raises.
    with open(outfile, 'w') as fp:
        for key in available_fields:
            val = getattr(material, key, None)
            if val is not None:
                fp.write(f"{key}: {val}\n")

def save_meta_json(material, outfile):
    """Serialize ``material.dict()`` to *outfile* as 2-space indented JSON.

    ``MontyEncoder`` is used as the JSON encoder class.
    """
    with open(outfile, 'w') as fp:
        payload = material.dict()
        json.dump(payload, fp, indent = 2, cls = MontyEncoder)

def _fetch_bandstructure(mpr, mid, **query):
    """Return ``mpr.get_bandstructure_by_material_id(mid, **query)`` or None on error."""
    try:
        return mpr.get_bandstructure_by_material_id(mid, **query)
    except Exception:
        # Best effort: a failed request is treated as "no data".  Catching
        # Exception (not a bare except) lets Ctrl-C / SystemExit propagate.
        return None


def get_band_data(mid, mpr, kpath):
    """Download the electronic band structure of material *mid*.

    Args:
        mid: Material id.
        mpr: Client exposing ``get_bandstructure_by_material_id``.
        kpath: Path selection:
            ``"symline"``       - default path first, then ``BSPathType.hinuma``,
                                  then ``BSPathType.latimer_munro`` until one
                                  request yields data;
            ``"Hinuma"``        - ``BSPathType.hinuma``;
            ``"Latimer-Munro"`` - ``BSPathType.latimer_munro``;
            ``"uniform"``       - ``line_mode = False``;
            anything else       - the client's default request.

    Returns:
        The band-structure object, or None (after printing a notice) when no
        data could be obtained.
    """
    if kpath == "symline":
        bs = _fetch_bandstructure(mpr, mid)
        if not bs:
            bs = _fetch_bandstructure(mpr, mid, path_type = BSPathType.hinuma)
        if not bs:
            bs = _fetch_bandstructure(mpr, mid, path_type = BSPathType.latimer_munro)
    elif kpath == "Hinuma":
        bs = _fetch_bandstructure(mpr, mid, path_type = BSPathType.hinuma)
    elif kpath == "Latimer-Munro":
        bs = _fetch_bandstructure(mpr, mid, path_type = BSPathType.latimer_munro)
    elif kpath == "uniform":
        bs = _fetch_bandstructure(mpr, mid, line_mode = False)
    else:
        bs = _fetch_bandstructure(mpr, mid)

    if bs is None:
        print(f"  Band data is not available: Skip")
        return None

    return bs

def get_dos_data(mid, mpr):
    """Download the electronic DOS of material *mid*.

    Returns:
        The object returned by ``mpr.get_dos_by_material_id(mid)``, or None
        (after printing a notice) when the request fails or yields None.
    """
    try:
        dos = mpr.get_dos_by_material_id(mid)
    except Exception:
        # Best effort: a failed request counts as "no data"; Ctrl-C and
        # SystemExit are no longer swallowed.
        dos = None

    if dos is None:
        print(f"   DOS data is not available: Skip")
        return None

    return dos

def save_band_csv(bs, outfile):
    """Write the ``Spin.up`` bands of *bs* to *outfile* as CSV.

    Each row pairs a k-point's ``frac_coords`` with the band energy at that
    k-point (columns ``K-Point`` and ``Energy``).

    Returns:
        False when *bs* is falsy, True after the file has been written.
    """
    if not bs:
        return False

    rows = [
        [kpt.frac_coords, value]
        for band in bs.bands[Spin.up]
        for kpt, value in zip(bs.kpoints, band)
    ]

    table = pd.DataFrame(rows, columns=['K-Point', 'Energy'])
    table.to_csv(outfile, index = False)
    return True

def save_band_figure(bs, outfile):
    """Plot *bs* with ``BSPlotter`` and save the figure to *outfile*.

    Returns:
        False when *bs* is falsy or plotting/saving failed, True on success.
    """
    if not bs: return False

    try:
        fig_bs = BSPlotter(bs).get_plot().figure
        fig_bs.savefig(outfile)
        # Release the figure so figures do not pile up over a batch run.
        plt.close(fig_bs)
    except Exception as err:
        print(f"    Error in save_band_figure(): May be the given BandStructure object is not for SymmetryLine object\n")
        # Show the real cause too, and report the failure to the caller.
        print(f"    ({type(err).__name__}: {err})")
        return False

    return True

def save_dos_csv(dos, outfile):
    """Write the ``Spin.up`` DOS of *dos* to *outfile* as CSV.

    The energy column holds ``energy - dos.efermi``; the columns are named
    ``Energy`` and ``Density of States``.

    Returns:
        False when *dos* is falsy, True after the file has been written.
    """
    if not dos:
        return False

    # Only the total spin-up densities are exported; element-projected DOS is
    # not written by this function.
    rows = [
        [energy - dos.efermi, density]
        for energy, density in zip(dos.energies, dos.densities[Spin.up])
    ]

    table = pd.DataFrame(rows, columns=['Energy', 'Density of States'])
    table.to_csv(outfile, index = False)
    return True

def save_dos_figure(dos, outfile):
    """Plot the total and element DOS of *dos* and save the figure to *outfile*.

    Returns:
        False when *dos* is falsy or plotting/saving failed, True on success.
    """
    if not dos: return False

    try:
        dos_plotter = DosPlotter()
        dos_plotter.add_dos("Total DOS", dos)
        dos_plotter.add_dos_dict(dos.get_element_dos())

        fig_dos = dos_plotter.get_plot().figure
        fig_dos.savefig(outfile)
        # Release the figure so figures do not pile up over a batch run.
        plt.close(fig_dos)
    except Exception as err:
        print(f"    Error in save_dos_figure(): Error occured to plot the given dos object\n")
        # Show the real cause too, and report the failure to the caller.
        print(f"    ({type(err).__name__}: {err})")
        return False

    return True

def get_phonon_band_data(mid, mpr, kpath):
    """Download the phonon band structure of material *mid*.

    Args:
        mid: Material id.
        mpr: Client exposing ``get_phonon_bandstructure_by_material_id``.
        kpath: Accepted for symmetry with ``get_band_data``; not used.

    Returns:
        The phonon band-structure object, or None (after printing a notice)
        when the request fails or yields None.
    """
    try:
        ph_bs = mpr.get_phonon_bandstructure_by_material_id(mid)
    except Exception as err:
        # A failed request for one material must not abort the whole batch;
        # report the cause and treat it as "no data".
        print(f"  ({type(err).__name__}: {err})")
        ph_bs = None

    if ph_bs is None:
        print(f"  Phonon band data is not available: Skip")
        return None

    return ph_bs
    
def get_phonon_dos_data(mid, mpr):
    """Download the phonon DOS of material *mid*.

    Returns:
        The object returned by ``mpr.get_phonon_dos_by_material_id(mid)``, or
        None (after printing a notice) when the request fails or yields None.
    """
    try:
        ph_dos = mpr.get_phonon_dos_by_material_id(mid)
    except Exception as err:
        # A failed request for one material must not abort the whole batch;
        # report the cause and treat it as "no data".
        print(f"  ({type(err).__name__}: {err})")
        ph_dos = None

    if ph_dos is None:
        print(f"  Phonon DOS data is not available: Skip")
        return None

    return ph_dos

def save_phonon_band_json(ph_bs, outfile):
    """Dump ``ph_bs.as_dict()`` to *outfile* as 2-space indented JSON.

    ``MontyEncoder`` is used as the JSON encoder class.

    Returns:
        False when *ph_bs* is falsy, True after the file has been written.
    """
    if ph_bs:
        with open(outfile, 'w') as fp:
            payload = ph_bs.as_dict()
            json.dump(payload, fp, indent = 2, cls = MontyEncoder)
        return True

    return False

def save_phonon_band_csv(ph_bs, outfile):
    """Write the phonon bands of *ph_bs* to *outfile* as CSV.

    Each row pairs a q-point's ``frac_coords`` with the band frequency at
    that q-point (columns ``Q-Point`` and ``Frequency``).

    Returns:
        False when *ph_bs* is falsy, True after the file has been written.
    """
    if not ph_bs:
        return False

    rows = [
        [qpt.frac_coords, freq]
        for branch in ph_bs.bands
        for qpt, freq in zip(ph_bs.qpoints, branch)
    ]

    table = pd.DataFrame(rows, columns=['Q-Point', 'Frequency'])
    table.to_csv(outfile, index=False)
    return True

def save_phonon_band_figure(ph_bs, outfile):
    """Plot *ph_bs* with ``PhononBSPlotter`` and save the figure to *outfile*.

    Returns:
        False when *ph_bs* is falsy or plotting/saving failed, True on success.
    """
    if not ph_bs: return False

    try:
        fig_phonon_bs = PhononBSPlotter(ph_bs).get_plot().figure
        fig_phonon_bs.savefig(outfile)
        # Release the figure so figures do not pile up over a batch run.
        plt.close(fig_phonon_bs)
    except Exception as err:
        # Same policy as save_band_figure(): one failed plot must not abort
        # the remaining downloads.
        print(f"    Error in save_phonon_band_figure(): {type(err).__name__}: {err}\n")
        return False

    return True

def save_phonon_dos_json(ph_dos, outfile):
    """Dump ``ph_dos.as_dict()`` to *outfile* as 2-space indented JSON.

    ``MontyEncoder`` is used as the JSON encoder class.

    Returns:
        False when *ph_dos* is falsy, True after the file has been written.
    """
    if not ph_dos: return False

    with open(outfile, 'w') as f:
        json.dump(ph_dos.as_dict(), f, cls = MontyEncoder, indent = 2)

    # The unreachable CSV-export code that used to follow this return (it
    # referenced an undefined name `dos`) has been removed.
    return True

def save_phonon_dos_csv(ph_dos, outfile):
    """Write the phonon DOS of *ph_dos* to *outfile* as CSV.

    Rows pair ``ph_dos.frequencies`` with ``ph_dos.densities`` (columns
    ``Frequency`` and ``Density of States``).

    Returns:
        False when *ph_dos* is falsy, True after the file has been written.
    """
    if not ph_dos:
        return False

    rows = [[freq, dens] for freq, dens in zip(ph_dos.frequencies, ph_dos.densities)]

    table = pd.DataFrame(rows, columns=['Frequency', 'Density of States'])
    table.to_csv(outfile, index = False)
    return True

def save_phonon_dos_figure(ph_dos, outfile):
    """Plot *ph_dos* as "Total DOS" and save the figure to *outfile*.

    Returns:
        False when *ph_dos* is falsy or plotting/saving failed, True on success.
    """
    if not ph_dos: return False

    try:
        phonon_dos_plotter = PhononDosPlotter()
        phonon_dos_plotter.add_dos("Total DOS", ph_dos)

        fig_phonon_dos = phonon_dos_plotter.get_plot().figure
        fig_phonon_dos.savefig(outfile)
        # Release the figure so figures do not pile up over a batch run.
        plt.close(fig_phonon_dos)
    except Exception as err:
        # Same policy as save_dos_figure(): one failed plot must not abort
        # the remaining downloads.
        print(f"    Error in save_phonon_dos_figure(): {type(err).__name__}: {err}\n")
        return False

    return True

def search_structure(app, cfg, mpr):
    """Search the Materials Project and save the requested data per hit.

    The search is delegated to ``get_material_data``.  For every hit (up to
    ``cfg.nmaxdl`` entries) the structure is fetched with ``get_structure``
    and, depending on the flags in *cfg*, these files are written to
    ``cfg.base_dir``:

    * ``cfg.fstructure`` - the structure file (``cfg.output_format``);
    * ``cfg.fmeta``      - metadata as ``.txt`` and ``.json``;
    * ``cfg.fband``      - electronic band / DOS as ``.csv`` and ``.png``;
    * ``cfg.fphonon``    - phonon band / DOS as ``.json``, ``.csv`` and ``.png``.

    Args:
        app: Application object (not used here).
        cfg: Parameter object holding the search conditions and output flags.
        mpr: Materials Project client.

    Returns:
        The list of structures that were processed, or None when the search
        returned nothing.
    """
    mid_list, material_list, elements_list = get_material_data(
        cfg.formula, cfg.chemsys, cfg.material_ids, cfg.space_group,
        cfg.fuse_arg_line, cfg.arg_line, mpr)
    if mid_list is None:
        return None

    print()
    print("Material IDs:", mid_list)

    nmater = len(mid_list)
    # A negative nmaxdl takes the "all data will be downloaded" branch below,
    # so it is treated as "no limit" in the loop as well.
    limited = cfg.nmaxdl >= 0
    if limited and nmater > cfg.nmaxdl:
        print(f"{cfg.nmaxdl} data out of the found {nmater} data will be downloaded")
    else:
        print(f"{nmater} data will be downloaded")

    def out_path(filename):
        """Return *filename* placed inside cfg.base_dir."""
        return os.path.join(cfg.base_dir, filename)

    summary_keys = ["last_updated", "deprecated", "energy_above_hull", "theoretical",
                    "is_metal", "band_gap", "is_gap_direct", "database_IDs"]

    structures = []
    for idx, (material_id, material, elements_in) in enumerate(zip(mid_list, material_list, elements_list)):
        # Only enforce the limit when one is set; unconditionally breaking
        # here used to stop at the first entry for a negative nmaxdl.
        if limited and idx + 1 > cfg.nmaxdl:
            break

        print()
        structure = get_structure(material_id, mpr, cfg.formula, cfg.dE, elements_in)
        if structure is None:
            continue

        if cfg.fuse_arg_line:
            print(f"Found {structure.cformula} for material_id={material_id}: {cfg.arg_line}")
        else:
            print(f"Found {structure.cformula} for material_id={material_id}: {cfg.formula}/{cfg.chemsys}/{cfg.material_ids}")
        for key in summary_keys:
            val = getattr(material, key, None)
            if val is not None: print(f"  {key}: {val}")

        # NOTE: reporting of per-task functional settings (GGA / METAGGA /
        # LHFCALC from the task inputs) is not implemented.

        cformula = structure.cformula
        tag = f"{cformula}_{structure.material_id}"

        if cfg.output_format == "poscar":
            # The first search hit gets the bare VASP file name.
            if idx == 0:
                output_path = f"POSCAR"
            else:
                output_path = f"{tag}.POSCAR"
        else:
            output_path = f"{tag}.{cfg.output_format}"

        if cfg.fstructure:
            outfile = out_path(output_path)
            print(f"  Save {cformula} structure to [{outfile}]")
            save_structure(structure, cfg.fsymmetrize, outfile, idx, mpr, cfg.output_format, cfg.prec, cfg.formula)

        if cfg.fmeta:
            outfile = out_path(f'meta_{tag}.txt')
            print(f"  Save meta data to [{outfile}]")
            save_meta_text(structure.cmaterial, outfile, mpr)

            outfile = out_path(f'meta_{tag}.json')
            print(f"  Save meta data to [{outfile}]")
            save_meta_json(structure.cmaterial, outfile)

        if cfg.fband:
            print("  Get band data...")
            bs = get_band_data(material_id, mpr, cfg.kpath)

            if bs:
                outfile = out_path(f'band_{tag}.csv')
                print(f"  Save band data to [{outfile}]")
                save_band_csv(bs, outfile)

                outfile = out_path(f'band_{tag}.png')
                print(f"  Save band figure to [{outfile}]")
                save_band_figure(bs, outfile)

            print("  Get DOS data...")
            dos = get_dos_data(material_id, mpr)

            if dos:
                outfile = out_path(f'dos_{tag}.csv')
                print(f"  Save dos data to [{outfile}]")
                save_dos_csv(dos, outfile)

                outfile = out_path(f'dos_{tag}.png')
                print(f"  Save dos figure to [{outfile}]")
                save_dos_figure(dos, outfile)

        if cfg.fphonon:
            print("  Get phonon band data...")
            ph_bs = get_phonon_band_data(material_id, mpr, cfg.kpath)

            if ph_bs:
                outfile = out_path(f'phonon_band_{tag}.json')
                print(f"  Save phonon band data to [{outfile}]")
                save_phonon_band_json(ph_bs, outfile)

                outfile = out_path(f'phonon_band_{tag}.csv')
                print(f"  Save phonon band data to [{outfile}]")
                save_phonon_band_csv(ph_bs, outfile)

                outfile = out_path(f'phonon_band_{tag}.png')
                print(f"  Save phonon band figure to [{outfile}]")
                save_phonon_band_figure(ph_bs, outfile)

            print("  Get phonon dos data...")
            ph_dos = get_phonon_dos_data(material_id, mpr)

            if ph_dos:
                outfile = out_path(f'phonon_dos_{tag}.json')
                print(f"  Save phonon dos data to [{outfile}]")
                save_phonon_dos_json(ph_dos, outfile)

                outfile = out_path(f'phonon_dos_{tag}.csv')
                print(f"  Save phonon dos data to [{outfile}]")
                save_phonon_dos_csv(ph_dos, outfile)

                outfile = out_path(f'phonon_dos_{tag}.png')
                print(f"  Save phonon dos figure to [{outfile}]")
                save_phonon_dos_figure(ph_dos, outfile)

        structures.append(structure)

    return structures

def main():
    """Program entry point: configure the application and run the selected mode.

    Steps, in order: create the ``tkApplication`` and its parameter object,
    read ``--base_dir`` and ``--formula``, redirect console output to a log
    file, update the parameters from the command line, obtain an ``MPRester``
    client, print the available summary fields, and dispatch on ``cfg.mode``.
    """
    app = tkApplication()
    cfg = initialize(app)

    cfg.base_dir = app.check_arg('--base_dir', defval = '.', vartype = 'str')
    cfg.formula = app.check_arg('--formula', defval = None, vartype = 'str')

    # Build the log file name from cfg.formula and redirect console output to it.
    cfg.logfile = app.replace_path(None, template = ["{base_dir}", "{formula}-out.txt"], ext_dict = {"formula": cfg.formula, "base_dir": cfg.base_dir})
    app.redirect(heading = f"Open logfile [{cfg.logfile}]",
                targets = ["stdout", cfg.logfile], mode = 'w')

    # Store the parameters given as startup (command-line) arguments in cfg.
    print()
    print("Update parameters from command line arguments:")
    app.update_vars(cfg, apply_default = True)
    app.cfg.print_parameters()

    # Fall back to the module-level key read from the environment.
    if cfg.MP_APIKEY is None or cfg.MP_APIKEY == "":
        cfg.MP_APIKEY = MP_APIKEY

    mpr = get_MPRester(cfg.MP_APIKEY)
    # NOTE(review): the code below assumes app.terminate() does not return -- confirm.
    if not mpr: app.terminate(f"Error in main(): Can not get MPRester", pause = cfg.pause)

    available_fields = mpr.materials.summary.available_fields
    print()
    print("available_fields:")
    for f in sorted(available_fields):
        print("  ", f)

    if cfg.mode == 'structure':
        search_structure(app, cfg, mpr)
    elif cfg.mode == 'band':
        # NOTE(review): get_band_data is defined as (mid, mpr, kpath); this call
        # passes (app, cfg.kpath) and would raise TypeError. The intended
        # behaviour of 'band' mode is not visible here -- needs fixing by the owner.
        get_band_data(app, cfg.kpath)
    else:
        app.terminate(message = "\n", 
                        usage = app.usage,
                        post_message = f"Error in main: Invalide mode [{cfg.mode}]", 
                        pause = cfg.pause,
                     )

    app.terminate(pause = cfg.pause)


# Run main() only when the file is executed as a script.
if __name__ == "__main__":
    main()

