import os
import sys
import argparse
import docx
from docx.document import Document
from docx.text.paragraph import Paragraph
from docx.table import Table
from itertools import groupby

try:
    from lxml import etree
except ImportError:
    print("\n[エラー] 必要なライブラリ 'lxml' が見つかりません。", file=sys.stderr)
    print("このスクリプトを実行するには、まずライブラリをインストールしてください。", file=sys.stderr)
    sys.exit(f"\nコマンドプロンプトやターミナルで次のコマンドを実行してください:\npip install lxml")

# --- 数式変換のためのグローバル変数と関数 ---
ns = {
    'm': 'http://schemas.openxmlformats.org/officeDocument/2006/math',
    'w': 'http://schemas.openxmlformats.org/wordprocessingml/2006/main',
    'r': 'http://schemas.openxmlformats.org/officeDocument/2006/relationships',
    'wp': 'http://schemas.openxmlformats.org/drawingml/2006/wordprocessingDrawing',
    'a': 'http://schemas.openxmlformats.org/drawingml/2006/main',
}
SYMBOL_MAP = {
    '−': '-', '…': '\\dots', '≤': '\\le', '≥': '\\ge', '≠': '\\ne', '≈': '\\approx',
    '±': '\\pm', '×': '\\times', '÷': '\\div', '⋅': '\\cdot', '∘': '\\circ',
    'α': '\\alpha', 'β': '\\beta', 'γ': '\\gamma', 'δ': '\\delta', 'ε': '\\epsilon',
    'ζ': '\\zeta', 'η': '\\eta', 'θ': '\\theta', 'ι': '\\iota', 'κ': '\\kappa',
    'λ': '\\lambda', 'μ': '\\mu', 'ν': '\\nu', 'ξ': '\\xi', 'ο': 'o', 'π': '\\pi',
    'ρ': '\\rho', 'σ': '\\sigma', 'τ': '\\tau', 'υ': '\\upsilon', 'φ': '\\phi',
    'χ': '\\chi', 'ψ': '\\psi', 'ω': '\\omega', 'Α': 'A', 'Β': 'B', 'Γ': '\\Gamma',
    'Δ': '\\Delta', 'Ε': 'E', 'Ζ': 'Z', 'Η': 'H', 'Θ': '\\Theta', 'Ι': 'I', 'Κ': 'K',
    'Λ': '\\Lambda', 'Μ': 'M', 'Ν': 'N', 'Ξ': '\\Xi', 'Ο': 'O', 'Π': '\\Pi',
    'Ρ': 'P', 'Τ': 'T', 'Υ': '\\Upsilon', 'Φ': '\\Phi', 'Χ': 'X',
    'Ψ': '\\Psi', 'Ω': '\\Omega',
    '∫': '\\int', '∮': '\\oint', '∂': '\\partial', '∇': '\\nabla', '∞': '\\infty',
    '∀': '\\forall', '∃': '\\exists', '∅': '\\emptyset', '∈': '\\in', '∉': '\\notin',
    '⊂': '\\subset', '⊃': '\\supset', '∩': '\\cap', '∪': '\\cup',
    '→': '\\rightarrow', '←': '\\leftarrow', '⇒': '\\Rightarrow', '⇐': '\\Leftarrow',
    '↔': '\\leftrightarrow', '⇔': '\\Leftrightarrow', '↦': '\\mapsto',
    '∑': '\\sum',
    'Σ': '\\sum',
}

def _latex_delimiter(ch: str, side: str = "left") -> str:
    """
    \left と \right の直後に置く区切り文字を LaTeX に適したトークンへ変換する。
    - {, } は \{, \} にエスケープ
    - 〈〉や⟨⟩は \langle, \rangle に変換
    - ︵︶ 等が来た場合のフォールバックはそのまま返す
    side: "left" or "right"（\langle / \rangle の選択に使用）
    """
    if not ch:
        return ch
    # 中括弧は必ずエスケープ
    if ch == '{':
        return r'\{'
    if ch == '}':
        return r'\}'

    # 角括弧（いろいろなユニコードが来ることがある）
    left_angles  = {'⟨', '〈', '〈'}  # U+27E8, U+2329, U+3008
    right_angles = {'⟩', '〉', '〉'}  # U+27E9, U+232A, U+3009
    if ch in left_angles:
        return r'\langle' if side == 'left' else r'\langle'  # left のみ想定だが念のため
    if ch in right_angles:
        return r'\rangle' if side == 'right' else r'\rangle'

    # そのほかは素のまま返す（(), [], | などはそのままでOK）
    return ch

