# -*- coding: utf-8 -*-
import argparse
import base64
import os
import re
import sys
import time
import unicodedata
from pathlib import Path
from typing import Dict, List, Optional

# --- ライブラリのインポートチェック ---
try:
    import win32com.client  # pywin32 (PPTX -> PDF画像化用)
except ImportError:
    print("Import error: win32com. Please install it with: pip install pywin32")
    input("\nPress ENTER to terminate>>\n")
    sys.exit(1)

try:
    import fitz  # PyMuPDF (PDF -> PNG用)
except ImportError:
    print("Import error: fitz. Please install it with: pip install pymupdf")
    input("\nPress ENTER to terminate>>\n")
    sys.exit(1)

try:
    from pptx import Presentation  # テキスト抽出用
    from lxml import etree         # 数式解析用
except ImportError:
    print("Import error: python-pptx or lxml. Please install: pip install python-pptx lxml")
    input("\nPress ENTER to terminate>>\n")
    sys.exit(1)

from openai import OpenAI
try:
    import google.generativeai as genai
except ImportError:
    print("Geminiを利用する場合は 'pip install google-generativeai' が必要です。")
    input("\nPress ENTER to terminate>>\n")
    sys.exit(1)

from tkai_lib import read_ai_config


def terminate(pause: Optional[bool] = None) -> None:
    """Exit the program, optionally waiting for ENTER first.

    Args:
        pause: True to wait for ENTER before exiting. When None, a module-level
            ``pause`` global is honoured if something has set one; otherwise
            the program exits without prompting. (No ``pause`` global is
            defined in this file, so a plain lookup would raise NameError.)

    Raises:
        SystemExit: always, via ``sys.exit()``.
    """
    if pause is None:
        pause = bool(globals().get("pause", False))
    if pause:
        input("\nPress ENTER to terminate>>\n")
    sys.exit()

# =========================================================
# 1. Advanced Math Extraction Logic (Provided Code Integrated)
# =========================================================

# XML namespaces used when querying slide XML (PresentationML, DrawingML, OMML, ...).
NAMESPACES = {
    'p': 'http://schemas.openxmlformats.org/presentationml/2006/main',
    'a': 'http://schemas.openxmlformats.org/drawingml/2006/main',
    'm': 'http://schemas.openxmlformats.org/officeDocument/2006/math',
    'r': 'http://schemas.openxmlformats.org/officeDocument/2006/relationships',
    'a14': 'http://schemas.microsoft.com/office/drawing/2010/main'
}

# Unicode math symbols -> ASCII / LaTeX.
MATH_UNICODE_MAP = {
    '𝑷': 'P', '𝑽': 'V', '𝑹': 'R', '𝑻': 'T',
    '𝒂': 'a', '𝒃': 'b', '𝒄': 'c', '𝒅': '\\mathrm{d}', '𝒆': 'e', '𝒇': 'f', '𝒈': 'g',
    '𝒉': 'h', '𝒊': 'i', '𝒋': 'j', '𝒌': 'k', '𝒍': 'l', '𝒎': 'm', '𝒏': 'n',
    '𝒙': 'x', '𝒚': 'y', '𝒛': 'z',
    '𝟐': '2', '𝟏': '1', '𝟎': '0',
    '−': '-', '＋': '+', '÷': '/', '×': '*', '⋅': '\\cdot',
    '…': '...', '∞': '\\infty',
    '∑': '\\sum', '∏': '\\prod',
    '∫': '\\int', '∬': '\\iint', '∭': '\\iiint', '∮': '\\oint',
    'α': '\\alpha', 'β': '\\beta', 'γ': '\\gamma', 'δ': '\\delta', 'ε': '\\epsilon',
    'ζ': '\\zeta', 'η': '\\eta', 'θ': '\\theta', 'ι': '\\iota', 'κ': '\\kappa',
    'λ': '\\lambda', 'μ': '\\mu', 'ν': '\\nu', 'ξ': '\\xi', 'π': '\\pi', 'ρ': '\\rho',
    'σ': '\\sigma', 'τ': '\\tau', 'υ': '\\upsilon', 'φ': '\\phi', 'χ': '\\chi',
    'ψ': '\\psi', 'ω': '\\omega',
    'Γ': '\\Gamma', 'Δ': '\\Delta', 'Θ': '\\Theta', 'Λ': '\\Lambda',
    'Ξ': '\\Xi', 'Π': '\\Pi', 'Σ': '\\Sigma', 'Φ': '\\Phi', 'Ψ': '\\Psi', 'Ω': '\\Omega',
    '′': "'", '°': '\\degree', '℃': '\\degree C'
}

