"""
Semiconductor Carrier Statistics Simulator (+ fitting mode)

Original base: 031-Ne-T_oneshot_gemini3flash.py

Added features:
- --mode read|sim|fit (fit uses Nelder–Mead by default)
- Fitting parameters: ED, ND, EA, NA, Eg
- --fix p1,p2,... to fix parameters (names: ED,ND,EA,NA,Eg)
- Defaults for wide-gap n-type: EA and NA are fixed unless you explicitly unfix them
- ND, NA are optimized in log10-space
- Parameter uncertainty estimation via numerical Jacobian at optimum:
  Cov(theta) ≈ s^2 * (J^T J)^(-1),  s^2 = RSS/(N-M)

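Data file for fitting:
- CSV/TSV or Excel (.xls/.xlsx/.xlsm); a header row is required.
- Columns are selected by 0-based index: --temp_col (default 0) for T and
  --ne_col (default 3) for Ne.
  Example (two-column file; run with --ne_col 1):
    T,Ne
    300,1.2e17
    350,1.8e17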
"""

from __future__ import annotations

import sys
import numpy as np
import pandas as pd
import argparse
import matplotlib.pyplot as plt
import os
import json
from dataclasses import dataclass
from functools import lru_cache
from typing import Dict, Iterable, List, Tuple

from scipy import integrate, optimize, special
from scipy.constants import k, hbar, m_e, e

# --- constants ---
# Boltzmann constant divided by the elementary charge, i.e. k_B expressed in eV/K.
# Multiply by T [K] to obtain the thermal energy kT in eV.
KbT_const = k / e  # [eV/K]

# FD integral cache control: F_half() rounds eta to a multiple of 1e-3 and caches the
# result for speed. During numerical Jacobian evaluation, this quantization can make
# residuals piecewise-constant and lead to J≈0. Set FD_CACHE_DISABLE=True temporarily
# to integrate at the exact eta (slower, no caching); restore the old value afterwards.
FD_CACHE_DISABLE = False


# ============================================================
# core physics
# ============================================================

def m2Nc(T: float, me_eff: float) -> float:
    """Return the conduction-band effective density of states Nc in cm^-3.

    T is the temperature in K and me_eff the electron effective mass in units
    of the free-electron mass.
    """
    thermal_term = me_eff * m_e * k * T / (2 * np.pi * hbar**2)
    density_per_m3 = 2 * thermal_term ** 1.5
    # SI result is per m^3; 1 m^-3 = 1e-6 cm^-3
    return density_per_m3 * 1e-6


def m2Nv(T: float, mh_eff: float) -> float:
    """Return the valence-band effective density of states Nv in cm^-3.

    T is the temperature in K and mh_eff the hole effective mass in units
    of the free-electron mass.
    """
    thermal_term = mh_eff * m_e * k * T / (2 * np.pi * hbar**2)
    density_per_m3 = 2 * thermal_term ** 1.5
    # SI result is per m^3; 1 m^-3 = 1e-6 cm^-3
    return density_per_m3 * 1e-6


@lru_cache(maxsize=20000)
def _F_half_cached(eta_rounded_milli: int) -> float:
    """Memoized F_{1/2}(eta); the cache key is eta * 1000 rounded to an integer."""
    eta = eta_rounded_milli / 1000.0

    def _integrand(x):
        # sqrt(x) / (1 + exp(x - eta)), written with expit to stay overflow-safe
        return np.sqrt(x) * special.expit(eta - x)

    # A finite upper limit of 500 is used instead of infinity; it is typically
    # enough for eta values met in practical semiconductors.
    value, _abserr = integrate.quad(
        _integrand, 0.0, 500.0, epsabs=1e-10, epsrel=1e-8, limit=200
    )
    return float(value)


def F_half(eta: float) -> float:
    """Fermi–Dirac integral of order 1/2, without the 2/sqrt(pi) prefactor.

    Default path: eta is rounded to the nearest 1e-3 and the value is served
    from the memoized integrator. When the module flag FD_CACHE_DISABLE is
    True the integral is evaluated at the exact eta instead (slower, but
    smooth enough for finite-difference Jacobians).
    """
    if not FD_CACHE_DISABLE:
        quantized_key = int(np.round(eta * 1000.0))
        return _F_half_cached(quantized_key)

    def _integrand(x):
        return np.sqrt(x) * special.expit(eta - x)

    value, _abserr = integrate.quad(
        _integrand, 0.0, 500.0, epsabs=1e-10, epsrel=1e-8, limit=200
    )
    return float(value)


def Ne(Ef: float, Ec: float, T: float, Nc: float) -> float:
    """Electron density: Nc * (2/sqrt(pi)) * F_half((Ef - Ec) / kT)."""
    kT = KbT_const * T  # thermal energy in eV
    reduced_level = (Ef - Ec) / kT
    prefactor = 2.0 / np.sqrt(np.pi)
    return Nc * prefactor * F_half(reduced_level)


def Nh(Ef: float, Ev: float, T: float, Nv: float) -> float:
    """Hole density: Nv * (2/sqrt(pi)) * F_half((Ev - Ef) / kT)."""
    kT = KbT_const * T  # thermal energy in eV
    reduced_level = (Ev - Ef) / kT
    prefactor = 2.0 / np.sqrt(np.pi)
    return Nv * prefactor * F_half(reduced_level)


def NDp(Ef: float, Ec: float, ED: float, ND: float, T: float, g: float = 1.0) -> float:
    """Ionized donor density ND+.

    Donor level: E_D = Ec - ED (ED>0 means deeper below Ec).
    Occupancy for donor level (neutral) is
        f_D = 1 / (1 + g * exp((E_D - Ef)/kT))
    so ionized donors are ND+ = ND * (1 - f_D).

    Parameters:
      Ef, Ec : Fermi level and conduction-band edge [eV]
      ED     : donor depth below Ec [eV]
      ND     : donor density; a non-positive value returns 0.0
      T      : temperature [K]
      g      : donor degeneracy factor (user-defined; default=1),
               clamped to >= 1e-300 before taking the logarithm

    Notes:
      * With x = (Ef - E_D)/kT - ln(g), f_D = expit(x) and therefore
        1 - f_D = expit(-x). The ionized fraction is evaluated as expit(-x)
        directly: forming 1 - expit(x) cancels to exactly 0 as soon as
        expit(x) rounds to 1 (Ef far above the donor level), losing the
        small but non-zero ionized fraction.
      * expit keeps the evaluation overflow-safe.
    """
    if ND <= 0.0:
        return 0.0
    E_donor = Ec - ED
    x = (Ef - E_donor) / (KbT_const * T) - np.log(max(g, 1e-300))
    ionized_fraction = special.expit(-x)  # == 1 - f_D, without cancellation
    return ND * ionized_fraction


def NAm(Ef: float, Ev: float, EA: float, NA: float, T: float, g: float = 1.0) -> float:
    """Ionized acceptor density NA-.

    The acceptor level sits at E_A = Ev + EA (EA>0 means deeper above Ev).
    An ionized acceptor is an occupied acceptor state, so
        f_A = 1 / (1 + g * exp((E_A - Ef)/kT))
        NA- = NA * f_A

    Notes:
      * g is the acceptor degeneracy factor (user-defined; default=1); it is
        clamped to >= 1e-300 before the logarithm.
      * f_A is computed as expit((Ef - E_A)/kT - ln g), which is overflow-safe.
      * A non-positive NA short-circuits to 0.0.
    """
    if NA <= 0.0:
        return 0.0
    kT = KbT_const * T
    acceptor_level = Ev + EA
    log_g = np.log(max(g, 1e-300))
    occupied_fraction = special.expit((Ef - acceptor_level) / kT - log_g)
    return NA * occupied_fraction


def deltaQ(Ef: float, Ec: float, Ev: float, ED: float, EA: float, ND: float, NA: float,
           T: float, Nc: float, Nv: float, gD: float = 1.0, gA: float = 1.0) -> float:
    """Charge-neutrality residual at Fermi level Ef: p - n + ND+ - NA-.

    Notes
    -----
    The degeneracy factors gD/gA enter ND+ and NA-; use the same values when
    solving for Ef and when evaluating ND+/NA- afterwards.
    """
    electrons = Ne(Ef, Ec, T, Nc)
    holes = Nh(Ef, Ev, T, Nv)
    ionized_donors = NDp(Ef, Ec, ED, ND, T, g=gD)
    ionized_acceptors = NAm(Ef, Ev, EA, NA, T, g=gA)
    return holes - electrons + ionized_donors - ionized_acceptors


def solve_Ef(Ec: float, Ev: float, ED: float, EA: float, ND: float, NA: float,
             T: float, Nc: float, Nv: float,
             gD: float = 1.0, gA: float = 1.0,
             low0: float = -5.0, high0: float = 5.0,
             max_expand: int = 50) -> float:
    """Find the Fermi level where deltaQ vanishes, using brentq.

    The search starts on [low0, high0]; while the residual has the same sign
    at both ends the interval is widened by 1 eV on each side, at most
    max_expand times. Raises RuntimeError if no sign change is found.
    """
    extra = (Ec, Ev, ED, EA, ND, NA, T, Nc, Nv, gD, gA)

    def residual(level):
        return deltaQ(level, *extra)

    lo, hi = low0, high0
    r_lo, r_hi = residual(lo), residual(hi)

    for _ in range(max_expand):
        if not (r_lo * r_hi > 0):
            break  # bracket found (or residual is zero / not comparable)
        lo -= 1.0
        hi += 1.0
        r_lo, r_hi = residual(lo), residual(hi)

    if r_lo * r_hi > 0:
        raise RuntimeError(f"Failed to bracket Ef at T={T:.3g} K (residual has same sign).")

    root = optimize.brentq(deltaQ, lo, hi, args=extra, xtol=1e-12, rtol=1e-10, maxiter=200)
    return float(root)



@dataclass
class Params:
    """Physical parameter set passed to calc_at_T().

    Energies are in eV; calc_at_T() places the band edges at Ev = 0 and Ec = Eg.
    Densities share the unit of Nc/Nv (cm^-3). Effective masses are relative to
    the free-electron mass.
    """
    Eg: float  # band gap
    ED: float  # donor depth below Ec (donor level = Ec - ED)
    EA: float  # acceptor depth above Ev (acceptor level = Ev + EA)
    ND: float  # donor density
    NA: float  # acceptor density
    me: float  # electron effective mass / m0, used for Nc
    mh: float  # hole effective mass / m0, used for Nv


def calc_at_T(T: float, p: Params, args=None) -> Dict[str, float]:
    """Solve charge neutrality at temperature T and return the carrier statistics.

    Parameters:
      T    : temperature [K]
      p    : physical parameters (band edges are placed at Ev = 0, Ec = p.Eg)
      args : optional namespace; its 'gD' / 'gA' attributes, when present,
             give the donor / acceptor degeneracy factors (default 1.0 each)

    Returns a dict with keys "T", "Ef", "n", "p", "NDp", "NAm", "Nc", "Nv".

    Raises RuntimeError (from solve_Ef) when the Fermi level cannot be bracketed.
    """
    Nc = m2Nc(T, p.me)
    Nv = m2Nv(T, p.mh)
    # degeneracy factors for donor/acceptor ionization (user responsibility);
    # getattr with a default also covers args=None
    gD = float(getattr(args, 'gD', 1.0))
    gA = float(getattr(args, 'gA', 1.0))
    Ec, Ev = p.Eg, 0.0
    Ef = solve_Ef(Ec, Ev, p.ED, p.EA, p.ND, p.NA, T, Nc, Nv, gD=gD, gA=gA,
                  low0=-5.0, high0=p.Eg + 5.0)

    n = Ne(Ef, Ec, T, Nc)
    h = Nh(Ef, Ev, T, Nv)
    # Evaluate ND+/NA- with the SAME degeneracy factors that were used to solve
    # Ef; otherwise the reported values do not satisfy p - n + ND+ - NA- = 0.
    ndp = NDp(Ef, Ec, p.ED, p.ND, T, g=gD)
    nam = NAm(Ef, Ev, p.EA, p.NA, T, g=gA)
    return {"T": T, "Ef": Ef, "n": n, "p": h, "NDp": ndp, "NAm": nam, "Nc": Nc, "Nv": Nv}


# ============================================================
# fitting utilities
# ============================================================

# Fittable parameter names in their canonical order. pack_free_params() and
# unpack_params() both walk this list, so the order defines the layout of theta.
FIT_NAMES = ["ED", "ND", "EA", "NA", "Eg"]


def parse_fix_list(s: str | None) -> set[str]:
    """Split a comma-separated list into a set of names.

    Whitespace around each name is stripped and empty entries are dropped;
    None or an empty string yields an empty set.
    """
    if not s:
        return set()
    stripped = (chunk.strip() for chunk in s.split(","))
    return {name for name in stripped if name}

def sanitize_args_for_fit(args, fix: set[str]) -> None:
    """Clamp invalid parameter values on args in place, warning on stdout.

    Intended to run after --load and --fix have been applied.

    - ND/NA are optimized in log10 space when free, so a free density that is
      <= 0 is replaced by 1.0. A fixed density may stay at 0 ('no dopant').
    - Eg <= 0 becomes 1.0; negative ED/EA become 0.0.
    - Non-positive gD/gA become 1.0 (they are used inside a logarithm).
    """
    # Free densities must be strictly positive
    nd_is_free = "ND" not in fix
    if nd_is_free and getattr(args, "ND", 0.0) <= 0.0:
        print(f"  Warning: ND must be > 0 for log-optimization. Use ND=1.0 (given ND={args.ND}).")
        args.ND = 1.0

    na_is_free = "NA" not in fix
    if na_is_free and getattr(args, "NA", 0.0) <= 0.0:
        print(f"  Warning: NA must be > 0 for log-optimization. Use NA=1.0 (given NA={args.NA}).")
        args.NA = 1.0

    # Energies
    gap = getattr(args, "Eg", 0.0)
    if gap <= 0.0:
        print(f"  Warning: Eg must be > 0. Use Eg=1.0 (given Eg={args.Eg}).")
        args.Eg = 1.0

    donor_depth = getattr(args, "ED", 0.0)
    if donor_depth < 0.0:
        print(f"  Warning: ED must be >= 0. Use ED=0.0 (given ED={args.ED}).")
        args.ED = 0.0

    acceptor_depth = getattr(args, "EA", 0.0)
    if acceptor_depth < 0.0:
        print(f"  Warning: EA must be >= 0. Use EA=0.0 (given EA={args.EA}).")
        args.EA = 0.0

    # Degeneracy factors
    donor_g = getattr(args, "gD", 1.0)
    if donor_g <= 0.0:
        print(f"  Warning: gD must be > 0. Use gD=1.0 (given gD={args.gD}).")
        args.gD = 1.0

    acceptor_g = getattr(args, "gA", 1.0)
    if acceptor_g <= 0.0:
        print(f"  Warning: gA must be > 0. Use gA=1.0 (given gA={args.gA}).")
        args.gA = 1.0



def read_data(path: str, temp_col: int = 0, ne_col: int = 3, sheet_name: str | int | None = None) -> Tuple[np.ndarray, np.ndarray, pd.DataFrame]:
    """
    Load experimental (temperature, carrier density) data from CSV or Excel.

    - Columns are picked by 0-based index; by default temperature is column 0
      and carrier density is column 3 (the 4th column).
    - Files ending in .xls/.xlsx/.xlsm are read as Excel (sheet_name None means
      the first sheet; a digit-only string is treated as a sheet index).
      Anything else is read as delimited text with delimiter auto-detection.
    - Rows where either selected value is not finite are dropped from the
      returned arrays (the returned DataFrame is left untouched).

    Returns: (T, Ne, df)
    Raises: ValueError (path is None / too few columns), FileNotFoundError.
    """
    if path is None:
        raise ValueError("input file path is None")
    if not os.path.exists(path):
        raise FileNotFoundError(f"input file not found: {path}")

    suffix = os.path.splitext(path)[1].lower()
    if suffix in (".xls", ".xlsx", ".xlsm"):
        # read_excel(sheet_name=None) would return every sheet as a dict,
        # so map "not specified" to the first sheet explicitly.
        sheet = 0 if sheet_name is None else sheet_name
        if isinstance(sheet, str) and sheet.strip().isdigit():
            sheet = int(sheet.strip())  # "0", "1", ... select by position

        df = pd.read_excel(path, sheet_name=sheet)
        if isinstance(df, dict):
            # Still got several sheets: take the first one deterministically
            df = df[next(iter(df.keys()))]
    else:
        try:
            # sep=None lets the python engine sniff the delimiter
            df = pd.read_csv(path, sep=None, engine="python")
        except Exception:
            df = pd.read_csv(path)

    highest_index = max(temp_col, ne_col)
    if df.shape[1] <= highest_index:
        raise ValueError(f"Not enough columns in data. shape={df.shape}, temp_col={temp_col}, ne_col={ne_col}")

    temperatures = df.iloc[:, temp_col].to_numpy(dtype=float)
    densities = df.iloc[:, ne_col].to_numpy(dtype=float)

    # Keep only rows where both values are finite
    keep = np.isfinite(temperatures) & np.isfinite(densities)
    return temperatures[keep], densities[keep], df



def save_params(params: dict, filename: str = "fit_params.json"):
    """Write params to filename as 4-space-indented JSON (UTF-8) and report it on stdout."""
    serialized = json.dumps(params, indent=4)
    with open(filename, "w", encoding="utf-8") as out_file:
        out_file.write(serialized)
    print(f"Saved parameters: {filename}")

def load_params(filename: str = "fit_params.json") -> dict:
    """Return the JSON content of filename, or an empty dict if the file does not exist."""
    if not os.path.exists(filename):
        return {}
    with open(filename, "r", encoding="utf-8") as in_file:
        loaded = json.load(in_file)
    return loaded


def _cli_overrides(argv: List[str]) -> set[str]:
    """Collect the argparse destination names that appear explicitly in argv.

    Both '--opt value' and '--opt=value' spellings are recognised. Only the
    options listed in the table below are reported; anything else is ignored.
    """
    # option string -> argparse dest (only those relevant for sim/fit params)
    opt2dest = {
        "--Eg": "Eg",
        "--ED": "ED",
        "--EA": "EA",
        "--ND": "ND",
        "--NA": "NA",
        "--me": "me",
        "--mh": "mh",
        "--gD": "gD",
        "--gA": "gA",
        "--temp_col": "temp_col",
        "--ne_col": "ne_col",
        "--sheet": "sheet",
        "--Tmin": "Tmin",
        "--Tmax": "Tmax",
        "--nT": "nT",
        "--Nmin": "Nmin",
        "--Nmax": "Nmax",
        "--fix": "fix",
        "--unfix-defaults": "unfix_defaults",
        "--method": "method",
        "--load": "load",
        "--save": "save",
        "--input": "input",
        "--mode": "mode",
    }

    found: set[str] = set()
    count = len(argv)
    pos = 0
    while pos < count:
        token = argv[pos]
        pos += 1
        if not token.startswith("--"):
            continue

        option, equals, _value = token.partition("=")
        dest = opt2dest.get(option)
        if dest is not None:
            found.add(dest)

        # '--opt value': consume the following token when it is not itself a
        # long option (so a negative number such as '-5' counts as a value)
        if not equals and pos < count and not argv[pos].startswith("--"):
            pos += 1
    return found

def _apply_loaded_params_to_args(args, params: dict, cli_overrides: set[str]) -> None:
    """Copy known numeric entries of params onto args, leaving CLI-provided values alone.

    Keys outside the allowed set, keys named in cli_overrides, and values that
    cannot be converted to float are skipped silently.
    """
    allowed = {"Eg", "ED", "EA", "ND", "NA", "me", "mh", "gD", "gA"}
    for key, raw_value in params.items():
        if key not in allowed or key in cli_overrides:
            continue
        try:
            setattr(args, key, float(raw_value))
        except Exception:
            # malformed value: keep whatever args already holds
            pass

def pack_free_params(args, fix: set[str]) -> Tuple[np.ndarray, List[str]]:
    """
    Build the initial vector theta0 of free (non-fixed) parameters.

    Internal parameterization, in FIT_NAMES order:
      ED, Eg, EA : linear
      ND, NA     : log10 (labelled "log10(ND)" / "log10(NA)")

    Returns (theta0, labels). Raises ValueError if a free density is <= 0.
    """
    theta0: List[float] = []
    labels: List[str] = []
    for name in FIT_NAMES:
        if name in fix:
            continue

        if name not in ("ND", "NA"):
            theta0.append(float(getattr(args, name)))
            labels.append(name)
            continue

        density = getattr(args, name)
        if density <= 0:
            raise ValueError(f"{name} must be > 0")
        theta0.append(np.log10(density))
        labels.append(f"log10({name})")
    return np.array(theta0, dtype=float), labels


def unpack_params(theta: np.ndarray, args, fix: set[str]) -> Params:
    """Combine the free vector theta with fixed values from args into physical Params.

    theta is consumed in FIT_NAMES order, skipping fixed names; free ND/NA
    entries are stored as log10 and converted back with 10**x. me and mh
    always come from args.
    """
    resolved: Dict[str, float] = {}
    cursor = 0  # next unread position in theta
    for name in FIT_NAMES:
        if name in fix:
            resolved[name] = float(getattr(args, name))
            continue
        raw = float(theta[cursor])
        cursor += 1
        resolved[name] = 10.0 ** raw if name in ("ND", "NA") else raw

    return Params(
        Eg=float(resolved["Eg"]),
        ED=float(resolved["ED"]),
        EA=float(resolved["EA"]),
        ND=float(resolved["ND"]),
        NA=float(resolved["NA"]),
        me=float(args.me),
        mh=float(args.mh),
    )


def residuals_log10(T: np.ndarray, Ne_obs: np.ndarray, p: Params, args=None) -> np.ndarray:
    """Return log10(n_model) - log10(Ne_obs) for every temperature in T.

    Working in log10 space makes the residual a relative-error measure.
    Raises ValueError when any observed density is <= 0.
    """
    if np.any(Ne_obs <= 0):
        raise ValueError("Observed Ne must be > 0 for log residuals.")

    modelled = []
    for temperature in T:
        modelled.append(calc_at_T(float(temperature), p, args)["n"])
    # floor at 1e-300 so that log10 never sees zero
    floored = np.maximum(np.array(modelled, dtype=float), 1e-300)
    return np.log10(floored) - np.log10(Ne_obs)


def objective(theta: np.ndarray, T: np.ndarray, Ne_obs: np.ndarray, args, fix: set[str]) -> float:
    """Scalar cost for the optimizer: sum of squared log10 residuals plus penalties.

    Unphysical parameters (Eg <= 0, negative ED/EA, non-positive free ND/NA)
    return a flat 1e30 so that Nelder–Mead steps away from them. ED or EA
    exceeding Eg adds a soft quadratic penalty weighted by 1e3.
    """
    p = unpack_params(theta, args, fix)

    # hard rejections
    energies_invalid = p.Eg <= 0 or p.ED < 0 or p.EA < 0
    nd_invalid = ("ND" not in fix) and (p.ND <= 0)
    na_invalid = ("NA" not in fix) and (p.NA <= 0)
    if energies_invalid or nd_invalid or na_invalid:
        return 1e30

    # soft penalty: impurity levels are usually shallower than the gap
    penalty = 0.0
    for depth in (p.ED, p.EA):
        if depth > p.Eg:
            penalty += (depth - p.Eg) ** 2

    resid = residuals_log10(T, Ne_obs, p, args)
    return float(np.dot(resid, resid) + 1e3 * penalty)


def numerical_jacobian(theta: np.ndarray, T: np.ndarray, Ne_obs: np.ndarray, args, fix: set[str]) -> np.ndarray:
    """
    Central-difference Jacobian of the residual vector (not the objective)
    with respect to theta; shape is (number of residuals, theta.size).

    Step for component j: jac_relstep * (|theta_j| + 1) + jac_absstep, with
    both settings read from args (defaults 1e-6 and 1e-12).

    IMPORTANT:
    - F_{1/2}(eta) is normally cached with eta quantized to 1e-3.
    - Under the tiny perturbations used here that quantization can leave the
      residuals unchanged, giving J≈0 and hence (J^T J)≈0.
    - The quantized cache is therefore switched off (FD_CACHE_DISABLE=True)
      for the duration of this call and restored afterwards, even on error.
    """
    global FD_CACHE_DISABLE

    def resid_at(vec):
        return residuals_log10(T, Ne_obs, unpack_params(vec, args, fix), args)

    saved_flag = FD_CACHE_DISABLE
    FD_CACHE_DISABLE = True
    try:
        n_resid = resid_at(theta).size
        n_par = theta.size
        jac = np.zeros((n_resid, n_par), dtype=float)

        rel_step = float(getattr(args, 'jac_relstep', 1e-6))
        abs_step = float(getattr(args, 'jac_absstep', 1e-12))

        for col in range(n_par):
            center = theta[col]
            step = rel_step * (abs(center) + 1.0) + abs_step

            forward = theta.copy()
            forward[col] = center + step
            backward = theta.copy()
            backward[col] = center - step

            r_forward = resid_at(forward)
            r_backward = resid_at(backward)
            jac[:, col] = (r_forward - r_backward) / (2.0 * step)
        return jac
    finally:
        FD_CACHE_DISABLE = saved_flag

def estimate_param_errors(theta_hat: np.ndarray, T: np.ndarray, Ne_obs: np.ndarray, args, fix: set[str],
                          free_names: List[str]) -> Dict[str, Dict[str, float]]:
    """
    Estimate 1-sigma errors of the free parameters at the optimum theta_hat.

    Cov(theta) ≈ s^2 (J^T J)^-1 in the internal parameterization, with
    s^2 = RSS / max(1, N - M). The returned dict holds:
      - one {"estimate", "stderr"} entry per name in free_names;
      - "ND" / "NA" entries propagated from log10 space when those are free;
      - "<name>_phys" copies for free ED / EA / Eg;
      - diagnostics under "_cov_theta", "_corr_theta", "_JTJ", "_sigma2".
    """
    resid = residuals_log10(T, Ne_obs, unpack_params(theta_hat, args, fix), args)
    n_obs = resid.size
    n_par = theta_hat.size
    J = numerical_jacobian(theta_hat, T, Ne_obs, args, fix)

    JTJ = J.T @ J
    # small ridge (relative to the mean diagonal) against ill-conditioning
    ridge = 1e-12 * np.trace(JTJ) / max(1, n_par)
    regularized = JTJ + ridge * np.eye(n_par)
    try:
        inverse = np.linalg.inv(regularized)
    except np.linalg.LinAlgError:
        inverse = np.linalg.pinv(regularized)

    dof = max(1, n_obs - n_par)
    s2 = float(np.dot(resid, resid) / dof)
    Cov = s2 * inverse
    # correlation matrix (NaN where a variance is zero, 1 on the diagonal)
    # and the per-parameter standard errors
    corr, sig = cov_to_corr(Cov)

    out: Dict[str, Dict[str, float]] = {
        name: {"estimate": float(value), "stderr": float(err)}
        for name, value, err in zip(free_names, theta_hat, sig)
    }

    # Physical view of the densities: ND = 10^mu, so dND ≈ ln(10) * 10^mu * se
    for base in ("ND", "NA"):
        entry = out.get(f"log10({base})")
        if entry is None:
            continue
        est = 10.0 ** entry["estimate"]
        stderr = np.log(10.0) * est * entry["stderr"]
        out[base] = {"estimate": float(est), "stderr": float(stderr)}

    # ED/EA/Eg are already physical; expose them under a *_phys alias
    for base in ("ED", "EA", "Eg"):
        if base in out:
            out[base + "_phys"] = out[base].copy()

    out["_cov_theta"] = {"matrix": Cov}
    out["_corr_theta"] = {"matrix": corr}
    out["_JTJ"] = {"matrix": JTJ}
    out["_sigma2"] = {"value": s2, "dof": dof}
    return out




# ============================================================
# added: covariance/correlation eigen diagnostics + suggestions + prediction band
# ============================================================
def cov_to_corr(Cov: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Convert a covariance matrix into (correlation matrix, standard deviations).

    Negative diagonal entries are treated as zero variance. Off-diagonal
    correlations involving a zero-variance parameter are NaN; the diagonal
    is always 1.
    """
    cov = np.asarray(Cov, dtype=float)
    variances = np.maximum(np.diag(cov), 0.0)
    std = np.sqrt(variances)
    scale = np.outer(std, std)
    with np.errstate(divide="ignore", invalid="ignore"):
        corr = cov / scale
    corr[scale == 0] = np.nan   # undefined where either std is zero
    np.fill_diagonal(corr, 1.0)  # self-correlation is 1 by definition
    return corr, std

def eigen_sorted_sym(A: np.ndarray, descending: bool = True) -> Tuple[np.ndarray, np.ndarray]:
    """Eigen-decompose a symmetric matrix and return (evals, evecs) sorted by eigenvalue.

    Eigenvectors are the columns of evecs, reordered together with evals.
    The order is largest-first unless descending is False.
    """
    evals, evecs = np.linalg.eigh(np.asarray(A, dtype=float))
    order = np.argsort(evals)
    order = order[::-1] if descending else order
    return evals[order], evecs[:, order]

def summarize_eigenvectors(evals: np.ndarray, evecs: np.ndarray, names: List[str], topk: int = 3, compk: int = 4):
    """Describe the first topk eigenvectors in a human-readable form.

    For each reported eigenvector (column of evecs) the compk components with
    the largest magnitude are listed as (name, value) pairs, largest first.
    Returns a list of dicts with keys "rank" (1-based), "eigenvalue" and
    "components".
    """
    labels = list(names)
    n_report = min(topk, len(evals))
    summary = []
    for rank in range(n_report):
        vec = evecs[:, rank]
        by_magnitude = np.argsort(np.abs(vec))[::-1]
        leading = by_magnitude[:min(compk, len(by_magnitude))]
        summary.append({
            "rank": rank+1,
            "eigenvalue": float(evals[rank]),
            "components": [(labels[idx], float(vec[idx])) for idx in leading],
        })
    return summary

def propose_fix_candidates(names: List[str], theta_hat: np.ndarray, stderr: np.ndarray, corr: np.ndarray,
                           evals_cov: np.ndarray, evecs_cov: np.ndarray,
                           corr_thr: float = 0.95, relerr_thr: float = 0.5, topn: int = 3):
    """Rank parameters that look worth fixing or constraining (internal parameterization).

    Each parameter starts with a score equal to its relative standard error and
    gains bonuses for (a) pairwise |correlation| >= corr_thr and (b) carrying at
    least half of the largest weight in the first column of evecs_cov (the
    most-uncertain direction when the columns are sorted by descending
    eigenvalue). Parameters with relerr >= relerr_thr or a strong correlation
    are returned, best score first; if none qualifies, the top scorers are
    returned instead. At most topn entries are returned.
    """
    labels = list(names)
    estimates = np.asarray(theta_hat, dtype=float)
    errors = np.asarray(stderr, dtype=float)
    tiny = 1e-30  # keeps the divisions finite for zero-valued entries
    relerr = errors / (np.abs(estimates) + tiny)

    score = relerr.copy()
    reasons = {}
    for idx, label in enumerate(labels):
        reasons[label] = [f"relative stderr = {relerr[idx]:.3g}"]

    # (a) strong pairwise correlations, each unordered pair visited once
    count = len(labels)
    for first in range(count):
        for second in range(first+1, count):
            cij = corr[first, second]
            if not np.isfinite(cij) or abs(cij) < corr_thr:
                continue
            bonus = (abs(cij) - corr_thr) * 5.0 + 0.5
            score[first] += bonus
            score[second] += bonus
            reasons[labels[first]].append(f"strong corr with {labels[second]}: {cij:+.3f} (>= {corr_thr})")
            reasons[labels[second]].append(f"strong corr with {labels[first]}: {cij:+.3f} (>= {corr_thr})")

    # (b) weight in the leading covariance eigenvector
    if evecs_cov is not None and evecs_cov.size:
        magnitude = np.abs(evecs_cov[:, 0])
        w = magnitude / (magnitude.max() + tiny)
        for idx, label in enumerate(labels):
            if w[idx] < 0.5:
                continue
            score[idx] += 0.7 * w[idx]
            reasons[label].append(f"dominant in most-uncertain eigen-direction: |v|/max={w[idx]:.2f}")

    ranked = sorted(
        (
            {
                "param": label,
                "score": float(score[idx]),
                "estimate": float(estimates[idx]),
                "stderr": float(errors[idx]),
                "relerr": float(relerr[idx]) if np.isfinite(relerr[idx]) else None,
                "reasons": reasons[label],
            }
            for idx, label in enumerate(labels)
        ),
        key=lambda entry: entry["score"],
        reverse=True,
    )

    def _is_flagged(entry):
        large_error = entry["relerr"] is not None and entry["relerr"] >= relerr_thr
        return large_error or any("strong corr" in text for text in entry["reasons"])

    flagged = [entry for entry in ranked if _is_flagged(entry)]
    if not flagged:
        flagged = ranked[:min(topn, len(ranked))]
    return flagged[:min(topn, len(flagged))]

def prediction_band_log10_Ne(T: np.ndarray, theta_hat: np.ndarray, args, fix: set[str], Cov: np.ndarray,
                             nsigma: float = 1.0) -> Tuple[np.ndarray, np.ndarray | None, np.ndarray | None]:
    """Delta-method band for y(T)=log10(Ne_model(T)) in internal parameterization.

    Parameters:
      T         : temperatures [K] at which the band is evaluated
      theta_hat : optimum in the internal parameterization (free parameters only)
      args, fix : as for unpack_params(); args may also carry 'jac_relstep' /
                  'jac_absstep' (fallbacks 1e-4 / 1e-8 when absent)
      Cov       : covariance of theta_hat, or None
      nsigma    : half-width of the band in standard deviations

    Returns (y0, ylo, yhi) with y0 = log10 of the model density and
    ylo/yhi = y0 -/+ nsigma * sqrt(g^T Cov g), g = dy/dtheta by central
    differences. ylo and yhi are None when Cov is None or there are no free
    parameters.

    The sensitivities are evaluated with the quantized F_{1/2} cache disabled
    (FD_CACHE_DISABLE=True, restored afterwards). With the cache on, eta is
    rounded to 1e-3 and small parameter steps leave the model unchanged, which
    would make the derivatives — and hence the band — collapse or turn noisy.
    This mirrors what numerical_jacobian() does.
    """
    global FD_CACHE_DISABLE

    T = np.asarray(T, dtype=float)

    def _log10_model(theta: np.ndarray) -> np.ndarray:
        # log10 of the model electron density at every temperature, floored at 1e-300
        params = unpack_params(theta, args, fix)
        n_model = np.array([calc_at_T(float(t), params, args)["n"] for t in T], dtype=float)
        return np.log10(np.maximum(n_model, 1e-300))

    y0 = _log10_model(theta_hat)

    m = theta_hat.size
    if Cov is None or m == 0:
        return y0, None, None

    rel = float(getattr(args, "jac_relstep", 1e-4))
    ab = float(getattr(args, "jac_absstep", 1e-8))

    G = np.zeros((len(T), m), dtype=float)  # dy/dtheta
    old_flag = FD_CACHE_DISABLE
    FD_CACHE_DISABLE = True  # smooth model needed for finite differences
    try:
        for j in range(m):
            tj = theta_hat[j]
            h = rel * (abs(tj) + 1.0) + ab
            th_p = theta_hat.copy()
            th_p[j] = tj + h
            th_m = theta_hat.copy()
            th_m[j] = tj - h
            G[:, j] = (_log10_model(th_p) - _log10_model(th_m)) / (2.0 * h)
    finally:
        FD_CACHE_DISABLE = old_flag

    # var_y[n] = G[n, :] @ Cov @ G[n, :]; clip tiny negatives from round-off
    var_y = np.einsum("ni,ij,nj->n", G, Cov, G)
    sig_y = np.sqrt(np.maximum(var_y, 0.0))
    ylo = y0 - nsigma * sig_y
    yhi = y0 + nsigma * sig_y
    return y0, ylo, yhi

# ============================================================
# CLI / main
# ============================================================

def initialize():
    """Build the command-line parser and return the parsed arguments (from sys.argv)."""
    parser = argparse.ArgumentParser(description="Hall carrier density Ne(T) simulator / fitting (Nelder–Mead default)")
    add = parser.add_argument

    # --- I/O (same style as the mu-fit tool) ---
    add("--input", type=str, default="Hall-T1.xlsx", help="入力ファイル名 (CSV or Excel)")
    add("--sheet", type=str, default=None, help="Excel sheet name (default: first sheet)")
    add("--temp_col", type=int, default=0, help="温度列(0開始) [default: 0]")
    add("--ne_col", type=int, default=3, help="キャリア濃度列(0開始) [default: 3 (=4th col)]")

    # --- run mode: show data, simulate, or fit ---
    add("--mode", type=str, choices=["read", "sim", "fit"], default="read",
        help="read: データ表示/プロット, sim: 計算, fit: パラメータフィット")

    # --- physical parameters (initial guess in fit mode, model values in sim mode) ---
    add("--Eg", type=float, default=1.12, help="Bandgap [eV]")
    add("--me", type=float, default=1.08, help="Electron effective mass / m0")
    add("--mh", type=float, default=0.55, help="Hole effective mass / m0")
    add("--ND", type=float, default=1e17, help="Donor density [cm^-3]")
    add("--NA", type=float, default=1e15, help="Acceptor density [cm^-3]")
    add("--ED", type=float, default=0.045, help="Donor ionization energy below Ec [eV]")
    add("--EA", type=float, default=0.045, help="Acceptor ionization energy above Ev [eV]")

    # --- temperature sweep and plot range (sweep is used when no data file is read) ---
    add("--Tmin", type=float, default=50, help="Min Temperature [K]")
    add("--Tmax", type=float, default=600, help="Max Temperature [K]")
    add("--nT", type=int, default=50, help="Number of temperature points (sim without data)")
    add("--Nmin", type=float, default=1e10, help="Plot Min Density [cm^-3]")
    add("--Nmax", type=float, default=1e22, help="Plot Max Density [cm^-3]")

    # --- fitting options ---
    add("--method", type=str, default="nelder-mead",
        help="Optimization method for fit mode (default: nelder-mead)")
    add("--fix", type=str, default="",
        help="固定するパラメータ名をカンマ区切りで指定 (例: EA,NA,Eg). "
             "対象: ED,ND,EA,NA,Eg. 既定ではEA,NAは固定。")
    add("--unfix-defaults", action="store_true",
        help="If set, do NOT auto-fix EA and NA (only --fix is applied).")

    # --- temperature window used for fitting ---
    add("--Tfitmin", type=float, default=-1e100, help="フィットに使う最小温度[K] (default -1e100)")
    add("--Tfitmax", type=float, default=+1e100, help="フィットに使う最大温度[K] (default +1e100)")

    # --- uncertainty band / finite-difference step controls ---
    add("--band_sigma", type=float, default=1.0, help="誤差帯の幅 (nsigma). default 1.0 (±1σ)")
    add("--jac_relstep", type=float, default=1e-6, help="数値微分の相対ステップ")
    add("--jac_absstep", type=float, default=1e-12, help="数値微分の絶対ステップ")

    # --- diagnostic output files ---
    add("--diag_save", type=str, default="fit_diagnostics_opt.json", help="診断出力(JSON)ファイル名")
    add("--suggest_save", type=str, default="fit_fix_suggestions.json", help="固定候補提案(JSON)ファイル名")

    # --- parameter persistence ---
    add("--save", type=str, default="fit_params.json", help="fit結果保存ファイル (json)")
    add("--load", type=str, default="", help="初期値読込ファイル (json)")

    return parser.parse_args()

def run_read(args):
    """'read' mode: load the data file, print the table and plot Ne against T."""
    sheet = args.sheet if args.sheet else None
    T_obs, Ne_obs, df = read_data(
        args.input,
        temp_col=args.temp_col,
        ne_col=args.ne_col,
        sheet_name=sheet,
    )

    # Dump the raw table as read from the file
    print("\n--- 読み込みデータ ---")
    print(df)

    # Semi-log scatter plot of the measured density
    fig, ax = plt.subplots(figsize=(6.6, 5.0))
    ax.semilogy(T_obs, Ne_obs, "o", label="data Ne")
    ax.set_xlabel("Temperature (K)")
    ax.set_ylabel("Electron density Ne (cm$^{-3}$)")
    ax.grid(True, which="both", alpha=0.2)
    ax.legend()
    plt.title("Experimental data: Ne(T)")
    fig.tight_layout()
    plt.show()

def run_sim(args):
    """'sim' mode: compute Ne(T) for the current parameters and plot it.

    Parameter sources, lowest to highest priority: argparse defaults, JSON
    file, explicit CLI options. The JSON file is --load when given; otherwise
    the file named by --save is auto-loaded if it exists.

    If the input data file exists and can be read, the simulation uses its
    temperatures and the data are overlaid; otherwise a linear sweep
    Tmin..Tmax with nT points is used. Problems with the parameter file or
    the data file are reported as warnings and never abort the simulation.
    """
    # --- 1) parameters from JSON (CLI-provided values win) ---
    load_path = getattr(args, "load", "") or ""
    autoloaded = False
    if not load_path:
        cand = getattr(args, "save", "fit_params.json")
        if cand and os.path.exists(cand):
            load_path = cand
            autoloaded = True
    if load_path:
        try:
            if os.path.exists(load_path):
                params = load_params(load_path)
                cli_over = _cli_overrides(sys.argv[1:])
                _apply_loaded_params_to_args(args, params, cli_over)
                if autoloaded and getattr(args, "load", "") == "":
                    print(f"Auto-loaded params for simulation: {load_path}")
                else:
                    print(f"Loaded params for simulation: {load_path}")
            else:
                # Only reachable for an explicit --load (the auto-load candidate
                # was already checked); say so instead of silently using defaults.
                print(f"Warning: params file not found, using current values: '{load_path}'")
        except Exception as ex:
            print(f"Warning: failed to load params for simulation from '{load_path}': {ex}")

    # --- 2) optional experimental data (gives the T grid and an overlay) ---
    T_data = None
    Ne_data = None
    if args.input and os.path.exists(args.input):
        try:
            T_data, Ne_data, _df = read_data(args.input, temp_col=args.temp_col, ne_col=args.ne_col,
                                             sheet_name=args.sheet if args.sheet else None)
        except Exception as ex:
            # Best effort: fall back to the synthetic sweep, but tell the user why
            print(f"Warning: failed to read data from '{args.input}': {ex}")
            T_data, Ne_data = None, None

    if T_data is not None and T_data.size > 0:
        T = np.array(T_data, dtype=float)
    else:
        T = np.linspace(args.Tmin, args.Tmax, args.nT)

    # --- 3) model evaluation ---
    p = Params(Eg=args.Eg, me=args.me, mh=args.mh, ND=args.ND, NA=args.NA, ED=args.ED, EA=args.EA)
    n = np.array([calc_at_T(float(t), p, args)["n"] for t in T], dtype=float)

    # --- 4) plot ---
    fig, ax = plt.subplots(figsize=(6.6, 5.0))
    if T_data is not None and Ne_data is not None:
        ax.semilogy(T_data, Ne_data, "o", label="data Ne")
    ax.semilogy(T, n, "-", label="sim Ne(T)")
    ax.set_xlabel("Temperature (K)")
    ax.set_ylabel("Electron density Ne (cm$^{-3}$)")
    ax.set_ylim(args.Nmin, args.Nmax)
    ax.grid(True, which="both", alpha=0.2)
    ax.legend()
    plt.title("Simulation: Ne(T)")
    fig.tight_layout()
    plt.show()

def _fit_minimize_options(method):
    """Build the ``options`` dict passed to ``optimize.minimize`` for *method*.

    ``xatol``/``fatol`` are Nelder-Mead-specific tolerances; handing them to
    another solver only produces an "Unknown solver options" warning, so they
    are added for Nelder-Mead alone.
    """
    options = {"maxiter": 3000, "disp": True}
    if method == "Nelder-Mead":
        options["xatol"] = 1e-7
        options["fatol"] = 1e-10
    return options


def _fit_load_initial_params(args):
    """Return the initial-guess dict for the fit, or ``{}`` if none is usable.

    ``--load`` is used when given; otherwise an already existing ``--save``
    file is used as a convenience. A load failure is reported as a warning
    in both cases and results in an empty dict.
    """
    path = getattr(args, "load", "")
    if not path:
        save_path = getattr(args, "save", "")
        if not (save_path and os.path.exists(save_path)):
            return {}
        path = save_path
    try:
        loaded = load_params(path)
    except Exception as ex:
        print(f"Warning: failed to load '{path}': {ex}")
        return {}
    print(f"Loaded initial params: {path}")
    return loaded


def run_fit(args):
    """Fit the model to measured Ne(T) data and report, save and plot the result.

    Steps: read the data file, restrict it to the [Tfitmin, Tfitmax] window,
    optionally load initial parameters, determine the fixed parameters,
    minimize ``objective``, estimate parameter uncertainties, print
    covariance/correlation/eigen diagnostics with fix suggestions, save the
    fitted parameters, and finally plot data, fit and prediction band.

    Args:
        args: The namespace returned by ``initialize()``. Attributes read
            here: ``input``, ``temp_col``, ``ne_col``, ``sheet``, ``Tfitmin``,
            ``Tfitmax``, ``load``, ``save``, ``fix``, ``method``,
            ``diag_save``, ``suggest_save``, ``band_sigma``, ``Nmin``,
            ``Nmax`` and one attribute per name in ``FIT_NAMES``.

    Raises:
        SystemExit: If ``--input`` is missing, fewer than 3 points fall in
            the fit window, or ``--fix`` contains a name not in ``FIT_NAMES``.
    """
    # ------------------------------------------------------
    # 0) load data
    # ------------------------------------------------------
    if not args.input:
        raise SystemExit("--input is required in fit mode")

    T_obs, Ne_obs, _df = read_data(
        args.input,
        temp_col=args.temp_col,
        ne_col=args.ne_col,
        sheet_name=args.sheet if args.sheet else None,
    )

    # ------------------------------------------------------
    # 1) fitting temperature window
    # ------------------------------------------------------
    mfit = (T_obs >= args.Tfitmin) & (T_obs <= args.Tfitmax)
    n_fit_pts = int(np.sum(mfit))
    if n_fit_pts < 3:
        raise SystemExit(
            "fit mode: not enough points in the selected Tfit range. Adjust --Tfitmin/--Tfitmax."
        )
    T_fit = T_obs[mfit]
    Ne_fit = Ne_obs[mfit]

    print()
    print(f"--- Fit data points: {n_fit_pts}/{len(T_obs)} ---")
    # Only mention the window when it is not the effectively unbounded one.
    if not (args.Tfitmin <= -1e50 and args.Tfitmax >= 1e50):
        print(f"  Using T range: {args.Tfitmin} .. {args.Tfitmax} K")

    # ------------------------------------------------------
    # 2) optionally load initial parameters (do not override CLI)
    # ------------------------------------------------------
    init_dict = _fit_load_initial_params(args)
    if init_dict:
        # The CLI overrides are passed along so that values given explicitly
        # on the command line are not replaced by the loaded ones.
        cli_over = _cli_overrides(sys.argv[1:])
        _apply_loaded_params_to_args(args, init_dict, cli_over)

    # ------------------------------------------------------
    # 3) fixed parameters
    # ------------------------------------------------------
    fix = {s.strip() for s in args.fix.split(",") if s.strip()}

    # Sorted so that the reported offending name is deterministic.
    for name in sorted(fix):
        if name not in FIT_NAMES:
            raise SystemExit(
                f"--fix contains unknown name '{name}'. Allowed: {', '.join(FIT_NAMES)}"
            )

    sanitize_args_for_fit(args, fix)

    theta0, free_names = pack_free_params(args, fix)

    print("\n--- Initial parameters (args) ---")
    for nm in FIT_NAMES:
        v = getattr(args, nm)
        tag = "(FIXED)" if nm in fix else ""
        print(f"  {nm:3s} = {v:.6g} {tag}")

    print("\nfree parameters:", free_names if free_names else "(none)")
    print("fixed parameters:", sorted(fix) if fix else "(none)")

    # ------------------------------------------------------
    # 4) optimization
    # ------------------------------------------------------
    method = args.method.lower().strip()
    if method in ("nm", "nelder", "nelder-mead", "nelder_mead"):
        method = "Nelder-Mead"

    if free_names:
        res = optimize.minimize(
            objective,
            theta0,
            args=(T_fit, Ne_fit, args, fix),
            method=method,
            options=_fit_minimize_options(method),
        )
        theta_hat = np.asarray(res.x, dtype=float)
        fit_success = res.success
        fit_message = res.message
        fit_fun = res.fun
    else:
        # Every parameter is fixed: there is nothing to optimize, so the
        # objective is evaluated once instead of handing an empty parameter
        # vector to the optimizer.
        theta_hat = np.asarray(theta0, dtype=float)
        fit_success = True
        fit_message = "no free parameters; optimization skipped"
        fit_fun = float(objective(theta_hat, T_fit, Ne_fit, args, fix))

    p_hat = unpack_params(theta_hat, args, fix)

    print("\n=== Fit result ===")
    print(f"success: {fit_success}  message: {fit_message}")
    print(f"objective (sum of squares in log10): {fit_fun:.6g}")
    print("Fitted parameters:")
    print(f"  Eg = {p_hat.Eg:.6g} eV")
    print(f"  ED = {p_hat.ED:.6g} eV")
    print(f"  EA = {p_hat.EA:.6g} eV")
    print(f"  ND = {p_hat.ND:.6g} cm^-3")
    print(f"  NA = {p_hat.NA:.6g} cm^-3")
    print("Fixed:", ", ".join(sorted(fix)) if fix else "(none)")

    # ------------------------------------------------------
    # 5) error estimation (cov/corr) on the FIT RANGE only
    # ------------------------------------------------------
    err = estimate_param_errors(theta_hat, T_fit, Ne_fit, args, fix, free_names)
    s2 = err["_sigma2"]["value"]
    dof = err["_sigma2"]["dof"]

    print("\n=== Uncertainty estimate (1-sigma) ===")
    print(f"Residual variance s^2 = {s2:.6g}  (dof={dof})  [in log10-space]")
    for nm in free_names:
        est = err[nm]["estimate"]
        se = err[nm]["stderr"]
        print(f"  {nm:10s} = {est: .6g} ± {se:.3g}")
    # Extra entries keyed "ND"/"NA" are printed with units when present.
    for nm in ("ND", "NA"):
        if nm in err:
            print(f"  {nm:10s} = {err[nm]['estimate']:.6g} ± {err[nm]['stderr']:.3g}  [cm^-3]")

    # ------------------------------------------------------
    # 6) diagnostics (cov/corr/eigen) + suggestions
    # ------------------------------------------------------
    Cov = err.get("_cov_theta", {}).get("matrix", None)
    Corr = err.get("_corr_theta", {}).get("matrix", None)
    JTJ = err.get("_JTJ", {}).get("matrix", None)

    diagnostics_available = (
        Cov is not None
        and Corr is not None
        and JTJ is not None
        and len(free_names) > 0
        and np.all(np.isfinite(Cov))
    )

    if diagnostics_available:
        Cov = np.asarray(Cov, dtype=float)
        Corr = np.asarray(Corr, dtype=float)
        JTJ = np.asarray(JTJ, dtype=float)

        evals_cov, evecs_cov = eigen_sorted_sym(Cov, descending=True)
        evals_JTJ, evecs_JTJ = eigen_sorted_sym(JTJ, descending=False)

        # Condition number is only reported for a positive-definite J^T J
        # (eigenvalues are ascending: last / first = largest / smallest).
        cond_JTJ = None
        if np.all(evals_JTJ > 0):
            cond_JTJ = float(evals_JTJ[-1] / evals_JTJ[0])

        theta_est = np.array([err[nm]["estimate"] for nm in free_names], dtype=float)
        theta_se = np.array([err[nm]["stderr"] for nm in free_names], dtype=float)

        np.set_printoptions(precision=4, suppress=True, linewidth=140)

        print("\n=== Covariance matrix (internal free params) ===")
        print("free params =", free_names)
        print(Cov)

        print("\n=== Correlation matrix (internal free params) ===")
        print("free params =", free_names)
        print(Corr)

        print("\n=== Eigen of Cov (large -> small): principal uncertainty directions ===")
        ev_summary = summarize_eigenvectors(
            evals_cov, evecs_cov, free_names,
            topk=min(4, len(free_names)),
            compk=min(5, len(free_names)),
        )
        for item in ev_summary:
            comps = ", ".join([f"{n}:{w:+.3f}" for n, w in item["components"]])
            print(f"  #{item['rank']}: eigenvalue={item['eigenvalue']:.6g}  components=({comps})")

        print("\n=== Eigen of J^T J (small -> large): poorly constrained directions first ===")
        print(evals_JTJ)
        if cond_JTJ is not None:
            print(f"cond(J^T J) ≈ {cond_JTJ:.3e} (larger = more ill-conditioned)")

        # Eigenvector of the smallest J^T J eigenvalue, largest components first.
        v_bad = evecs_JTJ[:, 0]
        order = np.argsort(np.abs(v_bad))[::-1]
        comps = ", ".join([f"{free_names[i]}:{v_bad[i]:+.3f}" for i in order[:min(5, len(order))]])
        print(f"worst (smallest-eigenvalue) direction: ({comps})")

        suggestions = propose_fix_candidates(
            free_names,
            theta_est,
            theta_se,
            Corr,
            evals_cov=evals_cov,
            evecs_cov=evecs_cov,
            corr_thr=0.95,
            relerr_thr=0.5,
            topn=3,
        )

        print("\n=== Suggested parameters to fix / constrain (heuristic) ===")
        print("※ 数値的に『決まりにくい』指標（大きい誤差/強相関/不確か方向への寄与）からの提案です。物理的妥当性・独立測定の有無で最終判断してください。")
        # "idx" rather than "k": "k" is the Boltzmann constant at module level.
        for idx, s in enumerate(suggestions, 1):
            re_str = f"{s['relerr']:.3g}" if s.get("relerr") is not None else "NA"
            print(f"  [{idx}] {s['param']}: estimate={s['estimate']:.6g}, stderr={s['stderr']:.3g}, relerr={re_str}, score={s['score']:.3g}")
            for r in s.get("reasons", [])[:6]:
                print(f"       - {r}")

        # save diagnostics
        diag = {
            "free_params": free_names,
            "theta_hat": [float(x) for x in theta_est],
            "stderr": [float(x) for x in theta_se],
            "cov": Cov.tolist(),
            "corr": Corr.tolist(),
            "sigma2": float(s2) if np.isfinite(s2) else None,
            "dof": int(dof),
            "eig_cov": {
                "eigenvalues_desc": [float(x) for x in evals_cov],
                "eigenvectors_desc": evecs_cov.tolist(),
                "summary": summarize_eigenvectors(
                    evals_cov, evecs_cov, free_names,
                    topk=min(6, len(free_names)),
                    compk=min(6, len(free_names)),
                ),
            },
            "eig_JTJ": {
                "eigenvalues_asc": [float(x) for x in evals_JTJ],
                "eigenvectors_asc": evecs_JTJ.tolist(),
                "cond": cond_JTJ,
                "worst_direction": [{"param": free_names[i], "weight": float(evecs_JTJ[i, 0])} for i in range(len(free_names))],
            },
        }

        if getattr(args, "diag_save", ""):
            try:
                save_params(diag, filename=args.diag_save)
            except Exception as ex:
                print(f"Warning: failed to save diagnostics json '{args.diag_save}': {ex}")

        if getattr(args, "suggest_save", ""):
            try:
                save_params({"suggestions": suggestions}, filename=args.suggest_save)
            except Exception as ex:
                print(f"Warning: failed to save suggestions json '{args.suggest_save}': {ex}")

    else:
        print("\n(診断) 共分散/相関の推定に必要な条件を満たさないため、診断出力をスキップしました。")
        print("       例: フィット点数が少ない / 全パラメータ固定 / dof<=0 / NaN など")

    # ------------------------------------------------------
    # 7) save fitted params (physical space)
    # ------------------------------------------------------
    out_params = {
        "Eg": float(p_hat.Eg),
        "ED": float(p_hat.ED),
        "EA": float(p_hat.EA),
        "ND": float(p_hat.ND),
        "NA": float(p_hat.NA),
        "me": float(p_hat.me),
        "mh": float(p_hat.mh),
        "Tfitmin": float(args.Tfitmin),
        "Tfitmax": float(args.Tfitmax),
        "fix": sorted(fix),
    }
    if getattr(args, "save", ""):
        try:
            save_params(out_params, filename=args.save)
        except Exception as ex:
            print(f"Warning: failed to save '{args.save}': {ex}")

    # ------------------------------------------------------
    # 8) plot (use all observed points; band from fit-range cov)
    # ------------------------------------------------------
    n_fit_all = np.array([calc_at_T(float(t), p_hat, args)["n"] for t in T_obs], dtype=float)

    Ne_lo = None
    Ne_hi = None
    try:
        Cov2 = err.get("_cov_theta", {}).get("matrix", None)
        if Cov2 is not None and len(free_names) > 0 and np.all(np.isfinite(Cov2)):
            _y0, ylo, yhi = prediction_band_log10_Ne(
                T_obs,
                theta_hat,
                args,
                fix,
                np.asarray(Cov2, dtype=float),
                nsigma=float(args.band_sigma),
            )
            if ylo is not None and yhi is not None:
                # The band is returned in log10(Ne); convert back to Ne.
                Ne_lo = 10.0 ** ylo
                Ne_hi = 10.0 ** yhi
    except Exception as ex:
        # The band is optional: the plot is still drawn without it.
        print(f"Warning: prediction band computation failed: {ex}")

    fig, ax = plt.subplots(figsize=(6.6, 5.0))
    # Curves are drawn in ascending-T order so lines/bands do not zig-zag.
    order = np.argsort(T_obs)
    ax.semilogy(T_obs, Ne_obs, "o", label="data Ne")
    ax.semilogy(T_obs[order], n_fit_all[order], "-", label="fit Ne(T)")

    if Ne_lo is not None and Ne_hi is not None:
        ax.fill_between(T_obs[order], Ne_lo[order], Ne_hi[order], alpha=0.25, label=f"model ±{args.band_sigma}σ")

    ax.set_xlabel("Temperature (K)")
    ax.set_ylabel("Electron density Ne (cm$^{-3}$)")
    if hasattr(args, "Nmin") and hasattr(args, "Nmax"):
        ax.set_ylim(args.Nmin, args.Nmax)
    ax.grid(True, which="both", alpha=0.2)
    ax.legend()
    plt.title("Fit result: Ne(T)")
    fig.tight_layout()
    plt.show()


def main():
    """Parse the command line and run the handler for the selected mode.

    ``read`` and ``fit`` have dedicated runners; every other mode value
    falls through to the simulation runner.
    """
    args = initialize()
    mode = args.mode

    if mode == "read":
        run_read(args)
        return

    if mode == "fit":
        run_fit(args)
        return

    run_sim(args)

# Script entry point: run the CLI only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()