# --- OMML (数式) パーサー ---
def parse_omml_element(element):
    """
    OMML (Office Math Markup Language) のXML要素を再帰的に解析し、
    LaTeX文字列に変換する。
    """
    if element is None:
        return ""
    tag = etree.QName(element).localname

    if tag == 't':
        text = element.text or ''

        # 関数名（sin, cos, ...）の m:nor 判定
        is_function = False
        try:
            parent_r = element.getparent()
            if etree.QName(parent_r).localname == 'r':
                rPr = parent_r.find('m:rPr', ns)
                if rPr is not None and rPr.find('m:nor', ns) is not None:
                    if text in ['sin', 'cos', 'tan', 'lim', 'log', 'ln', 'exp', 'sup', 'inf', 'max', 'min', 'deg']:
                        is_function = True
        except AttributeError:
            pass

        if is_function:
            return f"\\{text} "
        else:
            if text in ['sin', 'cos', 'tan', 'lim', 'log', 'ln', 'exp', 'sup', 'inf', 'max', 'min', 'deg']:
                return f"\\{text} "

            result_parts = []
            for i, char in enumerate(text):
                latex_char = SYMBOL_MAP.get(char, char)
                result_parts.append(latex_char)
                # \pi 直後に英字が続く場合のスペース
                if latex_char.startswith('\\') and len(latex_char) > 1 and latex_char[1].isalpha():
                    if (i + 1 < len(text)) and text[i + 1].isalpha():
                        result_parts.append(' ')
            return "".join(result_parts)

    if tag in ['r', 'e', 'num', 'den', 'sup', 'sub', 'base', 'oMath', 'oMathPara']:
        return "".join(parse_omml_element(child) for child in element)
    if tag == 'f':
        num = parse_omml_element(element.find('m:num', ns))
        den = parse_omml_element(element.find('m:den', ns))
        return f"\\frac{{{num}}}{{{den}}}"
    if tag == 'sSup':
        base = parse_omml_element(element.find('m:e', ns))
        sup = parse_omml_element(element.find('m:sup', ns))
        return f"{{{base}}}^{{{sup}}}"
    if tag == 'sSub':
        base = parse_omml_element(element.find('m:e', ns))
        sub = parse_omml_element(element.find('m:sub', ns))
        return f"{{{base}}}_{{{sub}}}"
    if tag == 'rad':
        base = parse_omml_element(element.find('m:e', ns))
        return f"\\sqrt{{{base}}}"
    if tag == 'd':
        # デリミタ付きの括弧（OMML: m:d）
        content = "".join(parse_omml_element(child) for child in element.findall('m:e', ns))
        dPr = element.find('m:dPr', ns)
        if dPr is not None:
            beg_char_elem = dPr.find('m:begChr', ns)
            end_char_elem = dPr.find('m:endChr', ns)
            beg_char = beg_char_elem.get('{http://schemas.openxmlformats.org/officeDocument/2006/math}val') if beg_char_elem is not None else '('
            end_char = end_char_elem.get('{http://schemas.openxmlformats.org/officeDocument/2006/math}val') if end_char_elem is not None else ')'
        else:
            beg_char, end_char = '(', ')'

        # ★ Pandoc/texmath が通るように区切り文字を安全化 ★
        lb = _latex_delimiter(beg_char, side="left")
        rb = _latex_delimiter(end_char, side="right")

        return f"\\left{lb} {content} \\right{rb}"
    if tag == 'nary':
        sub = parse_omml_element(element.find('m:sub', ns))
        sup = parse_omml_element(element.find('m:sup', ns))
        base = parse_omml_element(element.find('m:e', ns))
        naryPr = element.find('m:naryPr', ns)
        char_elem = naryPr.find('m:chr', ns) if naryPr is not None else None
        char = char_elem.get('{http://schemas.openxmlformats.org/officeDocument/2006/math}val') if char_elem is not None else 'Σ'
        symbol_latex = SYMBOL_MAP.get(char, char)
        sub_latex = f"_{{{sub}}}" if sub else ""
        sup_latex = f"^{{{sup}}}" if sup else ""
        return f"{symbol_latex}{sub_latex}{sup_latex} {base}"
    return "".join(parse_omml_element(child) for child in element)

