import argparse
import sys
import numpy as np
from pprint import pprint
from typing import Tuple, Optional, List, Dict, Any
import re
import math
import os
from numpy import sin, cos, tan, pi

# --- tklibのインポート ---
# IMPORTANT: tklibのディレクトリがsys.pathに含まれている必要があります
# 例: sys.path.append("d:/git/tkProg/tklib/python") 
try:
    from tklib.tkfile import tkFile
    from tklib.tkcrystal.tkcif import tkCIF
    from tklib.tkcrystal.tkcrystal import tkCrystal
    from tklib.tkcrystal.tkatomtype import tkAtomType
    # pymatgenとの相互運用に必要なモジュール
    from tklib.tkcrystal.tkcif2pymatgen import tkcrystal_to_pmg_structure
except ImportError:
    print("エラー: tklibライブラリが見つかりません。", file=sys.stderr)
    print("sys.pathにtklibのパスが正しく追加されているか確認してください。", file=sys.stderr)
    sys.exit(1)

# --- pymatgenのインポート ---
try:
    from pymatgen.core import Structure
    from pymatgen.core.lattice import Lattice
    from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
except ImportError:
    print("エラー: pymatgenライブラリが見つかりません。", file=sys.stderr)
    sys.exit(1)

# --- ヘルパー関数 (共通) ---

def print_matrix(m):
    """3x3行列を見やすいフォーマットで出力します。"""
    # tkCrystalが返す行列のフォーマットに合わせて表示
    print(f"| {m[0][0]:8.4f} {m[0][1]:8.4f} {m[0][2]:8.4f} |")
    print(f"| {m[1][0]:8.4f} {m[1][1]:8.4f} {m[1][2]:8.4f} |")
    print(f"| {m[2][0]:8.4f} {m[2][1]:8.4f} {m[2][2]:8.4f} |")

# --- BVS計算ヘルパー関数 (機能 5用) ---

def read_bv_params(path: str) -> Optional[Dict[str, Any]]:
    """
    指定されたパスからBVSパラメータファイルを読み込み、辞書として返します。
    """
    bv_params = {}
    if not os.path.exists(path):
        print(f"エラー: BVSパラメータファイルが見つかりません: [{path}]", file=sys.stderr)
        return None
        
    print(f"BVSパラメータファイルを読み込み中: [{os.path.basename(path)}]")
    with open(path, 'r', encoding='utf-8') as f:
        lines = f.readlines()

    reading = False
    for line in lines:
        line = line.strip()
        if not line:
            continue
        if line == '_valence_param_details':
            reading = True
            continue
        if reading:
            if line.startswith('#'):
                break
            parts = line.split()
            if len(parts) < 7:
                continue
            
            # atom1, Z1, atom2, Z2, r0, B, ref
            atom1, z1_str, atom2, z2_str, r0_str, b_str, ref = parts[:7]
            try:
                # Z1, Z2は使用しないが、元のプログラムに従って読み込む
                # z1 = int(z1_str); z2 = int(z2_str) 
                r0 = float(r0_str); B = float(b_str)
            except ValueError:
                continue
            
            # キーは 'AtomA:AtomB' 形式で保存
            key = f"{atom1}:{atom2}"
            if key not in bv_params:
                bv_params[key] = {
                    'r0': r0, 'B': B, 'ref': ref
                }
    return bv_params


# --- Pymatgen構造に基づくヘルパー関数 (機能 4用) ---

def site_effective_occupancy_and_mass_amu(site) -> Tuple[float, float]:
    """Pymatgen Siteの有効占有数と総質量を計算します。"""
    occ = 0.0
    mass = 0.0
    for sp, amt in site.species.items():
        occ += float(amt)
        amu = getattr(sp, "atomic_mass", 0.0)
        mass += float(amt) * float(amu if amu is not None else 0.0)
    return occ, mass

