regression.tklsq.tkplot のソースコード

"""tkplot.py

フィット結果の標準プロット補助。

細かい見た目はアプリ側で調整できるよう、
Figure/Axes を返す薄い関数にしている。

:doc:`tkplot_usage`
"""

from __future__ import annotations

from pathlib import Path
from typing import Callable, Mapping, Optional, Sequence, Tuple, Union
import numpy as np


ArrayLike = Sequence[float] | np.ndarray


[ドキュメント] def make_xcal( x: ArrayLike, *, n: int = 401, xmin: Optional[float] = None, xmax: Optional[float] = None, margin: float = 0.0, ) -> np.ndarray: """プロット用の滑らかな x 軸を生成します。 与えられたデータ点 `x` の範囲に基づいて、プロットに適した均一な間隔のx軸を生成します。 最小値、最大値、またはマージンを調整して、表示範囲を制御できます。 :param x: 既存のxデータ。このデータの範囲がデフォルトのx軸範囲の基準となります。 :type x: ArrayLike :param n: 生成するx軸の点の数。デフォルトは401点です。 :type n: int :param xmin: x軸の最小値を明示的に指定します。Noneの場合、`x`の最小値が使用されます。 :type xmin: Optional[float] :param xmax: x軸の最大値を明示的に指定します。Noneの場合、`x`の最大値が使用されます。 :type xmax: Optional[float] :param margin: x軸の範囲に対するマージンの割合。0.1を指定すると、範囲の10%が両端に追加されます。 :type margin: float :returns: プロット用に生成された滑らかなx軸のデータ。 :rtype: np.ndarray """ x_arr = np.asarray(x, dtype=float).reshape(-1) if xmin is None: xmin = float(np.nanmin(x_arr)) if xmax is None: xmax = float(np.nanmax(x_arr)) width = xmax - xmin xmin = xmin - margin * width xmax = xmax + margin * width return np.linspace(xmin, xmax, int(n))
[ドキュメント] def plot_fit_before_after( x: ArrayLike, y: ArrayLike, model_func: Callable[[np.ndarray, Mapping[str, float]], ArrayLike], p_before: Mapping[str, float], *, p_after: Optional[Mapping[str, float]] = None, xcal: Optional[ArrayLike] = None, yerr: Optional[ArrayLike] = None, band: Optional[Mapping[str, ArrayLike]] = None, xlabel: str = "x", ylabel: str = "y", title: str = "fit result", data_label: str = "data", before_label: str = "before", after_label: str = "after", out_png: Optional[Union[str, Path]] = None, show: bool = False, close: bool = True, ): """データ、フィット前、フィット後を重ねて描画します。 観測データ、フィット前のモデル予測、フィット後のモデル予測を一つのグラフ上に表示します。 オプションで、データの誤差棒やモデルの誤差帯を追加することも可能です。 生成されたFigureとAxesオブジェクトを返却し、さらに画像ファイルとして保存したり、 画面に表示したりすることもできます。 :param x: 観測データのx座標。 :type x: ArrayLike :param y: 観測データのy座標。 :type y: ArrayLike :param model_func: モデル関数。`y = model_func(x_array, params_dict)` の形式で呼び出されます。 :type model_func: Callable[[np.ndarray, Mapping[str, float]], ArrayLike] :param p_before: フィット前のモデルパラメータを含む辞書。 :type p_before: Mapping[str, float] :param p_after: (オプション) フィット後のモデルパラメータを含む辞書。指定しない場合、フィット後の曲線はプロットされません。 :type p_after: Optional[Mapping[str, float]] :param xcal: (オプション) モデル曲線を描画するためのx軸データ。指定しない場合、`make_xcal`によって自動生成されます。 :type xcal: Optional[ArrayLike] :param yerr: (オプション) データのy方向の誤差。指定した場合、エラーバーがプロットされます。 :type yerr: Optional[ArrayLike] :param band: (オプション) モデルの不確実性帯域に関する情報を含む辞書。 {"x": xband, "y_low": y_low, "y_high": y_high} または {"x": xband, "y_mean": y_mean, "sigma": sigma} の形式。 :type band: Optional[Mapping[str, ArrayLike]] :param xlabel: x軸のラベル文字列。 :type xlabel: str :param ylabel: y軸のラベル文字列。 :type ylabel: str :param title: プロットのタイトル文字列。 :type title: str :param data_label: データ点に表示する凡例ラベル。 :type data_label: str :param before_label: フィット前の曲線に表示する凡例ラベル。 :type before_label: str :param after_label: フィット後の曲線に表示する凡例ラベル。 :type after_label: str :param out_png: (オプション) プロットを保存するPNGファイルのパス。指定しない場合、ファイルは保存されません。 :type out_png: Optional[Union[str, Path]] :param show: プロットを表示するかどうか。Trueの場合、`plt.show()`が呼び出されます。 :type show: bool :param close: プロットを閉じるかどうか。Trueの場合、`plt.close(fig)`が呼び出されます。 `show=True`の場合でも、明示的に閉じたい場合に利用します。 :type close: bool :returns: 生成されたmatplotlibのFigureオブジェクトとAxesオブジェクト。 :rtype: Tuple[matplotlib.figure.Figure, matplotlib.axes.Axes] """ import matplotlib.pyplot as plt x_arr = np.asarray(x, dtype=float).reshape(-1) y_arr = np.asarray(y, dtype=float).reshape(-1) if xcal is None: xcal_arr = make_xcal(x_arr) else: xcal_arr = np.asarray(xcal, dtype=float).reshape(-1) fig, ax = plt.subplots(figsize=(8, 5)) if yerr is None: ax.scatter(x_arr, y_arr, s=28, color="black", alpha=0.75, label=data_label) else: ax.errorbar( x_arr, y_arr, yerr=np.asarray(yerr, dtype=float).reshape(-1), fmt="o", ms=4, color="black", alpha=0.75, label=data_label, ) y_before = np.asarray(model_func(xcal_arr, p_before), dtype=float).reshape(-1) ax.plot(xcal_arr, y_before, "--", color="tab:orange", lw=2, label=before_label) if p_after is not None: y_after = np.asarray(model_func(xcal_arr, p_after), dtype=float).reshape(-1) ax.plot(xcal_arr, y_after, "-", color="tab:blue", lw=2, label=after_label) if band is not None: plot_band(ax, band) ax.set_xlabel(xlabel) ax.set_ylabel(ylabel) ax.set_title(title) ax.grid(True, alpha=0.3) ax.legend() fig.tight_layout() if out_png is not None: out_png = Path(out_png) out_png.parent.mkdir(parents=True, exist_ok=True) fig.savefig(out_png, dpi=160) if show: plt.show() if close: plt.close(fig) return fig, ax
[ドキュメント] def plot_band( ax, band: Mapping[str, ArrayLike], *, color: str = "tab:blue", alpha: float = 0.18, label: str = "uncertainty", ): """既存のAxesオブジェクトに誤差帯(不確実性帯域)を追加します。 モデルの予測の信頼区間や、データの不確かさを視覚的に表現するために使用します。 `band`辞書には、xデータとyの低/高値、またはyの平均値と標準偏差を含める必要があります。 :param ax: 誤差帯を追加するmatplotlibのAxesオブジェクト。 :type ax: matplotlib.axes.Axes :param band: 誤差帯に関する情報を含む辞書。 {"x": xband, "y_low": y_low, "y_high": y_high} または {"x": xband, "y_mean": y_mean, "sigma": sigma} の形式である必要があります。 :type band: Mapping[str, ArrayLike] :param color: 誤差帯の塗りつぶし色。 :type color: str :param alpha: 誤差帯の透明度。0.0(完全に透明)から1.0(完全に不透明)の範囲。 :type alpha: float :param label: 誤差帯の凡例ラベル。 :type label: str :raises ValueError: `band`辞書に必要なキー('x'、および'y_low'/'y_high'または'y_mean'/'sigma')が含まれていない場合。 :returns: なし (Axesを直接変更します)。 :rtype: None """ if "x" not in band: raise ValueError("band must contain 'x'") x = np.asarray(band["x"], dtype=float).reshape(-1) if "y_low" in band and "y_high" in band: y_low = np.asarray(band["y_low"], dtype=float).reshape(-1) y_high = np.asarray(band["y_high"], dtype=float).reshape(-1) elif "y_mean" in band and "sigma" in band: y_mean = np.asarray(band["y_mean"], dtype=float).reshape(-1) sigma = np.asarray(band["sigma"], dtype=float).reshape(-1) y_low = y_mean - sigma y_high = y_mean + sigma else: raise ValueError("band must contain either y_low/y_high or y_mean/sigma") ax.fill_between(x, y_low, y_high, color=color, alpha=alpha, label=label)
[ドキュメント] def save_progress_plot( iteration: int, x: ArrayLike, y: ArrayLike, model_func: Callable[[np.ndarray, Mapping[str, float]], ArrayLike], p_before: Mapping[str, float], p_current: Mapping[str, float], *, out_dir: Union[str, Path] = ".", prefix: str = "fit_progress", **kwargs, ) -> Path: """フィット途中の画像を保存します。主にコールバック関数からの呼び出しを想定しています。 フィットアルゴリズムの各イテレーションでのモデルの適合状況を視覚的に追跡するために、 指定されたディレクトリにPNG画像としてプロットを保存します。 ファイル名はイテレーション番号に基づいて自動的に生成され、進行状況を容易に確認できます。 :param iteration: 現在のフィットイテレーション番号。ファイル名に埋め込まれます。 :type iteration: int :param x: 観測データのx座標。 :type x: ArrayLike :param y: 観測データのy座標。 :type y: ArrayLike :param model_func: モデル関数。`y = model_func(x_array, params_dict)` の形式で呼び出されます。 :type model_func: Callable[[np.ndarray, Mapping[str, float]], ArrayLike] :param p_before: フィット開始前のモデルパラメータを含む辞書。プロットの「before」曲線に使用されます。 :type p_before: Mapping[str, float] :param p_current: 現在のイテレーションにおけるモデルパラメータを含む辞書。プロットの「after」曲線に使用されます。 :type p_current: Mapping[str, float] :param out_dir: (オプション) 画像ファイルを保存する出力ディレクトリのパス。 :type out_dir: Union[str, Path] :param prefix: (オプション) 保存される画像ファイルのプレフィックス。 ファイル名は `prefix_0001.png` のようになります。 :type prefix: str :param kwargs: `plot_fit_before_after` に渡される追加のキーワード引数。 :returns: 保存された画像ファイルの完全なパス。 :rtype: Path """ out_dir = Path(out_dir) out_dir.mkdir(parents=True, exist_ok=True) out_png = out_dir / f"{prefix}_{int(iteration):04d}.png" plot_fit_before_after( x, y, model_func, p_before, p_after=p_current, out_png=out_png, title=f"fit progress iter={iteration}", **kwargs, ) return out_png