# The explicit table above only lists part of the "mathematical bold italic"
# alphabet (a-n, x-z, P/R/T/V) and of the bold digits (0-2). Fill in the rest
# from the contiguous Unicode ranges. setdefault keeps every explicit entry,
# notably '𝒅' -> '\\mathrm{d}' (differential).
for _offset in range(26):
    MATH_UNICODE_MAP.setdefault(chr(0x1D468 + _offset), chr(ord('A') + _offset))  # bold italic A-Z
    MATH_UNICODE_MAP.setdefault(chr(0x1D482 + _offset), chr(ord('a') + _offset))  # bold italic a-z
for _offset in range(10):
    MATH_UNICODE_MAP.setdefault(chr(0x1D7CE + _offset), str(_offset))  # bold digits 0-9
del _offset

# OMML n-ary operator character -> LaTeX command.
NARY_TO_LATEX = {
    '∑': '\\sum', '∏': '\\prod',
    '∫': '\\int', '∬': '\\iint', '∭': '\\iiint', '∮': '\\oint',
    '⋀': '\\bigwedge', '⋁': '\\bigvee', '⋂': '\\bigcap', '⋃': '\\bigcup',
}
# Characters recognised as n-ary operators when scanning text; kept in sync
# with NARY_TO_LATEX so every operator that can be translated can be detected.
OPERATOR_CHARS = set(NARY_TO_LATEX)

# Per-slide prompt (Japanese) sent together with the slide image.
# The placeholders {slide_no}, {slide_text} and {lang} are filled in with
# str.replace() by call_openai()/call_gemini(), not with str.format(), so any
# other braces (e.g. LaTeX in the extracted slide text) are left untouched.
# The word "改行" ("line break") is part of the instruction text itself and is
# sent to the model verbatim.
prompt_template_openai = """
あなたはプロフェッショナルな講師です。
スライドから抽出されたテキスト情報（Markdown形式）と添付のスライド画像を統合し、
解説を作成してください。

# スライド番号: {slide_no}
# 抽出テキスト情報:
{slide_text}
# 出力言語: {lang}

# 指示:
1. 画像内の図表やレイアウトを考慮しつつ、抽出テキスト内の数式を含めて正確に解説してください。
2. 以下のフォーマットで出力してください。
3. 出力Markdown中の数式は全て、
$$改行
LaTeX改行
$$
のブロック数式で出力してください

改行
## 1. 解説
## 2. 図・グラフの分析
## 3. このスライドから読み取れる詳細な情報、データの議論・予測
改行
"""

# Gemini uses exactly the same prompt text.
prompt_template_gemini = prompt_template_openai


def _safe_text_replace_math_unicode(text: str) -> str:
    if not text:
        return ""
    for u, ltx in MATH_UNICODE_MAP.items():
        text = text.replace(u, ltx)
    return text

def _find_first(element, candidates):
    """候補タグ名（'m:sub'や'm:low'など）のうち最初に見つかった要素を返す"""
    for cand in candidates:
        found = element.find(cand, NAMESPACES)
        if found is not None:
            return found
    return None