def structure_effective_counts(struct: Structure) -> Tuple[float, float, int]:
    """Pymatgen Structureの有効原子数、総質量、サイト数を計算します。"""
    eff_atoms = 0.0
    total_mass = 0.0
    for site in struct.sites:
        occ, mass = site_effective_occupancy_and_mass_amu(site)
        eff_atoms += occ
        total_mass += mass
    return eff_atoms, total_mass, len(struct.sites)

def report_pmg_structure_density(struct: Structure, label: str) -> Tuple[float, float]:
    """Pymatgen Structureの密度をレポートします。"""
    a, b, c = struct.lattice.abc
    volume = struct.volume

    eff_atoms, total_mass_amu, n_sites = structure_effective_counts(struct)
    
    total_mass_g = total_mass_amu * 1.66053906660e-24 
    volume_cm3 = volume * 1.0e-24 

    rho_atom = eff_atoms / volume if volume > 0 else float("nan") # atoms/Å^3
    rho_mass = total_mass_g / volume_cm3 if volume_cm3 > 0 else float("nan") # g/cm^3

    print(f"\n[{label} Structure (Pymatgen)]")
    print(f"  体積 (Volume) : {volume:.6f} $\\text{Å}^3$")
    print(f"  有効原子数 ($\\Sigma \\text{occ}$) : {eff_atoms:.6f}")
    print(f"  総質量 (Mass) : {total_mass_amu:.6f} amu")
    print(f"  原子密度 (Atomic $\\rho$): {rho_atom:.6f} atoms/$\\text{Å}^3$")
    print(f"  質量密度 (Mass $\\rho$): {rho_mass:.6f} g/$\\text{cm}^3$")

    return rho_atom, rho_mass


def change_basis_preserving_geometry(structure: Structure,
                                     T: np.ndarray,
                                     t: Optional[np.ndarray] = None,
                                     xyz_tol: float = 1e-6) -> Structure:
    """
    Pymatgen Structureオブジェクトの基底を変換し、重複サイトをマージします。
    """
    print("  => Transformation matrix [T]:")
    print_matrix(T)
    detT = float(np.linalg.det(T))
    if abs(detT) < 1e-12:
        raise ValueError("Transformation matrix is singular (det ~ 0).")

    V = np.array(structure.lattice.matrix, dtype=float)
    Vp = T @ V
    Tinv = np.linalg.inv(T)

    frac = np.array([site.frac_coords for site in structure.sites], dtype=float)
    frac_p = frac @ Tinv
    if t is not None:
        frac_p = frac_p + t.reshape(1,3)

    frac_p = np.mod(frac_p, 1.0)

    # 同一位置の原子のうち１つだけ残してkept_species/coordsに残す
    round_ndp = max(0, int(-np.log10(max(xyz_tol, 1e-12)))) + 1
    kept_coords: List[np.ndarray] = []
    kept_species = []
    seen = set()
    for fc, site in zip(frac_p, structure.sites):
        # 座標と原子種をキーにして重複を判定
        key = (round(fc[0]+xyz_tol, round_ndp), round(fc[1]+xyz_tol, round_ndp), round(fc[2]+xyz_tol, round_ndp),
               site.species)
        if key in seen: continue
        seen.add(key)
        kept_species.append(site.species)
        kept_coords.append(fc)

    new_lattice = Lattice(Vp)
    new_structure = Structure(new_lattice, kept_species, kept_coords, coords_are_cartesian=False)
    new_structure.sort()

    removed = len(structure.sites) - len(new_structure.sites)
    if removed > 0:
        print(f"  [Info] Merged {removed} duplicate sites due to det(T)={detT:.6f}.")
    return new_structure


# --- 解析機能ブロック 1: tkCIFによる読み込みと基本情報 ---

