#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import argparse
from mpl_toolkits.mplot3d.art3d import Poly3DCollection

# skimage.measure is imported lazily in plot_3d_fermi_surface to avoid crash if not installed

NORMALIZATION_UNIT = 2.0 * np.pi  # internal k (units of 2π/a) times this -> physical k, with lattice constant a = 1


# =============================================================================
# Utilities
# =============================================================================

def reduce_to_first_bz(k_int: float) -> float:
    """Fold an internal-unit k coordinate into the first BZ [-0.5, 0.5) (period 1)."""
    shifted = k_int + 0.5
    return shifted % 1.0 - 0.5


def build_g_vecs(dim: int) -> np.ndarray:
    """
    Return the minimal reciprocal-lattice set {0, +e_1, -e_1, ..., +e_dim, -e_dim}.

    Rows are ordered 0, +e_1, -e_1, +e_2, -e_2, ... and are scaled to physical
    units by NORMALIZATION_UNIT. Result shape: (1 + 2*dim, dim).
    """
    rows = [np.zeros(dim, dtype=float)]
    for unit in np.eye(dim, dtype=float):
        rows.append(unit)
        rows.append(-unit)
    return np.array(rows, dtype=float) * NORMALIZATION_UNIT


def nband_of_model(dim: int, mode: str) -> int:
    """
    Number of bands of the toy model (spin degeneracy not included).

    "tb" has a single band; "free" and "nfe" have one band per plane wave
    G in {0, ±e_1, ..., ±e_dim}. Any other mode raises ValueError.
    """
    if mode == "tb":
        return 1
    if mode not in ("free", "nfe"):
        raise ValueError("bad mode")
    return 2 * dim + 1


def determine_energy_range(E_samples: np.ndarray, E_range_user=None):
    """
    Choose the energy window (Emin, Emax).

    If ``E_range_user`` is given, it is returned as a pair of floats and
    ``E_samples`` is not inspected. Otherwise the min/max of ``E_samples`` are
    widened by 5% of their spread on each side (by 0.05 when all samples are
    equal).
    """
    if E_range_user is not None:
        lo, hi = (float(value) for value in E_range_user)
        return lo, hi

    lo = float(np.min(E_samples))
    hi = float(np.max(E_samples))
    span = hi - lo if hi > lo else 1.0
    margin = 0.05 * span
    return lo - margin, hi + margin


# =============================================================================
# Band energies at k
# =============================================================================

def get_band_energies_at_k_dim(k_vec_int: np.ndarray, model: str, dim: int, args) -> np.ndarray:
    """
    Return the band energies of the toy model at one k-point.

    Parameters
    ----------
    k_vec_int : np.ndarray
        k-point of shape (dim,) in internal units (multiples of 2π/a).
    model : str
        "tb"   - single tight-binding band (see tb_energy),
        "free" - free-electron energies |k - G|^2 over the minimal G-set,
        "nfe"  - toy nearly-free-electron Hamiltonian on the same G-set.
    dim : int
        Spatial dimension.
    args : argparse.Namespace
        Only ``args.v`` (the uniform off-diagonal coupling) is read, for "nfe".

    Returns
    -------
    np.ndarray
        Energies sorted ascending, shape (Nband,).

    Raises
    ------
    ValueError
        If ``model`` is unknown (tb_energy also raises for unsupported dim).
    """
    k_phys = np.asarray(k_vec_int, dtype=float) * NORMALIZATION_UNIT

    if model == "tb":
        # Single band: the plane-wave basis is not needed at all.
        return np.array([tb_energy(k_phys, dim)], dtype=float)

    if model not in ("free", "nfe"):
        raise ValueError(f"Unknown model: {model}")

    g_vecs = build_g_vecs(dim)
    # Kinetic term |k - G|^2 for every basis vector G (no 1/2 prefactor in this toy model).
    kinetic = np.sum((k_phys - g_vecs) ** 2, axis=1)

    if model == "free":
        return np.sort(kinetic)

    # NFE: every pair of distinct plane waves is coupled by the same V = args.v.
    n_g = len(g_vecs)
    H = np.full((n_g, n_g), float(args.v))
    np.fill_diagonal(H, kinetic)
    return np.linalg.eigvalsh(H)


def tb_energy(k_phys: np.ndarray, dim: int) -> float:
    """
    Nearest-neighbour tight-binding energy, E = -sum_i cos(k_i), for dim = 1, 2, 3.

    Hopping t = 1 and on-site energy 0; ``k_phys`` is in physical units (a = 1).
    Raises ValueError for any other dimension.
    """
    axes_by_dim = {1: (0,), 2: (0, 1), 3: (0, 1, 2)}
    if dim not in axes_by_dim:
        raise ValueError("dim must be 1, 2, or 3")

    first, *rest = axes_by_dim[dim]
    total = np.cos(k_phys[first])
    for axis in rest:
        total = total + np.cos(k_phys[axis])
    return -float(total)


# =============================================================================
# BZ sampling -> DOS -> EF(nelec)
# =============================================================================

def sample_energies_in_bz(args, dim: int) -> np.ndarray:
    """
    Sample every band on a uniform grid over the first BZ and flatten the result.

    Each axis uses ``np.linspace(-0.5, 0.5, args.ef_res, endpoint=False)``
    (internal units), so the grid has ``ef_res**dim`` points. For "tb" the
    energies are evaluated vectorised; for the other modes
    get_band_energies_at_k_dim is called per k-point and all bands are
    concatenated (k-point major, then band).
    """
    n = args.ef_res
    axis_pts = np.linspace(-0.5, 0.5, n, endpoint=False)

    if dim not in (1, 2, 3):
        raise ValueError("dim must be 1,2,3")

    mesh = np.meshgrid(*([axis_pts] * dim), indexing="ij")
    pts = np.stack([component.ravel() for component in mesh], axis=1)  # (n**dim, dim)

    if args.mode == "tb":
        k_phys = pts * NORMALIZATION_UNIT
        cos_total = np.cos(k_phys[:, 0])
        for col in range(1, dim):
            cos_total = cos_total + np.cos(k_phys[:, col])
        return np.asarray(-cos_total, dtype=float)

    per_k = [get_band_energies_at_k_dim(kv, args.mode, dim, args) for kv in pts]
    return np.asarray(per_k, dtype=float).ravel()


