import argparse
import os
import re
from pptx import Presentation
from lxml import etree

# 名前空間の定義
NAMESPACES = {
    'p': 'http://schemas.openxmlformats.org/presentationml/2006/main',
    'a': 'http://schemas.openxmlformats.org/drawingml/2006/main',
    'm': 'http://schemas.openxmlformats.org/officeDocument/2006/math',
    'r': 'http://schemas.openxmlformats.org/officeDocument/2006/relationships',
    'a14': 'http://schemas.microsoft.com/office/drawing/2010/main'
}

# Unicode 数学記号をASCII/LaTeXに変換する辞書（積分の拡張・斜体補正を含む）
MATH_UNICODE_MAP = {
    '𝑷': 'P', '𝑽': 'V', '𝑹': 'R', '𝑻': 'T',
    '𝒂': 'a', '𝒃': 'b', '𝒄': 'c', '𝒅': '\\mathrm{d}', '𝒆': 'e', '𝒇': 'f', '𝒈': 'g',
    '𝒉': 'h', '𝒊': 'i', '𝒋': 'j', '𝒌': 'k', '𝒍': 'l', '𝒎': 'm', '𝒏': 'n',
    '𝒙': 'x', '𝒚': 'y', '𝒛': 'z',
    '𝟐': '2', '𝟏': '1', '𝟎': '0',
    '−': '-', '＋': '+', '÷': '/', '×': '*', '⋅': '\\cdot',
    '…': '...', '∞': '\\infty',
    '∑': '\\sum', '∏': '\\prod',
    '∫': '\\int', '∬': '\\iint', '∭': '\\iiint', '∮': '\\oint',
    'α': '\\alpha', 'β': '\\beta', 'γ': '\\gamma', 'δ': '\\delta', 'ε': '\\epsilon',
    'ζ': '\\zeta', 'η': '\\eta', 'θ': '\\theta', 'ι': '\\iota', 'κ': '\\kappa',
    'λ': '\\lambda', 'μ': '\\mu', 'ν': '\\nu', 'ξ': '\\xi', 'π': '\\pi', 'ρ': '\\rho',
    'σ': '\\sigma', 'τ': '\\tau', 'υ': '\\upsilon', 'φ': '\\phi', 'χ': '\\chi',
    'ψ': '\\psi', 'ω': '\\omega',
    'Γ': '\\Gamma', 'Δ': '\\Delta', 'Θ': '\\Theta', 'Λ': '\\Lambda',
    'Ξ': '\\Xi', 'Π': '\\Pi', 'Σ': '\\Sigma', 'Φ': '\\Phi', 'Ψ': '\\Psi', 'Ω': '\\Omega',
    '′': "'", '°': '\\degree', '℃': '\\degree C'
}

# n-ary演算子のOMML文字→LaTeXの対応
NARY_TO_LATEX = {
    '∑': '\\sum', '∏': '\\prod',
    '∫': '\\int', '∬': '\\iint', '∭': '\\iiint', '∮': '\\oint',
    '⋀': '\\bigwedge', '⋁': '\\bigvee', '⋂': '\\bigcap', '⋃': '\\bigcup',
}
OPERATOR_CHARS = set(['∑', '∏', '∫', '∬', '∭', '∮'])


pause = 0

def terminate():
    if pause:
        input("\nPress ENTER to terminate\n")
    exit()

def initialize():
    parser = argparse.ArgumentParser(description="PowerPointファイルからテキスト、数式、図を抽出し、Markdownに出力します。")
    parser.add_argument("-i", "--input", required=True, help="入力するPowerPointファイル名 (例: test.pptx)")
    parser.add_argument("-o", "--output", required=True, help="出力するMarkdownファイル名 (例: output.md)")
    parser.add_argument("--xml", action="store_true", help="数式の元のOMML XMLを出力します。")
    parser.add_argument("--imagedir", default="images", help="画像ファイルを保存するディレクトリ (デフォルト: images)")
    parser.add_argument("--pause", type=int, default=0, help="終了時にENTERキー入力を要求するか (デフォルト: 0)")
    args = parser.parse_args()
    return args