def analyze_tkcif_basic(filepath: str) -> Optional[tkCrystal]:
    """
    tkCIFを使用してCIFファイルを読み込み、基本情報を表示し、tkCrystalオブジェクトを返します。
    """
    print("\n" + "="*50)
    print("--- 1. tkCIFによるファイル読み込みとCIFデータ ---")
    print("="*50)
    
    cif = tkCIF()
    cif.debug = 0 
    
    try:
        cifdata = cif.ReadCIF(filepath, find_valid_structure=1)
    except Exception as e:
        print(f"エラー: tkCIF.ReadCIFの実行中にエラーが発生しました: {e}", file=sys.stderr)
        return None

    print(f"ファイル: [{filepath}]")
    print("\n[tkCIFData.Print()出力]")
    cifdata.Print()

    cry = cifdata.GetCrystal()
    if cry is None:
        print("\nエラー: tkCIFDataからtkCrystalオブジェクトを取得できませんでした。", file=sys.stderr)
        return None
        
    return cry

# --- 解析機能ブロック 2: tkCrystalによる構造情報と密度計算 ---

def analyze_tkcrystal_structure(cry: tkCrystal):
    """
    tkCrystalオブジェクトを用いて構造情報（格子、サイト）と密度を計算し表示します。
    """
    print("\n" + "="*50)
    print("--- 2. tkCrystalによる構造情報と密度 ---")
    print("="*50)

    # 1. 構造情報の表示 (PrintInf())
    print("\n[tkCrystal.PrintInf()出力]")
    print("==============================================")
    cry.PrintInf() # 構造の概要（格子、対称操作、サイトなど）を出力
    print("==============================================")

    # 2. 密度の計算と検証
    
    d = cry.Density() # g/cm3
    ad = cry.AtomDensity() # /cm3
    
    print(f"\n[密度計算結果]")
    print(f"  質量密度 (Density): {d:.6f} g/$\\text{cm}^3$")
    print(f"  原子密度 (AtomDensity): {ad:.2e} atoms/$\\text{cm}^3$")

    # 3. 密度の検証 (元のプログラムのロジックを再現)
    if not (0.8 < d < 20.0):
        print(f"\n❌ 検証失敗: 質量密度が異常な範囲 ({d:.6f} g/cm³) です。")
        
    if not (1.0e22 < ad < 10.0e23):
        print(f"\n❌ 検証失敗: 原子密度が異常な範囲 ({ad:.2e} /cm³) です。")

# --- 解析機能ブロック 3: tkAtomTypeの利用（テスト） ---

def test_tkatomtype():
    """
    tkAtomTypeを使用して原子の情報を取得するテストです。
    """
    print("\n" + "="*50)
    print("--- 3. tkAtomTypeによる原子情報テスト ---")
    print("="*50)

    atom = tkAtomType()
    test_element = 'Au'
    
    try:
        inf = atom.GetAtomInformation(test_element)
        print(f"[{test_element}] の原子情報:")
        print(f"  Atomic Mass: {inf.get('AtomicMass', 'N/A')}")
        print(f"  Covalent Radius: {inf.get('CovRadius', 'N/A')}")
        print(f"  Ion Radius (VI-Coord): {inf.get('IonRadiusVI', 'N/A')}")
        
    except Exception as e:
        print(f"エラー: tkAtomTypeの実行中にエラーが発生しました: {e}", file=sys.stderr)


# --- 解析機能ブロック 4: Pymatgen連携による単位格子変換と密度検証 ---