def compute_dos(E_samples: np.ndarray, nE: int, sigma: float, E_range, scale_states: float):
    """
    Build a density of states from sampled band energies.

    Every sample carries weight ``scale_states / len(E_samples)``, so the DOS
    integrates to ``scale_states`` when all samples lie inside ``E_range``.
    Samples outside the window are simply not counted, in both branches (the
    histogram is NOT renormalised to the in-range samples).

    Parameters
    ----------
    E_samples : array_like
        Sampled energies; flattened if multi-dimensional.
    nE : int
        Number of energy points (Gaussian branch) or histogram bins.
    sigma : float or None
        Gaussian broadening width; ``None`` or ``<= 0`` selects a histogram.
    E_range : (float, float)
        Energy window (Emin, Emax).
    scale_states : float
        Total number of states represented by the samples (e.g. spin * Nband).

    Returns
    -------
    (E_axis, dos)
        Histogram branch: bin centres and DOS per bin.
        Gaussian branch: ``np.linspace(Emin, Emax, nE)`` and the broadened DOS.
        Empty ``E_samples`` yield an all-zero DOS.
    """
    E_samples = np.asarray(E_samples, dtype=float).ravel()
    Emin, Emax = map(float, E_range)
    n_samples = E_samples.size

    if sigma is None or sigma <= 0:
        counts, edges = np.histogram(E_samples, bins=nE, range=(Emin, Emax))
        centers = 0.5 * (edges[:-1] + edges[1:])
        if n_samples == 0:
            return centers, np.zeros(len(centers), dtype=float)
        # Normalise by the TOTAL sample count. density=True would divide by the
        # in-range count only and re-inflate the DOS when E_range cuts states off.
        hist = counts / (n_samples * np.diff(edges))
        return centers, hist * scale_states

    E_grid = np.linspace(Emin, Emax, nE)
    dos = np.zeros_like(E_grid, dtype=float)
    if n_samples == 0:
        return E_grid, dos

    inv = 1.0 / (np.sqrt(2.0 * np.pi) * sigma)
    block = 20000  # chunk size bounds the (nE x block) temporary below
    for start in range(0, n_samples, block):
        Ei = E_samples[start:start + block]
        x = (E_grid[:, None] - Ei[None, :]) / sigma
        dos += inv * np.exp(-0.5 * x * x).sum(axis=1)

    dos /= n_samples        # ∫dos dE ≈ 1
    dos *= scale_states     # ∫dos dE ≈ scale_states
    return E_grid, dos


def ef_from_dos(nelec: float, E_grid: np.ndarray, dos: np.ndarray, warn: bool = True) -> float:
    """
    Solve N(EF) = nelec, where N(E) is the trapezoidal integral of the DOS.

    N(E) starts at 0 at ``E_grid[0]``. ``nelec`` is clamped to the integrated
    range [0, N(E_grid[-1])] before linear interpolation. If ``warn`` is true,
    a message is printed when clamping occurs.
    """
    energies = np.asarray(E_grid, dtype=float)
    density = np.asarray(dos, dtype=float)

    trapezoids = 0.5 * (density[:-1] + density[1:]) * np.diff(energies)
    n_cum = np.concatenate([[0.0], np.cumsum(trapezoids)])
    n_lo, n_hi = n_cum[0], n_cum[-1]

    target = float(np.clip(nelec, n_lo, n_hi))

    if warn:
        if nelec < n_lo - 1e-12:
            print(f"[WARN] nelec={nelec} is below integrated minimum ({n_lo:.6g}). "
                  f"EF is clamped to Emin={energies[0]:.6g}.")
        if nelec > n_hi + 1e-12:
            print(f"[WARN] nelec={nelec} exceeds integrated states in DOS range ({n_hi:.6g}). "
                  f"EF is clamped to Emax={energies[-1]:.6g}. Consider widening --E_ef_range or omit it.")

    return float(np.interp(target, n_cum, energies))


# =============================================================================
# kF along 1D path and root solving
# =============================================================================

def refine_root_bisection(func, a, b, tol=1e-10, max_iter=80):
    """
    Find a root of ``func`` in [a, b] by bisection.

    Returns ``a`` or ``b`` if ``func`` is exactly zero there, and ``None`` if
    func(a) and func(b) have the same strict sign. Iteration stops when
    |func(mid)| < tol or the bracket is narrower than tol (then ``mid`` is
    returned), or after ``max_iter`` halvings (then the bracket midpoint is
    returned).
    """
    f_left = func(a)
    f_right = func(b)
    if f_left == 0.0:
        return a
    if f_right == 0.0:
        return b
    if f_left * f_right > 0:
        return None

    left, right = a, b
    for _ in range(max_iter):
        center = 0.5 * (left + right)
        f_center = func(center)
        converged = abs(f_center) < tol or abs(right - left) < tol
        if converged:
            return center
        if f_left * f_center <= 0:
            right = center
        else:
            left, f_left = center, f_center
    return 0.5 * (left + right)


