#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Lattice conversion program using tkCIF + pymatgen (partial-occupancy aware).
- --conversion prim | rhomb | hex | orth | MATRIX
  * prim : primitive cell (SpacegroupAnalyzer→primitive)
  * rhomb: if R hex setting → rhombohedral primitive (Hex→Rhombo)
  * hex  : if R rhombo setting → hexagonal conventional (Rhombo→Hex)
  * orth : if lattice is centered (A/B/C/F/I) → primitive by proper T
  * MATRIX: '(a,b,c)(d,e,f)(g,h,i)' or '(a,b,c,tx)(d,e,f,ty)(g,h,i,tz)'
           entries can be arithmetic (1/3, sqrt(2)/2,...)
- Basis change: V' = T @ V, f' = f @ inv(T) + t ; wrap & merge duplicates.
- Outputs cell volume, effective atom count (sum of occupancies), total mass, atomic & mass densities.
- Confirms equality of densities within eps (default 1e-4).
"""

import argparse
import sys
import os
import re
import math
import numpy as np
from typing import Tuple, Optional, List

from pymatgen.core.structure import Structure
from pymatgen.core.lattice import Lattice
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer


# Extend the module search path so the tklib imports below can be resolved.
# NOTE(review): this is an absolute, machine-specific path; adjust it (or
# install tklib) when running on another system.
sys.path.append("d:/git/tkProg/tklib/python")

from tklib.tkcrystal.tkcif import tkCIF
from tklib.tkcrystal.tkcif2pymatgen import tkcrystal_to_pmg_structure


def terminate():
    """Wait for the user to press ENTER, then exit the program.

    Raises:
        SystemExit: always, via ``sys.exit()``.
    """
    try:
        input("\nPress ENTER to terminate>>")
    except EOFError:
        # stdin is closed or exhausted (non-interactive run): there is
        # nothing to wait for, so fall through and exit normally.
        pass
    # Use sys.exit() instead of the bare exit() helper: exit() is injected by
    # the site module for interactive use and is not guaranteed to exist
    # (e.g. under ``python -S`` or in frozen executables).
    sys.exit()
    
def initialize():
    """Define the command-line interface and return the parsed arguments.

    Returns:
        argparse.Namespace with the attributes ``input_file``, ``pymatgen``,
        ``conversion``, ``direction``, ``sym_tol``, ``xyz_tol`` and ``eps``.
    """
    parser = argparse.ArgumentParser(
        description="Lattice conversion program using pymatgen.")

    conversion_help = (
        "Conversion spec: prim | rhomb | hex | orth | "
        "MATRIX like '(a,b,c)(d,e,f)(g,h,i)' "
        "or '(a,b,c,tx)(d,e,f,ty)(g,h,i,tz)'. "
        "Each a,b,c,... can be arithmetic (e.g., 1/2, sqrt(2)/2)."
    )

    # (flags, keyword options) for every argument, in registration order.
    argument_table = [
        (("input_file",),
         dict(type=str, help="Input CIF file path.")),
        (("-p", "--pymatgen"),
         dict(type=int, default=0,
              help="Flag to use pymatgen.CifParser. (Default: Use tkCIF.ReadCIF()")),
        (("-c", "--conversion"),
         dict(type=str, default="prim", help=conversion_help)),
        (("-d", "--direction"),
         dict(type=str, default="OriginalToConverted",
              choices=["OriginalToConverted", "ConvertedToOriginal"],
              help="If a MATRIX is given, reverse by using its inverse when ConvertedToOriginal.")),
        (("-t", "--sym_tol"),
         dict(type=float, default=0.01,
              help="Symmetry tolerance for space group analysis.")),
        (("-x", "--xyz_tol"),
         dict(type=float, default=1.0e-4,
              help="Tolerance to merge duplicate fractional sites after basis change.")),
        (("--eps",),
         dict(type=float, default=1.0e-4,
              help="Relative tolerance for density equality check.")),
    ]
    for flags, options in argument_table:
        parser.add_argument(*flags, **options)

    return parser.parse_args()

# ------------------------- eval-based parser -------------------------

# Global namespace handed to eval() when matrix entries are evaluated:
# builtins are emptied and only a handful of math names are exposed.
SAFE_GLOBALS = {
    "__builtins__": {},
    **{name: getattr(math, name)
       for name in ("sqrt", "sin", "cos", "tan", "pi", "e")},
}

def print_matrix(message, T):
    """Print ``message`` followed by the rows of ``T``.

    Only the first three entries of each row are printed, each as a
    fixed-width float with six decimals, framed by ``|`` characters.
    """
    row_template = "      |{:10.6f} {:10.6f} {:10.6f}|"
    print(message)
    for entries in T:
        print(row_template.format(entries[0], entries[1], entries[2]))

def parse_number(expr: str) -> float:
    """Evaluate one arithmetic expression and return its value as a float.

    The expression is evaluated with eval() using SAFE_GLOBALS as the global
    namespace and an empty local namespace.

    Args:
        expr: Expression text such as ``"1/3"`` or ``"sqrt(2)/2"``.

    Returns:
        The evaluated value converted with ``float()``.

    Raises:
        ValueError: If the expression is empty, cannot be evaluated, or does
            not evaluate to something ``float()`` accepts. The original
            exception is chained as the cause.
    """
    # SECURITY NOTE(review): eval() with an emptied ``__builtins__`` is NOT a
    # real sandbox; attribute access on literals can still reach arbitrary
    # objects. This is acceptable only while ``expr`` comes from the user's
    # own command line. Do not feed text from untrusted sources through this
    # function without replacing eval() by a restricted expression parser.
    try:
        return float(eval(expr, SAFE_GLOBALS, {}))
    except Exception as exc:
        # Normalise SyntaxError / NameError / TypeError / ZeroDivisionError
        # into one exception type so callers have a single failure to handle.
        raise ValueError(f"Invalid numeric expression {expr!r}: {exc}") from exc

def parse_conversion_matrix(spec: str):
    """Parse a matrix spec into a 3x3 matrix and a 3-component shift.

    Accepted forms are ``'(a,b,c)(d,e,f)(g,h,i)'`` and
    ``'(a,b,c,tx)(d,e,f,ty)(g,h,i,tz)'``. Every entry is passed to
    parse_number(), so entries may be arithmetic expressions, including ones
    that use parentheses themselves (``sqrt(2)/2``, ``(1+2)/3``).

    A row is a parenthesised group that contains a comma at its own nesting
    level and does not enclose another such group. Parentheses without a
    comma (function calls, arithmetic grouping) therefore stay part of the
    entry they belong to. Text outside the row groups is ignored.

    Args:
        spec: The matrix specification; all whitespace is removed first.

    Returns:
        Tuple ``(T, t)``: ``T`` is a (3, 3) float array holding the first
        three entries of each row, ``t`` a length-3 float array holding the
        optional fourth entry of each row (0.0 where absent).

    Raises:
        ValueError: If there are not exactly three row groups or a row does
            not have 3 or 4 entries. Errors raised by parse_number() for an
            individual entry propagate unchanged.
    """
    s = "".join(spec.split())

    # One [content_start, has_comma, encloses_row] record per open '('.
    open_groups = []
    rows = []
    for pos, ch in enumerate(s):
        if ch == '(':
            open_groups.append([pos + 1, False, False])
        elif ch == ',' and open_groups:
            open_groups[-1][1] = True
        elif ch == ')' and open_groups:
            start, has_comma, encloses_row = open_groups.pop()
            is_row = has_comma and not encloses_row
            if is_row:
                rows.append(s[start:pos])
            # Tell the enclosing group (if any) that it wraps a row, so an
            # outer wrapper such as '((..),(..),(..))' is not taken as a row.
            if open_groups and (is_row or encloses_row):
                open_groups[-1][2] = True

    if len(rows) != 3:
        raise ValueError("Matrix spec must have exactly 3 parenthesis groups.")

    # Rows never contain a nested comma, so a plain split is sufficient.
    split_rows = [row.split(',') for row in rows]
    for parts in split_rows:
        if len(parts) not in (3, 4):
            raise ValueError("Each row must have 3 or 4 entries.")

    T = np.zeros((3, 3), dtype=float)
    t = np.zeros(3, dtype=float)
    for i, parts in enumerate(split_rows):
        for j in range(3):
            T[i, j] = parse_number(parts[j])
        if len(parts) == 4:
            t[i] = parse_number(parts[3])
    return T, t

# ------------------------- Reporting -------------------------

def site_effective_occupancy_and_mass_amu(site) -> Tuple[float, float]:
    """Return the summed occupancy and occupancy-weighted mass of one site.

    Iterates ``site.species.items()`` as (species, amount) pairs. A species
    without an ``atomic_mass`` attribute, or whose ``atomic_mass`` is None,
    contributes zero mass.

    Returns:
        Tuple ``(occupancy_sum, mass_sum)`` as floats.
    """
    occupancy_sum = 0.0
    mass_sum = 0.0
    for species, amount in site.species.items():
        weight = float(amount)
        per_atom_mass = getattr(species, "atomic_mass", 0.0)
        if per_atom_mass is None:
            per_atom_mass = 0.0
        occupancy_sum += weight
        mass_sum += weight * float(per_atom_mass)
    return occupancy_sum, mass_sum

def structure_effective_counts(struct: Structure) -> Tuple[float, float, int]:
    """Accumulate occupancy and mass over every site of ``struct``.

    Returns:
        Tuple ``(effective_atoms, total_mass, n_sites)``: the summed
        occupancies and masses reported per site by
        site_effective_occupancy_and_mass_amu(), and the number of sites.
    """
    all_sites = struct.sites
    atoms_sum, mass_sum = 0.0, 0.0
    for per_site in map(site_effective_occupancy_and_mass_amu, all_sites):
        atoms_sum += per_site[0]
        mass_sum += per_site[1]
    return atoms_sum, mass_sum, len(all_sites)

def report_structure(struct: Structure, label: str) -> Tuple[float, float]:
    """Print a summary of ``struct`` and return its two densities.

    The summary lists lattice constants and angles, cell volume, number of
    sites, effective atom count (sum of occupancies), total mass, and the
    atomic and mass densities.

    Args:
        struct: Structure to summarise.
        label: Heading prefix printed before " Structure:".

    Returns:
        Tuple ``(rho_atom, rho_mass)``; each is NaN when the corresponding
        volume is not positive.
    """
    amu_in_gram = 1.66053906660e-24
    cubic_angstrom_in_cm3 = 1.0e-24

    lattice = struct.lattice
    a, b, c = lattice.abc
    alpha, beta, gamma = lattice.angles
    volume = struct.volume

    eff_atoms, total_mass_amu, n_sites = structure_effective_counts(struct)
    total_mass_g = total_mass_amu * amu_in_gram
    volume_cm3 = volume * cubic_angstrom_in_cm3

    # Densities default to NaN and are only computed for a positive volume.
    rho_atom = float("nan")
    if volume > 0:
        rho_atom = eff_atoms / volume
    rho_mass = float("nan")
    if volume_cm3 > 0:
        rho_mass = total_mass_g / volume_cm3

    summary = [
        f"\n{label} Structure:",
        f"  Lattice constants       : a={a:.6f}  b={b:.6f}  c={c:.6f}",
        f"  Lattice angles          : alpha={alpha:.2f}  beta={beta:.2f}  gamma={gamma:.2f}",
        f"  Volume per cell         : {volume:.6f} Å^3",
        f"  Sites (positions)       : {n_sites}",
        f"  Effective atoms (Σ occ) : {eff_atoms:.6f}",
        f"  Total mass (cell)       : {total_mass_amu:.6f} amu ({total_mass_g:.6e} g)",
        f"  Atomic density          : {rho_atom:.6f} atoms/Å^3",
        f"  Mass density            : {rho_mass:.6f} g/cm^3",
    ]
    for line in summary:
        print(line)

    return rho_atom, rho_mass

def dump_sites(struct: Structure, label: str):
    """Print one line per site: index, species with occupancies, fractional coordinates.

    Args:
        struct: Structure whose ``sites`` are listed.
        label: Text inserted into the section heading.
    """
    line_template = "[{:3d}] {:20s}  frac=({:8.4f}, {:8.4f}, {:8.4f})"
    print(f"\n--- {label} Structure Sites ---")
    index = 0
    for site in struct.sites:
        entries = []
        for sp, amt in site.species.items():
            entries.append(f"{sp.symbol}:{amt:.3f}")
        sp_occ = ", ".join(entries)
        x, y, z = site.frac_coords
        print(line_template.format(index, sp_occ, x, y, z))
        index += 1

# ------------------------- Symmetry helpers -------------------------

def detect_setting_and_centering(struct: Structure, sym_tol: float) -> Tuple[str, str, str]:
    """Return the space-group symbol, the R-lattice setting and the centering letter.

    The centering is the first character of the space-group symbol reported
    by SpacegroupAnalyzer; it is not measured from the input cell itself.
    For an 'R' symbol the setting is "hexagonal" when the lattice gamma is
    within 0.1 of 120 degrees (np.isclose), otherwise "rhombohedral".

    Args:
        struct: Structure to analyse.
        sym_tol: Passed to SpacegroupAnalyzer as ``symprec``.

    Returns:
        Tuple ``(symbol, setting, centering)``; ``setting`` is '-' for non-R
        symbols and ``centering`` is '-' when the first character is not one
        of P, A, B, C, I, F, R.
    """
    analyzer = SpacegroupAnalyzer(struct, symprec=sym_tol)
    spg = analyzer.get_space_group_symbol()
    first = spg[0] if spg else '-'

    if first != 'R':
        setting = '-'
    elif np.isclose(struct.lattice.gamma, 120.0, atol=0.1):
        setting = "hexagonal"
    else:
        setting = "rhombohedral"

    known_centerings = ('P', 'A', 'B', 'C', 'I', 'F', 'R')
    centering = '-'
    if first in known_centerings:
        centering = first
    return spg, setting, centering

def hex_to_rhombo_T() -> np.ndarray:
    """Return the 3x3 matrix used for the hexagonal -> rhombohedral conversion.

    The entries are multiples of 1/3; the determinant is 1/3.
    """
    thirds = np.array([[ 2,  1, 1],
                       [-1,  1, 1],
                       [-1, -2, 1]], dtype=float)
    return thirds / 3.0

def rhombo_to_hex_T() -> np.ndarray:
    """Return the 3x3 matrix used for the rhombohedral -> hexagonal conversion.

    The entries are integers stored as floats; the determinant is 3.
    """
    first_row = (1.0, -1.0, 0.0)
    second_row = (0.0, 1.0, -1.0)
    third_row = (1.0, 1.0, 1.0)
    return np.array([first_row, second_row, third_row], dtype=float)

def centering_to_prim_T(centering: str) -> Optional[np.ndarray]:
    """Return the centered -> primitive matrix for a centering letter.

    Args:
        centering: Centering letter; matched case-insensitively.

    Returns:
        A new (3, 3) float array for F, I, A, B or C, and None for any other
        value (including P and R).
    """
    matrices = {
        'F': (( 0.0,  0.5,  0.5),
              ( 0.5,  0.0,  0.5),
              ( 0.5,  0.5,  0.0)),
        'I': ((-0.5,  0.5,  0.5),
              ( 0.5, -0.5,  0.5),
              ( 0.5,  0.5, -0.5)),
        'A': (( 1.0,  0.0,  0.0),
              ( 0.0,  0.5,  0.5),
              ( 0.0, -0.5,  0.5)),
        'B': (( 0.5,  0.0, -0.5),
              ( 0.0,  1.0,  0.0),
              ( 0.5,  0.0,  0.5)),
        'C': (( 0.5, -0.5,  0.0),
              ( 0.5,  0.5,  0.0),
              ( 0.0,  0.0,  1.0)),
    }
    rows = matrices.get(centering.upper())
    if rows is None:
        return None
    return np.array(rows, dtype=float)

# ------------------------- Basis change core -------------------------
# ... (omitted: the previous change_basis_preserving_geometry implementation is used here unchanged) ...
# ------------------------- Main -------------------------

def _report_density_check(kind, unit, rho_orig, rho_new, eps):
    """Print whether a density is unchanged by the conversion within ``eps``.

    ``kind`` is the label used in the message ("Atomic" or "Mass") and
    ``unit`` the unit text printed after the two values.
    """
    # Two-sided relative comparison, written without a division so that a
    # zero reference density cannot divide by zero. A NaN on either side
    # makes the comparison False and is therefore reported as a change.
    if abs(rho_new - rho_orig) <= eps * abs(rho_orig):
        print(f"  {kind} densities are identical ({rho_orig} vs. {rho_new} {unit}) within eps={eps}")
    else:
        print("#" * 80)
        print(f"  Error!!: {kind} density changed from {rho_orig} to {rho_new} {unit}")
        print("#" * 80)
        print()


def main():
    """Run the lattice conversion selected on the command line.

    Steps:
      1. Read the CIF, with tkCIF by default or pymatgen's CifParser when
         ``--pymatgen`` is non-zero.
      2. Report the space group, detected setting/centering and the original
         structure.
      3. Build the converted structure: ``prim`` uses SpacegroupAnalyzer;
         ``rhomb``/``hex``/``orth`` and explicit matrices go through a 3x3
         matrix T (plus optional shift) and change_basis_preserving_geometry.
      4. Report the converted structure and compare atomic and mass
         densities with the original using the relative tolerance ``--eps``.
      5. Write ``<input name>_converted.cif`` next to the input file.

    Any failure is printed and the program exits through terminate().
    """
    args = initialize()
    infile = args.input_file
    conv_spec = args.conversion
    direction = args.direction
    sym_tol = args.sym_tol
    eps = args.eps
    xyz_tol = args.xyz_tol

    if not os.path.exists(infile):
        print(f"Error: Input file '{infile}' not found.")
        terminate()

    print("Lattice conversion using pymatgen")
    print(f"Input file            : {infile}")
    print(f"Use pymatgen.CifParser: {args.pymatgen}")
    print(f"Conversion            : {conv_spec}")
    print(f"Direction             : {direction}")
    print(f"Symmetry tolerance    : {sym_tol}")
    print(f"(x,y,z) tolerance     : {xyz_tol}")
    print("-" * 30)

    if args.pymatgen:
        # --- pymatgen CifParser -> pymatgen.Structure ---
        # CifParser is not imported at module level; import it here so the
        # name actually exists. Without this the branch always failed with a
        # NameError that was reported as a CIF read error.
        from pymatgen.io.cif import CifParser
        try:
            # occupancy_tolerance is left at CifParser's default.
            parser = CifParser(infile)
            s_orig = parser.parse_structures(primitive=False)[0]
        except Exception as e:
            print(f"Error in main(): Failed to read {infile}: {e}")
            print(f"   The CIF may include partial occupancy site.\n")
            terminate()
    else:
        # --- tkCIF -> pymatgen.Structure ---
        try:
            cif = tkCIF()
            cifdata = cif.ReadCIF(infile, find_valid_structure = True)
            cif.Close()
            cry = cifdata.GetCrystal()
            cry.PrintInf()
            s_orig = tkcrystal_to_pmg_structure(cry)
        except Exception as e:
            print(f"Error in main(): Failed to read {infile} with tkCIF: {e}")
            terminate()

    spg, setting, centering = detect_setting_and_centering(s_orig, sym_tol)
    print(f"Space Group: {spg}")
    if setting != '-':
        print(f"Detected setting    : {setting} (R-lattice)")
    print(f"Detected centering  : {centering}")

    rho_atom_orig, rho_mass_orig = report_structure(s_orig, "Original")
    dump_sites(s_orig, "Original")

    T = None
    tvec = np.zeros(3, dtype=float)
    conv_key = conv_spec.strip().lower()
    if conv_key == 'romb':
        # 'romb' is accepted as an alternative spelling of 'rhomb'; it can
        # never be a valid matrix spec, so this cannot shadow other input.
        conv_key = 'rhomb'

    if conv_key == 'prim':
        print("\n[Conversion] primitive cell via SpacegroupAnalyzer")
        try:
            print("Transform lattice:")
            s_conv = SpacegroupAnalyzer(s_orig, symprec=sym_tol).get_primitive_standard_structure()
        except Exception as e:
            # Best-effort fallback, but say so instead of hiding the failure.
            print(f"  SpacegroupAnalyzer failed ({e}); using Structure.get_primitive_structure() instead.")
            s_conv = s_orig.get_primitive_structure()
        rho_atom_new, rho_mass_new = report_structure(s_conv, "Converted")
        dump_sites(s_conv, "Converted")
    elif conv_key in ('rhomb', 'hex', 'orth'):
        print("\n[Conversion] auto rule:", conv_key)
        if conv_key == 'rhomb' and centering == 'R' and setting == 'hexagonal':
            T = hex_to_rhombo_T()
        elif conv_key == 'hex' and centering == 'R' and setting == 'rhombohedral':
            T = rhombo_to_hex_T()
        elif conv_key == 'orth':
            # None when the centering has no matrix (reported as an error below).
            T = centering_to_prim_T(centering)
        else:
            # 'rhomb'/'hex' requested but the lattice is not in the matching
            # R setting: the structure is passed through unchanged.
            T = np.identity(3)
    else:
        print("\n[Conversion] explicit matrix parsing")
        try:
            T, tvec = parse_conversion_matrix(conv_spec)
            if direction == "ConvertedToOriginal":
                # np.linalg.inv raises LinAlgError (a ValueError subclass)
                # for a singular matrix.
                T = np.linalg.inv(T)
                # NOTE(review): T is already inverted here, so this computes
                # -t @ inv(T_given). For the convention f' = f @ inv(T) + t
                # stated in the module docstring the inverse shift would be
                # -t @ T_given. Left unchanged because the implementation of
                # change_basis_preserving_geometry is not visible; confirm
                # against it.
                tvec = -tvec @ T
        except (ValueError, SyntaxError, NameError, TypeError, ArithmeticError) as e:
            print(f"Error in main(): Invalid conversion matrix '{conv_spec}': {e}")
            terminate()

    if conv_key != 'prim' and T is None:
        print()
        print(f"Error in main(): Inconsistent lattice for conversion to [{conv_key}]")
        print()
        terminate()

    if T is not None:
        print_matrix("Transformation matrix [T]:", T)
        det_T = np.linalg.det(T)
        print(f"det(T) = {det_T:.8f}")
        if abs(det_T) < 1.0e-12:
            # A singular matrix does not define a lattice basis.
            print("Error in main(): Transformation matrix is singular (det(T) = 0).")
            terminate()
        # NOTE(review): change_basis_preserving_geometry is not defined in
        # the visible part of this file; it must be supplied for this branch
        # to run.
        s_conv = change_basis_preserving_geometry(s_orig, T, t=tvec, xyz_tol=xyz_tol)
        rho_atom_new, rho_mass_new = report_structure(s_conv, "Converted")
        dump_sites(s_conv, "Converted")

    print()
    print("Check consistency:")
    # Atomic density is reported in atoms/Å^3 (see report_structure), not g/cm3.
    _report_density_check("Atomic", "atoms/Å^3", rho_atom_orig, rho_atom_new, eps)
    _report_density_check("Mass", "g/cm3", rho_mass_orig, rho_mass_new, eps)

    outfile = os.path.splitext(infile)[0] + "_converted.cif"
    s_conv.to(filename=outfile, fmt="cif")
    print(f"\nConverted CIF file saved to: {outfile}")

    if args.pymatgen:
        print()
        print("#"*80)
        print("# Please check OCCUPANCY: pymatgen may fail to read partial occupancies correctly.")
        print("#"*80)

# Script entry point: run the conversion, then wait for ENTER before the
# script ends (main() itself exits through terminate() on errors).
if __name__ == "__main__":
    main()
    input("\nPress ENTER to terminate>>")
    