def omml_to_latex(omath_para_element):
    """
    段落要素 (<w:p>) を受け取り、その中の最初の m:oMath 要素をLaTeXに変換する。
    """
    math_uri = ns.get('m', 'http://schemas.openxmlformats.org/officeDocument/2006/math')
    qualified_tag_name = '{%s}oMath' % math_uri
    omath_element = next(omath_para_element.iterdescendants(qualified_tag_name), None)
    if omath_element is not None:
        return parse_omml_element(omath_element)
    return ""

def is_inline_math(item):
    """
    item (Run または lxml要素) がインライン数式 (m:oMath) かチェックする。
    """
    math_uri = ns.get('m', 'http://schemas.openxmlformats.org/officeDocument/2006/math')
    qualified_tag_name = '{%s}oMath' % math_uri

    # lxml element の場合
    if hasattr(item, 'tag'):
        if item.tag == qualified_tag_name:
            return True
        return next(item.iterdescendants(qualified_tag_name), None) is not None

    # python-docx Run の場合
    if hasattr(item, '_element'):
        r = item._element
        if r.tag == qualified_tag_name:
            return True
        return next(r.iterdescendants(qualified_tag_name), None) is not None

    return False

def paragraph_contains_math(para: Paragraph):
    """
    段落要素に数式 (m:oMath) が含まれているかチェックする。
    """
    math_uri = ns.get('m', 'http://schemas.openxmlformats.org/officeDocument/2006/math')
    qualified_tag_name = '{%s}oMath' % math_uri
    for _ in para._element.iterdescendants(qualified_tag_name):
        return True
    return False

def iter_runs_and_omath_in_order(para: Paragraph):
    """
    段落内で、XMLの出現順に <w:r> は python-docx の Run、
    <m:oMath> は lxml element のまま返すジェネレータ。
    """
    w_r = f'{{{ns["w"]}}}r'
    m_omath = f'{{{ns["m"]}}}oMath'
    runs = list(para.runs)
    i = 0
    # 直下の子要素の順番をそのまま辿る
    for child in para._element.iterchildren():
        if child.tag == w_r:
            if i < len(runs):
                yield runs[i]
                i += 1
        elif child.tag == m_omath:
            yield child
        else:
            continue

# --- メイン処理関数 ---
def process_runs_to_markdown(items):
    """
    Runオブジェクトとlxml要素(m:oMath)の混在リストを受け取り、Markdownを返す。
    """
    def get_item_key(item):
        is_math = is_inline_math(item)
        if is_math:
            return (None, None, None, None, True)
        else:
            # 非数式は Run を想定
            is_bold = getattr(item, 'bold', False)
            is_italic = getattr(item, 'italic', False)
            is_superscript = False
            is_subscript = False
            if hasattr(item, 'font') and item.font is not None:
                is_superscript = getattr(item.font, 'superscript', False)
                is_subscript = getattr(item.font, 'subscript', False)
            return (is_bold, is_italic, is_superscript, is_subscript, False)

    line = ""
    for (is_bold, is_italic, is_superscript, is_subscript, is_math), items_group in groupby(items, key=get_item_key):
        styled_text = ""
        current_items = list(items_group)

        if is_math:
            math_uri = ns.get('m', 'http://schemas.openxmlformats.org/officeDocument/2006/math')
            qualified_tag_name = '{%s}oMath' % math_uri
            for item in current_items:
                omath_element = None
                # lxml の <m:oMath>
                if hasattr(item, 'tag') and item.tag == qualified_tag_name:
                    omath_element = item
                # Run の内部に <m:oMath>
                elif hasattr(item, '_element'):
                    omath_element = next(item._element.iterdescendants(qualified_tag_name), None)
                if omath_element is not None:
                    latex_formula = parse_omml_element(omath_element)
                    styled_text += f"${latex_formula}$"
        else:
            text_chunk = "".join(r.text for r in current_items if hasattr(r, 'text') and r.text)
            if not text_chunk:
                continue
            styled_text = text_chunk
            # 太字・斜体
            if is_bold and is_italic:
                styled_text = f"***{styled_text}***"
            elif is_bold:
                styled_text = f"**{styled_text}**"
            elif is_italic:
                styled_text = f"*{styled_text}*"
            # 上付き・下付き（Pandoc）
            if is_superscript:
                if styled_text.strip():
                    styled_text = f"^{styled_text}^"
            elif is_subscript:
                if styled_text.strip():
                    styled_text = f"~{styled_text}~"

        line += styled_text
    return line

