import os
import sys
import re
import json
import openpyxl
import pymatgen
from pymatgen.io.vasp import Poscar


from tklib.tkapplication import tkApplication
from tklib.tkinifile import tkIniFile


infile = 'defect_energy_summary.json'
poscar_path = 'POSCAR'
poscar_defect_path = 'POSCAR'

# Resolved after command-line parsing so it follows the actual input file.
outfile = None
logfile = None
summaryfile = None

dS_default = 0.0


# Positional command-line arguments (all optional):
#   argv[1]: pydefect json file
#   argv[2]: output .xlsx file (default: <infile without extension>.xlsx)
#   argv[3]: POSCAR of the ideal model
#   argv[4]: POSCAR of the defect model
if len(sys.argv) > 1:
    infile = sys.argv[1]
if len(sys.argv) > 2:
    outfile = sys.argv[2]
if len(sys.argv) > 3:
    poscar_path = sys.argv[3]
if len(sys.argv) > 4:
    poscar_defect_path = sys.argv[4]

# Derive the default output name from the infile actually in use; deriving it
# before argv parsing ignored a user-supplied infile.
if outfile is None:
    outfile = os.path.splitext(infile)[0] + '.xlsx'


app = tkApplication()


def print_json_recursive(data, indent=0):
    """Print a JSON-like structure as an indented tree.

    Each nesting level is indented by two spaces. Dict entries are printed as
    ``key: value`` and list/tuple items as ``: value``. A nested dict, list or
    tuple is printed as a ``key:`` (or ``:``) header line, followed by its
    contents one level deeper. It is printed only once; the original printed
    nested dicts a second time as a flat repr.

    Args:
        data: a dict, or any other iterable of values (list/tuple).
        indent: current nesting depth.
    """
    pad = '  ' * indent
    if isinstance(data, dict):
        entries = [(f'{key}', value) for key, value in data.items()]
    else:
        # Sequence items have no key, so only the ':' marker is printed.
        entries = [('', value) for value in data]

    for label, value in entries:
        if isinstance(value, (dict, list, tuple)):
            print(pad + f'{label}:')
            print_json_recursive(value, indent + 1)
        else:
            print(pad + f'{label}: {value}')
    
def print_json(infile):
    """Load the JSON file `infile` and print it as an indented tree.

    The tree is printed by print_json_recursive(), starting at depth 2.
    """
    # JSON text is UTF-8 (RFC 8259); do not depend on the platform locale
    # encoding (e.g. cp932 on Japanese Windows).
    with open(infile, 'r', encoding='utf-8') as f:
        data = json.load(f)

    # Print the child elements of `data` recursively.
    print_json_recursive(data, indent=2)

def _read_poscar_sites(path, label):
    """Read a POSCAR file, print its sites, and count the sites per species.

    Args:
        path: POSCAR/CONTCAR file to read.
        label: model name used in the printed header ("ideal" / "defect").

    Returns:
        (volume, nsites): ``Structure.volume`` of the cell and a dict mapping
        each species string to its number of sites.
    """
    poscar = Poscar.from_file(path, check_for_POTCAR=False, read_velocities=False)
    structure = poscar.structure
    nsites = {}
    print(f"Atom sites in {label} model:")
    for site in structure.sites:
        atom = site.species_string
        print(f"  {atom}:", site.coords)
        nsites[atom] = nsites.get(atom, 0) + 1
    return structure.volume, nsites


def _supercell_factor(volume, volume_defect):
    """Return the integer expansion factor volume_defect / volume.

    Terminates the application if the ratio is not within 5% of an integer.
    """
    # round() instead of int(x + 0.05): a ratio slightly below an integer
    # (e.g. 1.94) must map to 2; int() gave 1 and failed the 5% check below.
    kVolume = round(volume_defect / volume)
    if abs(kVolume * (volume / volume_defect) - 1.0) > 0.05:
        print("\nError in json2input_xlsx.to_input_xlsx():")
        print("    Inconsistent volumes between the ideal and the defect cells.")
        print(f"    volume_defect={volume_defect:.4f} / volume={volume:.4f} = {volume_defect / volume:.4f} must be an integer within error of 5%")
        app.terminate("*** Check if you specify CONTCAR in structure_opt, or if you used structure_opt/CONTCAR for defect supercells", pause = True)
    return kVolume


