import os
import sys
import numpy as np
import matplotlib.pyplot as plt
import tkinter as tk

#from pymatgen.core import Composition
from pymatgen.core import Lattice, Structure, Site


from tklib.tkapplication import tkApplication
from tklib.tkfile import tkFile
from tklib.tkutils import get_ext, getarg, getintarg, getfloatarg, pint, pfloat, terminate
from tklib.tkutils import split_file_path, replace_path
from tklib.tkwrapper.tktqdm import tqdm
from tklib.tkcrystal.tkcrystal import tkCrystal
from tklib.tkcrystal.tkpymatgen import tkPymatgenVASP, tkPymatgen
from tklib.tkcrystal.tkalamode import read_input, read_taudata, read_log, read_kpoints, read_properties, cal_contrib, convert_k_list, convert_k

from tklib.tkgraphic.tkplotevent import tkPlotEvent, RangeSelector
from tklib.tkgraphic.tkplot_pyplot import select_plt, tkPlot_pyplot
from tklib.tkgraphic.tkplot_tkinter import tkPlot_tkinter
from tklib.tkgui.tksimple_gui import tkWidgets, CustomDialog_with_config


# NOTE(review): this string comes after the import statements, so Python does
# not use it as the module docstring (__doc__ stays None). Move it above the
# imports if it is meant to document the module.
"""
Convert ALAMODE .xyz to CIF
"""


# Plugin version tag. It is not referenced anywhere in this file; presumably
# it is read by the tklib plugin machinery (confirm).
plugin_ver = "structure:0.2"


#================================
# global parameters
#================================
# check_file_type() compares the lower-cased file extension with this value.
default_ext = ".xyz"
# File-type label returned by check_file_type() and get_input_type().
input_type  = "ALAMODE .xyz"
# File-type label returned by get_output_type().
output_type = "CIF"


#=============================
# Define command-line arguments
#=============================
def initialize(app = None, cparams = None):
    """Register this converter's command-line options on *app*.

    Sets ``cparams.debug`` to 0, then calls
    ``app.add_argument(opt=..., type=..., defval=...)`` once per option in a
    fixed order. Both *app* and *cparams* must be supplied despite the
    ``None`` defaults.
    """
    cparams.debug = 0

    # (option, type name, default value) -- registered in exactly this order.
    option_table = (
        # ALAMODE input/log/xyz files and the output file
        ('--inputfile', "str", 'anime.in'),
        ('--logfile',   "str", 'k25.log'),
        ('--infile',    "str", 'vmd_data_k25_k2/ik2.anime01.xyz'),
        ('--outfile',   "str", 'structure.cif'),

        # tolerance handed to the symmetrization step
        ('--symprec', "float", 1.0e-5),

        # drawing range, bond cutoff and atom transparency
        ('--draw_range',   "str",   "2,2,2"),
        ('--max_distance', "float", 2.75),
        ('--atom_alpha',   "float", 0.5),

        # animation controls
        ('--displacement', "float", 0.2),
        ('--sleep',        "float", 0.01),
        ('--time_step',    "float", 0.05),
        ('--nstep',        "int",   100),

        # atomic radius coefficient and bond radius
        ('--kr',     "float", 200.0),
        ('--bond_r', "float", 3.0),

        # displacement-vector (arrow) geometry
        ('--vector_r',           "float", 1.5),
        ('--vector_length',      "float", 2.0),
        ('--arrow_length_ratio', "float", 0.3),

        # registered but not read anywhere in this file
        ('--coordinate', "str", 'fc'),
        ('--site_name',  "str", 'as'),

        # font sizes
        ('--fontsize',        "int", 16),
        ('--legend_fontsize', "int", 12),
    )

    for opt_name, type_name, default in option_table:
        app.add_argument(opt = opt_name, type = type_name, defval = default)