def find_kf_crossings_1d(k_grid_int, energies_1d, EF, dim, args):
    """
    Locate Fermi crossings E_n(kx) = EF along the straight cut k = (kx, 0, ..., 0).

    Sign changes of E_n - EF between neighbouring grid points are refined by
    bisection on the band energy recomputed via get_band_energies_at_k_dim.
    Roots are folded into the first BZ; consecutive roots of the same band
    closer than ``args.kf_merge_tol`` are merged by averaging.

    Parameters
    ----------
    k_grid_int : np.ndarray
        kx values (internal units), one per row of ``energies_1d``.
    energies_1d : np.ndarray
        Band energies, shape (Nk, Nband).
    EF : float
        Fermi energy.
    dim : int
        Model dimension.
    args : argparse.Namespace
        Reads ``mode``, ``kf_tol`` and ``kf_merge_tol``.

    Returns
    -------
    list of (band_index, kx_reduced) tuples, sorted by band and then kx.
    """
    _, n_bands = energies_1d.shape

    def band_energy(kx_int, band):
        k_vec = np.zeros(dim, dtype=float)
        k_vec[0] = kx_int
        return float(get_band_energies_at_k_dim(k_vec, args.mode, dim, args)[band])

    raw = []
    for band in range(n_bands):
        shifted = energies_1d[:, band] - EF
        segments = zip(k_grid_int[:-1], k_grid_int[1:], shifted[:-1], shifted[1:])
        for a, b, fa, fb in segments:
            if fa == 0.0:
                raw.append((band, a))
                continue
            if fa * fb > 0:
                continue
            root = refine_root_bisection(
                lambda x, band=band: band_energy(x, band) - EF, a, b, tol=args.kf_tol
            )
            if root is not None:
                raw.append((band, root))

    folded = sorted(
        ((band, reduce_to_first_bz(kx)) for band, kx in raw),
        key=lambda item: (item[0], item[1]),
    )

    merged = []
    for band, kx in folded:
        if merged and merged[-1][0] == band and abs(kx - merged[-1][1]) < args.kf_merge_tol:
            merged[-1][1] = 0.5 * (merged[-1][1] + kx)
        else:
            merged.append([band, kx])

    return [tuple(entry) for entry in merged]


# =============================================================================
# Effective mass overlay
# =============================================================================

def effective_mass_1d(k_int: np.ndarray, E: np.ndarray) -> np.ndarray:
    """
    Effective mass m*(k) = 1 / (d^2E / d(ka)^2) along a 1D path.

    The second derivative is a central difference using the spacing of the
    first two k points (a uniform grid is assumed), taken in internal units and
    converted with k_phys = 2π k_int. The end points, and every point when
    fewer than 3 are given, are NaN.
    """
    k_arr = np.asarray(k_int, dtype=float)
    e_arr = np.asarray(E, dtype=float)

    if len(k_arr) < 3:
        return np.full_like(e_arr, np.nan, dtype=float)

    h = k_arr[1] - k_arr[0]
    curvature = np.full_like(e_arr, np.nan, dtype=float)
    curvature[1:-1] = (e_arr[2:] - 2.0 * e_arr[1:-1] + e_arr[:-2]) / (h * h)

    with np.errstate(divide="ignore", invalid="ignore"):
        inverse_curvature = 1.0 / curvature

    # d^2E/dk_phys^2 = d^2E/dk_int^2 / (2π)^2  =>  m* gains a factor (2π)^2.
    return inverse_curvature * (2.0 * np.pi) ** 2


# =============================================================================
# Band + DOS plotter (Straight path only)
# =============================================================================

