import re
import sys
from fractions import Fraction

from pymatgen.io.cif import CifParser
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
from pymatgen.core.structure import Structure
from pymatgen.core.lattice import Lattice
from pymatgen.core.periodic_table import Element


from tklib.tkapplication import tkApplication
from tklib.tkutils import terminate, pint, pfloat, getarg, getintarg, getfloatarg
from tklib.tkfile import tkFile
from tklib.tkcrystal.tkcif import tkCIF


# Tolerance constant; note that main() does not reference it.
prec = 1.0e-3

app = tkApplication()
argv = sys.argv
narg = len(argv)

# The first command-line argument, when present, names the input CIF file;
# otherwise fall back to the bundled SrTiO3 example.
infile = argv[1] if narg > 1 else 'SrTiO3.cif'


#==========================================
# Main program
#==========================================
def main():
    """Read a CIF file, analyse its symmetry with pymatgen, and write a new CIF.

    The input path is the module-level ``infile``. Progress is printed to
    stdout and, via ``app.redirect``, to ``{filebody}-out.txt``. The output
    ``{filebody}-symmetrized.cif`` contains the cell, the space group, every
    symmetry operation (rotation AND translation) in xyz notation, and one
    representative atom per set of symmetry-equivalent sites.
    """
    logfile = app.replace_path(infile, template = ["{dirname}", "{filebody}-out.txt"])
    print(f"Open logfile [{logfile}]")
    app.redirect(targets = ["stdout", logfile], mode = 'w')

    convCIFfile = app.replace_path(infile, template = ["{dirname}", "{filebody}-symmetrized.cif"])

    print("")
    print(f"input : {infile}")
    print(f"log file: {logfile}")
    print(f"output symmetrized CIF file: {convCIFfile}")

    #=====================================
    # Read with tkCIF
    #=====================================
    print("")
    print(f"Read [{infile}] for tkCIF")
    cif = tkCIF()
    cifdata = cif.ReadCIF(infile, find_valid_structure = True)
    if cifdata is None:
        app.terminate(f"Error: Can not read [{infile}]", pause = True)
        sys.exit()

    cifdata.Print()
    cry_source = cifdata.GetCrystal()

    #=====================================
    # Read and analyse with pymatgen
    #=====================================
    print("")
    print(f"Read [{infile}] for pymatgen")
    structure = Structure.from_file(infile)
    print("Structure:", structure)

    print("")
    print("Symmetrized:")
    analyzer = SpacegroupAnalyzer(structure)
    symmetrized_structure = analyzer.get_symmetrized_structure()

    spg_name, spg_num = symmetrized_structure.get_space_group_info()
    # Each SymmOp's as_dict()['matrix'] is its 4x4 affine matrix: the rotation
    # part is the upper-left 3x3 block and the fractional translation is the
    # last column. Both parts are needed to write the operation correctly.
    symmetry_ops = analyzer.get_symmetry_operations()
    symmetry_matrix = [op.as_dict()['matrix'] for op in symmetry_ops]
    nsym = len(symmetry_matrix)
    print("Space group: ", spg_name, spg_num)
    print(f"Number of symmetry operations: {nsym}")

    print("")
    print("Lattice: ")
    lattice_inf = symmetrized_structure.lattice
    a, b, c, alpha, beta, gamma = lattice_inf.parameters
    volume = lattice_inf.volume
    aij = lattice_inf.matrix
    print("  Lattice parameters:", a, b, c, alpha, beta, gamma)
    print("  Matrix: ", aij)
    print("  Volume: ", volume)

    print("")
    print("Equivalent sites:")
    eq_sites = symmetrized_structure.equivalent_sites
    for sites in eq_sites:
        composition = sites[0].species
        for atom_name in composition.keys():
            print("  ", atom_name, sites[0].frac_coords, "  occ=", composition[atom_name])

    print("")
    print("All sites:")
    sites_inf = symmetrized_structure.sites
    for site in sites_inf:
        composition = site.species
        for atom_name in composition.keys():
            print("  ", atom_name, site.frac_coords, "  occ=", composition[atom_name])

    #=====================================
    # Save
    #=====================================
    print("")
    print(f"Save the symmetrized structure to [{convCIFfile}]")
    # symmetrized_structure.to(filename = convCIFfile, fmt = "cif") does not
    # save the symmetrized structure, so the CIF is written by hand below.
    out = tkFile(convCIFfile, 'w')
    if out.fp is None:
        app.terminate(f"Error: Can not write to [{convCIFfile}]", pause = True)
        # Stop here (as after the read failure above) instead of writing to an
        # unopened file.
        sys.exit()

    # Close the output even if a write fails part-way.
    try:
        out.write(f"data_{cry_source.path}\n")
        out.write(f"_cell_length_a     {a}\n")
        out.write(f"_cell_length_b     {b}\n")
        out.write(f"_cell_length_c     {c}\n")
        out.write(f"_cell_angle_alpha  {alpha}\n")
        out.write(f"_cell_angle_beta   {beta}\n")
        out.write(f"_cell_angle_gamma  {gamma}\n")
        out.write(f"_cell_volume       {volume}\n")
        out.write("\n")
        out.write(f"_symmetry_space_group_name_H-M  '{spg_name}'\n")
        out.write(f"_symmetry_Int_Tables_number     {spg_num}\n")
        # Quoted: CIF values containing blanks (e.g. 'Sr Ti O3') must be quoted.
        out.write(f"_chemical_formula_structural  '{cry_source.ChemicalFormula}'\n")
        out.write(f"_chemical_formula_sum         '{cry_source.ChemicalFormula}'\n")
        out.write(f"_cell_formula_units_Z   {cry_source.ChemicalFormulaUnit}\n")

        out.write("\n")
        out.write("loop_\n")
        out.write(" _symmetry_equiv_pos_site_id\n")
        out.write(" _symmetry_equiv_pos_as_xyz\n")
        for isym, affine in enumerate(symmetry_matrix, start = 1):
            x, y, z = affine_to_xyz(affine)
            out.write(f" {isym:2}  '{x}, {y}, {z}'\n")

        out.write("\n")
        out.write("loop_\n")
        out.write(" _atom_site_type_symbol\n")
        out.write(" _atom_site_fract_x\n")
        out.write(" _atom_site_fract_y\n")
        out.write(" _atom_site_fract_z\n")
        out.write(" _atom_site_occupancy\n")
        # One representative (the first site) per equivalent-site group; the
        # symmetry operations above regenerate the rest.
        for sites in eq_sites:
            composition = sites[0].species
            pos = sites[0].frac_coords
            for atom_name in composition.keys():
                occ = composition[atom_name]
                out.write(f" {atom_name}")
                out.write(f"    {pos[0]:12.8f}  {pos[1]:12.8f}  {pos[2]:12.8f}")
                out.write(f"  {occ}\n")
    finally:
        out.close()

    app.terminate(pause = True)


