import argparse
import os
import shutil
import glob
import io
try:
    from dotenv import load_dotenv
except:
    print("Error: Failed to import dotenv")
    input("Install: pip install dotenv")
try:
    from openpyxl import Workbook, load_workbook
except:
    print("Error: Failed to import openpyxl")
    input("Install: pip install openpyxl")
try:
    import PyPDF2
except:
    print("Error: Failed to import PyPDF2")
    input("Install: pip install pypdf2")

config_path = "translate.env"
if not os.path.isfile(config_path):
    script_dir = os.path.dirname(os.path.abspath(__file__))
    config_path = os.path.join(script_dir, config_path)

DEFAULT_TEMPLATE = "(author_first)_(author_last)_(short_title)_(shortest_name)_(year).pdf"

#=========================================================
# Prompt template: start
#=========================================================
prompt_template = """
以下の学術論文のテキストから、指定された情報をJSON形式で抽出してください。
また、first author、last author、短縮したtitle (short_title)、shortest_name、発行年から
推奨ファイル名 {{template}} を作ってください。
short_titleの単語はCaptalFirstにして、空白、.、,などの文字は削除してください。

推奨ファイル名（filename_rename）:
著者リスト（authors）   ：
第一著者（author_first）：
最終著者（author_last） ：
論文題目（title）           :
論文題目略称（short_title） :
発行雑誌名（journal）       ：
発行雑誌名略称（short_name）：
発行雑誌名の最短の略称（shortest_name）：
発行年（year）   :
巻番号（volume）：
号番号（issue） ：
ページ（pages ）：
DOI（doi）：

情報が不明な場合は、空文字列 "" を使用してください。

--- 論文テキスト ---
{{text}}
"""
# end

load_dotenv(dotenv_path = config_path)
account_inf_path = os.getenv("account_inf_path", "accounts.env")
load_dotenv(dotenv_path = account_inf_path)

openai_api_key = os.getenv("OPENAI_API_KEY")
openai_model = os.getenv("openai_model", "gpt-4o")
temperature = float(os.getenv("temperature", "0.3"))
max_tokens  = int(os.getenv("max_tokens", "2000"))

google_api_key = os.getenv("GOOGLE_API_KEY")
gemini_model = os.getenv("gemini_model", "gemini-2.0-flash")


def initialize():
    parser = argparse.ArgumentParser(description="学術論文のPDFからメタデータを抽出します。")
    parser.add_argument("--api", type=str, default='openai',
                        help="生成AIのAPI: openai|google")
    parser.add_argument("input_file", type=str, help="処理するPDFファイルのパス")
    parser.add_argument("--summary_path", type=str, default = "summary.xlsx", 
                        help="summary.xlsxのパス")
    parser.add_argument("--recursive", type=int, default=0, 
                        help="1の場合、サブディレクトリまで処理する")
    parser.add_argument("--rename", type=int, default=0, 
                        help="1の場合、生成AIで推薦ファイル名を取得できた場合、リネームする")
#    parser.add_argument("--rename", action="store_true", 
#                        help="生成AIで推薦ファイル名を取得できた場合、リネームする")
    parser.add_argument("--delete_original", type=int, default=1, 
                        help="1の場合、renameしたら元のファイルを削除する")
    parser.add_argument("--max_bytes", type=int, default=10000,
                        help="APIに送る最大テキストバイト数。論文全体を送る必要がない場合に指定。")
    parser.add_argument("--template", type=str, default=DEFAULT_TEMPLATE,
                        help="ファイル名を変える際のテンプレート {{DEFAULT_TEMPLATE}}")
    return parser

# --- PDFからテキストを抽出する関数 ---
def extract_text_from_pdf(pdf_path: str, max_bytes: int) -> str:
    """
    PDFファイルからテキストを抽出し、指定された最大バイト数に制限する。
    """
    text_buffer = io.StringIO()
    try:
        with open(pdf_path, 'rb') as f:
            pdf_reader = PyPDF2.PdfReader(f)
            num_pages = len(pdf_reader.pages)
            
            for page_num in range(num_pages):
                page = pdf_reader.pages[page_num]
                text_buffer.write(page.extract_text() or '')
                
                # バイト数チェック
                if text_buffer.tell() > max_bytes:
                    break
    except Exception as e:
        print(f"PDFファイルの読み込み中にエラーが発生しました: {e}")
        return ""

    full_text = text_buffer.getvalue()
    
    # バイト数でテキストを切り詰める
    encoded_text = full_text.encode('utf-8', 'ignore')
    if len(encoded_text) > max_bytes:
        truncated_text = encoded_text[:max_bytes].decode('utf-8', 'ignore')
        return truncated_text
    
    return full_text

# --- APIに送信してメタデータを抽出する関数 ---

def get_metadata_from_openai(text: str, template: str) -> dict:
    """
    OpenAI APIを使用して論文のメタデータを抽出する。
    """
    try:
        import openai
        openai.api_key = openai_api_key
        if not openai.api_key:
            raise ValueError("OPENAI_API_KEY が取得できません。")

        prompt = prompt_template.replace("{{text}}", text).replace("{{template}}", template)
        response = openai.chat.completions.create(
            model=openai_model,
            messages=[
                {"role": "user", "content": prompt}
            ],
            response_format={"type": "json_object"}
        )
        json = eval(response.choices[0].message.content)
        if type(json) is list or type(json) is tuple:
            return json[0]
        else:
            return json

    except ImportError:
        print("openai ライブラリがインストールされていません。`pip install openai` を実行してください。")
    except ValueError as e:
        print(f"APIキーエラー: {e}")
    except Exception as e:
        print(f"OpenAI APIの呼び出し中にエラーが発生しました: {e}")
        print("  response:", response)
    return {}

