"""
fermi_dirac_e_npvector, fermi_dirac_e_if, fermi_dirac_eを
ライブラリtkfermi_integral.pyから読み込み、x = [-100, -50, -40, -20, -10, -5, -2, -1, -0.5, -0.1, 0.0, 0.1, 0.5, 1.0, 2, 5, 10, 20, 40, 50, 100]の値と相対誤差を計算し、
tol=1e-20を基準にしてテストして、FAILかPASSかを出力（出力には３関数の値とdiff値もいれる）するとともに、
ベンチマークするプログラムを作ってください
"""


import numpy as np
import time
from tkfermi_integral import (
    fermi_dirac_e_npvector,
    fermi_dirac_e_if,
    fermi_dirac_e
)

# ------------------------------------------------------------
# Test points
# ------------------------------------------------------------
# Test grid: the 21 points from the task description (the middle segment),
# bracketed by extreme arguments down to -1000 and up to +1000.
x_values = np.array(
    [-1000.0]
    + [-100, -50, -40, -20, -10, -5, -2, -1, -0.5, -0.1]
    + [0.0, 0.1, 0.5, 1.0, 2, 5, 10, 20, 40, 50, 100]
    + [300, 500, 600, 700, 800, 900, 1000],
    dtype=float,
)

# Parameters forwarded to the scalar implementations as (x, eta, g).
eta, g = 0.0, 1.0

# PASS threshold on the relative error. It is far below float64 machine
# epsilon (~2.2e-16), so in practice only exact agreement passes.
tol = 1.0e-20

# ------------------------------------------------------------
# Helper: relative error
# ------------------------------------------------------------
def relerr(a, b):
    """Return the relative error |a - b| / |b| of a with respect to b.

    The denominator |b| is clamped from below at 1e-300, so a zero
    reference yields a finite result instead of a division by zero.
    Scalars and NumPy arrays are both accepted; arrays are handled
    element-wise.
    """
    tiny = 1e-300
    denominator = np.maximum(np.abs(b), tiny)
    return np.abs(a - b) / denominator

# ------------------------------------------------------------
# Wrappers to unify interface
# ------------------------------------------------------------
def fd_high_precision_scalar(x, eta, g):
    """Evaluate fermi_dirac_e_npvector at the single point x.

    eta and g are accepted only so that this wrapper has the same
    (x, eta, g) signature as the other wrappers. They are NOT forwarded:
    the vector routine is called with the one-element array [x] alone,
    and the first entry of its result is returned.
    """
    one_point = np.array([x])
    vector_result = fermi_dirac_e_npvector(one_point)
    return vector_result[0]

def fd_if_scalar(x, eta, g):
    """Forward (x, eta, g) unchanged to fermi_dirac_e_if and return its result."""
    result = fermi_dirac_e_if(x, eta, g)
    return result

def fd_expit_scalar(x, eta, g):
    """Forward (x, eta, g) unchanged to fermi_dirac_e and return its result."""
    result = fermi_dirac_e(x, eta, g)
    return result

# ------------------------------------------------------------
# Test each function
# ------------------------------------------------------------
def test_function(name, func, x_values, ref_values):
    """Print one PASS/FAIL line per test point for a single implementation.

    For each pair (x, ref) taken in lockstep from x_values and ref_values,
    ``func(x, eta, g)`` is evaluated with the module-level eta and g. Its
    relative error against ref is computed with ``relerr``. A point PASSes
    when that error is strictly below the module-level ``tol``. Each line
    shows x, the computed value, the reference, the diff and the verdict.
    Returns None.
    """
    print(f"\n=== Testing {name} ===")
    for point, expected in zip(x_values, ref_values):
        got = func(point, eta, g)
        err = relerr(got, expected)
        if err < tol:
            verdict = "PASS"
        else:
            # Also reached when err is NaN, since NaN < tol is False.
            verdict = "FAIL"
        print(
            f"x={point:8.2f}  val={got:.8e}  ref={expected:.8e}  "
            f"diff={err:.3e}  {verdict}"
        )

# ------------------------------------------------------------
# Reference values: the vectorized fermi_dirac_e_npvector output
# ------------------------------------------------------------
ref_values = fermi_dirac_e_npvector(x_values)

# ------------------------------------------------------------
# Point-by-point tests of the other two implementations
# ------------------------------------------------------------
# fermi_dirac_e_npvector itself is not tested here. It produced ref_values,
# so comparing fd_high_precision_scalar against them would only compare the
# one-point call with the whole-array call of the same routine.
test_function("scalar_if", fd_if_scalar, x_values, ref_values)
test_function("expit", fd_expit_scalar, x_values, ref_values)

# ------------------------------------------------------------
# Benchmark
# ------------------------------------------------------------
def benchmark(name, func, x_values, repeat=20000):
    """Return the seconds spent evaluating one implementation `repeat` times.

    Parameters
    ----------
    name : str
        Selects the timed workload:

        * "npvector": one fermi_dirac_e_npvector call on the whole array.
        * "scalar_if": one fermi_dirac_e_if call per element, with the
          results collected into an ndarray.
        * "expit": one fermi_dirac_e call on the whole array, so
          fermi_dirac_e must accept an ndarray.
        * any other name: ``func(x, eta, g)`` once per element, with the
          results collected into an ndarray.
    func : callable
        Scalar (x, eta, g) wrapper. It is only used for names other than
        the three listed above.
    x_values : array-like
        Points evaluated on every repetition.
    repeat : int, optional
        Number of passes over x_values (default 20000).

    Returns
    -------
    float
        Elapsed time in seconds, measured with time.perf_counter().
    """
    # Resolve the workload once, so the string comparisons are not part of
    # the timed loop. Previously the `func` argument was ignored, and any
    # unrecognised name silently benchmarked fermi_dirac_e under the wrong
    # label.
    if name == "npvector":
        def run():
            fermi_dirac_e_npvector(x_values)
    elif name == "scalar_if":
        def run():
            np.array([fermi_dirac_e_if(x, eta, g) for x in x_values])
    elif name == "expit":
        def run():
            fermi_dirac_e(x_values, eta, g)
    else:
        def run():
            np.array([func(x, eta, g) for x in x_values])

    # perf_counter is monotonic and has the highest available resolution.
    # time.time() can jump with system clock adjustments and is coarse on
    # some platforms.
    t0 = time.perf_counter()
    for _ in range(repeat):
        run()
    return time.perf_counter() - t0

print("\n=== Benchmark (lower is faster) ===")
t_np = benchmark("npvector", fd_high_precision_scalar, x_values)
t_if = benchmark("scalar_if", fd_if_scalar, x_values)
t_ex = benchmark("expit", fd_expit_scalar, x_values)

print(f"npvector : {t_np:.4f} sec")
print(f"scalar_if: {t_if:.4f} sec")
print(f"expit    : {t_ex:.4f} sec")