def get_base_path(infile):
    """Return ``(vasp, base_path)`` for *infile*.

    *vasp* is a freshly created tkVASP object. When *infile* is an existing
    regular file, *base_path* is ``vasp.getdir(infile)``. Otherwise *infile*
    is returned unchanged; presumably this covers the case where it already
    names a directory (confirm).

    NOTE(review): ``tkVASP`` is not imported in this file (only
    tkPymatgenVASP and tkPymatgen are imported from tkpymatgen), so calling
    this function raises NameError. Its call in main() is commented out.
    Either import tkVASP from tklib or remove this helper.
    """
    vasp = tkVASP()
    if os.path.isfile(infile):
        base_path = vasp.getdir(infile)
    else:
        base_path = infile

    return vasp, base_path

def check_exist(path, print_level = 0):
    """Return True if *path* is an existing regular file, otherwise None.

    When *print_level* is non-zero, a missing file is reported on stdout.
    The error message now names this function; it previously said
    ``check_file_type()``, which misdirected anyone reading the log.
    """
    if os.path.isfile(path):
        return True

    if print_level:
        print(f"Error in check_exist(): [{path}] does not exist")
    return None

def check_file_type(infile, inf = None, app = None, cparams = None, print_level = 0):
    """Classify *infile* by its extension.

    Returns None when *infile* is neither an existing file nor an existing
    directory. Otherwise returns ``{"file_type": input_type}`` when the
    lower-cased extension equals ``default_ext``, and
    ``{"file_type": "VASP POSCAR"}`` for any other extension.
    """
    exists = os.path.isfile(infile) or os.path.isdir(infile)
    if not exists:
        return None  #f"Error: file [{infile}] does not exist"

    is_xyz = get_ext(infile).lower() == default_ext
    file_type = input_type if is_xyz else "VASP POSCAR"
    return {"file_type": file_type}

def get_output_path(infile, inf = None, app = None, cparams = None):
    """Derive the CIF output path for *infile*, store it and return it.

    The path is built by ``app.replace_path()`` from the template parts
    ``"{dirname}"`` and ``"{filebody}.cif"``. The result is assigned to
    ``cparams.outfile``, and that attribute's value is returned.
    """
    path_template = ["{dirname}", "{filebody}.cif"]
    new_outfile = app.replace_path(infile, template = path_template)
    cparams.outfile = new_outfile
    return cparams.outfile

def get_input_type(inf = None, app = None, cparams = None):
    """Return the input-file descriptor ``{"file_type": input_type}``."""
    descriptor = dict(file_type = input_type)
    return descriptor

def get_output_type(inf = None, app = None, cparams = None):
    """Return the output-file descriptor ``{"file_type": output_type}``."""
    descriptor = dict(file_type = output_type)
    return descriptor

def _read_xyz_frames(fin, aij, sample_name, pm, symprec, max_frames = 2):
    """Parse up to *max_frames* frames from an opened ALAMODE .xyz tkFile.

    Each frame consists of an atom-count line, a comment line and one
    ``name x y z`` line per atom. From the comment line, the text after the
    first '(' is parsed for the mode index, the three numbers inside the
    parentheses for (kx, ky, kz), and the fields after '=' for ``f`` and
    ``istep``. These values are parsed but not used further.

    Returns ``(crystal_list, structure_list, reduced_formula)``, where
    *reduced_formula* belongs to the last frame read, or is None when the
    file holds no frame.
    """
    crystal_list = []
    structure_list = []
    reduced_formula = None

    istructure = 0
    # Only the first frames are needed for the finite-difference velocities.
    while istructure < max_frames:
        line = fin.readline()
        if not line: break

        cry = tkCrystal()
        cry.SetSampleName(sample_name)
        cry.SetCrystalName(sample_name)
        cry.set_lattice_vectors(aij)
        latt = cry.lattice_parameters()

        lattice_obj = Lattice.from_parameters(*latt)
        structure = Structure(lattice_obj, [], [])

        nsites = pint(line)

        # Comment line of the frame (see docstring for the fields parsed).
        line = fin.readline()
        _aa = line.split('(', 1)
        imode = pint(_aa[1])

        _aa2 = _aa[1].split(')', 1)
        _aa3 = _aa2[0].split()
        kx = pfloat(_aa3[0])
        ky = pfloat(_aa3[1])
        kz = pfloat(_aa3[2])

        _aa4 = _aa2[1].split('=')
        f = pfloat(_aa4[1], strict = False)
        istep = pint(_aa4[-1])

        print(f"istructure={istructure}")
        for _ in tqdm(range(nsites)):
            line = fin.readline()
            if not line: break

            words = line.split()
            cry.AddAtomSite(name = words[0],
                            pos = [float(words[1]), float(words[2]), float(words[3])])

        cry.expand_coordinates()
        crystal_list.append(cry)

        atom_sites = cry.atom_site_list()
        for i in range(nsites):
            site = atom_sites[i]
            structure.append(site.atom_name_only(), site.position(), coords_are_cartesian = True)

        formula, reduced_formula = pm.chemical_formula(structure)
        # NOTE(review): the symmetrization result is not used; the call is
        # kept from the original implementation.
        pm.symmetrize(structure, symprec)

        structure_list.append(structure)
        istructure += 1

    return crystal_list, structure_list, reduced_formula