def get_slide_title(slide):
    """スライドからタイトルを抽出します。"""
    for shape in slide.shapes:
        if getattr(shape, "has_text_frame", False) and shape.is_placeholder and shape.placeholder_format.type == 1:
            title_text = shape.text
            if title_text:
                return title_text.strip()
    for shape in slide.shapes:
        if getattr(shape, "has_text_frame", False):
            first_text = (shape.text or "").strip()
            if first_text:
                return first_text.split('\n')[0]
    return "無題のスライド"

def _safe_text_replace_math_unicode(text: str) -> str:
    if not text:
        return ""
    for u, ltx in MATH_UNICODE_MAP.items():
        text = text.replace(u, ltx)
    return text

def _find_first(element, candidates):
    """候補タグ名（'m:sub'や'm:low'など）のうち最初に見つかった要素を返す"""
    for cand in candidates:
        found = element.find(cand, NAMESPACES)
        if found is not None:
            return found
    return None

def _detect_nary_op_char(element):
    """
    n-aryの演算子文字を検出。
    1) 通常: m:naryPr/m:chr@m:val
    2) opEmu（演算子がランにある）: 子のテキストから ∫, ∬, ∭, ∮, ∑, ∏ を拾う
       （ただし m:e, m:sub, m:sup の中は除外）
    """
    op_tag = element.find('m:naryPr/m:chr', NAMESPACES)
    if op_tag is not None:
        val = op_tag.get(f"{{{NAMESPACES['m']}}}val", "")
        if val:
            return val
    ts = element.xpath('.//m:t[not(ancestor::m:e) and not(ancestor::m:sub) and not(ancestor::m:sup)]', namespaces=NAMESPACES)
    for t in ts:
        s = t.text or ""
        for ch in s:
            if ch in OPERATOR_CHARS:
                return ch
    return ""

