#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Slides (PPTX) + Zoom audio/video -> Markdown textbook chapters.
- Extracts slide text, speaker notes, and embedded images from PPTX
- Transcribes audio/video via local Whisper (openai-whisper) or faster-whisper (both local)
- Emits per-lecture Markdown and a combined book.md (+ SRT/VTT transcripts)

Author: ChatGPT (starter kit, Whisper-local enabled)
License: MIT
"""
import argparse
import json
import os
import re
import shutil
import sys
from pathlib import Path
from typing import Dict, List, Optional

from pptx import Presentation

# Optional ASR backends. They are only *located* here, not imported; the real
# import happens inside transcribe_audio() when a backend is actually used.
# find_spec() reports whether a top-level module with that name is importable.
# NOTE(review): any unrelated package that is also named "whisper" would be
# reported as openai-whisper here — confirm if that can happen in your env.
import importlib.util

_HAVE_FASTER = importlib.util.find_spec("faster_whisper") is not None
_HAVE_OPENAI_WHISPER = importlib.util.find_spec("whisper") is not None

# File extensions treated as a lecture recording. Video containers are included;
# the file path is handed to the ASR backend unchanged.
AUDIO_EXTS = {'.mp3', '.m4a', '.wav', '.flac', '.mp4', '.mkv', '.m4v', '.mov'}


# ---------- Utilities ----------
def get_device(arg_device: str) -> str:
    """Turn the --device choice into a concrete device name.

    Any value other than "auto" is passed through untouched. For "auto" the
    result is "cuda" when torch imports and reports CUDA as available, and
    "cpu" otherwise (including when torch is missing or raises).
    """
    if arg_device == "auto":
        try:
            import torch  # type: ignore
            cuda_ok = bool(torch.cuda.is_available())
        except Exception:
            cuda_ok = False
        return "cuda" if cuda_ok else "cpu"
    return arg_device


def ensure_newline(s: str) -> str:
    """Return *s* as-is when it already ends with a newline, else with one newline appended."""
    if s.endswith("\n"):
        return s
    return f"{s}\n"


def seconds_to_timestamp(t: Optional[float]) -> str:
    """Format *t* seconds as an SRT timestamp ``HH:MM:SS,mmm``.

    ``None`` and negative values are treated as 0. The value is rounded to whole
    milliseconds *before* it is split into fields, so a carry propagates
    correctly (1.9996 -> ``00:00:02,000``, never the invalid ``00:00:01,1000``).
    """
    total_ms = max(0, int(round((t or 0.0) * 1000)))
    hours, rem = divmod(total_ms, 3_600_000)
    minutes, rem = divmod(rem, 60_000)
    seconds, millis = divmod(rem, 1000)
    return f"{hours:02d}:{minutes:02d}:{seconds:02d},{millis:03d}"


def write_srt(segments, path: Path, from_whisper: bool):
    """
    Write *segments* to *path* as a SubRip (.srt) file, UTF-8 encoded.

    segments:
      - openai-whisper (from_whisper=True): dicts with 'start', 'end', 'text'
      - faster-whisper (from_whisper=False): objects with .start, .end, .text
    Segments whose text is empty after stripping are skipped. Cue numbers are
    assigned consecutively from 1 over the cues actually written, so the file
    never contains a cue without a text line. Times are formatted with
    seconds_to_timestamp().
    """
    out: List[str] = []
    cue_no = 0
    for seg in segments:
        if from_whisper:
            start, end, text = seg.get("start"), seg.get("end"), seg.get("text")
        else:
            start = getattr(seg, "start", None)
            end = getattr(seg, "end", None)
            text = getattr(seg, "text", None)
        text = (text or "").strip()
        if not text:
            continue
        cue_no += 1
        out += [
            str(cue_no),
            f"{seconds_to_timestamp(start)} --> {seconds_to_timestamp(end)}",
            text,
            "",  # blank line terminates the cue
        ]
    path.write_text("\n".join(out), encoding="utf-8")


def write_vtt(segments, path: Path, from_whisper: bool):
    """
    Write *segments* to *path* as a WebVTT file, UTF-8 encoded.

    WebVTT starts with a "WEBVTT" header and uses '.' before the milliseconds.
    Segment shapes match write_srt(): dicts for openai-whisper
    (from_whisper=True), attribute objects for faster-whisper. Segments whose
    text is empty after stripping are skipped.
    """
    def to_vtt_ts(t: Optional[float]) -> str:
        # Round to whole milliseconds first so the carry propagates (no ".1000").
        total_ms = max(0, int(round((t or 0.0) * 1000)))
        hours, rem = divmod(total_ms, 3_600_000)
        minutes, rem = divmod(rem, 60_000)
        seconds, millis = divmod(rem, 1000)
        return f"{hours:02d}:{minutes:02d}:{seconds:02d}.{millis:03d}"

    lines: List[str] = ["WEBVTT", ""]
    for seg in segments:
        if from_whisper:
            start, end, text = seg.get("start"), seg.get("end"), seg.get("text")
        else:
            start = getattr(seg, "start", None)
            end = getattr(seg, "end", None)
            text = getattr(seg, "text", None)
        text = (text or "").strip()
        if not text:
            continue
        lines += [f"{to_vtt_ts(start)} --> {to_vtt_ts(end)}", text, ""]
    path.write_text("\n".join(lines), encoding="utf-8")


# ---------- Discovery ----------
def find_lecture_folders(input_root: Path) -> List[Path]:
    """Return the immediate subfolders of *input_root* that contain a .pptx file.

    Folders are returned sorted by path. A missing or non-directory
    *input_root* yields an empty list instead of raising, so the caller can
    report "no lecture folders" cleanly. Detection uses find_first(), so a
    folder counts as a lecture exactly when find_first() would pick a deck
    from it.
    """
    print()
    print(f"Find lecture folders from [{input_root}]")
    if not input_root.is_dir():
        return []
    return [
        sub
        for sub in sorted(input_root.iterdir())
        if sub.is_dir() and find_first(sub, {'.pptx'}) is not None
    ]


def find_first(p: Path, exts: set) -> Optional[Path]:
    """Return the first regular file in *p*, by sorted name, whose suffix is in *exts*.

    Suffixes are lower-cased before comparison, so *exts* should hold
    lower-case entries such as ".pptx". Two kinds of file are skipped:

    - Hidden files, whose names start with "." (e.g. macOS "._name" metadata files).
    - Office owner/lock files, whose names start with "~$".

    Neither is a real document. With non-ASCII names, "~$講義.pptx" sorts
    *before* "講義.pptx" and would otherwise be returned.

    Returns None when nothing matches.
    """
    for f in sorted(p.iterdir()):
        if f.name.startswith((".", "~$")):
            continue
        if f.is_file() and f.suffix.lower() in exts:
            return f
    return None


# ---------- PPTX Extraction ----------
# python-pptx MSO_SHAPE_TYPE members compare equal to these plain ints.
_MSO_GROUP = 6
_MSO_PICTURE = 13


def _shape_type_of(shape):
    """Return ``shape.shape_type``, or None if reading it fails.

    ``shape_type`` is computed on access and may raise for shapes python-pptx
    cannot classify, so it is read defensively.
    """
    try:
        return shape.shape_type
    except Exception:
        return None


def _iter_leaf_shapes(shapes):
    """Yield each shape in *shapes*, replacing group shapes by their children (recursively)."""
    for shape in shapes:
        if _shape_type_of(shape) == _MSO_GROUP and hasattr(shape, "shapes"):
            yield from _iter_leaf_shapes(shape.shapes)
        else:
            yield shape


def _is_title_placeholder(shape) -> bool:
    """True if *shape* is the slide title placeholder (placeholder idx 0, the rule SlideShapes.title uses)."""
    if not getattr(shape, "is_placeholder", False):
        return False
    try:
        return shape.placeholder_format.idx == 0
    except Exception:
        return False


def _table_lines(shape) -> List[str]:
    """Return one " | "-joined line per non-empty row if *shape* holds a table, else []."""
    if not getattr(shape, "has_table", False):
        return []
    lines = []
    for row in shape.table.rows:
        # Collapse in-cell line breaks so each row stays a single bullet.
        cells = [" ".join((cell.text or "").split()) for cell in row.cells]
        if any(cells):
            lines.append(" | ".join(cells))
    return lines


def extract_pptx(pptx_path: Path, out_dir: Path) -> Dict:
    """
    Extract slide-wise text, notes, and embedded images.

    - Title: text of the slide's title placeholder (idx 0), with whitespace
      collapsed to single spaces so it stays a one-line heading. A SUBTITLE
      placeholder is therefore not mistaken for the title.
    - Bullets: the non-empty lines of every other text-bearing shape, plus one
      line per table row. Shapes inside groups are included.
    - Images: embedded pictures are written to ``out_dir/images`` as
      ``img_NNNN.<ext>``. Numbering continues across the whole deck. Pictures
      that cannot be extracted (e.g. linked rather than embedded) are reported
      on stderr and skipped.

    Returns:
      { "slides": [ {"title": str, "bullets": [..], "notes": str, "images": ["images/img_0001.png", ...]} ] }
    Image paths always use "/" because they are used verbatim as Markdown links.
    """
    prs = Presentation(str(pptx_path))
    slides_data = []
    images_dir = out_dir / "images"
    images_dir.mkdir(parents=True, exist_ok=True)

    image_counter = 1

    for idx, slide in enumerate(prs.slides, start=1):
        title_text = ""
        bullets: List[str] = []
        notes_text = ""
        saved_images: List[str] = []

        for shape in _iter_leaf_shapes(slide.shapes):
            # Embedded pictures
            if _shape_type_of(shape) == _MSO_PICTURE:
                try:
                    img = shape.image
                    img_name = f"img_{image_counter:04d}.{img.ext}"
                    (images_dir / img_name).write_bytes(img.blob)
                except Exception as e:
                    print(f"[WARN] {pptx_path.name} slide {idx}: image not extracted: {e}", file=sys.stderr)
                else:
                    saved_images.append(f"images/{img_name}")
                    image_counter += 1

            # Text: title placeholder vs. everything else
            if getattr(shape, "has_text_frame", False) and shape.text_frame is not None:
                text = (shape.text or "").strip()
                if _is_title_placeholder(shape):
                    if not title_text:
                        title_text = " ".join(text.split())
                    continue
                bullets.extend(s for s in (line.strip() for line in text.splitlines()) if s)

            # Tables (graphic frames) have no text frame of their own.
            try:
                bullets.extend(_table_lines(shape))
            except Exception as e:
                print(f"[WARN] {pptx_path.name} slide {idx}: table text not extracted: {e}", file=sys.stderr)

        # Speaker notes
        notes_slide = slide.notes_slide if slide.has_notes_slide else None
        if notes_slide is not None and notes_slide.notes_text_frame is not None:
            notes_text = (notes_slide.notes_text_frame.text or "").strip()

        slides_data.append({
            "title": title_text or f"Slide {idx}",
            "bullets": bullets,
            "notes": notes_text,
            "images": saved_images,
        })

    return {"slides": slides_data}


def copy_slides_png_if_exists(lecture_dir: Path, out_dir: Path) -> List[str]:
    """
    Copy pre-exported full-slide images from lecture_dir/slides_png to out_dir/slides_png.

    Only regular, non-hidden files with a .png/.jpg/.jpeg suffix
    (case-insensitive) are copied. Hidden files such as macOS "._Slide1.PNG"
    would otherwise be taken for slide images. Returns the copied paths
    relative to *out_dir* in name-sorted order. They always use "/" because
    they become Markdown link targets. Returns [] when the source folder does
    not exist.
    """
    src = lecture_dir / "slides_png"
    if not src.is_dir():
        return []
    dst = out_dir / "slides_png"
    dst.mkdir(parents=True, exist_ok=True)
    copied: List[str] = []
    for f in sorted(src.iterdir()):
        if f.name.startswith("."):
            continue
        if f.is_file() and f.suffix.lower() in {'.png', '.jpg', '.jpeg'}:
            shutil.copy2(f, dst / f.name)
            copied.append(f"slides_png/{f.name}")
    return copied


# ---------- Transcription ----------
def transcribe_audio(
    audio_path: Path,
    out_dir: Path,
    lang: str,
    whisper_model: str,
    asr_backend: str,
    device: str,
    whisper_model_dir: Optional[str],
    ja_initial_prompt: Optional[str],
) -> Optional[Path]:
    """
    Transcribe *audio_path* with a local ASR backend into *out_dir*.

    Writes transcript.txt, and tries to write transcript.srt / transcript.vtt
    from the segment timings.

    asr_backend:
      "auto"             -> faster-whisper if installed, else openai-whisper.
      "faster"/"whisper" -> that backend; if it is not installed, the other
                            one is used (with a warning) when available.
      "none" or unknown  -> no transcription.
    lang: language code for the backend. "auto" or "" lets Whisper detect it.
    whisper_model_dir: model download/cache directory (download_root), passed
      to both backends. None keeps each backend's default location.
    ja_initial_prompt: optional initial prompt, passed to both backends. A
      punctuated prompt can help Whisper punctuate Japanese output.

    Returns the path to transcript.txt, or None when transcription was skipped
    or failed. Failures are reported on stderr rather than raised, so one bad
    recording does not abort the whole build.
    """
    if asr_backend not in ("auto", "faster", "whisper"):
        return None

    # Resolve the requested backend against what is actually installed.
    if asr_backend == "auto":
        if _HAVE_FASTER:
            backend = "faster"
        elif _HAVE_OPENAI_WHISPER:
            backend = "whisper"
        else:
            backend = None
    elif asr_backend == "faster":
        if _HAVE_FASTER:
            backend = "faster"
        else:
            print("[WARN] faster-whisper not installed; falling back to Whisper if available.", file=sys.stderr)
            backend = "whisper" if _HAVE_OPENAI_WHISPER else None
    else:  # "whisper"
        if _HAVE_OPENAI_WHISPER:
            backend = "whisper"
        else:
            print("[WARN] openai-whisper not installed; falling back to faster-whisper if available.", file=sys.stderr)
            backend = "faster" if _HAVE_FASTER else None

    if backend is None:
        print("[WARN] No ASR backend installed. Skipping transcription.", file=sys.stderr)
        return None

    language = lang if lang and lang.lower() != "auto" else None
    transcript_txt = out_dir / "transcript.txt"
    srt_path = out_dir / "transcript.srt"
    vtt_path = out_dir / "transcript.vtt"

    try:
        dev = get_device(device)

        if backend == "whisper":
            import whisper  # type: ignore
            # download_root selects where models are stored (e.g. --model-dir "C:/whisper_models").
            model = whisper.load_model(whisper_model, device=dev, download_root=whisper_model_dir)

            transcribe_opts = dict(
                language=language,
                task="transcribe",
                condition_on_previous_text=True,
                fp16=(dev == "cuda"),
            )
            # If punctuation rarely appears, an initial_prompt may help (recommended for Japanese).
            if ja_initial_prompt:
                transcribe_opts["initial_prompt"] = ja_initial_prompt

            result = model.transcribe(str(audio_path), **transcribe_opts)
            text = (result.get("text") or "").strip()
            transcript_txt.write_text(text + "\n", encoding="utf-8")
            segments = result.get("segments") or []
            from_whisper = True

        else:  # "faster"
            from faster_whisper import WhisperModel  # type: ignore
            compute_type = "float16" if dev == "cuda" else "int8"
            model = WhisperModel(
                whisper_model,
                device=dev,
                compute_type=compute_type,
                download_root=whisper_model_dir,
            )
            segments_iter, _info = model.transcribe(
                str(audio_path),
                language=language,
                vad_filter=True,
                beam_size=5,
                initial_prompt=ja_initial_prompt or None,
            )
            # faster-whisper yields segments lazily; materialize once for all outputs.
            segments = list(segments_iter)
            text_lines = [t for t in (seg.text.strip() for seg in segments) if t]
            transcript_txt.write_text("\n".join(text_lines) + "\n", encoding="utf-8")
            from_whisper = False

        try:
            write_srt(segments, srt_path, from_whisper=from_whisper)
            write_vtt(segments, vtt_path, from_whisper=from_whisper)
        except Exception as e:
            print(f"[WARN] Failed to write SRT/VTT: {e}", file=sys.stderr)

        return transcript_txt

    except Exception as e:
        print(f"[WARN] Transcription failed: {e}", file=sys.stderr)
        return None


# ---------- Markdown ----------
def write_markdown(lecture_name: str, out_dir: Path, slides_json: Dict, slide_pngs: List[str], transcript_path: Optional[Path]) -> Path:
    """
    Render one lecture chapter to ``out_dir/<lecture_name>.md`` and return that path.

    Each slide becomes a "## Slide N — title" section containing:

    - its bullets;
    - its speaker notes as a blockquote;
    - the N-th full-slide image from *slide_pngs*;
    - its embedded images.

    Full-slide images are matched to slides in *natural* name order, so
    "Slide10.PNG" follows "Slide9.PNG" instead of "Slide1.PNG". Link targets
    use "/" and are wrapped in <...> when they contain spaces or parentheses,
    which plain Markdown link targets cannot hold. If *transcript_path*
    exists, its text is appended as an appendix.
    """
    def natural_key(rel: str):
        # re.split with a capturing group puts digit runs at odd indices.
        parts = re.split(r"(\d+)", rel)
        return [int(p) if i % 2 else p.lower() for i, p in enumerate(parts)], rel

    def link_target(rel: str) -> str:
        target = Path(rel).as_posix()
        return f"<{target}>" if any(c in target for c in " ()") else target

    md_path = out_dir / f"{lecture_name}.md"
    lines = [f"# {lecture_name}: {slides_json.get('source', 'slides.pptx')}", ""]

    slides: List[Dict] = slides_json["slides"]
    ordered_pngs = sorted(
        (p for p in slide_pngs if Path(p).suffix.lower() in {'.png', '.jpg', '.jpeg'}),
        key=natural_key,
    )
    png_by_index = dict(enumerate(ordered_pngs, start=1))

    for i, s in enumerate(slides, start=1):
        # A heading must be a single line; collapse any embedded line breaks.
        title = " ".join((s.get("title") or "").split()) or f"Slide {i}"
        lines += [f"## Slide {i} — {title}", ""]

        bullets = s.get("bullets") or []
        if bullets:
            lines += [f"- {b}" for b in bullets]
            lines.append("")

        notes = (s.get("notes") or "").strip()
        if notes:
            lines += ["**Notes**", ""]
            lines += [f"> {line}" for line in notes.splitlines()]
            lines.append("")

        if i in png_by_index:
            lines += [f"![Slide {i}]({link_target(png_by_index[i])})", ""]

        for img_rel in s.get("images", []):
            lines += [f"![Embedded]({link_target(img_rel)})", ""]

        lines += ["---", ""]

    if transcript_path and transcript_path.exists():
        lines += ["## Transcript (Appendix)", ""]
        txt = transcript_path.read_text(encoding="utf-8", errors="ignore")
        lines += [txt.strip(), ""]

    md_path.write_text("\n".join(lines), encoding="utf-8")
    return md_path


# ---------- Main ----------
def _natural_sort_key(text: str):
    """Sort key that orders digit runs numerically ("Lecture2" < "Lecture10") and other text case-insensitively.

    The original string is the tie-breaker, so the order is deterministic.
    """
    parts = re.split(r"(\d+)", text)
    # re.split with a capturing group puts the digit runs at the odd indices.
    return [int(p) if i % 2 else p.lower() for i, p in enumerate(parts)], text


# Markdown image: ![alt](target) or ![alt](<target>)
_IMAGE_LINK_RE = re.compile(r"!\[(?P<alt>[^\]\n]*)\]\((?P<target><[^>\n]*>|[^)\n]*)\)")


def _rebase_image_links(md_text: str, prefix: str) -> str:
    """Return *md_text* with each relative image target rewritten as ``prefix/target``.

    Chapter files link images relative to their own folder, but book.md lives
    in the output root, so every chapter link needs the chapter folder
    prepended. Absolute URLs, root-relative paths, fragments and data: URIs
    are left unchanged. Backslash separators are normalized to "/". The result
    is wrapped in <...> when it contains spaces or parentheses.
    """
    def rebase(m):
        target = m.group("target").strip()
        raw = target[1:-1] if target.startswith("<") else target
        raw = raw.replace("\\", "/")
        if not raw or raw.startswith(("/", "#", "data:")) or "://" in raw:
            return m.group(0)
        new_target = f"{prefix}/{raw}"
        if any(c in new_target for c in " ()"):
            new_target = f"<{new_target}>"
        return f"![{m.group('alt')}]({new_target})"

    return _IMAGE_LINK_RE.sub(rebase, md_text)


def main():
    """Command-line entry point: build one Markdown chapter per lecture folder, then book.md.

    Exits with status 1 when the input root is not a directory or has no
    lecture folders containing a .pptx.
    """
    ap = argparse.ArgumentParser(description="Build textbook (Markdown) from PPTX + audio/video.")
    ap.add_argument("--input-root", required=True, help="Root directory containing per-lecture folders")
    ap.add_argument("--output-root", required=True, help="Output directory")
    ap.add_argument("--lang", default="ja", help="Transcription language code (e.g., ja, en)")
    ap.add_argument("--whisper-model", default="small", help="Whisper model name (small/medium/large-v3 etc.)")

    # ASR options
    ap.add_argument("--asr", choices=["auto", "whisper", "faster", "none"], default="auto",
                    help="ASR backend: local Whisper (openai-whisper), faster-whisper, auto, or none")
    ap.add_argument("--device", choices=["auto", "cpu", "cuda"], default="auto",
                    help="Device for ASR backends")
    ap.add_argument("--model-dir", default=None,
                    help="Directory to cache/download Whisper models (for openai-whisper)")
    ap.add_argument("--ja-initial-prompt", default=None,
                    help="Initial prompt for Whisper (e.g., '句読点を正しく挿入してください。'). Helpful for Japanese punctuation.")

    args = ap.parse_args()

    in_root = Path(args.input_root).expanduser().resolve()
    out_root = Path(args.output_root).expanduser().resolve()
    if not in_root.is_dir():
        print(f"[ERR] Input root is not a directory: {in_root}", file=sys.stderr)
        sys.exit(1)
    out_root.mkdir(parents=True, exist_ok=True)

    # Natural order, so "Lecture10" is processed after "Lecture9", not after "Lecture1".
    lectures = sorted(find_lecture_folders(in_root), key=lambda d: _natural_sort_key(d.name))
    if not lectures:
        print(f"[ERR] No lecture folders with .pptx found in: {in_root}", file=sys.stderr)
        sys.exit(1)

    book_paths = []
    for lec_dir in lectures:
        lec_name = lec_dir.name
        lec_out = out_root / lec_name
        lec_out.mkdir(parents=True, exist_ok=True)

        pptx = find_first(lec_dir, {'.pptx'})
        audio = find_first(lec_dir, AUDIO_EXTS)
        if pptx is None:
            print(f"[WARN] {lec_name}: no PPTX found, skipping.", file=sys.stderr)
            continue

        data = extract_pptx(pptx, lec_out)
        data["source"] = pptx.name
        (lec_out / "slides_text.json").write_text(json.dumps(data, ensure_ascii=False, indent=2), encoding="utf-8")

        slide_pngs = copy_slides_png_if_exists(lec_dir, lec_out)

        transcript = None
        if audio and args.asr != "none":
            transcript = transcribe_audio(
                audio_path=audio,
                out_dir=lec_out,
                lang=args.lang,
                whisper_model=args.whisper_model,
                asr_backend=args.asr,
                device=args.device,
                whisper_model_dir=args.model_dir,
                ja_initial_prompt=args.ja_initial_prompt,
            )

        md = write_markdown(lec_name, lec_out, data, slide_pngs, transcript)
        book_paths.append(md)
        print(f"[OK] Built chapter: {md}")

    book_md = out_root / "book.md"
    chapters = sorted(book_paths, key=lambda p: _natural_sort_key(p.relative_to(out_root).as_posix()))
    with open(book_md, "w", encoding="utf-8") as fw:
        for p in chapters:
            # Chapter links are relative to the chapter folder; book.md sits one level up.
            prefix = p.parent.relative_to(out_root).as_posix()
            fw.write(_rebase_image_links(p.read_text(encoding="utf-8"), prefix))
            fw.write("\n\n\n")
    print(f"[OK] Combined book: {book_md}")


# Script entry point: run main() only when this file is executed directly, not when imported.
if __name__ == "__main__":
    main()
