# cif_templater.py
"""
CIFテンプレート・メタデータ生成モジュール

[概要]
このスクリプトは、マスタープログラム (`run_workflow.py`) から呼び出されることを
前提とした「部品（モジュール）」です。
指定されたCIFファイルを解析し、後続のCIF生成プロセスで必要となる2つのファイルを
自動で生成します。

[主な機能]
1. Jinja2テンプレートファイルの生成 (.cif.j2):
   元のCIFファイル内の原子情報を、置換用の変数 (例: {{ cation_1.symbol }}) に
   置き換えたテンプレートを作成します。

2. メタデータファイルの生成 (.meta.json):
   CIFファイルから自動で解析した以下の重要な情報をJSONファイルとして保存します。
   - 原子の役割 (例: 'cation_1')
   - 各役割の価数 (例: 4)
   - 正確な組成比 (例: {'cation_1': 1, 'anion_1': 2})

[使い方]
このスクリプトは通常、直接実行しません。`run_workflow.py`が内部で
`CifTemplateGenerator`クラスをインポートして使用します。
デバッグ目的で単独実行することも可能ですが、その場合はコードの修正が必要です。
"""
import os
import re
import json
from collections import defaultdict

class CifTemplateGenerator:
    """Build a Jinja2 template and a metadata JSON file from one CIF file.

    The generator reads the CIF at ``cif_path``, assigns every element found
    in the atom-site loop a role name (``cation_N`` for a positive oxidation
    number, ``anion_N`` otherwise), and writes two files next to the input:

    * ``<cif_path>.j2`` -- the CIF text with element symbols, oxidation
      numbers and site labels replaced by Jinja2 placeholders.
    * ``<cif_path>.meta.json`` -- the valence of each role and the number of
      atoms of each role per formula unit.

    Attributes:
        cif_path: Path of the source CIF file.
        template_path: Path the template is written to.
        meta_path: Path the metadata JSON is written to.
        original_content: Raw text of the CIF file (filled by ``_parse``).
        atom_roles: List of dicts with the keys ``role_name``,
            ``original_symbol``, ``valence`` and ``original_labels``.
        stoichiometry: Mapping of role name to its rounded atom count per
            formula unit.
    """

    # One token is either a single-quoted string or a run of non-blanks.
    _TOKEN_RE = re.compile(r"'[^']*'|\S+")
    # Leading letters of a type symbol or label, e.g. "Ti" in "Ti4+" / "Ti1".
    _ELEMENT_RE = re.compile(r'[A-Za-z]+')
    # Plain decimal number with optional sign and exponent.
    _NUMBER_RE = re.compile(r'[+-]?(?:\d+\.?\d*|\.\d+)(?:[eE][+-]?\d+)?')

    def __init__(self, cif_path):
        """Remember the input path and derive the two output paths.

        Args:
            cif_path: Path to an existing CIF file.

        Raises:
            FileNotFoundError: If ``cif_path`` does not exist.
        """
        if not os.path.exists(cif_path):
            raise FileNotFoundError(f"エラー: 指定されたCIFファイルが見つかりません: {cif_path}")
        self.cif_path = cif_path
        self.template_path = f"{cif_path}.j2"
        self.meta_path = f"{cif_path}.meta.json"
        self.original_content = ""
        self.atom_roles = []
        self.stoichiometry = {}

    def run(self):
        """Parse the CIF, write the template and the metadata file.

        Returns:
            Tuple ``(template_path, meta_path)`` of the files written.
        """
        self._parse()
        self._generate_template()
        self._save_metadata()
        return self.template_path, self.meta_path

    def _parse_cif_data(self):
        """Split ``self.original_content`` into loop blocks and single values.

        Returns:
            Dict with two entries: ``'loops'``, a list of
            ``{'keys': [...], 'values': [[...], ...]}`` dicts in file order,
            and ``'single_values'``, a mapping of data name to the rest of
            its line with surrounding single quotes removed.
        """
        parsed_data = {'loops': [], 'single_values': {}}
        lines = self.original_content.splitlines()
        i = 0
        while i < len(lines):
            line = lines[i].strip()
            if not line or line.startswith('#'):
                i += 1
                continue

            if line.lower() == 'loop_':
                i += 1
                loop_block = {'keys': [], 'values': []}
                # Header: consecutive data names.
                while i < len(lines) and lines[i].strip().startswith('_'):
                    loop_block['keys'].append(lines[i].strip())
                    i += 1
                # Body: rows until a blank line, comment, data name, or the
                # start of the next loop / data block.
                while i < len(lines):
                    row = lines[i].strip()
                    if not row or row.startswith(('_', '#')):
                        break
                    # Compared case-insensitively, like the loop start
                    # above, so "LOOP_" is not swallowed as a data row.
                    if row.lower().startswith(('loop_', 'data_')):
                        break
                    tokens = self._TOKEN_RE.findall(row)
                    loop_block['values'].append([t.strip("'") for t in tokens])
                    i += 1
                parsed_data['loops'].append(loop_block)
                # ``i`` already points at the first unconsumed line.
                continue

            if line.startswith('_'):
                parts = self._TOKEN_RE.findall(line)
                if len(parts) > 1:
                    key = parts[0]
                    value = ' '.join(parts[1:]).strip("'")
                    parsed_data['single_values'][key] = value
            i += 1
        return parsed_data

    @staticmethod
    def _find_loop(loops, key):
        """Return the first loop whose header contains ``key``, else None."""
        return next((loop for loop in loops if key in loop['keys']), None)

    @staticmethod
    def _column_index(keys, key):
        """Return the column of ``key`` in ``keys`` or raise ValueError."""
        try:
            return keys.index(key)
        except ValueError:
            raise ValueError(f"ループブロックに '{key}' が見つかりません。") from None

    @classmethod
    def _is_number(cls, text):
        """Tell whether ``text`` is a number, ignoring a trailing "(esd)"."""
        return cls._NUMBER_RE.fullmatch(text.split('(')[0]) is not None

    @staticmethod
    def _parse_oxidation_number(text):
        """Convert an oxidation number such as "4", "+4" or "4.0" to int."""
        try:
            number = float(text)
        except ValueError:
            raise ValueError(f"酸化数を解釈できません: {text}") from None
        # Rejects fractional values as well as nan / inf.
        if not number.is_integer():
            raise ValueError(f"酸化数を解釈できません: {text}")
        return int(number)

    @staticmethod
    def _read_formula_units(single_values):
        """Return Z from ``_cell_formula_units_Z`` (1 when it is absent)."""
        z_str = single_values.get('_cell_formula_units_Z', '1')
        z_match = re.match(r'\d+', z_str)
        if z_match is None or int(z_match.group(0)) == 0:
            raise ValueError(f"'_cell_formula_units_Z' の値が不正です: {z_str}")
        return int(z_match.group(0))

    def _read_valences(self, loops):
        """Map each element symbol to the first oxidation number listed."""
        oxidation_loop = self._find_loop(loops, '_atom_type_oxidation_number')
        if oxidation_loop is None:
            raise ValueError("'_atom_type_oxidation_number' を含むループブロックが見つかりません。")
        keys = oxidation_loop['keys']
        symbol_idx = self._column_index(keys, '_atom_type_symbol')
        number_idx = self._column_index(keys, '_atom_type_oxidation_number')

        element_to_valence = {}
        for row in oxidation_loop['values']:
            if len(row) <= max(symbol_idx, number_idx):
                raise ValueError(f"酸化数ループの行を解釈できません: {' '.join(row)}")
            valence = self._parse_oxidation_number(row[number_idx])
            element = self._ELEMENT_RE.match(row[symbol_idx])
            if element is not None:
                # Keyed by the bare element so that "S" can never pick up
                # the entry of "Sr2+"; the first entry per element wins.
                element_to_valence.setdefault(element.group(0), valence)
        return element_to_valence

    def _read_atom_sites(self, loops, formula_units):
        """Collect labels and per-formula-unit atom counts for each element.

        Returns:
            Tuple ``(symbol_to_labels, amounts)``: the distinct site labels
            of every element in file order, and the summed site
            multiplicity of every element divided by ``formula_units``.
        """
        atom_site_loop = self._find_loop(loops, '_atom_site_label')
        if atom_site_loop is None:
            raise ValueError("'_atom_site_label' を含むループブロックが見つかりません。")
        keys = atom_site_loop['keys']

        label_idx = keys.index('_atom_site_label')
        # Prefer the explicit type symbol; fall back to the label itself.
        if '_atom_site_type_symbol' in keys:
            symbol_idx = keys.index('_atom_site_type_symbol')
        else:
            symbol_idx = label_idx
        fract_x_idx = self._column_index(keys, '_atom_site_fract_x')
        # The CIF core dictionary knows the multiplicity under two names.
        if '_atom_site_symmetry_multiplicity' in keys:
            multiplicity_idx = keys.index('_atom_site_symmetry_multiplicity')
        else:
            multiplicity_idx = self._column_index(
                keys, '_atom_site_site_symmetry_multiplicity')
        last_needed = max(label_idx, symbol_idx, fract_x_idx, multiplicity_idx)

        symbol_to_labels = defaultdict(list)
        amounts = defaultdict(float)
        for row in atom_site_loop['values']:
            # Rows that are too short or whose x coordinate is not a number
            # are not site records; skip them. Negative coordinates are
            # valid and must be kept.
            if len(row) <= last_needed or not self._is_number(row[fract_x_idx]):
                continue

            element = self._ELEMENT_RE.match(row[symbol_idx])
            if element is None:
                raise ValueError("元素記号を特定できませんでした。")
            symbol = element.group(0)

            multiplicity = re.match(r'\d+', row[multiplicity_idx])
            if multiplicity is None:
                raise ValueError(f"多重度を解釈できません: {row[multiplicity_idx]}")
            amounts[symbol] += float(int(multiplicity.group(0))) / formula_units

            label = row[label_idx]
            if label not in symbol_to_labels[symbol]:
                symbol_to_labels[symbol].append(label)
        return symbol_to_labels, amounts

    def _parse(self):
        """Read the CIF and fill ``atom_roles`` and ``stoichiometry``.

        Elements are processed in alphabetical order; each one with a
        positive oxidation number becomes the next ``cation_N``, every other
        one the next ``anion_N``.

        Raises:
            ValueError: If a required loop or column is missing, a value
                cannot be interpreted, or an element has no oxidation number.
        """
        print(f"ℹ️  CIFファイル '{self.cif_path}' を解析しています...")
        with open(self.cif_path, 'r', encoding='utf-8') as f:
            self.original_content = f.read()

        # Start from a clean slate so a second run() does not duplicate roles.
        self.atom_roles = []
        self.stoichiometry = {}

        cif_data = self._parse_cif_data()
        formula_units = self._read_formula_units(cif_data['single_values'])
        element_to_valence = self._read_valences(cif_data['loops'])
        symbol_to_labels, amounts = self._read_atom_sites(
            cif_data['loops'], formula_units)

        cation_count, anion_count = 0, 0
        for symbol in sorted(symbol_to_labels):
            valence = element_to_valence.get(symbol)
            if valence is None:
                raise ValueError(f"元素 {symbol} の価数を特定できませんでした。")

            if valence > 0:
                cation_count += 1
                role_name = f"cation_{cation_count}"
            else:
                anion_count += 1
                role_name = f"anion_{anion_count}"

            self.atom_roles.append({
                'role_name': role_name,
                'original_symbol': symbol,
                'valence': valence,
                'original_labels': symbol_to_labels[symbol]
            })
            self.stoichiometry[role_name] = int(round(amounts[symbol]))

        print("✅ 解析完了。")

    def _generate_template(self):
        """Write the Jinja2 template derived from ``original_content``.

        The chemical name / formula lines get fixed placeholders. For every
        role, the type symbol with charge (e.g. ``Ti4+``), the oxidation
        number on its atom-type line, and the site labels at the start of a
        line are replaced by placeholders that refer to the role.
        """
        template_content = self.original_content
        template_content = re.sub(r"^_chemical_name_common.*$", "_chemical_name_common '{{ chemical_name_common }}'", template_content, flags=re.MULTILINE)
        template_content = re.sub(r"^_chemical_formula_structural.*$", "_chemical_formula_structural '{{ chemical_formula_structural }}'", template_content, flags=re.MULTILINE)
        template_content = re.sub(r"^_chemical_formula_sum.*$", "_chemical_formula_sum '{{ chemical_formula_sum }}'", template_content, flags=re.MULTILINE)

        for role in self.atom_roles:
            symbol, valence, role_name = role['original_symbol'], role['valence'], role['role_name']
            valence_sign = '+' if valence > 0 else '-'
            original_valence_str = f"{symbol}{abs(valence)}{valence_sign}"
            escaped_symbol = re.escape(original_valence_str)

            if valence > 0:
                jinja_type_symbol = "{{ " + role_name + ".symbol }}{{ " + role_name + ".valence }}" + valence_sign
            else:
                jinja_type_symbol = "{{ " + role_name + ".symbol }}{{ " + role_name + ".valence|abs }}" + valence_sign
            jinja_valence = "{{ " + role_name + ".valence }}"

            # Atom-type line "<symbol> <oxidation number>": accept leading
            # indentation and the spellings "4", "+4" and "4.0", but never a
            # longer number such as "40". Only blanks/tabs separate the two
            # columns so the match cannot run into the next line.
            sign_pattern = "-" if valence < 0 else r"\+?"
            number_pattern = rf"{sign_pattern}{abs(valence)}(?:\.0*)?(?![\d.])"
            type_line_pattern = rf"^([ \t]*){escaped_symbol}([ \t]+){number_pattern}"
            template_content = re.sub(
                type_line_pattern,
                lambda m: m.group(1) + jinja_type_symbol + m.group(2) + jinja_valence,
                template_content,
                flags=re.MULTILINE,
            )

            # Every other stand-alone occurrence of the type symbol. The
            # lookarounds leave the surrounding whitespace untouched, so a
            # symbol in the first or last column does not merge two lines.
            template_content = re.sub(
                rf"(?<!\S){escaped_symbol}(?!\S)",
                lambda m: jinja_type_symbol,
                template_content,
            )

            for original_label in role['original_labels']:
                # Keep the trailing digits so that labels stay distinct.
                numeric_suffix = re.search(r'\d*$', original_label).group(0)
                jinja_label = "{{ " + role_name + ".symbol }}" + numeric_suffix
                template_content = re.sub(
                    f"^{re.escape(original_label)}(\\s+)",
                    lambda m: jinja_label + m.group(1),
                    template_content,
                    flags=re.MULTILINE,
                )

        with open(self.template_path, 'w', encoding='utf-8') as f:
            f.write(template_content)
        print(f"✅ テンプレートファイルを作成しました: {self.template_path}")

    def _save_metadata(self):
        """Write role valences and stoichiometry to ``meta_path`` as JSON."""
        metadata = {
            'roles': {role['role_name']: {'valence': role['valence']} for role in self.atom_roles},
            'stoichiometry': self.stoichiometry
        }
        with open(self.meta_path, 'w', encoding='utf-8') as f:
            json.dump(metadata, f, indent=2)
        print(f"✅ メタデータファイルを作成しました: {self.meta_path}")