def omml_to_latex(element):
    """
    OMMLのXML要素を再帰的に解析し、LaTeXに変換します。
    - <m:oMathPara> / <m:oMath> を起点に下位を解析
    - <m:int>, <m:nary>（opEmu含む）, <m:limLow>/<m:limUpp> に対応
    - ルート, 分数, 上下付き, 上下同時, デリミタに対応
    """
    tag = etree.QName(element).localname

    if tag in ('oMath', 'oMathPara'):
        return "".join(omml_to_latex(child) for child in element)

    elif tag == 'f':  # Fraction
        num = element.find('m:num', NAMESPACES)
        den = element.find('m:den', NAMESPACES)
        return f"\\frac{{{omml_to_latex(num)}}}{{{omml_to_latex(den)}}}"

    elif tag == 'rad':  # Root
        deg = element.find('m:deg', NAMESPACES)
        e = element.find('m:e', NAMESPACES)
        if deg is not None:
            return f"\\sqrt[{omml_to_latex(deg)}]{{{omml_to_latex(e)}}}"
        return f"\\sqrt{{{omml_to_latex(e)}}}"

    elif tag == 'sSup':  # Superscript
        e = element.find('m:e', NAMESPACES)
        sup = element.find('m:sup', NAMESPACES)
        return f"{omml_to_latex(e)}^{{{omml_to_latex(sup)}}}"

    elif tag == 'sSub':  # Subscript
        e = element.find('m:e', NAMESPACES)
        sub = element.find('m:sub', NAMESPACES)
        return f"{omml_to_latex(e)}_{{{omml_to_latex(sub)}}}"

    elif tag == 'sSubSup':  # Subscript + Superscript
        e = element.find('m:e', NAMESPACES)
        sub = element.find('m:sub', NAMESPACES)
        sup = element.find('m:sup', NAMESPACES)
        base = omml_to_latex(e)
        sub_l = omml_to_latex(sub) if sub is not None else ""
        sup_l = omml_to_latex(sup) if sup is not None else ""
        return f"{base}_{{{sub_l}}}^{{{sup_l}}}"

    elif tag == 'd':  # Delimiter（括弧など）
        beg_chr = element.find('m:dPr/m:begChr', NAMESPACES)
        end_chr = element.find('m:dPr/m:endChr', NAMESPACES)
        beg = beg_chr.get(f"{{{NAMESPACES['m']}}}val") if beg_chr is not None else "("
        end = end_chr.get(f"{{{NAMESPACES['m']}}}val") if end_chr is not None else ")"
        content = "".join(omml_to_latex(child) for child in element if etree.QName(child).localname != 'dPr')
        return f"{beg}{content}{end}"

    elif tag == 'r':  # Run
        return "".join(omml_to_latex(child) for child in element)

    elif tag == 't':  # Text
        return _safe_text_replace_math_unicode(element.text or "")

    elif tag == 'limLow':
        base = omml_to_latex(element.find('m:e', NAMESPACES))
        low = omml_to_latex(element.find('m:lim', NAMESPACES))
        return f"{base}_{{{low}}}"

    elif tag == 'limUpp':
        base = omml_to_latex(element.find('m:e', NAMESPACES))
        upp = omml_to_latex(element.find('m:lim', NAMESPACES))
        return f"{base}^{{{upp}}}"

    # 積分専用タグ（環境により出ることがある）
    elif tag == 'int':
        lower = _find_first(element, ['m:sub', 'm:low'])
        upper = _find_first(element, ['m:sup', 'm:up'])
        e = element.find('m:e', NAMESPACES)
        lower_ltx = omml_to_latex(lower) if lower is not None else ""
        upper_ltx = omml_to_latex(upper) if upper is not None else ""
        e_ltx = omml_to_latex(e) if e is not None else ""
        if lower_ltx or upper_ltx:
            return f"\\int_{{{lower_ltx}}}^{{{upper_ltx}}}{e_ltx}"
        else:
            return f"\\int {e_ltx}"

    # 汎用 n-ary（和・積・積分など）
    elif tag == 'nary':
        op_char = _detect_nary_op_char(element)
        lower_tag = _find_first(element, ['m:sub', 'm:low'])
        upper_tag = _find_first(element, ['m:sup', 'm:up'])
        content_tag = element.find('m:e', NAMESPACES)

        lower_ltx = omml_to_latex(lower_tag) if lower_tag is not None else ''
        upper_ltx = omml_to_latex(upper_tag) if upper_tag is not None else ''
        content_ltx = omml_to_latex(content_tag) if content_tag is not None else ''

        op_ltx = NARY_TO_LATEX.get(op_char, _safe_text_replace_math_unicode(op_char))
        if not op_ltx and (lower_ltx or upper_ltx):
            # opEmu 検出失敗時の保険：見た目が積分であることが多い
            op_ltx = '\\int'

        if lower_ltx or upper_ltx:
            return f"{op_ltx}_{{{lower_ltx}}}^{{{upper_ltx}}}{content_ltx}"
        else:
            return f"{op_ltx} {content_ltx}".rstrip()

    # 未知のタグまたはその他のコンテナタグ
    return "".join(omml_to_latex(child) for child in element)

def split_latex_blocks(s: str):
    r"""
    LaTeX文字列を \\ （行区切り）で分割して、空要素を除去。
    例: r"\frac{∂L}{∂x}=0 \\ \frac{∂L}{∂y}=0"
        -> ["\\frac{∂L}{∂x}=0", "\\frac{∂L}{\partial y}=0"]
    """
    if not s:
        return []
    parts = [p.strip() for p in re.split(r'\\\\', s) if p.strip()]
    return parts if parts else [s.strip()]