def _assign_velocities(structure_list):
    """Store finite-difference displacements in ``site.properties["velocity"]``.

    The first frame uses a forward difference, the last frame a backward
    difference, and interior frames a central difference divided by 2. The
    function does nothing when fewer than two frames are given.
    """
    nstructures = len(structure_list)
    if nstructures < 2:
        return

    for istr, structure in enumerate(structure_list):
        if istr == 0:
            i, j, div = 0, 1, 1
        elif istr == nstructures - 1:
            i, j, div = istr - 1, istr, 1
        else:
            i, j, div = istr - 1, istr + 1, 2

        sites_i = structure_list[i].sites
        sites_j = structure_list[j].sites
        for isite, site in enumerate(structure.sites):
            pos_i = sites_i[isite].coords
            pos_j = sites_j[isite].coords
            site.properties["velocity"] = [(pos_j[k] - pos_i[k]) / div for k in range(3)]


def read_data(infile, app = None, cparams = None, print_level = 0):
    """Read an ALAMODE animation (.xyz) together with its input and log files.

    The lattice vectors come from the ALAMODE input file (``cparams.inputfile``)
    and log file (``cparams.logfile``). Atomic positions come from the first
    two frames of the .xyz file *infile*. Finite-difference displacements
    between those frames are stored per site as ``properties["velocity"]``.

    Parameters
    ----------
    infile : str
        Path of the ALAMODE .xyz file. It is now used directly; previously the
        function opened ``app.cparams.infile`` instead.
    app : optional
        Not used; kept for the plugin interface.
    cparams :
        Must provide ``inputfile``, ``logfile`` and ``infile``, plus
        ``get("symprec", default)``.
    print_level : int
        Verbosity.

    Returns
    -------
    dict or None
        Structures, crystals, lattice vectors and lattice parameters. Returns
        None when a file cannot be read or the .xyz file holds no frame.
    """
    infile = os.path.abspath(infile)

    if print_level >= 1:
        print("")
        print("ALAMODE input file path: ", cparams.inputfile)
        print("ALAMODE log file path  : ", cparams.logfile)
        print("ALAMODE .xyz file path : ", cparams.infile)

    _dirname, _basename, filebody, _ext = split_file_path(infile, check_dir = True)
    sample_name = filebody

    pm = tkPymatgen()

    print()
    print(f"Read input file [{cparams.inputfile}]")
    inf_input = read_input(cparams.inputfile, print_level = 1)
    if inf_input is None:
        if print_level:
            print()
            print(f"Error in alamode_xyz2cif.read_data(): Can not read input file [{cparams.inputfile}]")
            print()
        return None

    print()
    print(f"Read log file [{cparams.logfile}]")
    inf_log, _kp_list, _kp_list_ik = read_log(cparams.logfile)
    if inf_log is None:
        if print_level:
            print()
            print(f"Error in alamode_xyz2cif.read_data(): Can not read log file [{cparams.logfile}]")
            print()
        return None

    if inf_input["aij"]:
        aij_input = inf_input["aij"]
        aij  = inf_log["aij"]
        paij = inf_log["paij"]
        haij = inf_log.get("haij", None)
    else:
        # No lattice vectors in the input file: build them from the lattice
        # parameters. aij_input used to be left undefined here (NameError
        # below), and aij was a Lattice object rather than a vector matrix.
        aij_input = Lattice.from_parameters(*inf_input["latt"]).matrix
        aij  = aij_input
        paij = None
        haij = None

    _cry = tkCrystal()
    latt_aij_input = _cry.calculate_lattice_parameters_from_vector(aij_input)
    latt_aij  = _cry.calculate_lattice_parameters_from_vector(aij)
    latt_paij = _cry.calculate_lattice_parameters_from_vector(paij)
    latt_haij = _cry.calculate_lattice_parameters_from_vector(haij)
    if print_level:
        print("input aij    =", aij_input)
        print("supercell aij=", aij)
        print("primitive cell aij=", paij)
        print("harmonic supercell aij=", haij)
        print("input lattice parameters             =", latt_aij_input)
        print("supercell lattice parameters         =", latt_aij)
        print("primitive cell lattice parameters    =", latt_paij)
        print("harmonic supercell lattice parameters=", latt_haij)

    print()
    print(f"Read xyz file [{cparams.infile}]")
    fin = tkFile(infile, "r")
    if fin is None or fin.fp is None:
        if print_level:
            print()
            print(f"Error in alamode_xyz2cif.read_data(): Can not read [{infile}]")
            print()
        return None

    symprec = cparams.get("symprec", 1.0e-4)
    try:
        crystal_list, structure_list, reduced_formula = _read_xyz_frames(
            fin, aij, sample_name, pm, symprec)
    finally:
        # Close the handle checked above explicitly instead of relying on
        # garbage collection.
        fin.fp.close()

    if not structure_list:
        if print_level:
            print()
            print(f"Error in alamode_xyz2cif.read_data(): No structure found in [{infile}]")
            print()
        return None

    if len(structure_list) >= 2 and print_level:
        print()
        print("Calculate velocities")
    _assign_velocities(structure_list)

    if print_level > 1:
        print("last structure:")
        print(structure_list[-1])

    inf = {}
    inf["data_list_type"] = "[tkCIF]"
    inf["filename"]    = infile
    inf["filenames"]   = [infile]
    inf["sample_name"] = reduced_formula
    inf["meta"]        = []
    inf["pymatgen"]       = pm
    inf["crystal_list"]   = crystal_list
    inf["structure_list"] = structure_list

    inf["aij_input"] = aij_input
    inf["aij"]  = aij
    inf["paij"] = paij
    inf["haij"] = haij
    inf["lattice_parameter(aij_input)"] = latt_aij_input
    inf["lattice_parameter(aij)"]       = latt_aij
    inf["lattice_parameter(paij)"]      = latt_paij
    inf["lattice_parameter(haij)"]      = latt_haij

    return inf