def save_image_from_run(run, doc, media_dir):
    """
    run要素から画像を抽出し、指定されたディレクトリに保存する。
    画像のMarkdownタグを返す。
    """
    try:
        blip_qname = f'{{{ns["a"]}}}blip'
        blip = next(run._element.iterdescendants(blip_qname), None)
        if blip is None:
            return None

        rId = blip.get(f'{{{ns["r"]}}}embed')
        if not rId:
            return None

        image_part = doc.part.related_parts[rId]
        image_filename = os.path.basename(image_part.partname)
        image_path = os.path.join(media_dir, image_filename)

        with open(image_path, "wb") as f:
            f.write(image_part.blob)

        relative_path = os.path.join(os.path.basename(media_dir), image_filename).replace("\\", "/")
        return f"![{image_filename}]({relative_path})"

    except (IndexError, KeyError, AttributeError, IOError) as e:
        print(f"[警告] 画像の抽出に失敗しました: {e}", file=sys.stderr)
        return None

def handle_paragraph(para: Paragraph, doc: Document, media_dir: str, list_counters: dict):
    """
    単一の段落を処理し、Markdown文字列を返す。
    画像、ブロック数式、見出し、リスト、通常テキストを処理する。
    """
    # --- 1. 画像段落の処理 ---
    drawing_qname = f'{{{ns["w"]}}}drawing'
    if next(para._element.iterdescendants(drawing_qname), None) is not None:
        for run in para.runs:
            image_md = save_image_from_run(run, doc, media_dir)
            if image_md:
                return image_md + "\n\n"  # 画像を見つけたらその段落は終了

    # --- 2. ブロック数式 (ディスプレイ数式) の処理 ---
    math_para_qname = f'{{{ns["m"]}}}oMathPara'
    if next(para._element.iterdescendants(math_para_qname), None) is not None:
        latex_formula = omml_to_latex(para._element)
        if latex_formula:
            return f"$$\n{latex_formula.strip()}\n$$\n\n"

    # --- 3. 見出しの処理 ---
    style_name = para.style.name.lower() if para.style and para.style.name else ""
    if 'heading 1' in style_name or '見出し 1' in style_name:
        return f"# {process_runs_to_markdown(list(iter_runs_and_omath_in_order(para)))}\n\n"
    if 'heading 2' in style_name or '見出し 2' in style_name:
        return f"## {process_runs_to_markdown(list(iter_runs_and_omath_in_order(para)))}\n\n"
    if 'heading 3' in style_name or '見出し 3' in style_name:
        return f"### {process_runs_to_markdown(list(iter_runs_and_omath_in_order(para)))}\n\n"
    if 'heading 4' in style_name or '見出し 4' in style_name:
        return f"#### {process_runs_to_markdown(list(iter_runs_and_omath_in_order(para)))}\n\n"

    # --- 4. リストの処理 ---
    prefix = ""
    pPr = para._p.pPr
    if pPr is not None and pPr.numPr is not None and pPr.numPr.numId is not None:
        num_id, ilvl = pPr.numPr.numId.val, (pPr.numPr.ilvl.val if pPr.numPr.ilvl is not None else 0)
        if num_id not in list_counters:
            list_counters[num_id] = {}
        if ilvl not in list_counters[num_id]:
            list_counters[num_id][ilvl] = 0
        list_counters[num_id][ilvl] += 1
        # 下位レベルのカウンターをリセット
        for higher_ilvl in range(ilvl + 1, 10):
            if higher_ilvl in list_counters.get(num_id, {}):
                list_counters[num_id][higher_ilvl] = 0
        prefix = "    " * ilvl + f"{list_counters[num_id][ilvl]}. "
    elif 'list paragraph' in style_name or '箇条書き' in style_name:
        ilvl = pPr.numPr.ilvl.val if pPr is not None and pPr.numPr is not None and pPr.numPr.ilvl is not None else 0
        prefix = "    " * ilvl + "* "

    # --- 5. 通常の段落テキストの処理 ---
    mixed_items = list(iter_runs_and_omath_in_order(para))
    if not mixed_items:
        return "\n"
    line = process_runs_to_markdown(mixed_items)
    if not line.strip():
        return "\n"
    return prefix + line + "\n\n"

