# Usage example: python fit_minimize.py --mode read --input data.csv --p0 1.0 0.0 1.0 2.0


#!/usr/bin/env python3
import argparse
import sys
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import minimize


def read_data(path: Path):
    """
    Read 2-column numeric data (x, y) from a delimited text file.

    The delimiter is guessed from the file suffix: tab for ``.tsv``/``.txt``,
    comma otherwise. If parsing with that delimiter fails, the other of
    comma/tab is tried, and finally arbitrary whitespace.

    For each delimiter the file is parsed first as-is and then with its first
    line skipped. An optional header row is therefore tolerated, and the first
    data row of a header-less file is not silently dropped. Only the first two
    columns are used; extra columns are ignored.

    Parameters
    ----------
    path : Path (or str)
        Input file.

    Returns
    -------
    (x, y) : tuple of 1-D numpy arrays
        First and second column of the data.

    Prints an error to stderr and exits with status 1 if the file is missing,
    cannot be parsed as numbers, or has fewer than two columns or no rows.
    """
    path = Path(path)
    if not path.exists():
        print(f"Error: input file {path} not found.", file=sys.stderr)
        sys.exit(1)

    primary = "\t" if path.suffix.lower() in (".tsv", ".txt") else ","
    secondary = "," if primary == "\t" else "\t"

    data = None
    # delimiter=None makes np.loadtxt split on any whitespace; it is the last
    # resort for space-separated files.
    for delimiter in (primary, secondary, None):
        for skiprows in (0, 1):
            try:
                # ndmin=2 keeps a single-row file as shape (1, ncols) instead
                # of squeezing it to 1-D.
                data = np.loadtxt(path, delimiter=delimiter, skiprows=skiprows, ndmin=2)
            except ValueError:
                # Non-numeric field (e.g. a header) or wrong delimiter.
                continue
            break
        if data is not None:
            break

    if data is None:
        print(f"Error: could not parse {path} as numeric CSV/TSV data.", file=sys.stderr)
        sys.exit(1)

    if data.ndim != 2 or data.shape[1] < 2 or data.shape[0] == 0:
        print("Error: input data must have at least two columns (x, y).", file=sys.stderr)
        sys.exit(1)

    return data[:, 0], data[:, 1]


def model_function(x, params):
    """
    Evaluate f(x) = A / (1 + ((x - omega0) / tau0) ** s).

    Parameters
    ----------
    x : array_like or scalar
        Evaluation points; converted with ``np.asarray`` so lists work too.
    params : sequence of 4 floats
        ``[A, omega0, tau0, s]``. At x == omega0 the model equals A; tau0
        scales the offset (x - omega0); s is the exponent applied to it.

    Returns
    -------
    numpy array (or numpy scalar for scalar x) with the shape of ``x``.

    Notes
    -----
    Where ``(x - omega0) / tau0`` is negative and ``s`` is not an integer,
    the real power is undefined and numpy yields ``nan`` (with a
    RuntimeWarning). ``tau0 == 0`` divides by zero.
    """
    A, omega0, tau0, s = params
    x = np.asarray(x, dtype=float)
    return A / (1.0 + ((x - omega0) / tau0) ** s)


def residuals(params, x, y):
    """
    Return the pointwise misfit ``model_function(x, params) - y``.

    Positive entries mean the model lies above the observed value.
    """
    predicted = model_function(x, params)
    return predicted - y


def rss(params, x, y):
    """
    Return the residual sum of squares ``sum((f(x, params) - y) ** 2)``.

    This scalar is the objective passed to ``scipy.optimize.minimize``.
    """
    return np.sum(np.square(residuals(params, x, y)))


def plot_data(x, y, ax=None, label="data", **kwargs):
    """
    Scatter-plot the observations (x, y).

    Draws on ``ax``, or on the current axes (``plt.gca()``) when ``ax`` is
    None. Extra keyword arguments are forwarded to ``Axes.scatter``.
    Returns the axes that was drawn on.
    """
    target = plt.gca() if ax is None else ax
    target.scatter(x, y, label=label, **kwargs)
    return target


