#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
TFT IV analyzer (transfer/output) for extracting threshold voltage and mobility.

Features (refactored):
- Read transfer curve data from CSV or Excel (xlsx/xls).
- Modes:
  - read          : load and plot transfer/output curves (no fitting)
  - fit           : fit sqrt(Ids) vs Vgs in the saturation regime and report parameter errors
                   + add confidence/prediction bands (data estimation error) and plots
  - autofitrange  : automatically choose a fitting range (xfitmin/xfitmax) by scanning candidates
                   and scoring fit quality; then run 'fit' using the chosen range.

Model (sat. regime, long-channel approximation):
  Ids_sat ≈ (W*Cox/(2*L)) * μ_sat * (Vgs - Vth)^2
  sqrt(Ids) = A + B * Vgs    where  B = sqrt((W*Cox/(2*L)) * μ_sat),  Vth = -A/B

This script performs ordinary least squares (OLS) on y = sqrt(Ids) vs x = Vgs.

Notes:
- By default, points with non-positive Ids (offset/noise) are excluded from the sqrt fit.
- Confidence/prediction bands are computed and plotted in the sqrt(Ids) space only.

Author: (refactor based on user's original script)
"""

from __future__ import annotations

import argparse
import math
import os
import re
from dataclasses import dataclass
from typing import List, Optional, Tuple, Dict

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt


# ----------------------------
# Constants
# ----------------------------
# Vacuum permittivity (CODATA 2018 recommended value).
EPS0 = 8.8541878128e-12  # F/m


# ----------------------------
# Utilities
# ----------------------------
# Matches the first decimal number (optional sign, optional exponent) in a string.
_FLOAT_RE = re.compile(r"[-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?")


def pfloat(value, default: Optional[float] = None) -> Optional[float]:
    """Return the first number found in *value* as a float.

    Numeric inputs (int, float, NumPy floating) are converted directly.
    Anything else is stringified and searched for the first decimal
    number, so labels such as 'Vds=10V' or '10.0 V' yield 10.0.

    *default* is returned when *value* is None, when no number is found,
    or when the conversion itself fails.
    """
    if value is None:
        return default

    if isinstance(value, (int, float, np.floating)):
        candidate = value
    else:
        hit = _FLOAT_RE.search(str(value).strip())
        if hit is None:
            return default
        candidate = hit.group(0)

    try:
        return float(candidate)
    except Exception:
        return default


def is_excel(path: str) -> bool:
    """Return True when *path* ends in a spreadsheet extension (case-insensitive)."""
    _, suffix = os.path.splitext(path)
    return suffix.lower() in (".xlsx", ".xls", ".xlsm", ".xlsb", ".ods")


# ----------------------------
# Data containers
# ----------------------------
@dataclass
class TFTGeometry:
    """Gate-insulator and channel dimensions used to derive Cox and mobility.

    All lengths are in metres.
    """
    dg_m: float = 100.0e-9     # gate insulator thickness (m)
    # NOTE(review): 11.9 is the value usually quoted for silicon; confirm it is
    # the intended default for the gate insulator.
    er_g: float = 11.9         # relative permittivity of the gate insulator
    W_m: float = 300.0e-6      # channel width (m)
    L_m: float = 50.0e-6       # channel length (m)

    @property
    def Cox_Fm2(self) -> float:
        """Gate capacitance per unit area: er_g * EPS0 / dg_m (F/m^2)."""
        permittivity = self.er_g * EPS0
        return permittivity / self.dg_m

    @property
    def beta(self) -> float:
        """Geometry factor W*Cox/(2L); multiplying by mobility in m^2/V/s gives A/V^2."""
        width_times_cox = self.W_m * self.Cox_Fm2
        return width_times_cox / (2.0 * self.L_m)


@dataclass
class TFTData:
    """Transfer-curve table: one Ids column per Vds value, one row per Vgs value."""
    Vgs: np.ndarray                 # (nVgs,)
    Vds: np.ndarray                 # (nVds,)
    Ids: np.ndarray                 # (nVgs, nVds)
    columns: List[str]              # length nVds, original column names

    def select_vds_column(self, vds0: float, tol: float = 1e-3, strategy: str = "ge_then_nearest") -> int:
        """
        Pick a Vds column index for a target Vds0.

        strategy:
          - 'ge_then_nearest': the smallest Vds satisfying Vds >= vds0 - tol;
                               if no column qualifies, the nearest Vds
          - 'nearest'        : the Vds closest to vds0 (tol is not used)
        Any other strategy string is treated as 'ge_then_nearest'.

        The result does not depend on the column order of the input table,
        and non-finite Vds entries are never selected. When several columns
        tie, the one appearing first is returned.

        Raises:
            ValueError: if there are no Vds columns, or none is finite.
        """
        v = np.asarray(self.Vds, dtype=float)
        if v.size == 0:
            raise ValueError("No Vds columns found.")
        finite = np.isfinite(v)
        if not finite.any():
            raise ValueError("No finite Vds values found.")

        if strategy != "nearest":
            # Among the columns at or above the target, take the closest one
            # instead of whichever happens to come first in the file.
            candidates = np.flatnonzero(finite & (v >= (vds0 - tol)))
            if candidates.size > 0:
                return int(candidates[np.argmin(v[candidates])])

        # Nearest column; NaN/inf entries are pushed to infinite distance so
        # that they cannot win argmin.
        distance = np.where(finite, np.abs(v - vds0), np.inf)
        return int(np.argmin(distance))


# ----------------------------
# I/O
# ----------------------------
def load_transfer_table(path: str, sheet: Optional[str] = None) -> pd.DataFrame:
    """
    Load the transfer curve table.

    Expected wide format:
      col0: Vgs
      col1..: Ids at different Vds (column name contains Vds value, e.g. 'Vds=10', '10V', etc.)

    Args:
        path: CSV or spreadsheet file; the reader is chosen from the extension.
        sheet: Excel sheet name. When None, the first sheet is read.

    Returns:
        The table as read by pandas, with at least two columns.

    Raises:
        ValueError: if the table has fewer than two columns.
    """
    if is_excel(path):
        # pandas interprets sheet_name=None as "all sheets" and returns a dict,
        # so an unspecified sheet must be mapped to the first sheet (index 0).
        sheet_name = 0 if sheet is None else sheet
        df = pd.read_excel(path, sheet_name=sheet_name)
    else:
        # Default (UTF-8) first, then cp932 for files that fail to decode.
        try:
            df = pd.read_csv(path)
        except UnicodeDecodeError:
            df = pd.read_csv(path, encoding="cp932")
    if df.shape[1] < 2:
        raise ValueError("Input must have at least 2 columns: Vgs and one Ids(Vds) column.")
    return df


def parse_transfer_df(df: pd.DataFrame) -> TFTData:
    """Convert a wide transfer table into a TFTData container.

    The first column is read as Vgs and every remaining column as Ids for
    one Vds; the Vds value is parsed from the column name with pfloat.
    Cells that are not numeric become NaN, and rows whose Vgs is not
    finite are dropped.

    If the Vds value of any column cannot be parsed, all Vds values are
    replaced by the sequential indices 0, 1, 2, ... (the original column
    names are still kept in ``columns``).
    """
    colnames = [str(name) for name in df.columns[1:]]

    vgs = pd.to_numeric(df.iloc[:, 0], errors="coerce").to_numpy(dtype=float)
    ids = df.iloc[:, 1:].apply(pd.to_numeric, errors="coerce").to_numpy(dtype=float)

    # Keep only the rows that carry a usable Vgs value.
    keep = np.isfinite(vgs)
    vgs, ids = vgs[keep], ids[keep, :]

    vds = np.asarray([pfloat(name, default=np.nan) for name in colnames], dtype=float)
    if not np.all(np.isfinite(vds)):
        # At least one header held no number: use column positions for all.
        vds = np.arange(len(colnames), dtype=float)

    return TFTData(Vgs=vgs, Vds=vds, Ids=ids, columns=colnames)


# ----------------------------
# Regression + uncertainty
# ----------------------------
@dataclass
class OLSResult:
    """Straight-line fit y = a + b*x together with its uncertainty summary."""
    a: float                    # intercept
    b: float                    # slope
    sa: float                   # standard error of a
    sb: float                   # standard error of b
    s: float                    # residual std (sqrt of unbiased variance)
    r: float                    # correlation coefficient
    dof: int                    # degrees of freedom (the methods use n = dof + 2)
    xbar: float                 # mean of the x values used in the fit
    Sxx: float                  # sum of squared x deviations from xbar
    cov_ab: float               # covariance between a and b

    def predict(self, x: np.ndarray) -> np.ndarray:
        """Evaluate the fitted line at *x*."""
        return self.a + self.b * x

    def se_mean(self, x: np.ndarray) -> np.ndarray:
        """Standard error of the mean response at *x*."""
        n_points = self.dof + 2
        spread = (x - self.xbar) ** 2 / self.Sxx
        return self.s * np.sqrt(1.0 / n_points + spread)

    def se_pred(self, x: np.ndarray) -> np.ndarray:
        """Standard error of a single new observation at *x*."""
        n_points = self.dof + 2
        spread = (x - self.xbar) ** 2 / self.Sxx
        return self.s * np.sqrt(1.0 + 1.0 / n_points + spread)


def ols_fit(x: np.ndarray, y: np.ndarray) -> OLSResult:
    """Ordinary least-squares fit of y = a + b*x with parameter uncertainties.

    Returns an OLSResult holding the intercept/slope, their standard
    errors, the residual standard deviation (dof = n - 2), the correlation
    coefficient and the a-b covariance.

    Raises:
        ValueError: if x and y differ in length, fewer than 3 points are
            given, or the x values have no spread.
    """
    xs = np.asarray(x, dtype=float)
    ys = np.asarray(y, dtype=float)
    n = xs.size
    if n != ys.size:
        raise ValueError("x and y must have the same length.")
    if n < 3:
        raise ValueError("Need at least 3 points for OLS with dof=n-2.")

    xbar = float(np.mean(xs))
    ybar = float(np.mean(ys))
    dx = xs - xbar
    Sxx = float(np.sum(dx * dx))
    if Sxx <= 0:
        raise ValueError("Sxx is zero; x values may be constant.")

    slope = float(np.sum(dx * (ys - ybar))) / Sxx
    intercept = ybar - slope * xbar

    residuals = ys - (intercept + slope * xs)
    dof = int(n - 2)
    s2 = float(np.sum(residuals * residuals) / dof)
    s = math.sqrt(max(s2, 0.0))

    return OLSResult(
        a=intercept,
        b=slope,
        sa=s * math.sqrt(1.0 / n + xbar * xbar / Sxx),
        sb=s / math.sqrt(Sxx),
        s=s,
        # n >= 3 is guaranteed by the check above, so r is always computed.
        r=float(np.corrcoef(xs, ys)[0, 1]),
        dof=dof,
        xbar=xbar,
        Sxx=Sxx,
        cov_ab=-xbar * s2 / Sxx,
    )


def t_critical_95(dof: int) -> float:
    """Two-sided 95% (alpha=0.05) critical value of Student's t: t_{0.975, dof}.

    SciPy is used when it can be imported. Otherwise a built-in fallback is
    used: tabulated values for dof 1..4 and, for larger dof, a series
    expansion of the t quantile around the normal quantile (in error by
    roughly 3e-4 at dof=5, shrinking as dof grows). The plain normal value
    1.96 is not used because it is far too small for few degrees of freedom.

    Returns NaN when dof < 1.
    """
    try:
        from scipy.stats import t as student_t
    except Exception:
        # Deliberate best-effort: SciPy is optional, and a broken install can
        # fail with more than ImportError.
        student_t = None
    if student_t is not None:
        return float(student_t.ppf(0.975, dof))

    if not dof >= 1:
        return float("nan")

    # The expansion converges slowly for very small dof, so tabulate those.
    small = {1: 12.706205, 2: 4.302653, 3: 3.182446, 4: 2.776445}
    if dof in small:
        return small[dof]

    z = 1.959963984540054  # standard normal quantile at 0.975
    z3, z5, z7, z9 = z ** 3, z ** 5, z ** 7, z ** 9
    g1 = (z3 + z) / 4.0
    g2 = (5.0 * z5 + 16.0 * z3 + 3.0 * z) / 96.0
    g3 = (3.0 * z7 + 19.0 * z5 + 17.0 * z3 - 15.0 * z) / 384.0
    g4 = (79.0 * z9 + 776.0 * z7 + 1482.0 * z5 - 1920.0 * z3 - 945.0 * z) / 92160.0
    nu = float(dof)
    return z + g1 / nu + g2 / nu ** 2 + g3 / nu ** 3 + g4 / nu ** 4


@dataclass
class FitResult:
    """Result of one sqrt(Ids)-vs-Vgs fit on a single Vds column."""
    vds_target: float   # Vds that was requested
    vds_used: float     # Vds of the column actually fitted
    vds_col: int        # index of that column in the Ids array
    xfitmin: float      # lower Vgs bound of the fit window
    xfitmax: float      # upper Vgs bound of the fit window
    n_used: int         # number of points that entered the fit

    ols: OLSResult      # straight-line fit of sqrt(Ids) against Vgs
    Vth: float          # threshold voltage, -a/b
    mu_sat: float       # saturation mobility, b^2 / beta
    sVth: float         # propagated standard error of Vth
    smu_sat: float      # propagated standard error of mu_sat

    x: np.ndarray       # Vgs values used in the fit
    y: np.ndarray       # sqrt(Ids) values used in the fit
    yhat: np.ndarray    # fitted line evaluated at x
    resid: np.ndarray   # y - yhat


def extract_fit(
    data: TFTData,
    geom: TFTGeometry,
    vds0: float,
    xfitmin: float,
    xfitmax: float,
    vds_strategy: str = "ge_then_nearest",
    exclude_negative_ids: bool = True,
) -> FitResult:
    """Fit sqrt(Ids) = a + b*Vgs on one Vds column and derive Vth and mu_sat.

    Args:
        data: parsed transfer table.
        geom: device geometry; supplies beta = W*Cox/(2L).
        vds0: target Vds; the column is chosen by data.select_vds_column.
        xfitmin, xfitmax: inclusive Vgs window for the fit.
        vds_strategy: column selection strategy passed to select_vds_column.
        exclude_negative_ids: when True, only points with Ids > 0 are used
            (zero and negative currents are dropped). When False, all finite
            points in the window are used and a negative Ids is an error.

    Returns:
        FitResult with the OLS fit, Vth = -a/b, mu_sat = b^2/beta and their
        propagated standard errors.

    Raises:
        ValueError: fewer than 3 usable points, negative Ids in the window
            while exclude_negative_ids is False, or a zero slope.
    """
    col = data.select_vds_column(vds0, strategy=vds_strategy)
    vds_used = float(data.Vds[col])

    Vgs = data.Vgs
    Ids = data.Ids[:, col]

    # Filter by Vgs range
    m_range = (Vgs >= xfitmin) & (Vgs <= xfitmax) & np.isfinite(Vgs) & np.isfinite(Ids)
    if exclude_negative_ids:
        m_range &= (Ids > 0)
    elif np.any(Ids[m_range] < 0):
        # sqrt of a negative current is NaN and would silently turn every
        # fitted quantity into NaN; fail loudly instead.
        raise ValueError(
            "Negative Ids values inside the fit range cannot be used in a sqrt(Ids) fit. "
            "Exclude them or adjust the fit range."
        )

    x = Vgs[m_range]
    y = np.sqrt(Ids[m_range])

    if x.size < 3:
        raise ValueError(
            f"Not enough valid points for fitting. "
            f"Check xfit range and Ids sign. valid_points={x.size}"
        )

    ols = ols_fit(x, y)

    # Parameters
    if ols.b == 0:
        raise ValueError("Fit slope b is zero; cannot compute Vth and mobility.")
    Vth = -ols.a / ols.b

    # μ_sat from b: b^2 = beta * μ_sat  => μ_sat = b^2 / beta
    beta = geom.beta
    mu_sat = (ols.b ** 2) / beta

    # First-order error propagation for Vth = -a/b, including the a-b covariance:
    # dVth/da = -1/b; dVth/db = a/b^2
    dV_da = -1.0 / ols.b
    dV_db = ols.a / (ols.b ** 2)
    var_Vth = (dV_da ** 2) * (ols.sa ** 2) + (dV_db ** 2) * (ols.sb ** 2) + 2.0 * dV_da * dV_db * ols.cov_ab
    # Clamp tiny negative values caused by rounding before taking the root.
    sVth = math.sqrt(max(var_Vth, 0.0))

    # μ = b^2 / beta => dμ/db = 2b/beta
    dmu_db = 2.0 * ols.b / beta
    smu_sat = abs(dmu_db) * ols.sb

    yhat = ols.predict(x)
    resid = y - yhat

    return FitResult(
        vds_target=vds0,
        vds_used=vds_used,
        vds_col=col,
        xfitmin=xfitmin,
        xfitmax=xfitmax,
        n_used=int(x.size),
        ols=ols,
        Vth=Vth,
        mu_sat=mu_sat,
        sVth=sVth,
        smu_sat=smu_sat,
        x=x,
        y=y,
        yhat=yhat,
        resid=resid,
    )


# ----------------------------
# Auto fit range
# ----------------------------
@dataclass
class RangeScanPoint:
    """Summary of the fit obtained for one candidate Vgs window."""
    xfitmin: float   # lower Vgs bound of the candidate window
    xfitmax: float   # upper Vgs bound of the candidate window
    n: int           # number of points used in the fit
    r: float         # correlation coefficient of the fit
    s: float         # residual standard deviation of the fit
    Vth: float       # threshold voltage extracted from this window
    mu: float        # saturation mobility extracted from this window


def scan_fit_ranges(
    data: TFTData,
    geom: TFTGeometry,
    vds0: float,
    x_min_candidates: np.ndarray,
    x_max_candidates: np.ndarray,
    min_points: int = 6,
    vds_strategy: str = "ge_then_nearest",
    exclude_negative_ids: bool = True,
) -> List[RangeScanPoint]:
    """Fit every (xfitmin, xfitmax) candidate pair and collect the usable ones.

    Pairs with xmax <= xmin are skipped, as are windows whose fit cannot be
    computed or that use fewer than *min_points* points.

    Args:
        data, geom, vds0: forwarded to extract_fit.
        x_min_candidates: candidate lower Vgs bounds.
        x_max_candidates: candidate upper Vgs bounds.
        min_points: minimum number of fitted points for a window to be kept.
        vds_strategy: Vds column selection strategy forwarded to extract_fit.
        exclude_negative_ids: forwarded to extract_fit.

    Returns:
        One RangeScanPoint per accepted window, in scan order.
    """
    pts: List[RangeScanPoint] = []
    for xmin in x_min_candidates:
        for xmax in x_max_candidates:
            if xmax <= xmin:
                continue
            try:
                fr = extract_fit(
                    data,
                    geom,
                    vds0=vds0,
                    xfitmin=float(xmin),
                    xfitmax=float(xmax),
                    vds_strategy=vds_strategy,
                    exclude_negative_ids=exclude_negative_ids,
                )
            except (ValueError, ArithmeticError):
                # An unfittable window (too few points, zero slope, division
                # by zero) is expected during a scan; other exception types
                # indicate a real bug and are allowed to propagate.
                continue
            if fr.n_used < min_points:
                continue
            pts.append(RangeScanPoint(
                xfitmin=float(xmin),
                xfitmax=float(xmax),
                n=fr.n_used,
                r=fr.ols.r,
                s=fr.ols.s,
                Vth=fr.Vth,
                mu=fr.mu_sat,
            ))
    return pts


def choose_best_range(points: List[RangeScanPoint]) -> Tuple[float, float, Dict[str, float]]:
    """
    Choose a best range via a simple score:
      score = (r^2) / (s + eps) * log(n)
    Then take the maximum score. Also report some diagnostics.

    A non-finite r contributes r^2 = 0. Candidates whose score is not finite
    (for example because s is NaN) are never selected; without this guard a
    NaN score would win np.argmax.

    Returns:
        (xfitmin, xfitmax, diag) where diag holds the score, r, s, n, Vth
        and mu of the selected candidate.

    Raises:
        ValueError: if *points* is empty or no candidate has a finite score.
    """
    if not points:
        raise ValueError("No valid fitting ranges found in scan.")
    eps = 1e-30  # keeps the division defined for a perfect fit (s == 0)
    scores = []
    for p in points:
        rr = (p.r ** 2) if np.isfinite(p.r) else 0.0
        score = rr / (p.s + eps) * math.log(max(p.n, 3))
        scores.append(float(score) if math.isfinite(score) else -math.inf)
    idx = int(np.argmax(scores))
    if scores[idx] == -math.inf:
        raise ValueError("No fitting range with a finite score found in scan.")
    best = points[idx]
    diag = {
        "score": float(scores[idx]),
        "r": float(best.r),
        "s": float(best.s),
        "n": float(best.n),
        "Vth": float(best.Vth),
        "mu": float(best.mu),
    }
    return best.xfitmin, best.xfitmax, diag


# ----------------------------
# Plotting
# ----------------------------
def plot_transfer_output(data: TFTData, title: str = "", show: bool = True, save: Optional[str] = None) -> None:
    """Plot the transfer curves (|Ids| vs Vgs, log scale) and output curves (Ids vs Vds).

    Args:
        data: parsed transfer table.
        title: prefix for both figure titles.
        show: call plt.show() when True; otherwise close both figures.
        save: base filename. The figures are written as '<root>_transfer<ext>'
            and '<root>_output<ext>', where root/ext come from splitting
            *save* at its extension ('.png' is used when there is none).
    """
    Vgs = data.Vgs
    Ids = data.Ids

    # Transfer curves (Ids vs Vgs) for each Vds
    fig1 = plt.figure()
    ax1 = fig1.add_subplot(111)
    for j in range(Ids.shape[1]):
        ax1.plot(Vgs, np.abs(Ids[:, j]), label=f"{data.columns[j]}")
    ax1.set_yscale("log")
    ax1.set_xlabel("Vgs (V)")
    ax1.set_ylabel("|Ids| (A)")
    ax1.set_title(title + " Transfer characteristics")
    ax1.grid(True, which="both", ls=":")
    # The curves carry labels, so show them.
    ax1.legend()

    # Output characteristics (Ids vs Vds) at each Vgs (if possible)
    fig2 = plt.figure()
    ax2 = fig2.add_subplot(111)
    Vds = data.Vds
    # No legend here: there is one curve per Vgs row, which can be many.
    for i, vg in enumerate(Vgs):
        ax2.plot(Vds, Ids[i, :], label=f"Vgs={vg:g}")
    ax2.set_xlabel("Vds (V)")
    ax2.set_ylabel("Ids (A)")
    ax2.set_title(title + " Output characteristics")
    ax2.grid(True, ls=":")

    if save:
        # Build distinct names from the extension; a plain str.replace(".png")
        # leaves other names unchanged and the figures overwrite each other.
        root, ext = os.path.splitext(save)
        ext = ext or ".png"
        fig1.savefig(f"{root}_transfer{ext}", dpi=200, bbox_inches="tight")
        fig2.savefig(f"{root}_output{ext}", dpi=200, bbox_inches="tight")
    if show:
        plt.show()
    else:
        plt.close(fig1)
        plt.close(fig2)


def plot_fit_with_bands(
    data: TFTData,
    fit: FitResult,
    title: str = "",
    alpha: float = 0.05,
    n_grid: int = 200,
    show: bool = True,
    save: Optional[str] = None,
) -> None:
    """Plot the sqrt(Ids) fit with confidence/prediction bands, plus the residuals.

    Args:
        data: parsed transfer table (all positive-Ids points of the fitted
            column are drawn for context).
        fit: result to visualise.
        title: prefix for the figure titles.
        alpha: two-sided significance level of the bands (0.05 gives 95%).
            Levels other than 0.05 require SciPy.
        n_grid: number of grid points across the fit window.
        show: call plt.show() when True; otherwise close both figures.
        save: base filename. The figures are written as '<root>_sqrtfit<ext>'
            and '<root>_residual<ext>' ('.png' is used when *save* has no
            extension).

    Raises:
        ValueError: if alpha is not strictly between 0 and 1.
        ImportError: if alpha differs from 0.05 and SciPy is unavailable.
    """
    if not 0.0 < alpha < 1.0:
        raise ValueError("alpha must satisfy 0 < alpha < 1.")

    # Prepare grid on the fitting interval
    xg = np.linspace(fit.xfitmin, fit.xfitmax, n_grid)
    yhat = fit.ols.predict(xg)

    if math.isclose(alpha, 0.05):
        tcrit = t_critical_95(fit.ols.dof)
    else:
        # Only the 95% level has a SciPy-free path.
        from scipy.stats import t as student_t
        tcrit = float(student_t.ppf(1.0 - alpha / 2.0, fit.ols.dof))
    level = f"{100.0 * (1.0 - alpha):g}%"

    se_mean = fit.ols.se_mean(xg)
    se_pred = fit.ols.se_pred(xg)

    y_mean_lo = yhat - tcrit * se_mean
    y_mean_hi = yhat + tcrit * se_mean
    y_pred_lo = yhat - tcrit * se_pred
    y_pred_hi = yhat + tcrit * se_pred

    # Original column data for this Vds (for context)
    Vgs_all = data.Vgs
    Ids_all = data.Ids[:, fit.vds_col]
    m_pos = np.isfinite(Vgs_all) & np.isfinite(Ids_all) & (Ids_all > 0)
    Vgs_pos = Vgs_all[m_pos]
    sqrtIds_pos = np.sqrt(Ids_all[m_pos])

    # --- sqrt plot with bands
    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.plot(Vgs_pos, sqrtIds_pos, "o", ms=4, label="data (Ids>0)")
    ax.plot(xg, yhat, "-", lw=2, label="fit")

    ax.fill_between(xg, y_mean_lo, y_mean_hi, alpha=0.25, label=f"{level} CI (mean)")
    ax.fill_between(xg, y_pred_lo, y_pred_hi, alpha=0.15, label=f"{level} PI (obs)")

    ax.axvline(fit.xfitmin, ls="--")
    ax.axvline(fit.xfitmax, ls="--")

    ax.set_xlabel("Vgs (V)")
    ax.set_ylabel("sqrt(Ids) (A^0.5)")
    ax.set_title(title + f" sqrt(Ids)-Vgs fit (Vds~{fit.vds_used:g} V)")
    ax.grid(True, ls=":")
    ax.legend()

    # --- residual plot
    fig2 = plt.figure()
    ax2 = fig2.add_subplot(111)
    ax2.axhline(0.0, lw=1)
    ax2.plot(fit.x, fit.resid, "o", ms=4)
    ax2.set_xlabel("Vgs (V)")
    ax2.set_ylabel("residual (sqrt(Ids))")
    ax2.set_title(title + " Residuals on fit range")
    ax2.grid(True, ls=":")

    if save:
        # Build distinct names from the extension; a plain str.replace(".png")
        # leaves other names unchanged and the figures overwrite each other.
        root, ext = os.path.splitext(save)
        ext = ext or ".png"
        fig.savefig(f"{root}_sqrtfit{ext}", dpi=200, bbox_inches="tight")
        fig2.savefig(f"{root}_residual{ext}", dpi=200, bbox_inches="tight")
    if show:
        plt.show()
    else:
        plt.close(fig)
        plt.close(fig2)


def plot_autorange_summary(points: List[RangeScanPoint], chosen: Tuple[float, float], title: str = "", show: bool = True, save: Optional[str] = None) -> None:
    """Plot mu_sat, Vth, r and residual std of every scanned window against xfitmax.

    Args:
        points: scan results; windows with different xfitmin share each plot.
        chosen: (xfitmin, xfitmax) of the selected window; only xfitmax is
            marked, as a dashed vertical line.
        title: prefix for the figure titles.
        show: call plt.show() when True; otherwise close all four figures.
        save: base filename. The figures are written as '<root>_scan_mu<ext>',
            '<root>_scan_vth<ext>', '<root>_scan_r<ext>' and
            '<root>_scan_s<ext>' ('.png' is used when *save* has no extension).
    """
    _, xmax = chosen
    xfitmaxs = np.array([p.xfitmax for p in points], dtype=float)

    # (values, legend label, y-axis label, title suffix, filename suffix)
    panels = [
        (np.array([p.mu for p in points], dtype=float), "mu_sat", "mu_sat (m^2/V/s)",
         " Range scan: mu_sat vs xfitmax (mixed xmin)", "_scan_mu"),
        (np.array([p.Vth for p in points], dtype=float), "Vth", "Vth (V)",
         " Range scan: Vth vs xfitmax (mixed xmin)", "_scan_vth"),
        (np.array([p.r for p in points], dtype=float), "r", "correlation r",
         " Range scan: r vs xfitmax", "_scan_r"),
        (np.array([p.s for p in points], dtype=float), "s (residual std)", "s (sqrt(Ids))",
         " Range scan: residual std vs xfitmax", "_scan_s"),
    ]

    figures = []
    for values, label, ylabel, heading, suffix in panels:
        fig = plt.figure()
        ax = fig.add_subplot(111)
        ax.plot(xfitmaxs, values, "o", ms=4, label=label)
        ax.set_xlabel("xfitmax (V)")
        ax.set_ylabel(ylabel)
        ax.set_title(title + heading)
        ax.grid(True, ls=":")
        ax.axvline(xmax, ls="--", label="chosen xfitmax")
        ax.legend()
        figures.append((fig, suffix))

    if save:
        # Build distinct names from the extension; a plain str.replace(".png")
        # leaves other names unchanged and the figures overwrite each other.
        root, ext = os.path.splitext(save)
        ext = ext or ".png"
        for fig, suffix in figures:
            fig.savefig(f"{root}{suffix}{ext}", dpi=200, bbox_inches="tight")

    if show:
        plt.show()
    else:
        for fig, _ in figures:
            plt.close(fig)


# ----------------------------
# CLI
# ----------------------------
def build_argparser() -> argparse.ArgumentParser:
    """Create the command-line parser; option defaults are shown in --help."""
    parser = argparse.ArgumentParser(
        description="TFT transfer/output analyzer: Vth and mobility extraction with error estimation.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    add = parser.add_argument

    # General options
    add("--mode", choices=["read", "fit", "autofitrange"], default="fit",
        help="Operation mode.")
    add("-i", "--input", default="TransferCurve.csv", help="Input file path (CSV or Excel).")
    add("--sheet", default=None, help="Excel sheet name (Excel only).")
    add("--title", default="", help="Figure title prefix.")
    add("--savefig", default=None, help="If set, save figures with this base filename (e.g. out.png).")
    add("--noshow", action="store_true", help="Do not show figures (useful with --savefig).")

    # Device parameters
    add("--dg", type=float, default=100.0e-9, help="Gate oxide thickness (m).")
    add("--er", type=float, default=11.9, help="Gate insulator relative permittivity.")
    add("--W", type=float, default=300.0e-6, help="Channel width (m).")
    add("--L", type=float, default=50.0e-6, help="Channel length (m).")

    # Fit parameters
    add("--vds0", type=float, default=10.0, help="Target Vds used for sqrt(Ids)-Vgs fitting.")
    add("--xfitmin", type=float, default=6.0, help="Minimum Vgs for fitting.")
    add("--xfitmax", type=float, default=10.0, help="Maximum Vgs for fitting.")
    add("--vds-strategy", choices=["ge_then_nearest", "nearest"], default="ge_then_nearest",
        help="How to choose the Vds column for the target vds0.")
    add("--keep-negative-ids", action="store_true",
        help="Do NOT exclude Ids<=0 points in sqrt fitting (may crash if negative).")

    # Automatic range scan parameters
    add("--scan-xmin-step", type=float, default=1.0, help="Vgs step for xfitmin candidates.")
    add("--scan-xmax-step", type=float, default=1.0, help="Vgs step for xfitmax candidates.")
    add("--scan-min-points", type=int, default=6, help="Minimum points required in a candidate range.")

    return parser


def main() -> None:
    """Command-line entry point: load the table, then run the selected mode.

    - read         : plot the transfer/output curves only.
    - fit          : fit the requested Vgs window, print a report and plot.
    - autofitrange : scan candidate windows, pick the best-scoring one,
                     then fit, report and plot as in 'fit'.
    """
    parser = build_argparser()
    args = parser.parse_args()
    show = not args.noshow

    geom = TFTGeometry(dg_m=args.dg, er_g=args.er, W_m=args.W, L_m=args.L)

    df = load_transfer_table(args.input, sheet=args.sheet)
    data = parse_transfer_df(df)

    if args.mode == "read":
        plot_transfer_output(data, title=args.title, show=show, save=args.savefig)
        return

    if args.mode == "fit":
        fr = extract_fit(
            data=data,
            geom=geom,
            vds0=args.vds0,
            xfitmin=args.xfitmin,
            xfitmax=args.xfitmax,
            vds_strategy=args.vds_strategy,
            exclude_negative_ids=not args.keep_negative_ids,
        )

        # report
        print("=== FIT RESULT ===")
        print(f"Input            : {args.input}")
        if args.sheet is not None:
            print(f"Sheet            : {args.sheet}")
        print(f"Vds target       : {fr.vds_target:g} V")
        print(f"Vds used         : {fr.vds_used:g} V (col={fr.vds_col}, name='{data.columns[fr.vds_col]}')")
        print(f"Fit range        : {fr.xfitmin:g} to {fr.xfitmax:g} V")
        print(f"Valid points     : n = {fr.n_used:d} (Ids>0 and finite in range)")
        print("")
        print("OLS on y=sqrt(Ids)=a+b*Vgs")
        print(f"  a  = {fr.ols.a:.6g}  ± {fr.ols.sa:.2g}")
        print(f"  b  = {fr.ols.b:.6g}  ± {fr.ols.sb:.2g}")
        print(f"  r  = {fr.ols.r:.6g}")
        print(f"  s  = {fr.ols.s:.6g}   (residual std, dof={fr.ols.dof})")
        print("")
        print("Extracted parameters")
        print(f"  Vth    = {fr.Vth:.6g}  ± {fr.sVth:.2g}   (V)")
        print(f"  mu_sat = {fr.mu_sat:.6g}  ± {fr.smu_sat:.2g}   (m^2/V/s)")
        print("")
        print("Device geometry")
        print(f"  Cox    = {geom.Cox_Fm2:.6g}  (F/m^2)")
        print(f"  beta   = W*Cox/(2L) = {geom.beta:.6g}")

        plot_transfer_output(data, title=args.title, show=show, save=args.savefig)
        plot_fit_with_bands(data, fr, title=args.title, show=show, save=args.savefig)
        return

    # autofitrange
    # A zero, negative or NaN step would make np.arange fail or yield no
    # candidates, so reject it with a clear command-line error first.
    if not (args.scan_xmin_step > 0 and args.scan_xmax_step > 0):
        parser.error("--scan-xmin-step and --scan-xmax-step must be positive.")

    Vgs_min = float(np.nanmin(data.Vgs))
    Vgs_max = float(np.nanmax(data.Vgs))
    xmin_candidates = np.arange(Vgs_min, Vgs_max, args.scan_xmin_step)
    # The tiny offset keeps Vgs_max itself among the xfitmax candidates.
    xmax_candidates = np.arange(Vgs_min + args.scan_xmax_step, Vgs_max + 1e-12, args.scan_xmax_step)

    # NOTE: the scan runs with the default Vds column selection and Ids>0
    # filtering; --vds-strategy / --keep-negative-ids only affect the final fit.
    points = scan_fit_ranges(
        data=data,
        geom=geom,
        vds0=args.vds0,
        x_min_candidates=xmin_candidates,
        x_max_candidates=xmax_candidates,
        min_points=args.scan_min_points,
    )

    xmin_best, xmax_best, diag = choose_best_range(points)
    print("=== AUTO FIT RANGE ===")
    print(f"Best range: xfitmin={xmin_best:g} V, xfitmax={xmax_best:g} V")
    print(f"Diagnostics: score={diag['score']:.4g}, r={diag['r']:.4g}, s={diag['s']:.4g}, n={int(diag['n'])}")
    print(f"            Vth≈{diag['Vth']:.6g} V, mu≈{diag['mu']:.6g} m^2/V/s")
    print("Now running fit with the chosen range...")

    plot_autorange_summary(points, (xmin_best, xmax_best), title=args.title, show=show, save=args.savefig)

    fr = extract_fit(
        data=data,
        geom=geom,
        vds0=args.vds0,
        xfitmin=xmin_best,
        xfitmax=xmax_best,
        vds_strategy=args.vds_strategy,
        exclude_negative_ids=not args.keep_negative_ids,
    )

    # report (shorter than in fit mode)
    print("=== FIT RESULT (AUTO RANGE) ===")
    print(f"Vds used         : {fr.vds_used:g} V (col={fr.vds_col}, name='{data.columns[fr.vds_col]}')")
    print(f"Fit range        : {fr.xfitmin:g} to {fr.xfitmax:g} V")
    print(f"Valid points     : n = {fr.n_used:d}")
    print(f"Vth    = {fr.Vth:.6g}  ± {fr.sVth:.2g}   (V)")
    print(f"mu_sat = {fr.mu_sat:.6g}  ± {fr.smu_sat:.2g}   (m^2/V/s)")

    plot_transfer_output(data, title=args.title, show=show, save=args.savefig)
    plot_fit_with_bands(data, fr, title=args.title, show=show, save=args.savefig)


# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()