def print_data(inf, app = None, cparams = None, print_level = 1):
    """Print every entry of ``inf["structure_list"]``.

    Each entry is preceded by a blank line and an ``icrystal=<index>``
    header.
    """
    for icrystal, structure_obj in enumerate(inf["structure_list"]):
        print("")
        print("icrystal=" + str(icrystal))
        print(structure_obj)

def save_data(outfiles, inf, app = None, cparams = None, print_level = 0):
    """Write the first structure of ``inf["structure_list"]`` to a file.

    Parameters
    ----------
    outfiles : str or sequence of str
        Output path. When a sequence is given, only its first element is used.
    inf : dict or None
        Data returned by read_data(); uses ``"structure_list"`` and
        ``"pymatgen"``.
    cparams :
        Must provide ``symprec``, which is passed to ``pm.to()``.
    print_level : int
        Verbosity. Also forwarded to ``pm.to()``.

    Returns
    -------
    bool
        True after ``pm.to()`` has been called. False when there is nothing
        to save; previously *inf* being None or holding no structure raised
        TypeError or IndexError.
    """
    # isinstance also accepts str subclasses, unlike the old type(...) == str.
    outfile = outfiles if isinstance(outfiles, str) else outfiles[0]

    if not inf or not inf.get("structure_list"):
        if print_level:
            print()
            print(f"Error in alamode_xyz2cif.save_data(): No structure to save to [{outfile}]")
        return False

    structure = inf["structure_list"][0]
    pm = inf["pymatgen"]

    print()
    print(f"Save structure to [{outfile}]")
    pm.to(outfile, structure, symprec = cparams.symprec, print_level = print_level)

    return True

