#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Fj_test_extreme.py

Fermi–Dirac integral (non-normalized) test & benchmark with edge/extreme cases.

We test the integral
    Fj(j, eta) = ∫_0^∞ x^j * f(eta - x) dx
where f(z) = 1/(1+exp(-z)) = expit(z).

Notes:
- For j = -1, the integral form diverges at x→0. We instead use the exact identity:
      F0(eta) = ∫_0^∞ f(eta-x) dx = log(1+exp(eta))
      d/deta F0(eta) = expit(eta)
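  and define Fj(-1, eta) := expit(eta). This matches the recurrence of the
  Γ-normalized integrals, d/deta [Fj/Γ(j+1)] = F(j-1)/Γ(j); for the non-normalized
  form used here the recurrence reads dFj/deta = j*F(j-1).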
- For other j > -1, we use numerical quadrature (quad) with a conservative finite upper bound.

Edge/limit checks added:
- Non-degenerate (eta << 0):  Fj ≈ exp(eta) * Γ(j+1)
- T→0 step limit (eta):      Fj_0K = ∫_0^{max(eta,0)} x^j dx = max(eta,0)^{j+1}/(j+1)
- Degenerate (eta >> 1): Sommerfeld (up to π^4 term where applicable)

Usage:
    python Fj_test_extreme.py
"""

import time
import math
import numpy as np
from scipy.special import expit, log1p, gamma
from scipy.integrate import quad

from tkfermi_integral import FermiIntegral_fast


# ============================================================
# Core implementations
# ============================================================
def _upper_bound(eta: float) -> float:
    """
    Return the finite upper integration limit used in place of +infinity.

    - For large positive eta the integrand behaves like x^j for x < eta, so the
      limit must lie beyond eta.
    - For x > eta the occupation decays like exp(eta - x); 60 units past eta it
      has dropped to about 1e-26, so the truncated tail is negligible in double
      precision.
    - The floor of 120 keeps the window wide when eta <= 60 (including eta < 0).
    """
    return float(max(eta + 60.0, 120.0))


def Fj_expit(j: float, eta: float) -> float:
    """
    Evaluate Fj(j, eta) = ∫_0^∞ x^j * expit(eta - x) dx with scipy's quad.

    Args:
        j:   order of the integral; must satisfy j > -1, or be exactly -1.
        eta: reduced chemical potential.

    Returns:
        The integral as a Python float. For j == -1 the analytic value
        expit(eta) is returned instead of integrating.

    Raises:
        ValueError: if j < -1.

    For eta < 0 the factor exp(eta) is pulled out of the integrand using
        expit(eta - x) = exp(eta) * exp(-x) * expit(x - eta),
    so that quad integrates an O(1) function. Without this, the absolute
    tolerance (1e-12) is larger than the whole integral once eta is strongly
    negative and quad stops after its first coarse estimate, losing almost all
    relative accuracy.
    """
    if j == -1:
        return float(expit(eta))
    if j < -1:
        raise ValueError("This implementation supports j > -1, except j=-1 (handled analytically).")

    upper = _upper_bound(eta)
    if eta < 0.0:
        # math.exp underflows silently to 0.0 for very negative eta, which is
        # the correct limiting value of the integral.
        scale = math.exp(eta)
        integrand = lambda x: (x ** j) * np.exp(-x) * expit(x - eta)
    else:
        scale = 1.0
        integrand = lambda x: (x ** j) * expit(eta - x)

    res, _ = quad(integrand, 0.0, upper, epsabs=1e-12, epsrel=1e-12, limit=200)
    return float(scale * res)


def Fj_raw(j: float, eta: float) -> float:
    """
    Evaluate Fj(j, eta) by quad using the textbook occupation 1/(1+exp(x-eta)).

    Serves as an independent cross-check of the expit-based variant: the
    integrand is written with a plain exponential instead of expit.

    Args:
        j:   order of the integral; j > -1, or exactly -1.
        eta: reduced chemical potential.

    Returns:
        The integral as a float; for j == -1 the analytic value expit(eta),
        i.e. d/deta log(1+exp(eta)).

    Raises:
        ValueError: if j < -1.
    """
    if j < -1:
        raise ValueError("This implementation supports j > -1, except j=-1 (handled analytically).")
    if j == -1:
        return float(expit(eta))

    def occupation_weighted(x):
        # x^j times the Fermi occupation in its raw exponential form.
        return (x ** j) / (1.0 + np.exp(x - eta))

    x_max = _upper_bound(eta)
    outcome = quad(occupation_weighted, 0.0, x_max, epsabs=1e-12, epsrel=1e-12, limit=200)
    return float(outcome[0])


def Fj_fast(j: float, eta: float):
    """
    Call FermiIntegral_fast(eta, j) and coerce the result to float.

    Note the argument order is swapped relative to this wrapper's signature.
    Any exception raised by the call or by the float conversion is swallowed
    on purpose and reported as None, so callers can treat None as
    "not available for this (j, eta)".
    """
    try:
        value = FermiIntegral_fast(eta, j)
        return float(value)
    except Exception:
        return None


# ============================================================
# Asymptotic / limit reference checks
# ============================================================
def Fj_nondegenerate_approx(j: float, eta: float) -> float:
    """
    Boltzmann (eta << 0) approximation of Fj.

    With expit(eta - x) ≈ exp(eta - x):
        Fj ≈ exp(eta) * ∫_0^∞ x^j e^{-x} dx = exp(eta) * Γ(j+1)

    For j == -1 the reference is expit(eta) ≈ exp(eta), so the Γ factor
    (which would be Γ(0)) is not applied.
    """
    boltzmann = np.exp(eta)
    if j == -1:
        return float(boltzmann)
    return float(boltzmann * gamma(j + 1.0))


def Fj_0K_step(j: float, eta: float) -> float:
    """
    Zero-temperature reference where expit(eta - x) is replaced by Θ(eta - x).

    For j != -1:
        Fj_0K = ∫_0^{max(eta,0)} x^j dx = max(eta,0)^(j+1) / (j+1)
    For j == -1 the step function itself is returned: 1.0 if eta > 0, else 0.0
    (the finite-temperature counterpart used elsewhere in this file is expit(eta)).
    """
    if j == -1:
        return 1.0 if eta > 0 else 0.0
    if eta <= 0:
        # Empty integration interval.
        return 0.0
    exponent = j + 1.0
    return float((eta ** exponent) / exponent)


def Fj_sommerfeld(j: float, eta: float) -> float:
    """
    Sommerfeld expansion (non-normalized) of Fj for eta >> 1, j > -1.

        Fj ≈ η^(j+1)/(j+1)
             + (π^2/6)    * j             * η^(j-1)
             + (7π^4/360) * j*(j-1)*(j-2) * η^(j-3)

    The series is asymptotic in 1/η for every j > -1, including negative and
    half-integer orders, so both correction terms are always included. For
    integer j the coefficients j and j*(j-1)*(j-2) vanish on their own, which
    is how the series terminates for j = 0, 1, 2.

    Args:
        j:   order of the integral (j > -1, or exactly -1).
        eta: reduced chemical potential; the result is only meaningful for
             eta >> 1.

    Returns:
        The truncated expansion as a float. For j == -1 the limit of
        expit(eta), i.e. 1.0. For eta <= 0 (outside the validity range) the
        0 K step value, which is 0.0.
    """
    if j == -1:
        # expit(eta) -> 1 for eta >> 1
        return 1.0
    if eta <= 0:
        # Not a valid regime; the 0 K step integral over an empty range is 0.
        return 0.0

    def _power(exponent: float) -> float:
        # Tiny eta with a negative exponent overflows; report that as inf
        # instead of raising, since the expansion is meaningless there anyway.
        try:
            return eta ** exponent
        except OverflowError:
            return math.inf

    total = (eta ** (j + 1.0)) / (j + 1.0)
    corrections = (
        ((math.pi ** 2 / 6.0) * j, j - 1.0),
        ((7.0 * math.pi ** 4 / 360.0) * j * (j - 1.0) * (j - 2.0), j - 3.0),
    )
    for coeff, exponent in corrections:
        # Skip vanishing terms so 0 * (huge power) can never produce NaN.
        if coeff != 0.0:
            total += coeff * _power(exponent)

    return float(total)


# ============================================================
# Test helpers
# ============================================================
def rel_err(a: float, b: float, tiny: float = 1e-300) -> float:
    """
    Relative deviation of a from the reference b: |a - b| / max(|b|, tiny).

    The denominator is floored at `tiny` so a zero reference does not cause a
    division by zero (the result is then simply very large).
    """
    denominator = abs(b)
    if denominator < tiny:
        denominator = tiny
    return abs(a - b) / denominator


def print_header(title: str):
    """Print a blank line, then `title` framed above and below by a 78-column '=' rule."""
    rule = "=" * 78
    print("", rule, title, rule, sep="\n")


# ============================================================
# Main tests
# ============================================================
def _run_consistency_checks(j_values, eta_values, tol_abs):
    """Section A: cross-check quad(expit), quad(raw) and the fast routine point by point."""
    rule = "-" * 78
    print_header("A) Consistency check: quad(expit) vs quad(raw) vs fast (if available)")
    print(f"abs diff tolerance (quad vs quad): {tol_abs:g}")
    print(rule)
    print("   j      eta        quad_expit         quad_raw        |diff|        fast          status")
    print(rule)

    for j in j_values:
        for eta in eta_values:
            v1 = Fj_expit(j, float(eta))
            v2 = Fj_raw(j, float(eta))
            d12 = abs(v1 - v2)
            passed = d12 < tol_abs

            # Fj_fast yields None when the fast routine cannot handle the point;
            # in that case only the quad-vs-quad agreement decides the status.
            v3 = Fj_fast(j, float(eta))
            if v3 is None:
                v3s = "  (n/a)"
            else:
                passed = passed and abs(v1 - v3) < tol_abs
                v3s = f"{v3: .6e}"
            status = "PASS" if passed else "FAIL"

            print(f"{j:6.1f}  {eta:8.2f}  {v1: .6e}  {v2: .6e}  {d12: .2e}  {v3s:>12}   {status}")
        print(rule)


def _run_limit_checks(j_values):
    """Section B: compare quad(expit) with the nondegenerate, 0 K, Sommerfeld and exact j=0 references."""
    rule = "-" * 78
    print_header("B) Edge/limit checks: nondegenerate (eta<<0), 0K-step, Sommerfeld (eta>>0)")
    print("We compare quad(expit) against asymptotic/limit formulae.\n"
          "For very negative eta:  Fj ≈ exp(eta)*Gamma(j+1)\n"
          "For large positive eta: use 0K step and Sommerfeld expansion as references.\n")
    print(rule)
    print("   j     eta      quad_expit        approx_type        approx_value      rel.err")
    print(rule)

    # 1) nondegenerate checks
    for j in j_values:
        for eta in [-40.0, -50.0]:
            v = Fj_expit(j, eta)
            a = Fj_nondegenerate_approx(j, eta)
            print(f"{j:5.1f}  {eta:7.1f}  {v: .6e}    nondegenerate   {a: .6e}    {rel_err(v,a):.2e}")

    print(rule)

    # 2) 0K-step check (finite eta, compare to step integral); j=-1 is excluded
    for j in [x for x in j_values if x != -1.0]:
        for eta in [-5.0, -1.0, 0.0, 1.0, 5.0, 10.0]:
            v = Fj_expit(j, eta)
            a = Fj_0K_step(j, eta)
            print(f"{j:5.1f}  {eta:7.1f}  {v: .6e}    0K step        {a: .6e}    {rel_err(v,a):.2e}")

    print(rule)

    # 3) Sommerfeld check (eta large); j=0 has a special exact form and is
    #    handled separately below
    for j in [x for x in j_values if x not in (-1.0, 0.0)]:
        for eta in [20.0, 40.0]:
            v = Fj_expit(j, eta)
            a = Fj_sommerfeld(j, eta)
            print(f"{j:5.1f}  {eta:7.1f}  {v: .6e}    Sommerfeld     {a: .6e}    {rel_err(v,a):.2e}")

    # Special exact identity for j=0
    print(rule)
    for eta in [-40.0, -10.0, 0.0, 10.0, 40.0]:
        v = Fj_expit(0.0, eta)
        a = float(np.log1p(np.exp(eta)))  # exact: ∫ f(eta-x) dx = log(1+exp(eta))
        print(f"{0.0:5.1f}  {eta:7.1f}  {v: .6e}    exact log1p    {a: .6e}    {rel_err(v,a):.2e}")


def _mean_call_time(func, repeats, *args):
    """Return the mean wall-clock seconds per call of func(*args) over `repeats` calls."""
    start = time.perf_counter()
    for _ in range(repeats):
        func(*args)
    return (time.perf_counter() - start) / repeats


def _run_benchmark():
    """Section C: print per-call timings of both quad variants and the fast routine."""
    rule = "-" * 78
    print_header("C) Benchmark (per-call time)")
    N = 30  # repeat count for the quad variants; the fast routine gets 20x more

    j_bench = [0.5, 1.5]
    eta_bench = [-10.0, -1.0, 0.0, 1.0, 10.0, 50.0]

    for j in j_bench:
        print(f"\n[j={j}]")
        print(rule)
        for eta in eta_bench:
            time_expit = _mean_call_time(Fj_expit, N, j, eta)
            time_raw = _mean_call_time(Fj_raw, N, j, eta)

            # The fast routine is called directly (argument order eta, j); any
            # failure is reported as "(n/a)" instead of aborting the benchmark.
            try:
                t_fast = _mean_call_time(FermiIntegral_fast, N * 20, eta, j)
            except Exception:
                t_fast = None

            fast_part = f"  fast={t_fast*1e6:8.3f} µs" if t_fast is not None else "  fast=(n/a)"
            print(f"eta={eta:6.1f} | quad(expit)={time_expit*1e3:8.3f} ms  quad(raw)={time_raw*1e3:8.3f} ms"
                  + fast_part)
        print(rule)


def _run_extreme_sanity(j_values, eta_extreme):
    """Section D: evaluate quad(expit) once per (j, eta) at the extreme eta values."""
    rule = "-" * 78
    print_header("D) Extreme eta sanity (quick)")
    print("We only run quad(expit) once per point for extreme eta to avoid long runtimes.")
    print(rule)
    print("   j      eta        Fj_expit")
    print(rule)
    for j in j_values:
        for eta in eta_extreme:
            v = Fj_expit(j, float(eta))
            print(f"{j:6.1f}  {eta:8.1f}  {v: .6e}")
        print(rule)


def main():
    """
    Run the whole report: consistency checks (A), limit checks (B),
    benchmark (C) and extreme-eta sanity sweep (D), all printed to stdout.
    """
    # Requested j list including -1
    j_values = [-1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0]

    # Baseline sweep
    eta_values = np.array([-20, -10, -5, -2, -1, -0.1, 0.0, 0.1, 1, 2, 5, 10, 20], dtype=float)

    # Extreme / edge cases
    eta_extreme = np.array([-100, -50, -40, 40, 50, 100], dtype=float)

    # Tolerance for cross-implementation numeric agreement (quad expit vs quad raw)
    tol_abs = 1e-10

    _run_consistency_checks(j_values, eta_values, tol_abs)
    _run_limit_checks(j_values)
    _run_benchmark()
    _run_extreme_sanity(j_values, eta_extreme)


# Run the full report only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()
