"""
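Test and benchmark script for the unnormalized Fermi integral

    F_j(eta) = ∫_0^∞ x**j / (1 + exp(x - eta)) dx

(no 1/Γ(j+1) prefactor). It compares two quad-based references
(Fj_expit, Fj_raw) against FermiIntegral_fast and then times all three.

Original request (kept verbatim):
Fermi積分関数のテストもお願いします。j=[-1, -.5, 0, .5, 1, 1.5, 2]とetaを振ってください。
Translation: Please also write tests for the Fermi integral function,
sweeping j = [-1, -0.5, 0, 0.5, 1, 1.5, 2] and varying eta.

Note: j = -1 cannot be tested with this definition. The integrand x**-1
is not integrable at x = 0 (the integral diverges), so the test sweep
below starts at j = -0.5.

Reference implementation supplied with the request:

def Fj(j, eta):
    integrand = lambda x: x**j * expit(eta - x)
    res, _ = quad(integrand, 0, max(eta + 50, 100))
    return res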
"""


import numpy as np
from scipy.special import expit
from scipy.integrate import quad
import time

from tkfermi_integral import FermiIntegral_fast


# -----------------------------
# expit-based Fermi integral (used as the reference below)
# -----------------------------
def Fj_expit(j, eta):
    """Compute the unnormalized Fermi integral F_j(eta) by adaptive quadrature.

    F_j(eta) = ∫_0^∞ x**j * expit(eta - x) dx, where expit(eta - x) is the
    Fermi-Dirac occupation 1 / (1 + exp(x - eta)).  The infinite upper limit
    is truncated at max(eta + 50, 100), so at the cutoff x - eta >= 50 and
    the occupation has already fallen below exp(-50).

    Parameters
    ----------
    j : float
        Order of the integral.  The integrand is integrable at x = 0 only
        for j > -1.
    eta : float
        Shift of the Fermi edge (the reduced chemical potential in the usual
        Fermi-Dirac notation).

    Returns
    -------
    float
        The value returned by ``scipy.integrate.quad``; its error estimate
        is discarded.
    """
    def weighted_occupation(x):
        return x**j * expit(eta - x)

    cutoff = max(eta + 50, 100)
    value, _abserr = quad(weighted_occupation, 0, cutoff)
    return value

# -----------------------------
# Naive Fermi-Dirac form (factor written out explicitly)
# -----------------------------
def Fj_raw(j, eta):
    """Compute F_j(eta) with the Fermi-Dirac factor written out directly.

    Same integral and truncation as Fj_expit, but the occupation is
    evaluated as ``1.0 / (1.0 + np.exp(x - eta))``.  If x - eta exceeds
    about 709, ``np.exp`` overflows to ``inf`` with a RuntimeWarning and the
    quotient becomes 0.0.  With the cutoff max(eta + 50, 100) this can only
    happen for very negative eta (eta < about -609).

    Parameters
    ----------
    j : float
        Order of the integral; integrable at x = 0 only for j > -1.
    eta : float
        Shift of the Fermi edge.

    Returns
    -------
    float
        Value from ``scipy.integrate.quad`` (error estimate discarded).
    """
    cutoff = max(eta + 50, 100)

    def fermi_dirac_weight(x):
        return x**j / (1.0 + np.exp(x - eta))

    return quad(fermi_dirac_weight, 0, cutoff)[0]

# -----------------------------
# Test settings
# -----------------------------
# j = -1 (in the original request) is omitted on purpose: with this
# unnormalized definition x**-1 is not integrable at x = 0, so the
# quad-based references have no finite value to compare against.
j_values = [-0.5, 0, 0.5, 1, 1.5, 2]
eta_values = np.linspace(-10, 10, 9)

# Allowed deviation from the expit reference.  It acts as an absolute
# tolerance while |F| <= 1 and as a relative tolerance above that.  Over this
# grid F_j(eta) reaches a few hundred, and a purely absolute 1e-10 would
# demand ~1e-13 relative agreement there.  That is far tighter than quad's
# default accuracy target (epsabs = epsrel = 1.49e-8).
tol = 1e-10

header = (f"{'j':>5}  {'eta':>6}   {'expit':>12}  {'raw':>12}  {'fast':>12}   "
          f"{'|expit-raw|':>11}  {'|expit-fast|':>12}   result")
rule = "-" * len(header)
n_fail = 0

print("Testing Fermi integrals Fj(j, eta)")
print(f"diff tolerance: {tol} (absolute for |F| <= 1, relative otherwise)")
print(rule)
print(header)
print(rule)

for j in j_values:
    for eta in eta_values:
        val1 = Fj_expit(j, eta)            # reference (expit form)
        val2 = Fj_raw(j, eta)              # naive 1/(1+exp) form
        val3 = FermiIntegral_fast(eta, j)  # implementation under test; argument order is (eta, j)

        limit = tol * max(1.0, abs(val1))
        diff_raw = abs(val1 - val2)
        diff_fast = abs(val1 - val3)
        # The row passes only if BOTH comparisons are within the tolerance
        # (a NaN diff compares False and therefore fails).
        passed = diff_raw < limit and diff_fast < limit
        if not passed:
            n_fail += 1
        result = "PASS" if passed else "FAIL"

        print(f"{j:5.1f}  {eta:6.2f}   {val1:12.5e}  {val2:12.5e}  {val3:12.5e}   "
              f"{diff_raw:11.3e}  {diff_fast:12.3e}   {result}")
    print(rule)

if n_fail:
    print(f"{n_fail} of {len(j_values) * len(eta_values)} cases FAILED")
else:
    print("All cases PASSED")


# ==============================================================
#  Benchmark
# ==============================================================
# Arbitrary order j used for the timing runs.
j_test = 0.5
N = 20  # repetitions per measurement (kept small: the quad versions are slow)


def _time_per_call(func, *args, repeat=N):
    """Return the mean wall-clock time in seconds of ``func(*args)`` over *repeat* calls."""
    t0 = time.perf_counter()
    for _ in range(repeat):
        func(*args)
    return (time.perf_counter() - t0) / repeat


print("\nBenchmarking Fj implementations")
print(f"j_test = {j_test}")
print("--------------------------------------------------------------")

# Warm-up call so that any one-time cost of the first invocation of
# FermiIntegral_fast (e.g. lazy initialisation or JIT compilation, if it uses
# any -- not visible from this file) is not charged to the first eta below.
FermiIntegral_fast(0.0, j_test)

for eta_test in [-1.0, -0.1, 0.0, 0.1, 0.5, 2.0]:
    time_expit = _time_per_call(Fj_expit, j_test, eta_test)
    time_raw = _time_per_call(Fj_raw, j_test, eta_test)
    # FermiIntegral_fast takes its arguments in (eta, j) order.
    time_fast = _time_per_call(FermiIntegral_fast, eta_test, j_test)

    print(f"eta_test = {eta_test}")
    print(f"expit version: {time_expit*1e3:8.3f} ms per call")
    print(f"raw  version:  {time_raw*1e3:8.3f} ms per call")
    print(f"fast version:  {time_fast*1e6:8.3f} µs per call")
    print("--------------------------------------------------------------")