def draw(tkplt, inf, app = None, cparams = None, ax = None, animation = True, stop_callback = None, print_level = 0):
    """Draw the first structure of *inf* into *ax*, optionally animated.

    Sites of ``inf["structure_list"][0]`` are selected when their Cartesian
    coordinates lie in ``[0, n_i * L_i]`` with a 0.1 tolerance. Here
    ``(nx, ny, nz)`` comes from ``cparams.draw_range`` ("nx,ny,nz") and
    ``L_i`` are the first three input lattice parameters. The selected sites
    are copied, together with their site properties, into a new Structure on
    the input lattice. That structure is passed to
    ``tkPymatgen.draw_structure()`` with the drawing and animation settings
    from *cparams*.
    """
    max_distance = cparams.max_distance
    draw_range   = [pfloat(s) for s in cparams.draw_range.split(',')]
    print()
    print("Draw configuration:")
    print(f"max_distance: {max_distance}")

    structure = inf["structure_list"][0]
    pm = tkPymatgen()

    latt_aij_input = inf["lattice_parameter(aij_input)"]
    nx, ny, nz = draw_range
    print("Input lattice parameters: ", latt_aij_input)
    # Label the three axes correctly (previously all three were printed as "x").
    print(f"Draw range: x in (0, {nx:.2f}) y in (0, {ny:.2f}) z in (0, {nz:.2f})")

    lattice_obj = Lattice.from_parameters(*latt_aij_input)
    structure_draw = Structure(lattice_obj, [], [])

    # NOTE(review): comparing Cartesian x, y, z with n * a, n * b, n * c is
    # only exact for orthogonal cells.
    eps = 0.1
    upper = (nx * latt_aij_input[0], ny * latt_aij_input[1], nz * latt_aij_input[2])
    for site in structure.sites:
        coords = site.coords
        if any(c < -eps or u + eps < c for c, u in zip(coords, upper)):
            continue
        structure_draw.append(site.specie.symbol, coords, coords_are_cartesian = True,
                              properties = site.properties.copy())

    if print_level:
        print()
        print("Draw structre:")
        print(structure_draw)

    pm.draw_structure(tkplt, ax, structure_draw, draw_range = [[0, nx], [0, ny], [0, nz]],
                max_distance = max_distance, kr = cparams.kr, atom_alpha = cparams.atom_alpha, bond_r = cparams.bond_r,
                vector_r = cparams.vector_r, vector_length = cparams.vector_length,
                arrow_length_ratio = cparams.arrow_length_ratio,
                displacement = cparams.displacement,
                animation = animation, stop_callback = stop_callback,
                sleep = cparams.sleep, time_step = cparams.time_step, nstep = cparams.nstep,
                fontsize = cparams.fontsize, legend_fontsize = cparams.legend_fontsize)

