"""
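Fast, cached Fermi–Dirac occupations and (non-normalized) Fermi–Dirac integrals
F_r(eta) = ∫_0^∞ ε^r / (1 + exp(ε - eta)) dε  (no 1/Gamma(r+1) factor).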
"""

from math import exp, gamma
from functools import lru_cache
import numpy as np
from scipy import integrate
from scipy.special import expit, log_expit


# Clip bound for reduced energies in the exponential occupation helpers.
# Beyond |x| = 40 the occupation is within about exp(-40) ~ 4e-18 of its
# limiting value 0 or 1.
xlim_exp = 40.0

# 2/sqrt(pi) = 1/Gamma(3/2): converts the non-normalized integral I_{1/2}
# used throughout this module into the conventionally normalized F_{1/2}.
_FD_HALF_PREFAC = 2.0 / np.sqrt(np.pi)  # 2/sqrt(pi)

# Coefficients z_k (k = 1..10) of the Sommerfeld (degenerate-limit) series
# used by _DegenerateFermiDirac. Numerically they equal the Dirichlet eta
# values eta(2k) = (1 - 2**(1 - 2k)) * zeta(2k), rounded to 6 digits
# (e.g. eta(2) = pi**2 / 12 = 0.822467..., eta(4) = 7 pi**4 / 720 = 0.947033...).
_DegenerateFD_z = (
    0.822467, 0.947033, 0.985551, 0.996233, 0.999040,
    0.999758, 0.999939, 0.999985, 0.999996, 0.999999
)


def _round_key(x, ndigit=10):
    """Return ``x`` as a float rounded to ``ndigit`` decimals (hashable cache key)."""
    as_float = float(x)
    return round(as_float, ndigits=ndigit)


# ------------------------------------------------------------
# Fermi–Dirac occupations (x = E/kT, eta = Ef/kT)
# ------------------------------------------------------------
def fermi_dirac_e_npvector(x, *, xlim=None):
    """
    Vectorized occupation f = 1 / (1 + exp(x)) for reduced energy x.

    Parameters
    ----------
    x : float or array-like
        Reduced energy (e.g. (E - Ef)/kT). Scalars and 0-d arrays are
        accepted as well as arrays of any shape.
    xlim : float, optional (keyword-only)
        Symmetric clip bound applied to ``x`` before evaluation. Defaults to
        the module constant ``xlim_exp``.

    Returns
    -------
    float or ndarray
        Occupation with the same shape as ``x``. Non-floating input is
        promoted to float64. Floating input keeps its dtype.
    """
    lim = xlim_exp if xlim is None else xlim
    x = np.asarray(x)
    if not np.issubdtype(x.dtype, np.floating):
        x = x.astype(float)
    x = np.clip(x, -lim, lim)
    # exp(-|x|) lies in (0, 1], so it can never overflow. For x > 0,
    # f = exp(-x) / (1 + exp(-x)). For x <= 0, -|x| == x, so
    # f = 1 / (1 + exp(x)).
    # np.where avoids boolean-mask indexing, which fails on the numpy
    # scalar that np.clip returns for scalar/0-d input.
    e = np.exp(-np.abs(x))
    return np.where(x > 0, e, 1.0) / (1.0 + e)

def fermi_dirac_e_if(x, eta, g=1.0):
    """
    Scalar occupation f = 1 / (1 + g * exp(x - eta)) using math.exp.

    The reduced energy x - eta is compared with the module bound
    ``xlim_exp``. Below -xlim_exp the result is exactly 1.0, and above
    +xlim_exp it is exactly 0.0. NOTE: these cut-offs do not take ``g``
    into account.
    """
    reduced = x - eta
    if abs(reduced) > xlim_exp:
        # Saturated tails: far below eta the level is full, far above it is empty.
        return 1.0 if reduced < 0 else 0.0
    return 1.0 / (1 + g * exp(reduced))

def fermi_dirac_e(x, eta, g=1.0):
    """
    Electron occupation with degeneracy prefactor g:

        f = 1 / (1 + g * exp(x - eta)) = expit((eta - x) - ln(g))

    Evaluated through scipy's expit, so it cannot overflow. It works
    element-wise on arrays.
    """
    log_g = np.log(g)
    return expit(eta - x - log_g)

def log_fermi_dirac_e(x, eta, g=1.0):
    """
    Natural log of f = 1 / (1 + g * exp(x - eta)), computed stably as
    log_expit((eta - x) - ln(g)). It stays finite where f itself underflows.
    """
    log_g = np.log(g)
    return log_expit(eta - x - log_g)

def fermi_dirac_h(x, eta, g=1.0):
    """
    Vacancy (hole) probability, defined with g as the prefactor of the
    exponential:

        h = 1 / (1 + g * exp(eta - x)) = expit((x - eta) - ln(g))

    For g == 1 this equals 1 - f, with f = fermi_dirac_e(x, eta). For g != 1
    it is NOT 1 - fermi_dirac_e(x, eta, g); that complement would be
    1 / (1 + exp(eta - x) / g). This g-prefactor form is the one used for
    the donor ionization ratio ND+/ND.
    """
    return expit((x - eta) - np.log(g))

def log_fermi_dirac_h(x, eta, g=1.0):
    """
    Natural log of h = 1 / (1 + g * exp(eta - x)), computed stably as
    log_expit((x - eta) - ln(g)).

    Equals log(1 - f) only for g == 1. For g != 1 the g-prefactor form is not
    the complement of fermi_dirac_e(x, eta, g).
    """
    return log_expit((x - eta) - np.log(g))

def ionized_acceptor_frac(eta, xA, g=1.0):
    """NA-/NA = 1 / (1 + g * exp(xA - eta)): fermi_dirac_e at the acceptor level xA."""
    return fermi_dirac_e(x=xA, eta=eta, g=g)

def ionized_donor_frac(eta, xD, g=1.0):
    """ND+/ND = 1 / (1 + g * exp(eta - xD)): fermi_dirac_h at the donor level xD."""
    return fermi_dirac_h(x=xD, eta=eta, g=g)

# ============================================================
# internal: degenerate limit (Sommerfeld-like)
# ============================================================

def _DegenerateFermiDirac(eta, r):
    """
    Degenerate-limit (large eta) value of the non-normalized integral F_r(eta)
    from a Sommerfeld-type asymptotic expansion:

        eta**(r+1)/(r+1) * [1 + 2 (r+1) * sum_{k=1..10} z_k * r (r-1) ... (r-2k+2) / eta**(2k)]

    z_k are taken from ``_DegenerateFD_z``. Exponentially small corrections
    are not included. The caller only uses this branch for eta >= 25 + 2 r.
    eta must be nonzero (the series divides by eta**2).
    """
    eta_sq = eta * eta
    leading = eta ** (r + 1.0) / (r + 1.0)

    # k = 1: the falling-factorial coefficient is just r.
    coeff = r / eta_sq
    series = coeff * _DegenerateFD_z[0]

    # Every later term multiplies in the next two falling-factorial factors
    # (falling, falling - 1) and one more 1/eta**2.
    falling = r - 1.0
    for z_k in _DegenerateFD_z[1:10]:
        coeff *= falling * (falling - 1.0) / eta_sq
        series += coeff * z_k
        falling -= 2.0

    return leading * (1.0 + 2.0 * (r + 1.0) * series)


# ============================================================
# internal: non-degenerate limit (alternating series)
# ============================================================

def _NonDegenerateFermiDirac(eta, r, reltol=1e-10, max_terms=200000):
    """
    Non-degenerate value of the non-normalized integral F_r(eta) from the
    alternating series

        F_r(eta) = Gamma(r + 1) * sum_{k>=1} (-1)**(k+1) * exp(k*eta) / k**(r+1)

    The series converges for eta < 0, and more slowly as eta -> 0. The caller
    uses this branch for eta < -0.1.

    Terms are summed in (+, -) pairs. Summation stops when the last
    subtracted term is at most ``reltol`` times the running total, or after
    ``max_terms`` terms.

    There is no separate one-term shortcut for very negative eta. A one-term
    exp(eta) approximation has relative error exp(eta)/2**(r+1), about 7e-10
    at eta = -20, r = 1/2, which exceeds the default ``reltol``. The loop
    honours ``reltol`` and already stops after one or two pairs there.

    Parameters
    ----------
    eta : float
        Reduced chemical potential (intended eta < 0).
    r : float
        Order of the integral.
    reltol : float
        Relative truncation tolerance.
    max_terms : int
        Maximum number of series terms (evaluated two at a time).

    Returns
    -------
    float
        The series sum multiplied by Gamma(r + 1).
    """
    p = r + 1.0
    total = 0.0
    k = 1.0
    for _ in range(max_terms // 2):
        total += exp(k * eta) / k ** p
        k += 1.0
        t_neg = exp(k * eta) / k ** p
        total -= t_neg
        k += 1.0
        # Use '<=' rather than '<'. For very negative eta the terms and the
        # total all underflow to 0.0, and 0.0 <= 0.0 must still stop the loop.
        if abs(t_neg) <= reltol * abs(total):
            break
    return total * gamma(p)


# ============================================================
# integrand (substitution ε = x²) and cached regime dispatcher
# ============================================================

def _func_Fr_x2(x, eta, r):
    """
    Integrand of F_r after substituting ε = x**2 (dε = 2x dx):

        2 * x**(2r+1) / (1 + exp(x**2 - eta))
    """
    power = 2.0 * r + 1.0
    occupation = expit(eta - x * x)
    return 2.0 * x ** power * occupation


@lru_cache(maxsize=1024)
def _FermiIntegral_cached(eta, r, epsabs, epsrel, limit):
    """
    Memoized regime dispatcher for the non-normalized integral F_r(eta).

    - eta >= 25 + 2 r : Sommerfeld-type expansion (_DegenerateFermiDirac)
    - eta < -0.1      : alternating series (_NonDegenerateFermiDirac)
    - otherwise       : scipy quad on the ε = x**2 integrand over
                        [0, 1], [1, 4] and [4, upper] in x, with
                        upper = max(17 + 4.5 r + eta, 4). The tail beyond
                        ``upper`` is neglected.

    ``epsabs``, ``epsrel`` and ``limit`` are passed to quad only. Its error
    estimates are discarded. All arguments must be hashable (lru_cache key).
    """
    if eta >= 25.0 + 2.0 * r:
        return _DegenerateFermiDirac(eta, r)
    if eta < -0.1:
        return _NonDegenerateFermiDirac(eta, r)

    upper = max(17.0 + 4.5 * r + eta, 4.0)

    def integrand(x):
        return _func_Fr_x2(x, eta, r)

    total = 0.0
    for a, b in ((0.0, 1.0), (1.0, 4.0), (4.0, upper)):
        total += integrate.quad(
            integrand, a, b, epsabs=epsabs, epsrel=epsrel, limit=limit
        )[0]
    return total


# ============================================================
# public API
# ============================================================

def FermiIntegral_fast(eta, r, *, epsabs=1.0e-10, epsrel=1.0e-8, limit=50):
    """
    Non-normalized Fermi–Dirac integral

        F_r(eta) = ∫_0^∞ ε^r / (1 + exp(ε - eta)) dε   (no 1/Gamma(r+1) factor)

    ``eta`` and ``r`` are rounded to 10 decimal places and used, together
    with the quad settings, as the key of an LRU cache (1024 entries).

    Parameters
    ----------
    eta : float
        Reduced chemical potential.
    r : float
        Order of the integral.
    epsabs, epsrel, limit : quad parameters
        Used only in the intermediate regime that is evaluated by quadrature.
    """
    eta_key = _round_key(eta)
    r_key = _round_key(r)
    return _FermiIntegral_cached(eta_key, r_key, epsabs, epsrel, limit)


def FermiIntegral_half(eta, *, epsabs=1.0e-10, epsrel=1.0e-8, limit=50):
    """Non-normalized F_{1/2}(eta) = ∫_0^∞ ε^{1/2} / (1 + exp(ε - eta)) dε."""
    return FermiIntegral_fast(
        eta, r=0.5, epsabs=epsabs, epsrel=epsrel, limit=limit
    )


def FermiIntegral_3half(eta, *, epsabs=1.0e-10, epsrel=1.0e-8, limit=50):
    """Non-normalized F_{3/2}(eta) = ∫_0^∞ ε^{3/2} / (1 + exp(ε - eta)) dε."""
    return FermiIntegral_fast(
        eta, r=1.5, epsabs=epsabs, epsrel=epsrel, limit=limit
    )


# ============================================================
# carrier density helpers (the integrals are NOT Gamma-normalized;
# the 2/sqrt(pi) = 1/Gamma(3/2) factor is applied explicitly here)
# ============================================================

def electron_density(Nc, eta_c, *, epsabs=1.0e-10, epsrel=1.0e-8, limit=50):
    """
    Electron density (nonparabolicity ignored; standard parabolic band).

    This library defines the (non-normalized) FD integral:
        I_{1/2}(eta) = ∫_0^∞ ε^{1/2} / (1 + exp(ε - eta)) dε
    (NO division by Gamma).

    Therefore:
        n = Nc * (2/sqrt(pi)) * I_{1/2}(eta_c)

    Parameters
    ----------
    Nc : float
        Effective density of states in conduction band [cm^-3] or [m^-3]
        (unit is carried through).
    eta_c : float or array-like
        Dimensionless chemical potential for electrons:
            eta_c = (Ef - Ec) / (kT) = β(Ef - Ec)
        Arrays of any shape are accepted, including zero-size arrays.

    Returns
    -------
    n : float or ndarray
        Electron density in the same unit as Nc. Scalar input gives a
        scalar. Array input gives a float array of the same shape.
    """
    eta_arr = np.asarray(eta_c)
    scale = Nc * _FD_HALF_PREFAC
    if eta_arr.ndim == 0:
        return scale * FermiIntegral_half(
            float(eta_arr), epsabs=epsabs, epsrel=epsrel, limit=limit
        )
    # FermiIntegral_half is scalar-only, so evaluate element-wise over the
    # flattened input. np.fromiter also handles zero-size arrays, which
    # np.nditer rejects (ValueError) without the "zerosize_ok" flag.
    values = np.fromiter(
        (
            FermiIntegral_half(float(v), epsabs=epsabs, epsrel=epsrel, limit=limit)
            for v in eta_arr.ravel()
        ),
        dtype=float,
        count=eta_arr.size,
    )
    return scale * values.reshape(eta_arr.shape)


def hole_density(Nv, eta_v, *, epsabs=1.0e-10, epsrel=1.0e-8, limit=50):
    """
    Hole density (standard parabolic valence band).

    Using the same non-normalized integral I_{1/2}:
        p = Nv * (2/sqrt(pi)) * I_{1/2}( beta(Ev - Ef) )

    If you pass:
        eta_v = beta(Ef - Ev)
    then beta(Ev - Ef) = -eta_v, so:
        p = Nv * (2/sqrt(pi)) * I_{1/2}(-eta_v)

    Parameters
    ----------
    Nv : float
        Effective density of states in valence band [cm^-3] or [m^-3]
        (unit is carried through).
    eta_v : float or array-like
        Dimensionless quantity defined as:
            eta_v = (Ef - Ev) / (kT) = β(Ef - Ev)
        Arrays of any shape are accepted, including zero-size arrays.

    Returns
    -------
    p : float or ndarray
        Hole density in the same unit as Nv. Scalar input gives a scalar.
        Array input gives a float array of the same shape.
    """
    eta_arr = np.asarray(eta_v)
    scale = Nv * _FD_HALF_PREFAC
    if eta_arr.ndim == 0:
        # Negate after converting to float so unsigned/bool dtypes cannot
        # wrap around or raise on unary minus.
        return scale * FermiIntegral_half(
            -float(eta_arr), epsabs=epsabs, epsrel=epsrel, limit=limit
        )
    # Element-wise evaluation over the flattened input. Unlike np.nditer
    # without "zerosize_ok", this also works for zero-size arrays.
    values = np.fromiter(
        (
            FermiIntegral_half(-float(v), epsabs=epsabs, epsrel=epsrel, limit=limit)
            for v in eta_arr.ravel()
        ),
        dtype=float,
        count=eta_arr.size,
    )
    return scale * values.reshape(eta_arr.shape)


# ============================================================
# self-test
# ============================================================

if __name__ == "__main__":
    # Smoke test: tabulate F_{1/2} and F_{3/2} together with the carrier
    # densities n and p (both from F_{1/2}) over a few reduced chemical
    # potentials, then report the lru_cache statistics.
    Nc = Nv = 1.0e19
    eta_list = [-10.0, -3.0, -1.0, 0.0, 2.0, 5.0, 10.0]

    print("Fermi–Dirac integral test")
    print("")

    print("r = 1/2")
    for eta in eta_list:
        F_half = FermiIntegral_half(eta)
        print(f"  eta={eta:6.2f}  F1/2={F_half:.8e}")

    print("\nr = 3/2")
    for eta in eta_list:
        n = electron_density(Nc, eta)
        p = hole_density(Nv, eta)
        F_3half = FermiIntegral_3half(eta)
        print(f"  eta={eta:6.2f}  F3/2={F_3half:.8e}  n={n:.8e}  p={p:.8e}")

    print("\ncache info:")
    print(_FermiIntegral_cached.cache_info())