def perform_pmg_conversion_on_tkcrystal(cry: tkCrystal, sym_tol: float = 0.01, density_eps: float = 1.0e-4):
    """
    tkCrystalをPymatgen Structureに変換し、Primitive Cellへの変換と密度検証を実行します。
    """
    print("\n" + "="*50)
    print("--- 4. Pymatgen連携による単位格子変換と密度検証 ---")
    print("="*50)

    # 1. tkCrystal -> Pymatgen Structure への変換
    try:
        s_orig = tkcrystal_to_pmg_structure(cry)
        print("✅ tkCrystalからPymatgen Structureへの変換に成功しました。")
        print(f"  Pymatgen Structure: {s_orig.formula}, Sites: {len(s_orig)}")
    except Exception as e:
        print(f"❌ エラー: Pymatgen Structureへの変換に失敗しました: {e}", file=sys.stderr)
        return

    # 2. 元のPymatgen Structureのレポート (密度を記録)
    print("\n[ステップ 1] 変換前のPymatgen構造情報と密度計算")
    rho_atom_orig, rho_mass_orig = report_pmg_structure_density(s_orig, "Original")

    # 3. Primitive Cellへの変換
    try:
        print("\n[ステップ 2] Primitive Standard Structureへの変換 (SpacegroupAnalyzerを使用)")
        analyzer = SpacegroupAnalyzer(s_orig, symprec=sym_tol)
        s_conv = analyzer.get_primitive_standard_structure()
    except Exception as e:
        print(f"警告: Standard Structureの取得に失敗。原始格子へのフォールバック: {e}", file=sys.stderr)
        s_conv = s_orig.get_primitive_structure()
        
    # 4. 変換後の構造のレポート
    print("\n[ステップ 3] 変換後のPymatgen構造情報と密度計算")
    rho_atom_new, rho_mass_new = report_pmg_structure_density(s_conv, "Converted (Primitive)")

    # 5. 密度の一致確認
    print("\n[ステップ 4] 密度の一貫性チェック (Pymatgen $\\rightarrow$ Pymatgen)")

    is_atom_rho_consistent = np.isclose(rho_atom_orig, rho_atom_new, rtol=density_eps)
    if is_atom_rho_consistent:
        print(f"✅ 原子密度は一致しています (許容誤差 $\\epsilon={density_eps}$)")
    else:
        print(f"❌ 警告: 原子密度が一致しませんでした。Original: {rho_atom_orig:.6f}, Converted: {rho_atom_new:.6f}")
        
    is_mass_rho_consistent = np.isclose(rho_mass_orig, rho_mass_new, rtol=density_eps)
    if is_mass_rho_consistent:
        print(f"✅ 質量密度は一致しています (許容誤差 $\\epsilon={density_eps}$)")
    else:
        print(f"❌ 警告: 質量密度が一致しませんでした。Original: {rho_mass_orig:.6f}, Converted: {rho_mass_new:.6f}")

    print(f"\n[Info] Primitive Structureの体積は、元の体積の {s_conv.volume / s_orig.volume:.4f} 倍です。")


# --- 解析機能ブロック 5: BVS（結合価数和）計算 ---

