#!/usr/bin/env python3
# -*- coding: utf-8 -*-

#pip install elevenlabs
#export ELEVENLABS_API_KEY="あなたのAPIキー"
"""
Usage:
  python tts_elevenlabs.py voices
  python tts_elevenlabs.py tts -t "本日は実験結果についてご説明いたします。" -o out.mp3
  python tts_elevenlabs.py stream -t "これから理論的背景について説明します。" -o out_stream.mp3
  python tts_elevenlabs.py tts -t "こんにちは" --voice "Adam" -o adam.mp3
  python tts_elevenlabs.py --api-key YOUR_KEY voices   # --api-key must come before the subcommand
"""

"""
ElevenLabs TTS runnable demo
- List voices
- Convert text to speech and save to file
- Streaming TTS and save to file (byte stream)

Requirements:
  pip install elevenlabs
Env:
  ELEVENLABS_API_KEY=...
"""

import os
import sys
import argparse
import re
from typing import Optional, Dict, Any, Tuple, List

# Multilingual model (covers Japanese among other languages).
DEFAULT_MODEL_ID = "eleven_multilingual_v2"

# Voice tried by default; if no listed voice has this name, the first listed
# voice is used instead.
DEFAULT_VOICE_NAME = "Rachel"

# Baseline voice_settings for every request; the CLI flags default to these
# values and callers' settings override them key by key.
DEFAULT_VOICE_SETTINGS = dict(
    stability=0.60,         # leans toward a calm, steady reading
    similarity_boost=0.75,  # consistency of the voice's timbre
)

def _import_client():
    """
    Import and return the ``ElevenLabs`` client class from the SDK.

    The import is deferred into this function so that importing this module
    does not itself require the SDK to be installed.

    Raises:
        ImportError: if ``elevenlabs`` cannot be imported. The message tells the
            user to install the package, and the original error is attached as
            ``__cause__``.
    """
    try:
        from elevenlabs import ElevenLabs  # standard entry point in recent SDK versions
    except ImportError as e:
        # Only a genuine import failure warrants the "pip install" advice; any
        # other error raised while importing the SDK propagates unchanged.
        raise ImportError(
            "elevenlabs の import に失敗しました。"
            " `pip install elevenlabs` を実行し、環境を確認してください。\n"
            f"詳細: {e}"
        ) from e
    return ElevenLabs

def get_client(api_key: Optional[str] = None):
    """
    Construct an ElevenLabs client.

    A truthy ``api_key`` is passed to the constructor explicitly. Otherwise the
    constructor is called with no arguments and the SDK resolves the key itself
    (the ELEVENLABS_API_KEY environment variable, per this script's setup notes).
    """
    client_cls = _import_client()
    ctor_kwargs = {"api_key": api_key} if api_key else {}
    return client_cls(**ctor_kwargs)

def get_voices(client) -> List[Dict[str, Any]]:
    """
    Fetch the available voices and flatten each one into a plain dict.

    Each dict has the keys ``name``, ``voice_id``, ``category`` and ``labels``.
    A missing ``name`` becomes ``""``; the other missing attributes become
    ``None``. ``voice_id`` falls back to an ``id`` attribute when ``voice_id``
    is absent or falsy. A response with no (or a falsy) ``voices`` attribute
    yields an empty list.
    """
    response = client.voices.get_all()
    raw_voices = getattr(response, "voices", []) or []

    def _as_dict(voice_obj) -> Dict[str, Any]:
        return {
            "name": getattr(voice_obj, "name", ""),
            "voice_id": getattr(voice_obj, "voice_id", None) or getattr(voice_obj, "id", None),
            "category": getattr(voice_obj, "category", None),
            "labels": getattr(voice_obj, "labels", None),
        }

    return [_as_dict(voice_obj) for voice_obj in raw_voices]