def extract_content_to_markdown(input_pptx, output_md, image_dir, include_xml=False):
    """
    PPTXファイルからコンテンツを抽出し、Markdownファイルに保存します。
    - 数式: oMathPara を優先、oMathPara 配下の oMath は除外して重複回避
    - 複数行の式は 1 行 = 1 ブロック $$ ... $$ に分割（Pandoc対策）
    """
    try:
        presentation = Presentation(input_pptx)
    except Exception as e:
        print(f"エラー: ファイル '{input_pptx}' を開けませんでした。{e}")
        return

    os.makedirs(image_dir, exist_ok=True)

    markdown_output = ""

    for i, slide in enumerate(presentation.slides):
        title = get_slide_title(slide)
        markdown_output += f"# スライド {i + 1}　{title}\n\n"

        # スライドの生XML
        slide_xml = slide.part.blob
        root = etree.fromstring(slide_xml)

        # テキスト
        text_elements = root.xpath('//a:t', namespaces=NAMESPACES)
        if text_elements:
            slide_text = "".join([elem.text for elem in text_elements if elem.text])
            if slide_text.strip():
                markdown_output += f"## テキスト\n\n{slide_text.strip()}\n\n"

        # 数式（ブロック優先: oMathPara と、oMathPara外の単独oMath）
        math_elements = root.xpath(
            '//m:oMathPara | //m:oMath[not(ancestor::m:oMathPara)]',
            namespaces=NAMESPACES
        )
        if math_elements:
            markdown_output += "## 数式\n\n"
            seen_omml = set()
            for math_elem in math_elements:
                # 重複ガード（OMML文字列単位）
                omml_string = etree.tostring(math_elem, encoding='unicode')
                if omml_string in seen_omml:
                    continue
                seen_omml.add(omml_string)

                tag = etree.QName(math_elem).localname

                # oMathPara は直下の oMath を1行ごとに出力
                if tag == 'oMathPara':
                    inner_oms = math_elem.findall('m:oMath', NAMESPACES)
                    if inner_oms:
                        for om in inner_oms:
                            latex_code = omml_to_latex(om)
                            for one_line in split_latex_blocks(latex_code):
                                one_line = one_line.replace('\\mathrm{d}', '\\,\\mathrm{d}')
                                markdown_output += f"$$ {one_line} $$\n\n"
                    else:
                        # フォールバック
                        latex_code = omml_to_latex(math_elem)
                        for one_line in split_latex_blocks(latex_code):
                            one_line = one_line.replace('\\mathrm{d}', '\\,\\mathrm{d}')
                            markdown_output += f"$$ {one_line} $$\n\n"

                    if include_xml:
                        markdown_output += f"**元のOMML XML:**\n```xml\n{omml_string.strip()}\n```\n\n"

                # 単独 oMath はそのまま（ただし \\ があれば分割）
                else:
                    latex_code = omml_to_latex(math_elem)
                    for one_line in split_latex_blocks(latex_code):
                        one_line = one_line.replace('\\mathrm{d}', '\\,\\mathrm{d}')
                        markdown_output += f"$$ {one_line} $$\n\n"
                    if include_xml:
                        markdown_output += f"**元のOMML XML:**\n```xml\n{omml_string.strip()}\n```\n\n"

        # 画像
        image_elements = root.xpath('//a:blip[@r:embed]', namespaces=NAMESPACES)
        if image_elements:
            markdown_output += "## 図\n\n"
            for j, image_elem in enumerate(image_elements):
                r_id = image_elem.get('{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed')
                try:
                    image_part = slide.part.rels[r_id].target_part
                    image_data = image_part.blob
                    image_ext = image_part.content_type.split('/')[-1]
                    image_filename = f"slide{i+1}_image{j+1}.{image_ext}"
                    image_path = os.path.join(image_dir, image_filename)
                    with open(image_path, 'wb') as f:
                        f.write(image_data)
                    markdown_output += f"![スライド{i+1}の図{j+1}]({os.path.join(os.path.basename(image_dir), image_filename)})\n\n"
                except KeyError:
                    markdown_output += f"図のリンクが見つかりませんでした (rId: {r_id})\n\n"
            markdown_output += "\n"

        markdown_output += "---\n\n"

    with open(output_md, "w", encoding="utf-8") as f:
        f.write(markdown_output)

    print(f"抽出が完了しました。結果は '{output_md}' に保存されました。")

def main():
    global pause

    args = initialize()
    pause = args.pause

    if not os.path.exists(args.input):
        print(f"エラー: 指定された入力ファイル '{args.input}' が見つかりません。")
        return

    extract_content_to_markdown(args.input, args.output, args.imagedir, include_xml=args.xml)

if __name__ == "__main__":
    main()
    terminate()
    