def _detect_nary_op_char(element):
    """
    Return the operator character of an OMML ``<m:nary>`` element.

    Resolution order:
      1. ``m:naryPr/m:chr/@m:val`` when present and non-empty.
      2. The first OPERATOR_CHARS character found in ``m:t`` text that belongs
         to this nary itself, i.e. is not nested in one of its ``m:e``,
         ``m:sub`` or ``m:sup`` containers. Nesting is checked only up to
         *element*, so a nary that itself sits inside another construct's
         ``m:e`` is still scanned (an XPath ``ancestor::`` test would look at
         the whole document and skip it).
      3. '∫' — ECMA-376 defines the integral as the n-ary operator when
         ``m:chr`` is omitted.
    """
    m_ns = NAMESPACES['m']
    op_tag = element.find('m:naryPr/m:chr', NAMESPACES)
    if op_tag is not None:
        val = op_tag.get(f"{{{m_ns}}}val", "")
        if val:
            return val

    containers = {f"{{{m_ns}}}e", f"{{{m_ns}}}sub", f"{{{m_ns}}}sup"}
    for t in element.iter(f"{{{m_ns}}}t"):
        # Walk up to (but not past) the nary to see whether this text is an operand/limit.
        node = t.getparent()
        nested = False
        while node is not None and node is not element:
            if node.tag in containers:
                nested = True
                break
            node = node.getparent()
        if nested:
            continue
        for ch in t.text or "":
            if ch in OPERATOR_CHARS:
                return ch
    return '∫'

# OMML <m:acc> accent characters (m:accPr/m:chr/@m:val) -> LaTeX accent commands.
_ACCENT_TO_LATEX = {
    '\u0300': '\\grave', '\u0301': '\\acute', '\u0302': '\\hat', '\u0303': '\\tilde',
    '\u0304': '\\bar', '\u0305': '\\overline', '\u0306': '\\breve', '\u0307': '\\dot',
    '\u0308': '\\ddot', '\u030c': '\\check', '\u20d7': '\\vec',
}

# Delimiter characters that must be written as commands in LaTeX math.
_DELIMITER_TO_LATEX = {
    '{': '\\{', '}': '\\}', '⟨': '\\langle', '⟩': '\\rangle',
    '〈': '\\langle', '〉': '\\rangle', '‖': '\\|',
    '⌊': '\\lfloor', '⌋': '\\rfloor', '⌈': '\\lceil', '⌉': '\\rceil',
}

# Function names that LaTeX provides as upright operator commands (\sin, \lim, ...).
_LATEX_FUNCTION_NAMES = frozenset({
    'sin', 'cos', 'tan', 'cot', 'sec', 'csc', 'arcsin', 'arccos', 'arctan',
    'sinh', 'cosh', 'tanh', 'coth', 'exp', 'log', 'ln', 'lg',
    'lim', 'max', 'min', 'sup', 'inf', 'det', 'dim', 'ker', 'deg', 'arg', 'gcd', 'Pr',
})


def _omml_val(element, path, default):
    """Return ``@m:val`` of the child at *path*; *default* if the child or the attribute is missing."""
    node = element.find(path, NAMESPACES)
    if node is None:
        return default
    value = node.get(f"{{{NAMESPACES['m']}}}val")
    return default if value is None else value


def _latex_operator_name(name):
    """Prefix a known function name ('sin', 'lim', ...) with a backslash; return other text unchanged."""
    return f"\\{name}" if name in _LATEX_FUNCTION_NAMES else name


def _latex_limits(lower, upper):
    """Return '_{lower}^{upper}', omitting whichever part is empty."""
    return (f"_{{{lower}}}" if lower else "") + (f"^{{{upper}}}" if upper else "")