def get_metadata_from_google(text: str, template: str) -> dict:
    """
    Google Gemini APIを使用して論文のメタデータを抽出する。
    """

    try:
        import google.generativeai as genai
        if not google_api_key:
            raise ValueError("GOOGLE_API_KEY' が取得できません。")

        genai.configure(api_key = google_api_key)
            
        model = genai.GenerativeModel(gemini_model)
        
        prompt = prompt_template.replace("{{text}}", text).replace("{{template}}", template)
        response = model.generate_content(
            prompt,
            generation_config={"response_mime_type": "application/json"}
        )
        json = eval(response.text)
        if type(json) is list or type(json) is tuple:
            return json[0]
        else:
            return json
    except ImportError:
        print("google-generativeai ライブラリがインストールされていません。`pip install google-generativeai` を実行してください。")
    except ValueError as e:
        print(f"APIキーエラー: {e}")
    except Exception as e:
        print(f"Google Gemini APIの呼び出し中にエラーが発生しました: {e}")
        print("  response:", response.text)
    return {}

def append_xlsx(path, mode, labels, data_list):
    if mode == "w" or not os.path.exists(path):
        wb = Workbook()
        ws = wb.active
        ws.append(labels)
    else:
        wb = load_workbook(filename = path)
        ws = wb.active

    for row in data_list:
        ws.append(row)

    wb.save(path)

def get_inf(input_file, summary_path, api = "openai", max_bytes = 10000, 
            rename = False, delete_original = True, template = DEFAULT_TEMPLATE):
    directory = os.path.dirname(input_file)
    filename  = os.path.basename(input_file)
    print(f"Input PDF file: [{input_file}]")
    print(f"  Directory: [{directory}]")
    print(f"  File name: [{filename}]")

    extracted_text = extract_text_from_pdf(input_file, max_bytes)
    if not extracted_text:
        print("テキストの抽出に失敗したか、ファイルが空です。処理を終了します。")
        print("PyCryptodomeが必要というメッセージが出たら、以下のようにinstallしてください。")
        print("Install: pip install pycryptodome ")
        return False

    labels = ["directory", "filename_original", "filename_rename", 
              "authors", "author_first", "author_last", 
              "title", "short_title", "journal", "short_name", "shortest_name", 
              "year", "volume", "issue", "pages", "doi"]

    print(f"  API: {api.upper()}")
    print(f"  max bytes: {max_bytes}")
    print(f"  抽出したテキストを送信中...")
    if api == 'openai':
        metadata = get_metadata_from_openai(extracted_text[:max_bytes], template)
    elif api == 'google':
        metadata = get_metadata_from_google(extracted_text[:max_bytes], template)
    else:
        print("Error in get_inf(): api が 'openai' または 'google' ではありません。")
        return None

    if metadata:
        print("\n--- 論文メタデータ ---")
        data = [directory, filename]
        for key, value in metadata.items():
            print(f"{key}: {value}")
            if type(value) is list or type(value) is tuple:
                data.append(", ".join(value))
            else:
                data.append(value)
    else:
        data = [directory, filename, "", "", "", "", "", 
                "", "", "", "", "", 
                "", "", "", "", ""]

    append_xlsx(summary_path, "a", labels, [data])

    if rename and metadata.get("filename_rename", None):
        new_filename = metadata["filename_rename"]
        if not new_filename.endswith('.pdf'):
            new_filename += '.pdf'

        if os.path.exists(new_filename):
            print(f"ファイル名 '{new_filename}' は既に存在します。")
            return False

        try:
            directory = os.path.dirname(input_file)
            os.chdir(directory)
            if delete_original:
                os.rename(input_file, new_filename)
                print(f"ファイル名を '{new_filename}' に変更しました。")
            else:
                shutil.copy(input_file, new_filename)
                print(f"'{input_file}' を '{new_filename}' にコピーしました。")
        except Exception as e:
            print(f"ファイル名の変更中にエラーが発生しました: {e}")

    return True
    
def main():
    parser = initialize()
    args = parser.parse_args()
    
    fmask = args.input_file
    if args.recursive:
        directory = os.path.dirname(fmask)
        base_name = os.path.basename(fmask)
        print("d=", directory)
        print("b=", base_name)
        if '*' not in directory:
            fmask = os.path.join(directory, '**', base_name)

    print()
    print(f"Search files for [{fmask}]")
    print(f"Recursive search? {args.recursive}")
    files = sorted(glob.glob(fmask, recursive = args.recursive))
    print()
    if len(files) == 0:
        print(f"Error: No file found.")
    else:
        print(f"Files found for [{args.input_file}]")
        for f in files:
            print(f"  {f}")

        for f in files:
            print()
            get_inf(f, args.summary_path, args.api, args.max_bytes, args.rename, args.delete_original, args.template)

    input("\nPress ENTER to terminate>>\n")

if __name__ == "__main__":
    main()