def plot_band(args, dim: int, EF: float, E_samples: np.ndarray, E_range_plot, E_range_ef, Egrid_dos=None, dos=None):
    """
    Plot bands along a straight kx cut, optionally with the DOS and m*(k).

    The cut is k = (kx, 0, ..., 0) with kx spanning ``args.k_path_range``
    sampled at ``3 * args.res`` points. It is the 1D band plot and, for dim > 1,
    a cut along the kx axis.

    Parameters
    ----------
    args : argparse.Namespace
        Reads k_path_range, res, mode, type, dos, spin and mstar (plus what the
        band / kF helpers read).
    dim : int
        Model dimension.
    EF : float
        Fermi energy; drawn as a dashed line, and kF marks E_n(kx) = EF.
    E_samples, E_range_ef
        Not used here; accepted for call-site compatibility.
    E_range_plot : (float, float)
        Energy-axis limits for the band panel and the DOS panel.
    Egrid_dos, dos : np.ndarray or None
        DOS curve; the DOS panel is drawn only when ``args.dos`` is set and
        both arrays are provided.
    """
    k_min, k_max = args.k_path_range
    k_range_int = np.linspace(k_min, k_max, args.res * 3)

    # Band energies along the cut.
    all_energies = []
    for kx_int in k_range_int:
        k_vec = np.zeros(dim, dtype=float)
        k_vec[0] = kx_int
        all_energies.append(get_band_energies_at_k_dim(k_vec, args.mode, dim, args))
    all_energies = np.array(all_energies, dtype=float)  # (Nk, Nband)

    # kF roots on this cut (already folded into the first BZ).
    kf_list = find_kf_crossings_1d(k_range_int, all_energies, EF, dim, args)

    Emin_plot, Emax_plot = E_range_plot

    is_dos_plot = args.dos and (Egrid_dos is not None) and (dos is not None)
    if is_dos_plot:
        scale_states = float(args.spin) * nband_of_model(dim, args.mode)
        fig, (ax_band, ax_dos) = plt.subplots(1, 2, figsize=(11, 5), gridspec_kw={"wspace": 0.28})
    else:
        fig, ax_band = plt.subplots(figsize=(6.5, 5))
        ax_dos = None

    # --- band panel ---
    is_reduced_zone = (k_min >= -0.5 and k_max <= 0.5)

    for ib in range(all_energies.shape[1]):
        ax_band.plot(k_range_int, all_energies[:, ib], linewidth=1.5, alpha=0.9)

    if is_reduced_zone:
        ax_band.set_xlim(-0.5, 0.5)

    ax_band.set_ylim(Emin_plot, Emax_plot)
    ax_band.set_title(rf"{args.type} ({args.mode})   $E_F={EF:.6g}$")
    ax_band.axhline(y=EF, linestyle="--", linewidth=2, color="gray", label=r"$E_F$")

    # kF vertical lines. An explicit colour keeps them identical to the legend
    # proxy (axvline otherwise uses the default line colour, not black).
    kf_drawn = False
    for _ib, kx_red in kf_list:
        if k_min <= kx_red <= k_max:
            ax_band.axvline(x=kx_red, linestyle=":", linewidth=1.2, color="black")
            kf_drawn = True
    if kf_drawn:
        ax_band.plot([], [], linestyle=":", color="black", label=r"$k_F$")

    # Effective-mass overlay on a twin y-axis.
    if args.mstar:
        ax_m = ax_band.twinx()
        for ib in range(all_energies.shape[1]):
            mstar = effective_mass_1d(k_range_int, all_energies[:, ib])
            mstar = np.clip(mstar, -5.0, 5.0)  # keep divergences near inflection points on-screen
            ax_m.plot(k_range_int, mstar, linestyle="--", linewidth=1.2, alpha=0.7, color="C1")

        ax_m.set_ylabel(r"Effective mass $m^*(k)$")
        ax_m.set_ylim(-5.0, 5.0)
        ax_m.axhline(0.0, color="gray", linewidth=0.8, alpha=0.6)
        ax_band.plot([], [], "C1--", label=r"$m^*(k)$")

    ax_band.set_xlabel(r"$k_x / (2\pi/a)$")
    ax_band.set_ylabel(r"Energy $E$")
    ax_band.grid(True, linestyle="--", alpha=0.5)
    ax_band.legend(loc="upper right")

    # --- DOS panel ---
    if ax_dos is not None:
        ax_dos.plot(dos, Egrid_dos, linewidth=1.6)
        ax_dos.axhline(EF, linestyle="--", linewidth=2, color="gray", label=r"$E_F$")

        # Share the band panel's energy range so both panels line up.
        ax_dos.set_ylim(Emin_plot, Emax_plot)

        ax_dos.set_xlabel(r"DOS  (states / cell / energy)")
        ax_dos.set_ylabel(r"Energy $E$")
        ax_dos.set_title(rf"DOS ($\int$DOS dE = {scale_states:.3g})")
        ax_dos.grid(True, linestyle="--", alpha=0.5)
        ax_dos.legend(loc="upper right")

    plt.tight_layout()
    plt.show()


# =============================================================================
# High-symmetry K-Path Builder (for 2D/3D band plot)
# =============================================================================

def build_k_path(points, labels, res_per_segment=40):
    """
    Build a piecewise-linear k-path through the given high-symmetry points.

    Each segment is sampled with ``res_per_segment`` steps. The first point of
    every segment is dropped because it equals the previous segment's end.

    Returns
    -------
    k_vecs : np.ndarray
        Path k-points, shape (1 + res_per_segment * (len(points) - 1), dim).
    x_dist : np.ndarray
        Cumulative path length at each k-point.
    x_nodes : list
        Cumulative path length at each high-symmetry point.
    labels
        Passed through unchanged.
    """
    k_chunks = [np.array(points[0])]
    x_chunks = [np.array([0.0])]
    x_nodes = [0.0]
    travelled = 0.0

    for start_pt, end_pt in zip(points[:-1], points[1:]):
        p0 = np.array(start_pt)
        p1 = np.array(end_pt)
        seg_len = np.linalg.norm(p1 - p0)

        k_chunks.append(np.linspace(p0, p1, res_per_segment + 1)[1:])
        x_chunks.append(np.linspace(travelled, travelled + seg_len, res_per_segment + 1)[1:])

        travelled += seg_len
        x_nodes.append(travelled)

    return np.vstack(k_chunks), np.concatenate(x_chunks), x_nodes, labels


def plot_band_with_path(args, dim: int, EF: float, E_range_plot):
    """
    Draw the band structure along a standard high-symmetry path.

    dim == 2 : square lattice, Gamma -> X -> M -> Gamma.
    dim == 3 : simple cubic,   Gamma -> X -> M -> Gamma -> R -> X.
    Any other dim prints an error and returns without plotting (1D bands are
    drawn by plot_band).

    Parameters
    ----------
    args : argparse.Namespace
        Reads ``res`` (points per path segment) and ``mode``.
    dim : int
        Lattice dimension.
    EF : float
        Fermi energy, drawn as a dashed horizontal line.
    E_range_plot : (float, float)
        Energy-axis limits.
    """
    if dim == 2:
        corners = {"G": [0.0, 0.0], "X": [0.5, 0.0], "M": [0.5, 0.5]}
        route = ("G", "X", "M", "G")
    elif dim == 3:
        corners = {
            "G": [0.0, 0.0, 0.0],
            "X": [0.5, 0.0, 0.0],
            "M": [0.5, 0.5, 0.0],
            "R": [0.5, 0.5, 0.5],
        }
        route = ("G", "X", "M", "G", "R", "X")
    else:
        print(f"[ERROR] Path plotting for dim={dim} is not supported.")
        return

    tick_text = {"G": r"$\Gamma$", "X": r"$X$", "M": r"$M$", "R": r"$R$"}
    k_vecs, x_coords, node_pos, node_labs = build_k_path(
        [corners[name] for name in route],
        [tick_text[name] for name in route],
        res_per_segment=args.res,
    )

    bands = np.array(
        [get_band_energies_at_k_dim(kv, args.mode, dim, args) for kv in k_vecs],
        dtype=float,
    )  # (Nk, Nband)

    _fig, ax = plt.subplots(figsize=(9, 6))

    for band in bands.T:
        ax.plot(x_coords, band, linewidth=1.5, alpha=0.9, color="C0")

    ax.axhline(EF, linestyle="--", color="gray", linewidth=1.5, label=r"$E_F$")

    # Vertical separators and tick labels at the high-symmetry points.
    for x_pos in node_pos:
        ax.axvline(x_pos, color="black", linestyle="-", linewidth=0.5)
    ax.set_xticks(node_pos)
    ax.set_xticklabels(node_labs, fontsize=12)

    e_lo, e_hi = E_range_plot
    ax.set_ylim(e_lo, e_hi)
    ax.set_xlim(x_coords[0], x_coords[-1])
    ax.set_ylabel("Energy")
    ax.set_title(rf"{dim}D Band Structure ({args.mode})   $E_F={EF:.4g}$")
    ax.grid(True, axis="y", linestyle="--", alpha=0.5)

    plt.tight_layout()
    plt.show()