def omml_to_latex(element):
    """
    Recursively convert an OMML (Office Math Markup) element into LaTeX.

    Fractions, radicals, sub/superscripts, delimiters, accents, function
    applications, limits and n-ary operators are translated explicitly; any
    other element is converted by concatenating the conversion of its
    children. ``None`` (a missing optional child) and non-element nodes such
    as XML comments yield "".
    """
    if element is None or not isinstance(element.tag, str):
        return ""
    tag = etree.QName(element).localname

    def part(name):
        # LaTeX of the direct m:<name> child ("" when it is absent).
        return omml_to_latex(element.find(f"m:{name}", NAMESPACES))

    if tag == 't':  # Text
        return _safe_text_replace_math_unicode(element.text or "")

    if tag == 'f':  # Fraction
        return f"\\frac{{{part('num')}}}{{{part('den')}}}"

    if tag == 'rad':  # Root; PowerPoint writes an empty <m:deg/> for square roots.
        degree = part('deg')
        body = part('e')
        if degree:
            return f"\\sqrt[{degree}]{{{body}}}"
        return f"\\sqrt{{{body}}}"

    if tag == 'sSup':  # Superscript
        return f"{part('e')}^{{{part('sup')}}}"

    if tag == 'sSub':  # Subscript
        return f"{part('e')}_{{{part('sub')}}}"

    if tag == 'sSubSup':  # Subscript + superscript
        return f"{part('e')}_{{{part('sub')}}}^{{{part('sup')}}}"

    if tag == 'd':  # Delimiter; ECMA-376 defaults: '(' ')' with '|' between items.
        beg = _omml_val(element, 'm:dPr/m:begChr', '(')
        end = _omml_val(element, 'm:dPr/m:endChr', ')')
        sep = _omml_val(element, 'm:dPr/m:sepChr', '|')
        items = [omml_to_latex(child) for child in element
                 if isinstance(child.tag, str) and etree.QName(child).localname != 'dPr']
        beg = _DELIMITER_TO_LATEX.get(beg, beg)
        end = _DELIMITER_TO_LATEX.get(end, end)
        sep = _DELIMITER_TO_LATEX.get(sep, sep)
        return f"{beg}{sep.join(items)}{end}"

    if tag == 'acc':  # Accent; ECMA-376 default accent is U+0302 (circumflex).
        base = part('e')
        command = _ACCENT_TO_LATEX.get(_omml_val(element, 'm:accPr/m:chr', '\u0302'))
        return f"{command}{{{base}}}" if command else base

    if tag == 'func':  # Function application, e.g. sin x
        name = part('fName').strip()
        return f"{_latex_operator_name(name)} {part('e')}".strip()

    if tag == 'limLow':
        return f"{_latex_operator_name(part('e'))}_{{{part('lim')}}}"

    if tag == 'limUpp':
        return f"{_latex_operator_name(part('e'))}^{{{part('lim')}}}"

    if tag == 'int':
        lower = omml_to_latex(_find_first(element, ['m:sub', 'm:low']))
        upper = omml_to_latex(_find_first(element, ['m:sup', 'm:up']))
        body = part('e')
        if lower or upper:
            return f"\\int{_latex_limits(lower, upper)}{body}"
        return f"\\int {body}"

    if tag == 'nary':
        op_char = _detect_nary_op_char(element)
        lower = omml_to_latex(_find_first(element, ['m:sub', 'm:low']))
        upper = omml_to_latex(_find_first(element, ['m:sup', 'm:up']))
        body = part('e')
        # An unresolved operator falls back to the ECMA-376 default (integral).
        op_ltx = (NARY_TO_LATEX.get(op_char)
                  or _safe_text_replace_math_unicode(op_char)
                  or '\\int')
        if lower or upper:
            return f"{op_ltx}{_latex_limits(lower, upper)}{body}"
        return f"{op_ltx} {body}".rstrip()

    # oMath, oMathPara, r and any construct without a dedicated rule.
    return "".join(omml_to_latex(child) for child in element)

def split_latex_blocks(s: str) -> List[str]:
    """Split LaTeX on the row separator ``\\\\`` into stripped, non-empty lines.

    Returns an empty list for empty, blank or separator-only input, so callers
    never emit an empty ``$$ $$`` block.
    """
    if not s:
        return []
    return [part.strip() for part in re.split(r'\\\\', s) if part.strip()]

# Placeholder types that hold a slide title (python-pptx PP_PLACEHOLDER values):
# TITLE = 1, CENTER_TITLE = 3 (used on title slides), VERTICAL_TITLE = 5.
_TITLE_PLACEHOLDER_TYPES = (1, 3, 5)


def get_slide_title(slide):
    """Return the slide's title as a single line.

    Preference order:
      1. the first title-type placeholder with text (line breaks in the title
         are joined with single spaces so the result fits in one header line);
      2. the first line of the first shape that has a text frame with text;
      3. "無題のスライド" ("untitled slide").
    """
    for shape in slide.shapes:
        if (getattr(shape, "has_text_frame", False) and shape.is_placeholder
                and shape.placeholder_format.type in _TITLE_PLACEHOLDER_TYPES):
            lines = [line.strip() for line in (shape.text or "").splitlines()]
            title_text = " ".join(line for line in lines if line)
            if title_text:
                return title_text
    for shape in slide.shapes:
        if getattr(shape, "has_text_frame", False):
            first_text = (shape.text or "").strip()
            if first_text:
                # splitlines() also splits on '\v', which python-pptx may use for soft line breaks.
                return first_text.splitlines()[0].strip()
    return "無題のスライド"