#==========================================
# Helpers: format symmetry operations as CIF xyz strings
#==========================================
def coefficient_to_str(value, eps=1.0e-6, max_denominator=12):
    """Return *value* as an explicitly signed string for an xyz term.

    Values within ``eps`` of a fraction whose denominator is at most
    ``max_denominator`` are written as that fraction (``0.5`` -> ``'+1/2'``,
    ``-0.75`` -> ``'-3/4'``, ``1.0`` -> ``'+1'``). Any other value falls back
    to six decimals (``0.123`` -> ``'+0.123000'``). The sign is always
    present, so terms can be concatenated directly.
    """
    frac = Fraction(value).limit_denominator(max_denominator)
    if abs(float(frac) - value) < eps:
        sign = '-' if frac < 0 else '+'
        return f"{sign}{abs(frac)}"
    return f"{value:+.6f}"


def affine_row_to_str(row, eps=1.0e-6):
    """Format one row of an affine symmetry matrix as a CIF xyz component.

    ``row[0:3]`` are the coefficients of x, y and z. ``row[3]``, if present,
    is the fractional translation, e.g. ``[-1, 0, 0, 0.5]`` -> ``'-x+1/2'``.
    Entries with magnitude below ``eps`` are omitted. Returns ``'0'`` when
    every entry is (near) zero.
    """
    terms = []
    for coef, axis in zip(row[:3], 'xyz'):
        if abs(coef) < eps:
            continue
        if abs(coef - 1.0) < eps:
            terms.append('+' + axis)
        elif abs(coef + 1.0) < eps:
            terms.append('-' + axis)
        else:
            terms.append(coefficient_to_str(coef, eps) + axis)
    if len(row) > 3 and abs(row[3]) >= eps:
        terms.append(coefficient_to_str(row[3], eps))

    text = ''.join(terms)
    if not text:
        return '0'
    # A leading '+' is redundant: "+x-y" -> "x-y".
    return text[1:] if text.startswith('+') else text


def affine_to_xyz(matrix, eps=1.0e-6):
    """Return the (x, y, z) component strings of a 3x4 or 4x4 affine matrix."""
    return tuple(affine_row_to_str(row, eps) for row in matrix[:3])


if __name__ == "__main__":
    # Run main() only when this file is executed as a script, not when imported.
    main()