# =============================================================================
# Fermi surface: build energy grid for FS plotting
# =============================================================================

# build_energy_grid_for_fs: sample one band on a regular k-grid for the FS plotters below.
def build_energy_grid_for_fs(args, dim: int):
    """
    Sample one band on a uniform k-grid over the first BZ for Fermi-surface plots.

    Each axis is ``np.linspace(-0.5, 0.5, args.res)`` in internal units, so both
    zone boundaries are included.

    Parameters
    ----------
    args : argparse.Namespace
        Reads ``res``, ``fs_band`` and ``mode`` (and ``v`` through the "nfe" model).
    dim : int
        1, 2 or 3.

    Returns
    -------
    (k_range_int, E)
        The 1D axis (length ``args.res``) and the band energy on the grid,
        shape ``(res,) * dim`` with "ij" ordering (axis 0 = kx).

    Raises
    ------
    ValueError
        For an unsupported ``dim`` or ``args.mode``.

    Notes
    -----
    Band-index handling differs per branch (kept for backward compatibility).
    The 1D branch and "nfe" index the eigenvalues directly, so an out-of-range
    index raises IndexError. "free" in 2D/3D clamps to the highest band.
    "tb" in 2D/3D has a single band and ignores ``fs_band``.
    """
    N = args.res
    k_range_int = np.linspace(-0.5, 0.5, N)
    band_index = int(args.fs_band)

    if dim == 1:
        E = np.zeros_like(k_range_int, dtype=float)
        for i, kx in enumerate(k_range_int):
            evals = get_band_energies_at_k_dim(np.array([kx], dtype=float), args.mode, 1, args)
            E[i] = evals[band_index]
        return k_range_int, E

    if dim not in (2, 3):
        raise ValueError("dim must be 1,2,3")

    axes = np.meshgrid(*([k_range_int] * dim), indexing="ij")
    axes_phys = [K * NORMALIZATION_UNIT for K in axes]

    if args.mode == "tb":
        # Vectorised -sum_i cos(k_i) (same formula as tb_energy).
        cos_sum = np.cos(axes_phys[0])
        for K in axes_phys[1:]:
            cos_sum = cos_sum + np.cos(K)
        return k_range_int, -cos_sum

    if args.mode == "free":
        # Vectorised |k - G|^2 for each G, then sort along the band axis.
        G = build_g_vecs(dim)
        bands = []
        for g in G:
            d2 = (axes_phys[0] - g[0]) ** 2
            for axis in range(1, dim):
                d2 = d2 + (axes_phys[axis] - g[axis]) ** 2
            bands.append(d2)
        all_bands = np.sort(np.stack(bands, axis=0), axis=0)
        # Indices past the last band are clamped to the highest band.
        return k_range_int, all_bands[min(band_index, len(G) - 1)]

    if args.mode == "nfe":
        # No closed form: diagonalise the toy Hamiltonian at every grid point.
        E = np.zeros_like(axes[0], dtype=float)
        for idx in np.ndindex(E.shape):
            kv = np.array([K[idx] for K in axes], dtype=float)
            E[idx] = get_band_energies_at_k_dim(kv, "nfe", dim, args)[band_index]
        return k_range_int, E

    raise ValueError(f"Unknown model: {args.mode}")


# =============================================================================
# Fermi surface plotters (1dfs/2dfs/3dfs)
# =============================================================================
def plot_1d_fermi_points(k_range_int: np.ndarray, E_line: np.ndarray, EF: float, args):
    """
    Plot a 1D band and mark its Fermi points, where E(k) = EF.

    Roots are located by linear interpolation between neighbouring grid points,
    folded into the first BZ and de-duplicated (rounded to 12 decimals).
    Intervals with non-finite energies are skipped.

    Parameters
    ----------
    k_range_int : np.ndarray
        k-axis in internal units.
    E_line : np.ndarray
        Band energy on ``k_range_int``.
    EF : float
        Fermi energy.
    args : argparse.Namespace
        Reads ``mode`` (for the title).
    """
    fig, ax = plt.subplots(figsize=(7, 4))
    ax.plot(k_range_int, E_line, linewidth=2.0)
    # Explicit colours: otherwise EF and the kF lines share the band's default
    # colour while the kF legend proxy takes the next colour in the cycle.
    ax.axhline(EF, linestyle="--", linewidth=2, color="gray", label=r"$E_F$")

    Ek = np.asarray(E_line, dtype=float) - EF
    roots = []
    for i in range(len(k_range_int) - 1):
        a, b = k_range_int[i], k_range_int[i + 1]
        fa, fb = Ek[i], Ek[i + 1]
        if not (np.isfinite(fa) and np.isfinite(fb)):
            continue
        if fa == 0.0:
            roots.append(a)
            continue
        if fa * fb > 0:
            continue
        # fa != 0 and fa * fb <= 0, so fb != fa and the division is safe.
        roots.append(a + (b - a) * (-fa) / (fb - fa))

    roots_red = sorted({round(reduce_to_first_bz(r), 12) for r in roots})
    for r in roots_red:
        ax.axvline(r, linestyle=":", linewidth=1.5, color="black")
    if roots_red:
        ax.plot([], [], linestyle=":", color="black", label=r"$k_F$")

    ax.set_xlabel(r"$k / (2\pi/a)$")
    ax.set_ylabel(r"Energy $E$")
    ax.set_title(rf"1D Fermi points ({args.mode})   $E_F={EF:.6g}$")
    ax.grid(True, linestyle="--", alpha=0.5)
    ax.legend(loc="best")
    plt.tight_layout()
    plt.show()