def _resolve_nsites(site, site_name, nsites, nsites_defect, kVolume):
    """Return the number of sites N0 for a defect on `site`.

    Interstitial sites get 1. Otherwise `site`, then its element part
    `site_name`, is looked up first in the ideal-model counts (scaled by
    kVolume) and then in the defect-model counts. If nothing matches, the
    application is terminated.
    """
    # NOTE(review): pydefect labels interstitial sites like "i1", "i2", ...;
    # both the bare "i" and numbered forms are treated as interstitial here.
    if re.fullmatch(r"i\d*", site):
        return 1
    for key in (site, site_name):
        if key in nsites:
            return nsites[key] * kVolume
        if key in nsites_defect:
            return nsites_defect[key]
    app.terminate(f"\nError in json2input_xlsx.to_input_xlsx(): Can not find nsite from site [{site}]\n", pause = True)
    # Only reached if app.terminate() returns; never reuse a previous defect's value.
    return None


def to_input_xlsx(infile, outfile):
    """Convert a pydefect defect-energy json file into an input.xlsx workbook.

    Reads the ideal-model POSCAR (module global ``poscar_path``) and the
    defect-model POSCAR (``poscar_defect_path``). From these it counts sites
    per species and the integer cell-volume ratio kVolume.

    It then reads ``infile`` and writes one worksheet row per (defect, charge):
        Atom, Site, q, dS/kB (= dS_default), N0, Ndoped (= 0),
        dH for each chemical-potential condition,
        dn(atom) for each atom, dH0, dEVBM, dEpc,
        mu(atom)@condition for each condition and atom,
    where dH = dH0 + dEVBM + dEpc - sum(mu(atom) * dn(atom)).

    The workbook is saved to ``outfile``. Values of the last written row are
    written to the ini file ``summaryfile`` (section 'results'). Finally the
    function waits for ENTER on stdin.

    Args:
        infile: pydefect json with keys "cbm", "supercell_vbm",
            "supercell_cbm", "defect_energies" and "rel_chem_pots".
        outfile: .xlsx file to write.
    """
    print()
    print("#=====================================================")
    print("# Convert pydefect json to input.xlsx file")
    print("#=====================================================")
    print(f"Input file: {infile}")
    print(f"Output file: {outfile}")
    print(f"POSCAR file (idea)  : {poscar_path}")
    print(f"POSCAR file (defect): {poscar_defect_path}")

    print()
    print(f"Read ideal model POSCAR: {poscar_path}")
    volume, nsites = _read_poscar_sites(poscar_path, "ideal")

    print()
    # Fixed: this used to print poscar_path although poscar_defect_path is read.
    print(f"Read defect model POSCAR: {poscar_defect_path}")
    volume_defect, nsites_defect = _read_poscar_sites(poscar_defect_path, "defect")

    print()
    print("Unit cell volume of idetal model:", volume)
    print("Unit cell volume of defect model:", volume_defect)
    print("Number of sites in idetal model:", nsites)
    print("Number of sites in defect model:", nsites_defect)
    kVolume = _supercell_factor(volume, volume_defect)
    print(f"Number of sites are taken from the ideal model but the unit cell is expaned by a factor of {kVolume}")

    print()
    print(f"Read pydefect json file: {infile}")
    with open(infile, 'r', encoding='utf-8') as f:
        data_dict = dict(json.load(f))

    CBM = data_dict["cbm"]
    supercell_VBM = data_dict["supercell_vbm"]
    supercell_CBM = data_dict["supercell_cbm"]
    defects_dict = data_dict["defect_energies"]
    chemical_potentials = data_dict["rel_chem_pots"]

    print()
    print(f"CBM (host): {CBM} eV")
    print(f"CBM (supercell): {supercell_CBM} eV")
    print(f"VBM (supercell): {supercell_VBM} eV")

    print()
    print("Chemical potentials:")
    cp_dict = {}
    for cp_name, inf in chemical_potentials.items():
        print(f"{cp_name}:")
        cp_dict[cp_name] = {}
        for atom_name, cp in inf.items():
            print(f"  {atom_name}: {cp} eV")
            cp_dict[cp_name][atom_name] = cp
    cp_names = list(cp_dict)

    atoms = set()
    print()
    print("Defects:")
    for defect_name, inf in defects_dict.items():
        atom_name, site = defect_name.split("_")
        if atom_name == 'Va':
            atom_name = 'V'
        else:
            atoms.add(atom_name)
        print(f"{atom_name}{site}: ", inf.keys())

        for io_atom, delta_n in inf["atom_io"].items():
            print(f"  {io_atom}: {delta_n}")
            atoms.add(io_atom)

        charges = inf["charges"]
        print(f"  charges: {charges}")

        for iq, dinf in enumerate(inf["defect_energies"]):
            charge = charges[iq]
            dH0 = dinf["formation_energy"]
            is_shallow = dinf["is_shallow"]
            dEcharge = dinf["energy_corrections"]["pc term"]
            dEVBM = dinf["energy_corrections"]["alignment term"]
            print(f"  q={charge:3}: dH0={dH0:10.4g} eV dEcharge={dEcharge:10.4g} eV  dEVBM={dEVBM:10.4g} eV  (is_shallow={is_shallow})")

    # Sorted so the column order is reproducible: iteration order of a set of
    # str depends on per-process hash randomization.
    atoms = sorted(atoms)

    wb = openpyxl.Workbook()
    ws = wb.active
    for col, title in enumerate(['Version2', 'Atom', 'Site', 'q', 'dS/kB', 'N0', 'Ndoped'], start=1):
        ws.cell(row=1, column=col, value=title)

    # Column groups; each group is followed by one empty separator column.
    cp_name_offset = 8
    for i, cp_name in enumerate(cp_names):
        ws.cell(row=1, column=cp_name_offset + i, value=cp_name)

    dn_offset = len(cp_names) + cp_name_offset + 1
    for i, atom in enumerate(atoms):
        ws.cell(row=1, column=dn_offset + i, value=f"dn({atom})")

    meta_offset = dn_offset + len(atoms) + 1
    ws.cell(row=1, column=meta_offset + 0, value="dH0")
    ws.cell(row=1, column=meta_offset + 1, value="dEVBM")
    ws.cell(row=1, column=meta_offset + 2, value="dEpc")

    mu_offset = meta_offset + 3
    col = mu_offset
    for cp_name in cp_names:
        for atom in atoms:
            ws.cell(row=1, column=col, value=f"mu({atom})@{cp_name}")
            col += 1

    print()
    print("Writing...")
    line = 2
    print("nsites in idetal model:", nsites)
    print(f"   kVolume={kVolume:.4f}")
    print("nsites in defect model:", nsites_defect)
    # (charge, dH0, dEVBM, dEcharge) of the last written row, used for the summary file.
    last_entry = None
    for defect_name, inf in defects_dict.items():
        atom_name, site = defect_name.split("_")
        atom_io = inf["atom_io"]
        energies = inf["defect_energies"]
        if atom_name == 'Va':
            atom_name = 'V'

        # Element part of a site label such as "O1" -> "O".
        m = re.match(r"([A-Z][a-z]?)(\d*)", site)
        site_name = m.group(1) if m else site

        print(f"defect name: {defect_name}  atom name: {atom_name}  site: {site}  site_name: {site_name}")
        _nsites = _resolve_nsites(site, site_name, nsites, nsites_defect, kVolume)
        print(f"{atom_name}@{site}: nsite={_nsites}")

        charges = inf["charges"]
        print("  charges:", charges)
        print("  atom_io:", atom_io)
        print("  atoms  :", atoms)
        for iq, dinf in enumerate(energies):
            charge = charges[iq]
            dH0 = dinf["formation_energy"]
            dEcharge = dinf["energy_corrections"]["pc term"]
            dEVBM = dinf["energy_corrections"]["alignment term"]

            print(f"  q={charge}:")
            for col, value in enumerate([atom_name, site, charge, dS_default, _nsites, 0], start=2):
                ws.cell(row=line, column=col, value=value)

            print(f"    chemical potentials: {atom_name}@{site} q={charge}")
            for icp, cp_name in enumerate(cp_names):
                print(f"      {cp_name}:", cp_dict[cp_name])
                dEmu = 0.0
                for atom in atoms:
                    delta_n = atom_io.get(atom, 0)
                    cp = cp_dict[cp_name][atom]
                    print(f"      for {atom}: dEmu = {dEmu} + {cp} * {delta_n}")
                    dEmu += cp * delta_n
                print(f"        dEmu = {dEmu}")
                dH = dH0 + dEVBM + dEcharge - dEmu
                print(f"        dH:{dH:10.6g} eV = dH0:{dH0:10.6g} + dEVBM:{dEVBM:10.6g} + dEpc:{dEcharge:10.6g} - dEmu{dEmu:10.6g}")
                ws.cell(row=line, column=cp_name_offset + icp, value=dH)

            for iatom, atom in enumerate(atoms):
                ws.cell(row=line, column=dn_offset + iatom, value=atom_io.get(atom, 0))

            ws.cell(row=line, column=meta_offset + 0, value=dH0)
            ws.cell(row=line, column=meta_offset + 1, value=dEVBM)
            ws.cell(row=line, column=meta_offset + 2, value=dEcharge)

            col = mu_offset
            for cp_name in cp_names:
                for atom in atoms:
                    ws.cell(row=line, column=col, value=cp_dict[cp_name][atom])
                    col += 1

            last_entry = (charge, dH0, dEVBM, dEcharge)
            line += 1

    print()
    print(f"Save to {outfile}")
    wb.save(outfile)

    print()
    print(f"Save to {summaryfile}")
    if last_entry is None:
        # Previously a NameError: no defect energies means no values to summarize.
        print("  No defect energies found; the summary file is not written.")
    else:
        charge, dH0, dEVBM, dEcharge = last_entry
        ini = tkIniFile()
        kwargs = {
            "defect_charge"  : charge,
            "dEBF"           : 0.0,
            # Avoid ZeroDivisionError for a neutral last entry.
            # NOTE(review): 0.0 is assumed for q = 0; confirm.
            "dEVBM"          : dEVBM / charge if charge else 0.0,
            "dEcorr"         : dEcharge,
            "Etot_ideal"     : 0.0,
            "Etot_defect"    : dH0,
            "VASP_ideal_Eg"  : CBM,
            "VASP_ideal_EVBM": 0.0,
            "VASP_ideal_ECBM": CBM,
            "VASP_defect_EVBM": supercell_VBM,
            "VASP_defect_ECBM": supercell_CBM,
            }
        ini.write_from_scratch(summaryfile, 'results', **kwargs)


    print()
    print("#=============================================================")
    print("#===========  NOTE !!! =======================================")
    print(f"# CONFIRM to check N0 in {outfile} ")
    print(f"#   They are taken from {poscar_path}, but those must be consist with")
    print(f"    the densities of the correspondding sites")
    print(f"    Particularly, the default N0 written are inaccurate")
    print(f"    if the crystal has two or more independent sites for same atom")
    print("#=============================================================")
    print("#=============================================================")
    print()

    input("\nPress ENTER to terminate>>\n")


def main():
    """Command-line entry point.

    Derives the module-level ``logfile`` (``<filebody>-j2x-out.txt``) and
    ``summaryfile`` (``<filebody>-j2x-summary.prm``) names from ``infile``
    via ``app.replace_path``. It then redirects output with ``app.redirect``
    to the targets stdout and the log file, and runs
    ``to_input_xlsx(infile, outfile)``.
    """
    global logfile, summaryfile

    dir_part = "{dirname}"
    logfile = app.replace_path(infile, template=[dir_part, "{filebody}-j2x-out.txt"])
    summaryfile = app.replace_path(infile, template=[dir_part, "{filebody}-j2x-summary.prm"])

    print("")
    print(f"Open logfile [{logfile}]")
    app.redirect(targets=["stdout", logfile], mode='w')

    to_input_xlsx(infile, outfile)


# Run the conversion only when executed as a script, not on import.
if __name__ == "__main__":
    main()