def get_formatted_cell_text(cell: docx.table._Cell):
    """
    テーブルのセル内のテキストを、書式を保持したMarkdown文字列として取得する。
    セル内の改行は <br> で処理する。
    """
    cell_md_parts = []
    for para in cell.paragraphs:
        line = process_runs_to_markdown(list(iter_runs_and_omath_in_order(para)))
        cell_md_parts.append(line)
    return "<br>".join(cell_md_parts).strip()

def handle_table(table: Table):
    """
    テーブルをMarkdown形式に変換する。
    セル内の書式（太字、数式など）も処理する。
    """
    markdown = []
    header_cells = [get_formatted_cell_text(cell) for cell in table.rows[0].cells]
    markdown.append("| " + " | ".join(header_cells) + " |")
    markdown.append("| " + " | ".join(["---"] * len(header_cells)) + " |")

    for row in table.rows[1:]:
        data_cells = [get_formatted_cell_text(cell) for cell in row.cells]
        markdown.append("| " + " | ".join(data_cells) + " |")

    return "\n".join(markdown) + "\n\n"

def convert(infile_path: str, outfile_path: str, imagedir: str):
    """
    DOCXファイルをMarkdownに変換するメイン関数。
    """
    if not os.path.exists(infile_path):
        print(f"[エラー] 指定されたファイルが見つかりません: {infile_path}", file=sys.stderr)
        return

    # 画像保存ディレクトリを作成 (出力MDファイルからの相対パスを想定)
    if not os.path.isabs(imagedir):
        md_file_dir = os.path.dirname(outfile_path)
        if md_file_dir:
            media_save_dir = os.path.join(md_file_dir, imagedir)
        else:
            media_save_dir = imagedir
    else:
        media_save_dir = imagedir

    os.makedirs(media_save_dir, exist_ok=True)
    print(f"画像保存ディレクトリ: {media_save_dir}")

    try:
        doc = docx.Document(infile_path)
        markdown_content = []
        list_counters = {}

        if not hasattr(doc, 'iter_inner_content'):
            print("[警告] 'iter_inner_content' が見つかりません。古い python-docx バージョンの可能性があります。", file=sys.stderr)
            print("[警告] 段落と表の順序が正しくない場合があります。", file=sys.stderr)
            body_elements = doc.element.body
        else:
            body_elements = doc.iter_inner_content()

        for block in body_elements:
            if isinstance(block, Paragraph):
                markdown_content.append(handle_paragraph(block, doc, media_save_dir, list_counters))
            elif isinstance(block, Table):
                markdown_content.append(handle_table(block))

        with open(outfile_path, 'w', encoding='utf-8') as md_file:
            md_file.write("".join(markdown_content))

        print(f"\n変換が正常に完了しました。出力ファイル: {outfile_path}")

        # 保存した画像があるかチェック
        try:
            if any(os.scandir(media_save_dir)):
                print(f"画像は {media_save_dir} に保存されました。")
            else:
                os.rmdir(media_save_dir)
        except OSError:
            pass

    except Exception as e:
        print(f"\n[致命的エラー] 変換中にエラーが発生しました: {e}", file=sys.stderr)
        import traceback
        traceback.print_exc()
        sys.exit(1)

def main():
    """ コマンドライン引数を処理し、変換を実行します。 """
    parser = argparse.ArgumentParser(
        description='Wordファイル(.docx)からテキスト、数式、図を抽出し、Markdownに出力します。',
        formatter_class=argparse.RawTextHelpFormatter
    )
    parser.add_argument("-i", "--input", required=True,
                        help="入力するWordファイル名 (例: report.docx)")
    parser.add_argument("-o", "--output", required=True,
                        help="出力するMarkdownファイル名 (例: output.md)")
    parser.add_argument("--imagedir", default="images",
                        help=("画像ファイルを保存するディレクトリ名。\n"
                              "出力MDファイルからの相対パスとして扱われます。\n"
                              "(デフォルト: images)"))
    parser.add_argument("--pause", type=int, default=0,
                        help="終了時にENTERキー入力を要求する場合、0以外を指定 (デフォルト: 0)")
    args = parser.parse_args()

    convert(args.input, args.output, args.imagedir)

    if args.pause != 0:
        input("\n処理が完了しました。Enterキーを押すと終了します...")

if __name__ == "__main__":
    main()