def calculate_bond_valence_sums(cry: tkCrystal, max_r: float = 2.7):
    """
    tkCrystalのサイト情報と自作のBVSパラメータを用いて結合価数和を計算します。
    """
    print("\n" + "="*50)
    print("--- 5. BVS（結合価数和）計算 ---")
    print("="*50)
    
    # --- 1. 環境変数とデータベースパスのチェック ---
    tkProg_Root = os.environ.get('tkProg_Root')
    if not tkProg_Root:
        print("❌ エラー: 環境変数 'tkProg_Root' が設定されていません。", file=sys.stderr)
        print("  BVS計算をスキップします。", file=sys.stderr)
        return

    db_path = os.path.join(tkProg_Root, 'tkdb', 'Databases', 'bvparm2020.cif')

    # --- 2. BVSパラメータの読み込み ---
    bv_params = read_bv_params(db_path)
    if bv_params is None:
        return

    # --- 3. BVS計算の実行 ---
    
    # 非対称単位の代表サイトを取得
    asym_sites = cry.AtomSiteList()
    
    # 単位セル内の全原子サイト（対称操作で展開されたものを含む）を取得
    expanded_sites = cry.ExpandedAtomSiteList()

    print(f"\nBVS計算を開始します (カットオフ距離: {max_r:.2f} $\\text{Å}$)")
    
    for i, site1 in enumerate(asym_sites):
        name1 = site1.AtomNameOnly()
        charge1 = site1.Charge()
        occ1 = site1.Occupancy()
        pos1 = site1.Position(1) # [0, 1) 範囲の分率座標

        bvs = 0.0
        
        # サイト名と期待される電荷の表示
        charge_str = f"({charge1:.0f}+)" if charge1 > 0 else f"({charge1:.0f}-)" if charge1 < 0 else ""
        print(f"\n{i:2d}: {name1}{charge_str} (Occ: {occ1:.2f})")
        
        # Expanded sites (全て) をループし、近傍原子を探す
        for site2 in expanded_sites:
            name2 = site2.AtomNameOnly()
            occ2 = site2.Occupancy()
            pos2 = site2.Position(1) 

            # 2サイト間の最短原子間距離を計算 (AllowZero=Falseで同一サイトを無視)
            # irange=[1, 1, 1] は PBC (周期境界条件) の探索範囲
            d = cry.GetNearestInterAtomicDistance(pos1, pos2, AllowZero=False, irange=[1, 1, 1])
            
            # カットオフ距離 $R_{max}$ を超える結合は無視
            if d is None or d > max_r:
                continue

            # BVSパラメータの取得 (順方向または逆方向)
            fwd = f"{name1}:{name2}"
            bwd = f"{name2}:{name1}"
            params = bv_params.get(fwd) or bv_params.get(bwd)
            
            if not params:
                # 警告として出力
                print(f"  ⚠️ Warning: BV params for {fwd} or {bwd} not found, skipping bond at {d:.3f} Å.")
                continue

            r0 = params['r0']; B = params['B']
            if B == 0: continue

            # 結合価数 (Bond Valence) の計算 $s_{ij} = \exp(\frac{R_0 - d_{ij}}{B})$
            bv = np.exp((r0 - d) / B)
            
            # サイト2の占有率 $occ_2$ を乗算して加算
            bvs += bv * occ2

            # 詳細出力
            print(f"  → {name2:<2} $\\text{r}_0={r0:.3f}$, $\\text{B}={B:.2f}$, $\\text{d}={d:.3f}$, $\\text{bv}={bv:.3f}$ ({params['ref']})")

        # 結合価数和の最終出力
        print(f"\n  ◆ Bond Valence Sum (BVS): {bvs:.3f}\n")


# --- メイン関数 ---

def main():
    """
    CLI引数からCIFファイルを受け取り、tkcif, tkcrystalの機能を順次実行します。
    """
    default_cif = "ZnO.cif" 
    
    parser = argparse.ArgumentParser(
        description="tklib (tkCIF/tkCrystal)の機能を実行する構造解析プログラム。",
        formatter_class=argparse.RawTextHelpFormatter
    )
    parser.add_argument(
        "cif_file", 
        type=str, 
        nargs='?', 
        default=default_cif,
        help=f"解析するCIFファイルのパス。\n(デフォルト: {default_cif} を想定)"
    )
    args = parser.parse_args()

    print(f"tklib参照プログラムを開始します。対象ファイル: {args.cif_file}\n")

    # --- 機能ブロック 1: CIF読み込みとCrystalオブジェクト取得 ---
    cry_object = analyze_tkcif_basic(args.cif_file)
    
    if cry_object:
        # --- 機能ブロック 2: Crystalオブジェクトの解析と密度 ---
        analyze_tkcrystal_structure(cry_object)

        # --- 機能ブロック 4: Pymatgen連携による単位格子変換 ---
        perform_pmg_conversion_on_tkcrystal(cry_object, sym_tol=0.01, density_eps=1.0e-4)
        
        # --- 機能ブロック 5: BVS計算 ---
        # Zn-O結合距離の目安から、カットオフ距離を2.7 Åに設定してテスト実行
        calculate_bond_valence_sums(cry_object, max_r=2.7)
        
    # --- 機能ブロック 3: tkAtomTypeのテスト ---
    test_tkatomtype()
    
    print("\n--- 全ての機能の実行を完了しました。 ---")

if __name__ == "__main__":
    main()
