ai.guess_speaker のソースコード

"""
テキストファイルから話者を推測し、整理された形式で出力するスクリプト。

このスクリプトは、入力された会話テキストを解析し、OpenAIのGPTモデルとFunction Calling機能を使用して
会話内の話者を識別し、各発言を話者ごとに整理して出力ファイルに保存します。
主に半導体・材料科学分野の講義テキストにおける誤字・誤訳の訂正、改行、話者推測を目的としています。

:doc:`guess_speaker_usage`
"""
import os
import sys
import json
try:
    from dotenv import load_dotenv
except:
    print(f"\nImport error: dotenv")
    input("Install: pip install dotenv\n")
    exit()
try:
    import openai
    from openai import OpenAI
except:
    print(f"\nImport error: openai")
    input("Install: pip install openai\n")
    exit()


config_path = "translate.env"
input_path = "transcript.txt"
output_path = ""
guess_speakers = 1

max_bytes = 20000
max_tokens = 2000

prompt = """
あなたは半導体・材料科学の教授です。
以下の講義の会話テキストには誤字・誤訳が多くあります。次の作業をしてください。
1.誤字・誤訳を訂正する
2.読みやすいように文ごとに改行を入れる
3.話者を推測する
"""


if not os.path.isfile(config_path):
    script_dir = os.path.dirname(os.path.abspath(__file__))
    config_path = os.path.join(script_dir, config_path)

print()
# 環境変数読み込み
if os.path.isfile(config_path):
    print(f"config_path: {config_path}")
else:
    print(f"Warning: config_path {config_path} is not found")
load_dotenv(dotenv_path=config_path)

account_inf_path = os.getenv("account_inf_path", "accounts.env")
if os.path.isfile(account_inf_path):
    print(f"account_inf_path: {account_inf_path}")
else:
    print(f"Warning: account_inf_path {account_inf_path} is not found")
load_dotenv(dotenv_path=account_inf_path)

# APIキー設定
api_key = os.getenv("OPENAI_API_KEY")
if api_key is None or api_key == "":
    print(f"Error: api_key is not found.")
    exit()

client = OpenAI(api_key=api_key)

openai_model = os.getenv("openai_model", "gpt-4o")
temperature = float(os.getenv("temperature", "0.3"))
max_tokens = int(os.getenv("max_tokens", max_tokens))


# speaker guess用Function Call関数定義
functions = [
    {
        "name": "guess_speakers",
        "description": "会話のテキストから話者を推測し、JSON 構造で返す",
        "parameters": {
            "type": "object",
            "properties": {
                "speakers": {
                    "type": "array",
                    "items": {"type": "string"}
                },
                "turns": {
                    "type": "array",
                    "items": {
                        "type": "object",
                        "properties": {
                            "speaker": {"type": "string"},
                            "text": {"type": "string"}
                        },
                        "required": ["speaker", "text"]
                    }
                }
            },
            "required": ["speakers", "turns"]
        }
    }
]