def _slide_paragraph_text(root) -> str:
    """Return the slide's DrawingML text, one line per non-blank paragraph (a:p)."""
    lines = []
    for paragraph in root.xpath('//a:p', namespaces=NAMESPACES):
        text = "".join(t.text for t in paragraph.xpath('.//a:t', namespaces=NAMESPACES) if t.text)
        if text.strip():
            lines.append(text.rstrip())
    return "\n".join(lines).strip()


def _slide_math_blocks(root) -> List[str]:
    """Convert each distinct OMML formula in the slide XML into a '$$ ... $$' line."""
    math_elements = root.xpath(
        '//m:oMathPara | //m:oMath[not(ancestor::m:oMathPara)]',
        namespaces=NAMESPACES
    )
    blocks: List[str] = []
    seen_omml = set()
    for math_elem in math_elements:
        # The same formula may occur more than once in the XML; emit it once.
        omml_string = etree.tostring(math_elem, encoding='unicode')
        if omml_string in seen_omml:
            continue
        seen_omml.add(omml_string)

        targets = [math_elem]
        if etree.QName(math_elem).localname == 'oMathPara':
            # An oMathPara can hold several oMath lines; convert each on its own.
            targets = math_elem.findall('m:oMath', NAMESPACES) or targets
        for target in targets:
            for line in split_latex_blocks(omml_to_latex(target)):
                # Thin space before differentials, e.g. "f(x)\,\mathrm{d}x".
                line = line.replace('\\mathrm{d}', '\\,\\mathrm{d}')
                blocks.append(f"$$ {line} $$")
    return blocks


def _slide_shape_text(slide) -> str:
    """Fallback text: the .text of every shape that has one, one per line."""
    return "\n".join(shape.text for shape in slide.shapes if hasattr(shape, "text"))


def extract_text_and_math_from_pptx(pptx_path: Path) -> Dict[int, str]:
    """
    Extract text and formulas from a PPTX for the AI prompt.

    Returns a dict {slide number (1-based): Markdown text}. Each entry holds a
    "### Title:" line, a "## Text Content:" section (one line per paragraph)
    and, when the slide contains OMML math, a "## Mathematical Formulas:"
    section with one ``$$ ... $$`` block per formula line.

    If the slide XML cannot be processed, the plain text of the slide's shapes
    is used instead. If the file itself cannot be read, the error is printed
    and whatever was collected so far (possibly {}) is returned.
    """
    print(f"Extracting text & math from PPTX: {pptx_path.name}...")
    slides_data: Dict[int, str] = {}
    try:
        prs = Presentation(str(pptx_path))
        for i, slide in enumerate(prs.slides, start=1):
            parts = [f"### Title: {get_slide_title(slide)}"]
            try:
                root = etree.fromstring(slide.part.blob)

                slide_text = _slide_paragraph_text(root)
                if slide_text:
                    parts.append(f"## Text Content:\n{slide_text}")

                math_blocks = _slide_math_blocks(root)
                if math_blocks:
                    parts.append("## Mathematical Formulas:")
                    parts.extend(math_blocks)

            except Exception as e:
                print(f"XML Parsing Error on Slide {i}: {e}")
                parts.append(_slide_shape_text(slide))

            slides_data[i] = "\n\n".join(parts)

    except Exception as e:
        print(f"Error reading PPTX: {e}")

    return slides_data


# =========================================================
# 2. Text Parsing (Existing Markdown File Support)
# =========================================================
# A "# Slide N" header line (case-insensitive) that starts a slide's section.
SLIDE_HEADER_RE = re.compile(r"^\s*#\s*Slide\s+(\d+)", re.IGNORECASE)