def create_window(app, cparams, tkplt, fig, axes, inf):
    """Build the Tk GUI around *fig* and draw the structure once.

    The "main" notebook page holds a parameter panel on the left and the
    figure canvas on the right. Toolbar and panel buttons trigger a one-shot
    redraw, starting or stopping the animation, and saving the figure to
    ``figure.png``. Widget values are copied back into *cparams* before every
    draw.

    Returns ``(wg, widgets_frame, config_left_pain)``.
    """
    tkplt.create_window(fig)
    tkplt.add_toolbar()

    wg = tkWidgets(parent = tkplt.parent, plt = plt)
    wg.set_font(size = 14)
    notebook = wg.add_tab()
    page1_frame = wg.add_page(notebook, title = "main")

    top_frame = tk.Frame(page1_frame)
    top_frame.pack(side = tk.TOP, expand = True, fill = "both")

    left_frame = tk.Frame(top_frame)
    left_frame.pack(side = tk.LEFT, anchor = "n", expand = True, fill = "x")

    canvas_frame = tk.Frame(top_frame)
    tkplt.add_canvas(fig, parent = canvas_frame)
    canvas_frame.pack(side = tk.LEFT, expand = True, fill = "both")

    def save_figure(fig, outfile = "figure.png"):
        print()
        print(f"Save figure to [{outfile}]")
        fig.savefig(outfile)

    def sync_params():
        # Copy widget values back into cparams; only keys cparams already has.
        wg.update_variables()
        for key in list(cparams.__dict__.keys()):
            if key in wg.vars.keys():
                setattr(cparams, key, getattr(wg.vars, key))

    def redraw():
        sync_params()
        axes[0].clear()
        draw(tkplt, inf, app = app, cparams = cparams, ax = axes[0],
                    animation = False, print_level = 1)
        tkplt.draw()

    def animation_start():
        print()
        print("Start animation")

        sync_params()
        wg.stop_animation = False
        axes[0].clear()
        # draw() polls this callback; animation_stop() flips the flag.
        draw(tkplt, inf, app = app, cparams = cparams, ax = axes[0],
                    animation = True, stop_callback = lambda: wg.stop_animation, print_level = 1)
        tkplt.draw()

    def animation_stop():
        print()
        print("Stop animation")
        wg.stop_animation = True
        tkplt.draw()

    tkplt.add_toolbar_button(text = "Redraw", command = redraw)
    tkplt.add_toolbar_button(text = "Animation start", command = animation_start)
    tkplt.add_toolbar_button(text = "Animation stop", command = animation_stop)
    tkplt.add_toolbar_button(text = "Save figure", command = lambda: save_figure(fig))

    wg.vars.update(**cparams.__dict__)

    widgets = [
            [{"type": "label", "label": "input file: "},
             {"type": "entry","varname": "inputfile", "vartype": "str", "width": 20, "state": "readonly", "expand": True, "fill": "x"}
            ],
            [{"type": "label", "label": "log file: "},
             {"type": "entry","varname": "logfile", "vartype": "str", "width": 20, "state": "readonly", "expand": True, "fill": "x"}
            ],
            [{"type": "label", "label": "xyz file: "},
             {"type": "entry","varname": "infile", "vartype": "str", "width": 20, "state": "readonly", "expand": True, "fill": "x"}
            ],

            [{"type": "label", "label": "Draw range:"},
             {"type": "entry","varname": "draw_range", "vartype": "str", "width": 20}
            ],

            [{"type": "label", "label": "Atom:"}],
            [{"type": "label", "label": "  atomic radius coefficient:"},
             {"type": "entry","varname": "kr", "vartype": "float", "width": 10},
             {"type": "label", "label": "alpha:"},
             {"type": "entry","varname": "atom_alpha", "vartype": "float", "width": 10},
            ],

            [{"type": "label", "label": "Chemical bond:"}],
            [{"type": "label", "label": "  max bond length:"},
             {"type": "entry","varname": "max_distance", "vartype": "float", "width": 10}
            ],
            [{"type": "label", "label": "  radius:"},
             {"type": "entry","varname": "bond_r", "vartype": "float", "width": 10},
            ],

            [{"type": "label", "label": "Displacement:"}],
            [{"type": "label", "label": "  vector length:"},
             {"type": "entry","varname": "vector_length", "vartype": "float", "width": 10},
             {"type": "label", "label": "  radius:"},
             {"type": "entry","varname": "vector_r", "vartype": "float", "width": 10}
            ],
            [{"type": "label", "label": "  arrow head ratio:"},
             {"type": "entry","varname": "arrow_length_ratio", "vartype": "float", "width": 10}
            ],

            [{"type": "label", "label": "One shot draw:"},
             {"type": "button", "name": "redraw", "text": "redraw", "anchor": "w", "command": redraw},
            ],

            [{"type": "label", "label": "Animation:"},
             {"type": "button", "name": "animation_start", "text": "start", "anchor": "w", "command": animation_start},
             {"type": "button", "name": "animation_stop", "text": "stop", "anchor": "w", "command": animation_stop},
            ],

            [{"type": "label", "label": "  sleep:"},
             {"type": "entry","varname": "sleep", "vartype": "float", "width": 10},
             {"type": "label", "label": "time step:"},
             {"type": "entry","varname": "time_step", "vartype": "float", "width": 10},
             {"type": "label", "label": "nstep:"},
             {"type": "entry","varname": "nstep", "vartype": "int", "width": 10}
            ],
            ]

    widgets_frame, config_left_pain = wg.add_widgets(parent = left_frame,
                        widgets = widgets, side = "top", expand = True, fill = "x")

    draw(tkplt, inf, app = app, cparams = cparams, ax = axes[0], animation = False, print_level = 1)

    return wg, widgets_frame, config_left_pain