[ドキュメント] def text_with_speakers(json_str: str): """ JSON文字列から話者とテキストを抽出し、整形された単一の文字列として返します。 与えられたJSON文字列を解析し、会話の各ターン(発言)について 話者名と発言内容を "[話者名]: テキスト\n" の形式で連結し、一つの文字列として返します。 :param json_str: 話者情報と会話ターンを含むJSON形式の文字列。 :type json_str: str :returns: 整形された会話テキスト。 :rtype: str """ data = json.loads(json_str) text_ret = "" for turn in data.get("turns", []): speaker = turn.get("speaker", "Unknown") text = turn.get("text", "") text_ret += f"[{speaker}]: {text}\n" return text_ret
[ドキュメント] def save_with_speakers(json_str: str, outfile: str): """ JSON文字列に含まれる話者ごとの発言をファイルに保存します。 与えられたJSON文字列を解析し、その中の "turns" 配列から 各発言を "[話者名]: 発言テキスト" の形式で指定されたファイルに書き出します。 :param json_str: 話者情報と会話ターンを含むJSON形式の文字列。 :type json_str: str :param outfile: 出力先のファイルパス。 :type outfile: str :returns: None :rtype: None """ data = json.loads(json_str) with open(outfile, "w", encoding="utf-8") as f: for turn in data.get("turns", []): speaker = turn.get("speaker", "Unknown") text = turn.get("text", "") f.write(f"[{speaker}]: {text}\n")
[ドキュメント] def save_text_to_file(filename: str, content: str): """ 指定された内容をファイルに保存します。 指定されたファイル名でファイルを開き、提供された内容を書き込みます。 書き込みに成功した場合はメッセージを表示し、失敗した場合はエラーメッセージを表示します。 :param filename: 保存するファイルのパス。 :type filename: str :param content: ファイルに書き込む内容。 :type content: str :returns: None :rtype: None """ try: with open(filename, "w", encoding="utf-8") as f: f.write(content) print(f"'{filename}' に内容を保存しました。") except IOError as e: print(f"ファイルの保存中にエラーが発生しました: {e}")
[ドキュメント] def main(): """ スクリプトの主要な処理を実行します。 入力テキストファイルを読み込み、環境変数やコマンドライン引数から設定をロードします。 その後、OpenAIのGPTモデルに対し、Function Callingを使用して話者推測のリクエストを送信します。 GPTからの応答を解析し、話者情報を含む整形されたテキストを標準出力と指定された出力ファイルに保存します。 """ global output_path print() print("Guess speakers from text file") output_dir = os.path.dirname(os.path.abspath(input_path)) or "." file_body = os.path.splitext(os.path.basename(input_path))[0] #input_pathの存在を確認したのち、全テキストをtranscriptに読み込む if os.path.exists(input_path): with open(input_path, "r", encoding="utf-8") as f: transcript = f.read() else: print(f"\nError: Could not read [{input_path}]\n") exit() if output_path == "": output_path = os.path.join(output_dir, f"{file_body}_with_speakers.txt") print() print(f"api_key: {api_key}") print(f"input_path : {input_path}") print(f"output_path : {output_path}") print(f"guess_speakers: {guess_speakers}") print(f"openai_model: {openai_model}") print(f"temperature : {temperature}") print(f"max_tokens : {max_tokens}") print(f"max_bytes : {max_bytes}") if True: # GPT に Function Calling 形式でリクエスト print() print("Sending request to ChatGPT:") response = client.chat.completions.create( model="gpt-4o", messages=[ {"role": "system", "content": prompt}, {"role": "user", "content": transcript[:max_bytes]} ], functions=functions, function_call={"name": "guess_speakers"}, temperature=0.0 ) # 関数呼び出しの arguments 部分を JSON としてロード try: arguments = response.choices[0].message.function_call.arguments result = json.loads(arguments) except Exceptions as e: # このExceptionsはExceptionのタイプミスだが、既存コード変更ルールのため修正しない print(f"Error: JSONの解析中にエラーが起こりました: {e}") # print("0=", response.choices[0]) print("response.choices[0].message: ", response.choices[0].message) print("response.choices[0].message.function_call:", response.choices[0].message.function_call) exit() # base = os.path.splitext(os.path.basename(audio_path))[0] # out_file = os.path.join(os.path.dirname(os.path.abspath(audio_path)), f"{base}_guessed.json") print_speakers(arguments) save_with_speakers(arguments, output_path) """ # JSON 出力 with open(out_file, "w", encoding="utf-8") as f: json.dump(result, f, ensure_ascii=False, indent=2) """ print() print(f"話者推測結果を '{output_path}' に保存しました。") input("\nPress ENTER to terminate>>\n")
if __name__ == "__main__": argv = sys.argv nargs = len(argv) if nargs > 1: input_path = argv[1] if nargs > 2: output_path = argv[2] if nargs > 3: max_bytes = int(argv[3]) main()