"""tkplot.py
フィット結果の標準プロット補助。
細かい見た目はアプリ側で調整できるよう、
Figure/Axes を返す薄い関数にしている。
:doc:`tkplot_usage`
"""
from __future__ import annotations
from pathlib import Path
from typing import Callable, Mapping, Optional, Sequence, Tuple, Union
import numpy as np
ArrayLike = Sequence[float] | np.ndarray
[ドキュメント]
def make_xcal(
x: ArrayLike,
*,
n: int = 401,
xmin: Optional[float] = None,
xmax: Optional[float] = None,
margin: float = 0.0,
) -> np.ndarray:
"""プロット用の滑らかな x 軸を生成します。
与えられたデータ点 `x` の範囲に基づいて、プロットに適した均一な間隔のx軸を生成します。
最小値、最大値、またはマージンを調整して、表示範囲を制御できます。
:param x: 既存のxデータ。このデータの範囲がデフォルトのx軸範囲の基準となります。
:type x: ArrayLike
:param n: 生成するx軸の点の数。デフォルトは401点です。
:type n: int
:param xmin: x軸の最小値を明示的に指定します。Noneの場合、`x`の最小値が使用されます。
:type xmin: Optional[float]
:param xmax: x軸の最大値を明示的に指定します。Noneの場合、`x`の最大値が使用されます。
:type xmax: Optional[float]
:param margin: x軸の範囲に対するマージンの割合。0.1を指定すると、範囲の10%が両端に追加されます。
:type margin: float
:returns: プロット用に生成された滑らかなx軸のデータ。
:rtype: np.ndarray
"""
x_arr = np.asarray(x, dtype=float).reshape(-1)
if xmin is None:
xmin = float(np.nanmin(x_arr))
if xmax is None:
xmax = float(np.nanmax(x_arr))
width = xmax - xmin
xmin = xmin - margin * width
xmax = xmax + margin * width
return np.linspace(xmin, xmax, int(n))
[ドキュメント]
def plot_fit_before_after(
x: ArrayLike,
y: ArrayLike,
model_func: Callable[[np.ndarray, Mapping[str, float]], ArrayLike],
p_before: Mapping[str, float],
*,
p_after: Optional[Mapping[str, float]] = None,
xcal: Optional[ArrayLike] = None,
yerr: Optional[ArrayLike] = None,
band: Optional[Mapping[str, ArrayLike]] = None,
xlabel: str = "x",
ylabel: str = "y",
title: str = "fit result",
data_label: str = "data",
before_label: str = "before",
after_label: str = "after",
out_png: Optional[Union[str, Path]] = None,
show: bool = False,
close: bool = True,
):
"""データ、フィット前、フィット後を重ねて描画します。
観測データ、フィット前のモデル予測、フィット後のモデル予測を一つのグラフ上に表示します。
オプションで、データの誤差棒やモデルの誤差帯を追加することも可能です。
生成されたFigureとAxesオブジェクトを返却し、さらに画像ファイルとして保存したり、
画面に表示したりすることもできます。
:param x: 観測データのx座標。
:type x: ArrayLike
:param y: 観測データのy座標。
:type y: ArrayLike
:param model_func: モデル関数。`y = model_func(x_array, params_dict)` の形式で呼び出されます。
:type model_func: Callable[[np.ndarray, Mapping[str, float]], ArrayLike]
:param p_before: フィット前のモデルパラメータを含む辞書。
:type p_before: Mapping[str, float]
:param p_after: (オプション) フィット後のモデルパラメータを含む辞書。指定しない場合、フィット後の曲線はプロットされません。
:type p_after: Optional[Mapping[str, float]]
:param xcal: (オプション) モデル曲線を描画するためのx軸データ。指定しない場合、`make_xcal`によって自動生成されます。
:type xcal: Optional[ArrayLike]
:param yerr: (オプション) データのy方向の誤差。指定した場合、エラーバーがプロットされます。
:type yerr: Optional[ArrayLike]
:param band: (オプション) モデルの不確実性帯域に関する情報を含む辞書。
{"x": xband, "y_low": y_low, "y_high": y_high} または
{"x": xband, "y_mean": y_mean, "sigma": sigma} の形式。
:type band: Optional[Mapping[str, ArrayLike]]
:param xlabel: x軸のラベル文字列。
:type xlabel: str
:param ylabel: y軸のラベル文字列。
:type ylabel: str
:param title: プロットのタイトル文字列。
:type title: str
:param data_label: データ点に表示する凡例ラベル。
:type data_label: str
:param before_label: フィット前の曲線に表示する凡例ラベル。
:type before_label: str
:param after_label: フィット後の曲線に表示する凡例ラベル。
:type after_label: str
:param out_png: (オプション) プロットを保存するPNGファイルのパス。指定しない場合、ファイルは保存されません。
:type out_png: Optional[Union[str, Path]]
:param show: プロットを表示するかどうか。Trueの場合、`plt.show()`が呼び出されます。
:type show: bool
:param close: プロットを閉じるかどうか。Trueの場合、`plt.close(fig)`が呼び出されます。
`show=True`の場合でも、明示的に閉じたい場合に利用します。
:type close: bool
:returns: 生成されたmatplotlibのFigureオブジェクトとAxesオブジェクト。
:rtype: Tuple[matplotlib.figure.Figure, matplotlib.axes.Axes]
"""
import matplotlib.pyplot as plt
x_arr = np.asarray(x, dtype=float).reshape(-1)
y_arr = np.asarray(y, dtype=float).reshape(-1)
if xcal is None:
xcal_arr = make_xcal(x_arr)
else:
xcal_arr = np.asarray(xcal, dtype=float).reshape(-1)
fig, ax = plt.subplots(figsize=(8, 5))
if yerr is None:
ax.scatter(x_arr, y_arr, s=28, color="black", alpha=0.75, label=data_label)
else:
ax.errorbar(
x_arr,
y_arr,
yerr=np.asarray(yerr, dtype=float).reshape(-1),
fmt="o",
ms=4,
color="black",
alpha=0.75,
label=data_label,
)
y_before = np.asarray(model_func(xcal_arr, p_before), dtype=float).reshape(-1)
ax.plot(xcal_arr, y_before, "--", color="tab:orange", lw=2, label=before_label)
if p_after is not None:
y_after = np.asarray(model_func(xcal_arr, p_after), dtype=float).reshape(-1)
ax.plot(xcal_arr, y_after, "-", color="tab:blue", lw=2, label=after_label)
if band is not None:
plot_band(ax, band)
ax.set_xlabel(xlabel)
ax.set_ylabel(ylabel)
ax.set_title(title)
ax.grid(True, alpha=0.3)
ax.legend()
fig.tight_layout()
if out_png is not None:
out_png = Path(out_png)
out_png.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(out_png, dpi=160)
if show:
plt.show()
if close:
plt.close(fig)
return fig, ax
[ドキュメント]
def plot_band(
ax,
band: Mapping[str, ArrayLike],
*,
color: str = "tab:blue",
alpha: float = 0.18,
label: str = "uncertainty",
):
"""既存のAxesオブジェクトに誤差帯(不確実性帯域)を追加します。
モデルの予測の信頼区間や、データの不確かさを視覚的に表現するために使用します。
`band`辞書には、xデータとyの低/高値、またはyの平均値と標準偏差を含める必要があります。
:param ax: 誤差帯を追加するmatplotlibのAxesオブジェクト。
:type ax: matplotlib.axes.Axes
:param band: 誤差帯に関する情報を含む辞書。
{"x": xband, "y_low": y_low, "y_high": y_high} または
{"x": xband, "y_mean": y_mean, "sigma": sigma} の形式である必要があります。
:type band: Mapping[str, ArrayLike]
:param color: 誤差帯の塗りつぶし色。
:type color: str
:param alpha: 誤差帯の透明度。0.0(完全に透明)から1.0(完全に不透明)の範囲。
:type alpha: float
:param label: 誤差帯の凡例ラベル。
:type label: str
:raises ValueError: `band`辞書に必要なキー('x'、および'y_low'/'y_high'または'y_mean'/'sigma')が含まれていない場合。
:returns: なし (Axesを直接変更します)。
:rtype: None
"""
if "x" not in band:
raise ValueError("band must contain 'x'")
x = np.asarray(band["x"], dtype=float).reshape(-1)
if "y_low" in band and "y_high" in band:
y_low = np.asarray(band["y_low"], dtype=float).reshape(-1)
y_high = np.asarray(band["y_high"], dtype=float).reshape(-1)
elif "y_mean" in band and "sigma" in band:
y_mean = np.asarray(band["y_mean"], dtype=float).reshape(-1)
sigma = np.asarray(band["sigma"], dtype=float).reshape(-1)
y_low = y_mean - sigma
y_high = y_mean + sigma
else:
raise ValueError("band must contain either y_low/y_high or y_mean/sigma")
ax.fill_between(x, y_low, y_high, color=color, alpha=alpha, label=label)
[ドキュメント]
def save_progress_plot(
iteration: int,
x: ArrayLike,
y: ArrayLike,
model_func: Callable[[np.ndarray, Mapping[str, float]], ArrayLike],
p_before: Mapping[str, float],
p_current: Mapping[str, float],
*,
out_dir: Union[str, Path] = ".",
prefix: str = "fit_progress",
**kwargs,
) -> Path:
"""フィット途中の画像を保存します。主にコールバック関数からの呼び出しを想定しています。
フィットアルゴリズムの各イテレーションでのモデルの適合状況を視覚的に追跡するために、
指定されたディレクトリにPNG画像としてプロットを保存します。
ファイル名はイテレーション番号に基づいて自動的に生成され、進行状況を容易に確認できます。
:param iteration: 現在のフィットイテレーション番号。ファイル名に埋め込まれます。
:type iteration: int
:param x: 観測データのx座標。
:type x: ArrayLike
:param y: 観測データのy座標。
:type y: ArrayLike
:param model_func: モデル関数。`y = model_func(x_array, params_dict)` の形式で呼び出されます。
:type model_func: Callable[[np.ndarray, Mapping[str, float]], ArrayLike]
:param p_before: フィット開始前のモデルパラメータを含む辞書。プロットの「before」曲線に使用されます。
:type p_before: Mapping[str, float]
:param p_current: 現在のイテレーションにおけるモデルパラメータを含む辞書。プロットの「after」曲線に使用されます。
:type p_current: Mapping[str, float]
:param out_dir: (オプション) 画像ファイルを保存する出力ディレクトリのパス。
:type out_dir: Union[str, Path]
:param prefix: (オプション) 保存される画像ファイルのプレフィックス。
ファイル名は `prefix_0001.png` のようになります。
:type prefix: str
:param kwargs: `plot_fit_before_after` に渡される追加のキーワード引数。
:returns: 保存された画像ファイルの完全なパス。
:rtype: Path
"""
out_dir = Path(out_dir)
out_dir.mkdir(parents=True, exist_ok=True)
out_png = out_dir / f"{prefix}_{int(iteration):04d}.png"
plot_fit_before_after(
x,
y,
model_func,
p_before,
p_after=p_current,
out_png=out_png,
title=f"fit progress iter={iteration}",
**kwargs,
)
return out_png