def plot_2d_fermi_surface(E_grid_2d: np.ndarray, EF: float, args, k_range_int: np.ndarray):
    """
    Draw the 2D Fermi contour E(k) = EF in a periodic (repeated-zone) view.

    Parameters
    ----------
    E_grid_2d : np.ndarray
        Band energy on the square grid ``k_range_int x k_range_int`` with "ij"
        ordering (axis 0 = kx). The grid is assumed to include both zone
        boundaries -0.5 and +0.5, as ``np.linspace(-0.5, 0.5, N)`` does, so its
        last row/column repeats the first one periodically.
    EF : float
        Fermi energy.
    args : argparse.Namespace
        Reads ``nx_view`` (half-width of the view, internal units) and ``mode``.
    k_range_int : np.ndarray
        1D k-axis of the grid. Only its documented layout is relied upon; the
        extended axes are rebuilt from the grid shape.
    """
    view_limit = args.nx_view
    N_TILES = int(np.ceil(2 * view_limit / 0.5))
    if N_TILES % 2 == 0:
        N_TILES += 1  # an odd tile count keeps the first BZ centred at the origin

    E_grid_2d = np.asarray(E_grid_2d, dtype=float)
    if E_grid_2d.ndim != 2 or min(E_grid_2d.shape) < 2:
        raise ValueError("E_grid_2d must be a 2D grid with at least 2 points per axis")

    # Drop the duplicated +0.5 row/column before tiling. Tiling the full grid
    # would repeat a sample at every seam and shrink the k-spacing. Then close
    # the grid with one wrap-around row/column so it spans
    # [-N_TILES/2, N_TILES/2] with the original spacing 1/(N-1).
    E_cell = E_grid_2d[:-1, :-1]
    E_tile = np.pad(np.tile(E_cell, (N_TILES, N_TILES)), ((0, 1), (0, 1)), mode="wrap")
    half_span = 0.5 * N_TILES
    kx_ext = np.linspace(-half_span, half_span, E_tile.shape[0])
    ky_ext = np.linspace(-half_span, half_span, E_tile.shape[1])

    fig, ax = plt.subplots(figsize=(8, 8))
    E_min = float(np.min(E_tile))
    # contourf needs strictly increasing levels; for EF <= E_min nothing is occupied.
    if EF > E_min:
        ax.contourf(kx_ext, ky_ext, E_tile.T, levels=[E_min, EF], colors=["lightblue"], alpha=0.6, extend="neither")
    ax.contour(kx_ext, ky_ext, E_tile.T, levels=[EF], colors=["red"], linewidths=2)

    # First-BZ boundary.
    limit_int = 0.5
    ax.set_xlim(-view_limit, view_limit)
    ax.set_ylim(-view_limit, view_limit)
    ax.plot([limit_int, limit_int], [-view_limit, view_limit], "k--", lw=1)
    ax.plot([-limit_int, -limit_int], [-view_limit, view_limit], "k--", lw=1)
    ax.plot([-view_limit, view_limit], [limit_int, limit_int], "k--", lw=1)
    ax.plot([-view_limit, view_limit], [-limit_int, -limit_int], "k--", lw=1)

    ax.set_aspect("equal")
    ax.set_xlabel(r"$k_x / (2\pi/a)$")
    ax.set_ylabel(r"$k_y / (2\pi/a)$")
    ax.set_title(rf"2D Fermi Surface (Periodic, {args.mode})   $E_F={EF:.6g}$")
    ax.grid(True, linestyle="--", alpha=0.5)
    ax.fill_between([], [], color="lightblue", alpha=0.6, label=r"Occupied ($E \leq E_F$)")  # legend proxy
    ax.legend(loc="lower left")
    plt.tight_layout()
    plt.show()