def load_slide_texts(txt_path: Path) -> Dict[int, str]:
    """Load per-slide text from a Markdown file split by "# Slide N" headers.

    Args:
        txt_path: Path (or str) of the text file. None, "" or a missing file
            yields {}.

    Returns:
        {slide number: stripped text between its header and the next header}.
        Text before the first header is ignored; a repeated slide number keeps
        the last section.
    """
    if not txt_path:
        return {}
    txt_path = Path(txt_path)
    if not txt_path.exists():
        return {}

    print(f"Loading text from file: {txt_path.name}")
    slides: Dict[int, str] = {}
    current_no: Optional[int] = None
    buf: List[str] = []

    # 'utf-8-sig' drops a leading BOM (common for files saved by Windows
    # editors); with plain 'utf-8' the '\ufeff' would stay in front of the
    # first "# Slide 1" header and the regex would not match it.
    with txt_path.open("r", encoding="utf-8-sig") as f:
        for line in f:
            m = SLIDE_HEADER_RE.match(line)
            if m is None:
                buf.append(line)
                continue
            if current_no is not None:
                slides[current_no] = "".join(buf).strip()
            current_no = int(m.group(1))
            buf = []
    if current_no is not None:
        slides[current_no] = "".join(buf).strip()
    return slides


# =========================================================
# 3. PDF/Image Conversion
# =========================================================
def pptx_to_pdf(pptx_path: Path, pdf_path: Path, visible: bool = False) -> None:
    """Export *pptx_path* to *pdf_path* through PowerPoint COM automation (Windows).

    The export is skipped when *pdf_path* already exists and is not older than
    *pptx_path* (or *pptx_path* no longer exists), so an edited PPTX is
    re-exported instead of reusing a stale PDF.

    Args:
        visible: show the PowerPoint window. When False, alerts are suppressed
            during the export and PowerPoint is quit afterwards, but only if it
            has no other presentation open (Dispatch may attach to an instance
            the user is working in).

    Raises:
        Exception: any COM error, after printing it.
    """
    pp_save_as_pdf = 32   # PpSaveAsFileType.ppSaveAsPDF
    pp_alerts_none = 1    # PpAlertLevel.ppAlertsNone

    if pdf_path.exists():
        if not pptx_path.exists() or pdf_path.stat().st_mtime >= pptx_path.stat().st_mtime:
            return
    abs_pptx = str(pptx_path.resolve())
    abs_pdf = str(pdf_path.resolve())

    print(f"Converting PPTX to PDF...")
    app = None
    previous_alerts = None
    try:
        app = win32com.client.Dispatch("PowerPoint.Application")
        if not visible:
            previous_alerts = app.DisplayAlerts
            app.DisplayAlerts = pp_alerts_none
        presentation = app.Presentations.Open(abs_pptx, ReadOnly=True, Untitled=False, WithWindow=visible)
        try:
            presentation.SaveAs(abs_pdf, pp_save_as_pdf)
        finally:
            presentation.Close()
    except Exception as e:
        print(f"PPTX->PDF Error: {e}")
        raise
    finally:
        # Runs even when Dispatch/Open failed, so no hidden PowerPoint is leaked.
        if app is not None and not visible:
            _release_powerpoint(app, previous_alerts)


def _release_powerpoint(app, previous_alerts) -> None:
    """Restore DisplayAlerts and quit PowerPoint if nothing else is open (best effort)."""
    try:
        if previous_alerts is not None:
            app.DisplayAlerts = previous_alerts
        if app.Presentations.Count == 0:
            app.Quit()
    except Exception as exc:  # cleanup must never mask the conversion's own outcome
        print(f"PowerPoint cleanup warning: {exc}")

def convert_pdf_to_images(pdf_path: Path, out_dir: Path, zoom: float = 2.0) -> List[Path]:
    """Render every PDF page to ``out_dir/slide_NNN.png`` (NNN is 1-based, zero-padded).

    Args:
        pdf_path: PDF to render.
        out_dir: output directory; created if missing.
        zoom: scale factor for both axes (PyMuPDF renders at 72 dpi at
            zoom 1.0, so the default 2.0 gives 144 dpi).

    Returns:
        The PNG paths in page order.
    """
    out_dir.mkdir(parents=True, exist_ok=True)
    matrix = fitz.Matrix(zoom, zoom)
    paths: List[Path] = []
    print(f"Converting PDF to Images...")
    # The context manager closes the document even if rendering fails midway.
    with fitz.open(str(pdf_path.resolve())) as doc:
        for page_no, page in enumerate(doc, start=1):
            out_path = out_dir / f"slide_{page_no:03d}.png"
            page.get_pixmap(matrix=matrix).save(str(out_path))
            paths.append(out_path)
    return paths