def pick_voice_id(
    voices: List[Dict[str, Any]],
    voice: Optional[str],
    default_name: Optional[str] = None,
) -> str:
    """
    Resolve a voice name or voice_id to a voice_id.

    Resolution order (first hit wins):
      1. ``voice`` equals a listed ``voice_id``;
      2. case-insensitive exact match on ``name``;
      3. case-insensitive substring match on ``name``;
      4. ``voice`` itself, if it looks like a raw voice_id (10+ characters of
         ``[A-Za-z0-9_-]``) -- allows ids that are not in the listing;
      5. the voice whose name equals ``default_name``
         (``DEFAULT_VOICE_NAME`` when ``default_name`` is None);
      6. the first listed voice.

    Name matching deliberately runs before the id heuristic. Otherwise a long
    voice name such as "MyNarrator01" would be sent to the API verbatim as a
    (non-existent) voice_id.

    Args:
        voices: dicts with at least ``name`` and ``voice_id`` keys (as built by
            ``get_voices``).
        voice: voice name or voice_id requested by the user; may be empty/None.
        default_name: name to try when ``voice`` does not resolve.

    Raises:
        RuntimeError: if ``voices`` has no entry with a usable ``voice_id``.
    """
    # Entries without an id cannot be used for a request; skip them so this
    # function never returns None.
    usable = [v for v in voices if v.get("voice_id")]
    if not usable:
        raise RuntimeError("voice 一覧が空です（APIキーや権限を確認してください）")

    requested = (voice or "").strip()

    # 1. Already a known voice_id.
    if any(v["voice_id"] == requested for v in usable):
        return requested

    # 2./3. Name lookup: exact match first, then substring.
    want = requested.lower()
    if want:
        for v in usable:
            if (v.get("name") or "").lower() == want:
                return v["voice_id"]
        for v in usable:
            if want in (v.get("name") or "").lower():
                return v["voice_id"]

    # 4. Unknown but id-shaped: pass through (e.g. a library voice not listed here).
    if re.fullmatch(r"[A-Za-z0-9_-]{10,}", requested):
        return requested

    # 5. Configured default voice name.
    fallback = (DEFAULT_VOICE_NAME if default_name is None else default_name).lower()
    for v in usable:
        if (v.get("name") or "").lower() == fallback:
            return v["voice_id"]

    # 6. First usable voice.
    return usable[0]["voice_id"]

def _write_audio_to_file(audio, outpath: str) -> int:
    """
    Write audio data to ``outpath`` (truncating it) and return the bytes written.

    ``audio`` may be a single ``bytes``/``bytearray`` object or any iterable of
    byte chunks (e.g. a generator). Falsy chunks (empty or None) are skipped.
    """
    # Treat a single bytes-like object as a one-chunk stream so that one loop
    # handles both input shapes.
    pieces = (audio,) if isinstance(audio, (bytes, bytearray)) else audio

    written = 0
    with open(outpath, "wb") as fh:
        for piece in pieces:
            if piece:
                fh.write(piece)
                written += len(piece)
    return written


def tts_convert_to_file(
    client,
    text: str,
    outpath: str,
    voice: str = DEFAULT_VOICE_NAME,
    model_id: str = DEFAULT_MODEL_ID,
    voice_settings: Optional[Dict[str, Any]] = None,
) -> str:
    """
    Synthesize ``text`` with a single (non-streaming) request and save it to ``outpath``.

    Args:
        client: ElevenLabs client (see ``get_client``).
        text: text to speak.
        outpath: destination file; it is overwritten.
        voice: voice name or voice_id, resolved via ``pick_voice_id``.
        model_id: TTS model id.
        voice_settings: overrides merged key by key on top of ``DEFAULT_VOICE_SETTINGS``.

    Returns:
        ``outpath``.

    Raises:
        RuntimeError: if no audio bytes were received. The empty output file is
            removed (best effort) before raising.
        Any exception from the SDK calls or the file write propagates unchanged.
    """
    voice_id = pick_voice_id(get_voices(client), voice)

    # Caller-supplied settings override the defaults key by key.
    vs = {**DEFAULT_VOICE_SETTINGS, **(voice_settings or {})}

    audio = client.text_to_speech.convert(
        text=text,
        voice_id=voice_id,
        model_id=model_id,
        voice_settings=vs,
    )

    # The SDK may hand back raw bytes or an iterator of chunks; the writer accepts both.
    total = _write_audio_to_file(audio, outpath)
    if total <= 0:
        # Don't leave a zero-byte file behind that looks like a successful result.
        try:
            os.remove(outpath)
        except OSError:
            pass  # best-effort cleanup; the RuntimeError below reports the failure
        raise RuntimeError("音声データが空でした（API応答/クォータ/モデル指定を確認）")

    return outpath

def tts_stream_to_file(
    client,
    text: str,
    outpath: str,
    voice: str = DEFAULT_VOICE_NAME,
    model_id: str = DEFAULT_MODEL_ID,
    voice_settings: Optional[Dict[str, Any]] = None,
) -> Tuple[str, int]:
    """
    Synthesize ``text`` via the streaming endpoint, writing chunks to ``outpath`` as they arrive.

    Args:
        client: ElevenLabs client (see ``get_client``).
        text: text to speak.
        outpath: destination file; it is overwritten.
        voice: voice name or voice_id, resolved via ``pick_voice_id``.
        model_id: TTS model id.
        voice_settings: overrides merged key by key on top of ``DEFAULT_VOICE_SETTINGS``.

    Returns:
        ``(outpath, total_bytes)`` with ``total_bytes > 0``.

    Raises:
        RuntimeError: if the stream produced no audio bytes (consistent with
            ``tts_convert_to_file``). The empty output file is removed (best
            effort) before raising.
        Any exception from the SDK calls or the file write propagates unchanged.
    """
    voice_id = pick_voice_id(get_voices(client), voice)

    # Caller-supplied settings override the defaults key by key.
    vs = {**DEFAULT_VOICE_SETTINGS, **(voice_settings or {})}

    stream_iter = client.text_to_speech.stream(
        text=text,
        voice_id=voice_id,
        model_id=model_id,
        voice_settings=vs,
    )

    # Shared writer: consumes the chunk iterator lazily, skipping empty chunks.
    total = _write_audio_to_file(stream_iter, outpath)
    if total <= 0:
        # Previously a 0-byte stream was reported as success; fail loudly instead.
        try:
            os.remove(outpath)
        except OSError:
            pass  # best-effort cleanup; the RuntimeError below reports the failure
        raise RuntimeError("音声データが空でした（API応答/クォータ/モデル指定を確認）")

    return outpath, total