def plot_model(x, params, ax=None, label="model", **kwargs):
    """
    Draw the model curve f(x; params) as a line.

    ``x`` is sorted before evaluation so the line runs left to right even
    when the data points are stored out of order. Passing unsorted x to
    ``Axes.plot`` would make the line zig-zag between points.

    Draws on ``ax``, or on ``plt.gca()`` when ``ax`` is None. Extra keyword
    arguments are forwarded to ``Axes.plot``. Returns the axes drawn on.
    """
    if ax is None:
        ax = plt.gca()
    x_sorted = np.sort(np.asarray(x, dtype=float))
    ax.plot(x_sorted, model_function(x_sorted, params), label=label, **kwargs)
    return ax


def mode_read(x, y):
    """Show the raw input data as a scatter plot (``--mode read``)."""
    _, ax = plt.subplots()
    plot_data(x, y, ax=ax, label="data", s=10)
    ax.set(xlabel="x", ylabel="y", title="mode=read: input data")
    ax.legend()
    plt.tight_layout()
    plt.show()


def mode_sim(x, y, p0):
    """
    Compare the model at the initial parameters ``p0`` with the data
    (``--mode sim``).

    First prints a warning to stderr if the initial RSS is non-finite or
    looks far from optimal. Then shows the data overlaid with the initial
    model curve, with the RSS in the legend.
    """
    rss0 = rss(p0, x, y)

    # Warnings are emitted before plt.show(), which typically blocks until
    # the window is closed; otherwise the user would only see them afterwards.
    # Heuristic for "initial parameters are far off": compare the RSS with
    # N * var(y), the RSS of the constant model y = mean(y) (np.var uses ddof=0).
    var_y = np.var(y)
    if not np.isfinite(rss0):
        print(
            f"Warning: initial RSS is not finite ({rss0}). The model returns nan/inf "
            "for these parameters (e.g. non-integer s with (x - omega0) / tau0 < 0, "
            "or tau0 == 0).",
            file=sys.stderr,
        )
    elif var_y > 0 and rss0 > 10 * len(x) * var_y:
        print(
            f"Warning: initial RSS ({rss0:.3g}) is much larger than data variance scale "
            f"(~{len(x) * var_y:.3g}). Initial parameters may be far from optimum.",
            file=sys.stderr,
        )

    fig, ax = plt.subplots()
    plot_data(x, y, ax=ax, label="data", s=10)
    plot_model(x, p0, ax=ax, label=f"initial model (RSS={rss0:.3g})", color="red")
    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.set_title("mode=sim: initial parameters vs data")
    ax.legend()
    plt.tight_layout()
    plt.show()


def _format_params(p, spec=".6g"):
    """Render a 4-element vector as 'A=..., omega0=..., tau=..., s=...' using format ``spec``."""
    A, omega0, tau, s = p
    return f"A={A:{spec}}, omega0={omega0:{spec}}, tau={tau:{spec}}, s={s:{spec}}"