def plot_3d_fermi_surface(E_grid_3d: np.ndarray, EF: float, args, k_range_int: np.ndarray):
    """
    Render the 3D Fermi surface E(k) = EF as a shaded triangle mesh.

    The isosurface comes from scikit-image's marching cubes. scikit-image is
    imported lazily, so the rest of the script works without it.

    Parameters
    ----------
    E_grid_3d : np.ndarray
        Band energy on the cubic grid ``k_range_int^3`` with "ij" ordering
        (axis 0 = kx, 1 = ky, 2 = kz).
    EF : float
        Iso-energy level.
    args : argparse.Namespace
        Reads ``mode`` (colour scheme) and ``fs_alpha`` (surface opacity).
    k_range_int : np.ndarray
        1D k-axis. It is assumed to be uniform over [-0.5, 0.5], like
        ``np.linspace(-0.5, 0.5, N)``; the vertex mapping below relies on this.
    """
    print("[INFO] Marching Cubes...")
    try:
        from skimage import measure
    except ImportError:
        print("\n[ERROR] 'scikit-image' is required for 3D Fermi surface plotting.")
        print("Please install it via: pip install scikit-image\n")
        return

    try:
        verts, faces, _normals, _values = measure.marching_cubes(E_grid_3d, EF, spacing=(1, 1, 1))
    except (ValueError, RuntimeError):
        # marching_cubes raises when no isosurface exists at EF
        # (e.g. EF outside the sampled energy range).
        print("[ERROR] Isosurface could not be generated at this EF.")
        return

    # Vertices are in index coordinates [0, N-1]; map them to [-0.5, 0.5].
    step_size = 1.0 / (len(k_range_int) - 1)
    verts = verts * step_size - 0.5

    # Per-face unit normals for the two-colour shading.
    tris = verts[faces]  # (n_faces, 3 vertices, 3 coords)
    face_normals = np.cross(tris[:, 1] - tris[:, 0], tris[:, 2] - tris[:, 0])
    norm_mag = np.linalg.norm(face_normals, axis=1)
    norm_mag[norm_mag == 0] = 1.0  # degenerate triangles: avoid division by zero
    face_normals = face_normals / norm_mag[:, np.newaxis]

    # Blend colour A (facing the light) with colour B (facing away).
    light_A = np.array([1.0, 1.0, 1.0])
    light_A /= np.linalg.norm(light_A)
    rgb_A = np.array(mcolors.to_rgb("cyan" if args.mode != "tb" else "orange"))
    rgb_B = np.array(mcolors.to_rgb("navy" if args.mode != "tb" else "darkred"))

    weights = np.clip((np.dot(face_normals, light_A) + 1.0) / 2.0, 0.0, 1.0)[:, np.newaxis]
    face_colors_rgb = weights * rgb_A + (1.0 - weights) * rgb_B
    face_colors_rgba = np.hstack((face_colors_rgb, np.full((len(faces), 1), args.fs_alpha)))

    fig = plt.figure(figsize=(10, 8))
    ax = fig.add_subplot(111, projection="3d")
    ax.set_box_aspect([1, 1, 1])

    mesh = Poly3DCollection(tris, facecolors=face_colors_rgba)
    mesh.set_edgecolor("black")
    mesh.set_linewidth(0.05)
    ax.add_collection3d(mesh)

    limit = 0.5
    ax.set_xlim(-limit, limit)
    ax.set_ylim(-limit, limit)
    ax.set_zlim(-limit, limit)

    title_str = rf"3D Fermi Surface ({args.mode})   $E_F={EF:.6g}$" + "\n" + r"(Dual Light Shading)"
    ax.set_title(title_str)
    ax.set_xlabel(r"$k_x / (2\pi/a)$")
    ax.set_ylabel(r"$k_y / (2\pi/a)$")
    ax.set_zlabel(r"$k_z / (2\pi/a)$")
    plt.show()


# =============================================================================
# Dispatcher (main function)
# =============================================================================

def dim_from_type(t: str) -> int:
    """Return the spatial dimension (1, 2 or 3) encoded in a --type value; ValueError otherwise."""
    lookup = {
        "1dband": 1, "1dfs": 1,
        "2dband": 2, "2dfs": 2,
        "3dband": 3, "3dfs": 3,
    }
    try:
        return lookup[t]
    except (KeyError, TypeError):
        raise ValueError("bad --type") from None