# =========================================================
# 4. AI Logic
# =========================================================
def call_gemini(model_name, slide_no, slide_text, image_path, lang):
    """Ask a Gemini model to explain one slide from its extracted text and PNG image.

    Args:
        model_name: Gemini model id.
        slide_no: slide number shown in the prompt.
        slide_text: Markdown text extracted for the slide.
        image_path: Path of the slide PNG.
        lang: output language written into the prompt.

    Returns:
        The reply text. Any exception raised while generating or reading the
        response is not propagated; it is returned as "Gemini Error: ..." so
        the report still gets an entry for the slide.
    """
    # slide_text is substituted last so placeholder-like text inside the slide
    # itself (e.g. a literal "{lang}") is never rewritten.
    prompt = (prompt_template_gemini
              .replace("{slide_no}", f"{slide_no}")
              .replace("{lang}", lang)
              .replace("{slide_text}", slide_text)
              .strip())

    model = genai.GenerativeModel(model_name)
    img_data = {'mime_type': 'image/png', 'data': image_path.read_bytes()}

    try:
        res = model.generate_content([prompt, img_data])
        return res.text
    except Exception as e:
        return f"Gemini Error: {e}"

def call_openai(client, model, slide_no, slide_text, image_path, lang):
    """Ask an OpenAI chat model to explain one slide from its extracted text and PNG image.

    Args:
        client: an ``OpenAI`` client.
        model: chat model id.
        slide_no: slide number shown in the prompt.
        slide_text: Markdown text extracted for the slide.
        image_path: Path of the slide PNG, sent inline as a base64 data URL.
        lang: output language written into the prompt.

    Returns:
        The reply text, or "" when the message carries no text content
        (``message.content`` can be None).
    """
    b64 = base64.b64encode(image_path.read_bytes()).decode('utf-8')
    # slide_text is substituted last so placeholder-like text inside the slide
    # itself (e.g. a literal "{lang}") is never rewritten.
    prompt = (prompt_template_openai
              .replace("{slide_no}", f"{slide_no}")
              .replace("{lang}", lang)
              .replace("{slide_text}", slide_text)
              .strip())

    res = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": [
            {"type": "text", "text": prompt},
            {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{b64}"}},
        ]}]
    )
    return res.choices[0].message.content or ""

def parse_args():
    """Build the CLI, parse ``sys.argv`` and resolve model names and API keys.

    Model defaults come from the environment: OPENAI_MODEL, OPENAI_MODEL5 and
    GEMINI_MODEL (falling back to GOOGLE_MODEL). ``--model``, when given,
    replaces the model of whichever API ``--api`` selects. The keys are taken
    from OPENAI_API_KEY and GEMINI_API_KEY / GOOGLE_API_KEY.

    Returns:
        (parser, args)
    """
    default_openai = os.getenv("OPENAI_MODEL", "gpt-4o")
    default_openai5 = os.getenv("OPENAI_MODEL5", "gpt-5.2")
    default_google = os.getenv("GEMINI_MODEL") or os.getenv("GOOGLE_MODEL", "gemini-2.5-flash")

    parser = argparse.ArgumentParser(description="PPTX -> Text/Math -> AI Explanation")
    parser.add_argument("infile", type=str, help="Path to input PPTX")
    parser.add_argument("--txt", type=str, default=None, help="Optional: Path to existing text file")
    parser.add_argument("--api", "-a", type=str, default="gemini", choices=["gemini", "google", "openai5", "openai"])
    parser.add_argument('--model', default=None, help='明示モデル名の指定（apiごとに適用先を切替）')
    parser.add_argument('--google_model', default=default_google)
    parser.add_argument("--openai_model", default=default_openai)
    parser.add_argument('--openai_model5', default=default_openai5)
    parser.add_argument("--visible", action="store_true", help="Show PPT window")
    parser.add_argument("--language", type=str, default="Japanese")

    parsed = parser.parse_args()
    parsed.openai_key = os.getenv("OPENAI_API_KEY")
    parsed.gemini_key = os.getenv("GEMINI_API_KEY") or os.getenv("GOOGLE_API_KEY")

    # Which attribute an explicit --model overrides, per --api choice.
    model_attr_for_api = {
        'openai5': 'openai_model5',
        'openai': 'openai_model',
        'gemini': 'google_model',
        'google': 'google_model',
    }
    if parsed.model:
        setattr(parsed, model_attr_for_api[parsed.api], parsed.model)

    return parser, parsed