def mode_fit(x, y, p0, method, nmaxiter, tol):
    """
    Fit the model to (x, y) by minimising the RSS with scipy.optimize.minimize
    (``--mode fit``).

    Parameters
    ----------
    x, y : numpy arrays
        Data points.
    p0 : sequence of 4 floats
        Initial parameters [A, omega0, tau0, s].
    method : str
        Passed as ``method`` to ``minimize``.
    nmaxiter : int or None
        If not None, passed as the ``maxiter`` option.
    tol : float or None
        Passed as ``tol``; None is replaced by 1e-8.

    Prints the parameters and RSS at every iteration, then a summary. Finally
    shows a plot with the data, the initial model and the fitted model.
    """
    n_calls = 0

    def callback(p, *_):
        # *_ absorbs the extra 'state' argument that trust-constr passes.
        nonlocal n_calls
        n_calls += 1
        current_rss = rss(p, x, y)
        print(f"iter {n_calls:4d}: {_format_params(p)}, RSS={current_rss:.6g}")

    options = {}
    if nmaxiter is not None:
        options["maxiter"] = nmaxiter
    if tol is None:
        tol = 1e-8

    print("Starting minimize...")
    print(f"  method   : {method}")
    # Empty spec -> plain str() of each value, as the original raw printout did.
    print(f"  p0       : {_format_params(p0, spec='')}")
    print(f"  nmaxiter : {nmaxiter}")
    print(f"  tol      : {tol}")

    res = minimize(
        rss,
        x0=np.asarray(p0, dtype=float),
        args=(x, y),
        method=method,
        tol=tol,
        callback=callback,
        options=options,
    )

    print("\nOptimization finished.")
    print(f"  success      : {res.success}")
    print(f"  message      : {res.message}")
    # Not every method reports 'nit'; fall back to the callback count.
    print(f"  niter        : {res.get('nit', n_calls)}")
    print(f"  final RSS    : {res.fun:.6g}")
    print(f"  fitted params: {_format_params(res.x)}")

    # Plot before/after
    fig, ax = plt.subplots()
    plot_data(x, y, ax=ax, label="data", s=10)
    plot_model(x, p0, ax=ax, label="initial model", color="red", linestyle="--")
    plot_model(x, res.x, ax=ax, label="fitted model", color="green")
    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.set_title(f"mode=fit: method={method}")
    ax.legend()
    plt.tight_layout()
    plt.show()


def build_argparser():
    """
    Build the command-line parser for this script.

    Options: --mode (read/sim/fit, required), --input (required),
    --method, --nmaxiter, --tol, and --p0 (exactly four floats, required).
    """
    parser = argparse.ArgumentParser(
        description="Nonlinear least-squares fitting using scipy.optimize.minimize"
    )
    parser.add_argument(
        "--mode",
        choices=["read", "sim", "fit"],
        required=True,
        help="Execution mode: read (data only), sim (simulate with initial params), fit (perform fitting)",
    )
    parser.add_argument(
        "--input",
        type=Path,
        required=True,
        help="Input data file (CSV or TSV, first column x, second column y)",
    )
    parser.add_argument(
        "--method",
        type=str,
        default="Nelder-Mead",
        help="Optimization method for scipy.optimize.minimize "
             "(e.g., Nelder-Mead, BFGS, Powell, CG, L-BFGS-B, ...; default: %(default)s)",
    )
    parser.add_argument(
        "--nmaxiter",
        type=int,
        default=None,
        help="Maximum number of iterations for optimizer (passed to maxiter option; "
             "default: the method's own limit)",
    )
    parser.add_argument(
        "--tol",
        type=float,
        default=1e-8,
        help="Convergence tolerance (tol argument of minimize; default: %(default)s)",
    )
    parser.add_argument(
        "--p0",
        type=float,
        nargs=4,
        metavar=("A0", "OMEGA0_0", "TAU0", "S0"),
        required=True,
        help="Initial parameters A0 OMEGA0_0 TAU0 S0 of the model "
             "A / (1 + ((x - omega0) / tau0)^s)",
    )
    return parser


def main():
    """
    Entry point: parse the CLI, load the data, report it, and run the
    selected mode.
    """
    args = build_argparser().parse_args()

    x, y = read_data(args.input)
    p0 = np.array(args.p0, dtype=float)

    print(f"Loaded data from {args.input} (N={len(x)})")
    print(f"Initial parameters: A={p0[0]}, omega0={p0[1]}, tau={p0[2]}, s={p0[3]}")

    # Each entry defers the call until the mode has been selected.
    handlers = {
        "read": lambda: mode_read(x, y),
        "sim": lambda: mode_sim(x, y, p0),
        "fit": lambda: mode_fit(x, y, p0, args.method, args.nmaxiter, args.tol),
    }
    handler = handlers.get(args.mode)
    if handler is None:
        # Defensive only: argparse's `choices` already rejects other modes.
        print(f"Unknown mode: {args.mode}", file=sys.stderr)
        sys.exit(1)
    handler()


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported as a module.
    main()