def plot_data(inf, app = None, cparams = None, print_level = 0):
    """Open the "Phonon animation" window for *inf* and enter the Tk main loop.

    Creates a single 3D subplot (8x8 inches at 100 dpi), builds the GUI with
    create_window() and blocks in ``root.mainloop()`` until the window closes.
    """
    tkplt, root, _tkpyplot = select_plt(use_tkinter = 1, plt = plt, parent = None,
                                        title = "Phonon animation")
    fig, axes = tkplt.subplots(1, 1, figsize = (8, 8), dpi = 100,
                               projection = '3d', tight_layout = False)

    # NOTE(review): the tkPlotEvent object is never used below; it is created
    # as in the original, in case its constructor registers handlers on plt.
    _plot_event = tkPlotEvent(plt)

    create_window(app, cparams, tkplt, fig, axes, inf)

    tkplt.pause(1.0e-5)
    root.mainloop()
#    input("\nPress ENTER to terminate>>\n")

def convert(inf, app = None, cparams = None, print_level = 0):
    """Apply no conversion: return *inf* itself (the same object)."""
    converted = inf
    return converted


def main():
    """Command-line entry point.

    Reads the ALAMODE input/log/.xyz files, writes the first structure as a
    CIF file and then opens the interactive viewer. stdout is also mirrored
    into ``<filebody>-out.txt`` next to the .xyz file.
    """
    #==================================================================
    # Initialize parameters
    #==================================================================
    app     = tkApplication(globals = globals(), locals = locals())
    cparams = app.get_params()
    app.cparams = cparams

    initialize(app, cparams)
    app.update_vars(cparams, apply_default = True)

    print("")
    print( "==========================================================================")
    print(" Convert ALAMODE .xyz file to CIF")
    print( "==========================================================================")
    print(f"ALAMODE input file: {cparams.inputfile}")
    print(f"ALAMODE log file  : {cparams.logfile}")
    print(f"ALAMODE .xyz file : {cparams.infile}")
    print(f"Output file       : {cparams.outfile}")

    logfile = app.replace_path(cparams.infile, template = ["{dirname}", "{filebody}-out.txt"])
    print(f"Open logfile [{logfile}]")
    app.redirect(targets = ["stdout", logfile], mode = 'w')

    # NOTE(review): this replaces any --outfile given on the command line
    # with a path derived from the .xyz file name.
    cparams.outfile = get_output_path(cparams.infile, app = app, cparams = cparams)
    print(f"Output CIF file : {cparams.outfile}")

    file_type = check_file_type(cparams.infile, print_level = 1)
    if file_type is None:
        app.terminate(f"Error: [{cparams.infile}] invalid file type", usage = "built-in")
        return
    if 'Error' in file_type:
        # Previously `usage = usage`, which referenced an undefined name.
        app.terminate(file_type, usage = "built-in")
        return

    inf = read_data(cparams.infile, app = app, cparams = cparams, print_level = 2)
    if inf is None:
        # read_data() reports the cause; do not pass None on to save/plot.
        app.terminate(f"Error: Can not read ALAMODE data from [{cparams.infile}]", usage = "built-in")
        return

    save_data([cparams.outfile], inf, app = app, cparams = cparams, print_level = 9)
    plot_data(inf, app = app, cparams = cparams, print_level = 0)

    app.terminate(usage = 'built-in')


if __name__ == "__main__":
    main()