def _mask_secret(secret: str, visible: int = 4) -> str:
    """Return a display-safe form of *secret*: a fixed mask plus its last *visible* chars.

    Secrets too short to reveal a suffix safely are masked entirely.
    """
    if len(secret) <= visible * 2:
        return "*" * 8
    return "*" * 8 + secret[-visible:]


def _build_parser() -> argparse.ArgumentParser:
    """Build the CLI parser: a global ``--api-key`` plus ``voices`` / ``tts`` / ``stream``."""
    p = argparse.ArgumentParser(description="ElevenLabs TTS demo (runnable)")
    # Global option: it must be given before the subcommand name.
    p.add_argument(
        "--api-key",
        default=None,
        help="API key (optional). If omitted, uses ELEVENLABS_API_KEY env var.",
    )
    sub = p.add_subparsers(dest="cmd", required=True)

    sub.add_parser("voices", help="List available voices")

    # tts and stream take identical options; only the help text and default output differ.
    synth_commands = (
        ("tts", "Convert text to speech and save to file", "output.mp3",
         "Output file path (e.g. output.mp3)"),
        ("stream", "Streaming TTS and save to file", "output_stream.mp3",
         "Output file path"),
    )
    for name, cmd_help, default_out, out_help in synth_commands:
        sp = sub.add_parser(name, help=cmd_help)
        sp.add_argument("-t", "--text", required=True, help="Text to speak")
        sp.add_argument("-o", "--out", default=default_out, help=out_help)
        sp.add_argument("--voice", default=DEFAULT_VOICE_NAME, help="Voice name or voice_id")
        sp.add_argument("--model", default=DEFAULT_MODEL_ID, help="Model id")
        sp.add_argument("--stability", type=float, default=DEFAULT_VOICE_SETTINGS["stability"])
        sp.add_argument("--similarity", type=float, default=DEFAULT_VOICE_SETTINGS["similarity_boost"])
    return p


def main():
    """
    CLI entry point: parse arguments, create the client and run the subcommand.

    Failures (missing key, client setup, API/file errors) print a message to
    stderr and exit with status 1. argparse usage errors exit with status 2.
    """
    args = _build_parser().parse_args()

    # Resolve the key here so that a missing key gives a clear message instead
    # of a KeyError traceback, and so that --api-key works without the env var.
    api_key = args.api_key or os.environ.get("ELEVENLABS_API_KEY")
    if not api_key:
        print(
            "❌ APIキーが未設定です（--api-key または ELEVENLABS_API_KEY を指定してください）",
            file=sys.stderr,
        )
        sys.exit(1)

    # Never print the full key: it would leak into terminal scrollback and logs.
    print()
    print("API_KEY=", _mask_secret(api_key))
    print()

    try:
        client = get_client(api_key)
    except Exception as e:
        print(f"❌ クライアント初期化に失敗: {e}", file=sys.stderr)
        sys.exit(1)

    if args.cmd == "voices":
        try:
            voices = get_voices(client)
        except Exception as e:
            print(f"❌ voice一覧の取得に失敗: {e}", file=sys.stderr)
            sys.exit(1)

        if not voices:
            print("(no voices found)")
            sys.exit(0)

        print("=== Voices ===")
        for v in voices:
            print(f"- {v['name']}  (id={v['voice_id']})  category={v.get('category')} labels={v.get('labels')}")
        return

    # tts and stream share every option, so build the call arguments once.
    request = dict(
        client=client,
        text=args.text,
        outpath=args.out,
        voice=args.voice,
        model_id=args.model,
        voice_settings={"stability": args.stability, "similarity_boost": args.similarity},
    )

    if args.cmd == "tts":
        try:
            out = tts_convert_to_file(**request)
        except Exception as e:
            print(f"❌ TTS失敗: {e}", file=sys.stderr)
            sys.exit(1)
        print(f"✅ saved: {out}")
        return

    if args.cmd == "stream":
        try:
            out, total = tts_stream_to_file(**request)
        except Exception as e:
            print(f"❌ Streaming TTS失敗: {e}", file=sys.stderr)
            sys.exit(1)
        print(f"✅ saved (stream): {out}  bytes={total}")
        return


if __name__ == "__main__":
    main()