# =========================================================
# Main
# =========================================================
def main():
    """Run the pipeline: slide text -> slide images -> AI explanation -> Markdown report.

    The report ``<stem>_explanation.md`` is written next to the input file. It
    is recreated with a header, then appended to after every slide so that
    finished slides survive an interrupted run. Exits with status 1 when the
    input file or the required API key is missing.
    """
    _, args = parse_args()

    infile_path = Path(args.infile)
    if not infile_path.exists():
        print(f"Error: input file not found: {infile_path}")
        sys.exit(1)
    is_pptx = infile_path.suffix.lower() == ".pptx"

    work_dir = infile_path.parent / (infile_path.stem + "_work")
    img_dir = work_dir / "slides_png"
    work_dir.mkdir(exist_ok=True)

    # 1. Text/Math extraction: an existing --txt file wins, otherwise parse the PPTX.
    slide_texts: Dict[int, str] = {}
    txt_path = Path(args.txt) if args.txt else None
    if txt_path is not None and not txt_path.exists():
        print(f"Warning: --txt file not found, ignoring it: {txt_path}")
        txt_path = None
    if txt_path is not None:
        slide_texts = load_slide_texts(txt_path)
    elif is_pptx:
        slide_texts = extract_text_and_math_from_pptx(infile_path)

    # 2. PDF/Image conversion (any non-PPTX input is treated as a PDF).
    if is_pptx:
        pdf_path = work_dir / (infile_path.stem + ".pdf")
        pptx_to_pdf(infile_path, pdf_path, visible=args.visible)
    else:
        pdf_path = infile_path

    images = convert_pdf_to_images(pdf_path, img_dir)

    # 3. AI setup
    use_openai = args.api in ("openai5", "openai")
    if args.api == "openai5":
        model = args.openai_model5
    elif args.api == "openai":
        model = args.openai_model
    else:
        model = args.google_model
    print(f"Using API: {args.api} / Model: {model}")

    client_openai = None
    if use_openai:
        if args.openai_key is None:
            print("Error: OPENAI_API_KEY missing.")
            sys.exit(1)
        client_openai = OpenAI()
    else:
        # "google" is an alias of "gemini"; both need the key configured.
        if args.gemini_key is None:
            print("Error: GEMINI_API_KEY and GOOGLE_API_KEY missing.")
            sys.exit(1)
        genai.configure(api_key=args.gemini_key)

    # 4. Processing
    out_md = infile_path.with_name(infile_path.stem + "_explanation.md")
    with open(out_md, "w", encoding="utf-8") as f:
        f.write(f"# Analysis Report: {infile_path.name}\n\n---\n")

    for i, img_p in enumerate(images, start=1):
        print(f"Processing Slide {i}/{len(images)}...", end=" ", flush=True)
        # Missing or empty slide text both fall back to the placeholder.
        txt = slide_texts.get(i) or "(テキスト情報なし)"

        try:
            if use_openai:
                content = call_openai(client_openai, model, i, txt, img_p, args.language)
            else:
                content = call_gemini(model, i, txt, img_p, args.language)

            chunk = f"# Slide {i}\n\n{content}\n\n---\n"
            print("Done.")
            time.sleep(1)  # presumably to pace successive API calls

        except Exception as e:
            print(f"Error: {e}")
            chunk = f"# Slide {i}\n\nError: {e}\n\n---\n"

        # Append after every slide so completed slides are kept if the run aborts.
        with open(out_md, "a", encoding="utf-8") as f:
            f.write(chunk)

    print(f"\nSaved to: {out_md}")


if __name__ == "__main__":
    # Run the whole pipeline, then leave through terminate(), which can wait
    # for ENTER before exiting.
    main()
    terminate()
    