import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import minimize


# Solver name handed to scipy.optimize.minimize by fit_with_visualization.
# Alternative setting kept for reference: method = "Nelder-Mead"
method = "bfgs"


def gaussian(x, A, mu, sigma):
    """Evaluate a Gaussian peak ``A * exp(-(x - mu)**2 / (2 * sigma**2))``.

    Args:
        x: Point(s) at which to evaluate the peak.
        A: Peak height at ``x == mu``.
        mu: Position of the peak centre.
        sigma: Width parameter of the peak.

    Returns:
        The peak value(s) at ``x``.
    """
    offset = x - mu
    two_variance = 2 * sigma**2
    exponent = -(offset**2) / two_variance
    return A * np.exp(exponent)


def baseline(x, b0, b1):
    """Evaluate the straight-line background ``b0 + b1 * x``.

    Args:
        x: Point(s) at which to evaluate the line.
        b0: Intercept (value at ``x == 0``).
        b1: Slope.

    Returns:
        The line value(s) at ``x``.
    """
    slope_term = b1 * x
    return b0 + slope_term


def model(x, p):
    """Evaluate the full fit model: a Gaussian peak on a linear baseline.

    Args:
        x: Point(s) at which to evaluate the model.
        p: Five parameters in the order ``(A, mu, sigma, b0, b1)``; the first
            three go to ``gaussian`` and the last two to ``baseline``.

    Returns:
        ``gaussian(x, A, mu, sigma) + baseline(x, b0, b1)``.
    """
    amplitude, center, width, intercept, slope = p
    peak = gaussian(x, amplitude, center, width)
    background = baseline(x, intercept, slope)
    return peak + background


def objective(p, x, y):
    """Return the sum of squared residuals between ``y`` and ``model(x, p)``.

    Args:
        p: Model parameters, in the order expected by ``model``.
        x: Sample positions.
        y: Observed values at ``x``.

    Returns:
        The summed squared difference between the data and the model.
    """
    residuals = y - model(x, p)
    squared = residuals**2
    return np.sum(squared)


def fit_with_visualization(x, y, p0):
    """Fit ``model`` to the data by least squares while animating the progress.

    Minimizes ``objective`` with ``scipy.optimize.minimize`` using the solver
    named by the module-level ``method`` setting, capped at 200 iterations.
    Every time the optimizer invokes the callback, the current parameters
    and error are printed and the figure is redrawn with the data, the
    initial guess and the current fit. Once the optimizer stops, interactive
    mode is switched off and ``plt.show()`` is called.

    Args:
        x: Sample positions passed to ``model``.
        y: Observed values at ``x``.
        p0: Initial parameter guess in the order ``(A, mu, sigma, b0, b1)``.

    Returns:
        A tuple ``(best_params, history_params, history_error)`` where
        ``best_params`` is ``result.x`` from the optimizer, ``history_params``
        is the list of parameter vectors seen by the callback, and
        ``history_error`` is the matching list of objective values.
    """
    history_params = []
    history_error = []

    plt.ion()
    try:
        _, ax = plt.subplots(figsize=(8, 5))

        # The initial-guess curve never changes, so evaluate it once instead
        # of on every redraw.
        initial_curve = model(x, p0)

        # Called by the optimizer once per step with the current parameters.
        def callback(p):
            err = objective(p, x, y)
            # Copy so the stored history does not depend on whether the
            # optimizer reuses the array it passes in.
            history_params.append(p.copy())
            history_error.append(err)
            step = len(history_params)

            print(f"step {step}")
            print(f"params = {p}")
            print(f"error  = {err:.5f}")
            print("-" * 40)

            # Redraw the figure from scratch for this step.
            ax.clear()
            ax.plot(x, y, "o", ms=4, label="data")
            ax.plot(x, initial_curve, "--", label="initial")
            ax.plot(x, model(x, p), "-", label="current fit")
            ax.legend()
            ax.set_title(f"Iteration {step}  Error={err:.4f}")
            plt.pause(0.1)

        # Run the optimization with the solver selected by the module-level
        # `method` setting.
        result = minimize(
            objective,
            p0,
            args=(x, y),
            method=method,
            callback=callback,
            options={"maxiter": 200, "disp": True},
        )
    finally:
        # Always leave interactive mode, even if the optimizer, the callback
        # or the plotting raised.
        plt.ioff()

    plt.show()

    return result.x, history_params, history_error

if __name__ == "__main__":
    # Synthetic data: the model evaluated at known parameters, plus random
    # noise drawn with np.random.randn and scaled by 0.3.
    true_params = [5, 5, 0.8, 1.0, 0.1]
    x = np.linspace(0, 10, 200)
    noise = 0.3 * np.random.randn(len(x))
    y = model(x, true_params) + noise

    # Deliberately poor initial guess.
    p0 = [3, 4, 1.0, 0.0, 0.0]

    popt, params_hist, err_hist = fit_with_visualization(x, y, p0)

    print("最終パラメータ:", popt)