def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the argparse parser holding every command-line option of this script."""
    parser = argparse.ArgumentParser(
        description="Toy band / DOS / EF(nelec from DOS integral) / kF / m*(k) and FS (1dfs/2dfs/3dfs) simulator."
    )

    parser.add_argument("--type", type=str, default="3dband",
                        choices=["1dband", "2dband", "3dband", "1dfs", "2dfs", "3dfs"],
                        help="Output type: band plots or Fermi surface plots.")

    parser.add_argument("--mode", type=str, default="tb",
                        choices=["free", "tb", "nfe"],
                        help="Model: free / tb / nfe (toy NFE Hamiltonian).")

    # fixed Fermi energy (bypasses the DOS integral)
    parser.add_argument("--ef_fixed", type=float, default=None,
                        help="Fixed Fermi energy (EF). If set, skips DOS calculation (nelec is ignored).")

    # electron count
    parser.add_argument("--nelec", type=float, default=1.0,
                        help="Electrons per unit cell (including spin). EF is obtained from DOS integral (ignored if --ef_fixed is set).")
    parser.add_argument("--spin", type=float, default=2.0,
                        help="Spin degeneracy factor (default 2).")

    # grid / resolution
    parser.add_argument("--res", type=int, default=40,
                        help="Resolution for FS grids / band path segments.")
    parser.add_argument("--ef_res", type=int, default=24,
                        help="BZ sampling resolution per axis for DOS/EF (cost grows as ef_res^dim; ignored if --ef_fixed is set).")

    # NFE coupling
    parser.add_argument("--v", type=float, default=2.0,
                        help="Off-diagonal coupling strength in toy NFE Hamiltonian.")

    # band path
    parser.add_argument("--k_path_range", type=float, nargs=2, default=[-0.5, 0.5],
                        metavar=("KMIN", "KMAX"),
                        help="kx range for 1D band/1D cut plots (internal unit).")

    # plot Y-range (visual only)
    parser.add_argument("--E_range", type=float, nargs=2, default=None, metavar=("EMIN", "EMAX"),
                        help="Energy axis range for plotting (visual only).")

    # EF/DOS calculation range (affects the computed EF)
    parser.add_argument("--E_ef_range", type=float, nargs=2, default=None, metavar=("EMIN", "EMAX"),
                        help="Energy range used for DOS/EF calculation. Default: auto from sampled energies. (ignored if --ef_fixed is set).")

    # DOS options
    parser.add_argument("--dos", action="store_true",
                        help="Also plot DOS (next to band). Requires DOS calculation (--ef_fixed must be None).")
    parser.add_argument("--dos_nE", type=int, default=800,
                        help="Energy grid points for DOS.")
    parser.add_argument("--dos_sigma", type=float, default=0.15,
                        help="Gaussian broadening sigma for DOS (<=0 => histogram).")

    # kF root controls
    parser.add_argument("--kf_tol", type=float, default=1e-10,
                        help="Tolerance for kF root finding.")
    parser.add_argument("--kf_merge_tol", type=float, default=1e-4,
                        help="Merge tolerance for reduced kF lines (internal unit).")

    # effective mass overlay
    parser.add_argument("--mstar", action="store_true",
                        help="Overlay effective mass m*(k) on band plot (right axis), clipped to [-5,5].")

    # FS options
    parser.add_argument("--nx_view", type=float, default=1.5,
                        help="2D FS periodic view range (internal unit).")
    parser.add_argument("--fs_band", type=int, default=0,
                        help="Band index used for FS isosurface/contour (default 0).")
    parser.add_argument("--fs_alpha", type=float, default=0.6,
                        help="Alpha for 3D FS surface.")

    return parser


def main():
    """
    Command-line entry point.

    Parses the options and determines EF, either from ``--ef_fixed`` or by
    inverting the DOS integral N(EF) = ``--nelec``. It then picks the energy
    range for plotting and dispatches to the plotter selected by ``--type``.
    Invalid option values are reported via ``parser.error`` (usage message,
    exit status 2).

    Returns
    -------
    int
        0 on success; used as the process exit status.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()
    dim = dim_from_type(args.type)

    # Basic checks.
    if args.spin <= 0:
        parser.error("--spin must be > 0")
    if args.nelec < 0:
        parser.error("--nelec must be >= 0")
    if args.res < 2:
        # FS grids and the kx cut need at least two points per axis.
        parser.error("--res must be >= 2")

    # --- EF / DOS ---
    E_samples = None
    Egrid_dos = None
    dos = None
    E_range_ef = None

    if args.ef_fixed is not None:
        # EF given directly: skip the BZ sampling and DOS entirely.
        EF = float(args.ef_fixed)
        print(f"[INFO] EF (Fixed) = {EF:.6g}")
        if args.dos:
            print("[WARN] --dos ignored because --ef_fixed is set (skipping time-consuming BZ sampling).")
            args.dos = False  # no DOS data exists, so disable the DOS panel

        if args.E_range is None:
            print("[WARN] EF is fixed but --E_range is not set. Using default auto-range logic might be inaccurate.")
    else:
        # EF from nelec: sample the BZ, build the DOS, invert its integral.
        if args.ef_res < 1:
            parser.error("--ef_res must be >= 1")
        if args.dos_nE < 2:
            parser.error("--dos_nE must be >= 2")

        print(f"[INFO] Sampling energies in {dim}D BZ for DOS/EF (ef_res={args.ef_res}) ...")
        E_samples = sample_energies_in_bz(args, dim)

        Nband = nband_of_model(dim, args.mode)
        scale_states = float(args.spin) * Nband
        if args.nelec > scale_states + 1e-12:
            print(f"[WARN] nelec={args.nelec} exceeds max capacity spin*Nband={scale_states:.6g} "
                  f"for this toy basis; EF will saturate near top of sampled states.")

        E_range_ef = determine_energy_range(E_samples, args.E_ef_range)

        Egrid_dos, dos = compute_dos(
            E_samples,
            nE=args.dos_nE,
            sigma=args.dos_sigma,
            E_range=E_range_ef,
            scale_states=scale_states
        )
        EF = ef_from_dos(args.nelec, Egrid_dos, dos, warn=True)
        print(f"[INFO] EF(from DOS integral) = {EF:.6g}   (nelec={args.nelec}, spin={args.spin}, Nband={Nband})")

    # --- energy range for plotting ---
    if args.E_range is not None:
        E_range_plot = (float(args.E_range[0]), float(args.E_range[1]))
    elif E_range_ef is not None:
        E_range_plot = E_range_ef
    else:
        # Fixed EF without --E_range: no sampled energies exist, so show EF ± 2.
        pad = 2.0
        E_range_plot = (EF - pad, EF + pad)

    # --- dispatch ---
    if args.type in ("2dband", "3dband"):
        plot_band_with_path(args, dim, EF, E_range_plot)
        return 0

    if args.type == "1dband":
        # Straight kx cut; DOS data (if computed) goes to the side panel.
        plot_band(args, dim, EF, E_samples, E_range_plot, E_range_ef, Egrid_dos, dos)
        return 0

    # Fermi-surface plots use the EF determined above.
    if args.type == "1dfs":
        k_range_int, E_line = build_energy_grid_for_fs(args, 1)
        plot_1d_fermi_points(k_range_int, E_line, EF, args)
        return 0

    if args.type == "2dfs":
        k_range_int, E2 = build_energy_grid_for_fs(args, 2)
        plot_2d_fermi_surface(E2, EF, args, k_range_int)
        return 0

    if args.type == "3dfs":
        k_range_int, E3 = build_energy_grid_for_fs(args, 3)
        plot_3d_fermi_surface(E3, EF, args, k_range_int)
        return 0

    raise ValueError("unknown --type")


if __name__ == "__main__":
    # main() returns the exit status; SystemExit hands it to the interpreter.
    exit_status = main()
    raise SystemExit(exit_status)