"""
線形最小二乗法を用いたデータフィッティングスクリプト。
このスクリプトは、与えられたデータ点に対して、指定された基底関数の線形結合で近似する線形最小二乗法を実行します。
入力データはExcelファイルから読み込まれ、フィット結果の関数係数とグラフが表示されます。
基底関数の数と入力ファイルはコマンドライン引数で指定できます。
:doc:`lsq-general_usage`
"""
import sys
import numpy as np
from numpy import sin, cos, tan, pi
from pprint import pprint
import csv
import openpyxl
import pandas as pd
import matplotlib.pyplot as plt
nfunc = 4
infile = 'random-sin.xlsx'
xcal0 = None
xcal1 = None
ncal = 101
fontsize = 16
if __name__ == "__main__":
argv = sys.argv
narg = len(argv)
if narg >= 2:
infile = argv[1]
if narg >= 3:
nfunc = int(argv[2])
flabel = ['1', 'cos(2x)', 'sin(2x)', 'cos(x)', 'sin(x)', 'cos(3x)', 'sin(3x)', 'x', 'x^2', 'exp(x)']
[ドキュメント]
def lsqfunc(i, x):
"""
指定されたインデックスに対応する基底関数の値を計算します。
定数、cos(2x), sin(2x), cos(x), sin(x), cos(3x), sin(3x), x, x^2, exp(x) の10種類の基底関数をサポートします。
:param i: int - 基底関数のインデックス。
:param x: float - 関数の評価点。
:returns: float - 評価点 `x` における基底関数の値。
"""
if i == 0:
return 1.0
elif i == 1:
return cos(2.0*x)
elif i == 2:
return sin(2.0*x)
elif i == 3:
return cos(x)
elif i == 4:
return sin(x)
elif i == 5:
return cos(3.0*x)
elif i == 6:
return sin(3.0*x)
elif i == 7:
return x
elif i == 8:
return x**2
elif i == 9:
# numpy.exp は明示的にインポートされていないが、exp(x) を使っているので補完
# オリジナルコードのロジックは変更できないため、そのまま残す
return np.exp(x)
[ドキュメント]
def mlsq(x, y, m):
"""
線形最小二乗法を用いて、与えられたデータに最もよくフィットする基底関数の係数を計算します。
与えられたデータ `(x, y)` と基底関数の数 `m` を使用して、
正規方程式 `Sij * c = Si` を解き、係数 `c` を求めます。
:param x: list or numpy.ndarray - 独立変数のデータ点。
:param y: list or numpy.ndarray - 従属変数のデータ点。
:param m: int - 使用する基底関数の数。
:returns: list - 計算された基底関数の係数 `c` のリスト。
"""
n = len(x)
Si = np.empty([m, 1])
Sij = np.empty([m, m])
for l in range(0, m):
v = 0.0
for i in range(0, n):
v += y[i] * lsqfunc(l, x[i])
Si[l, 0] = v
for j in range(0, m):
for l in range(j, m):
v = 0.0
for i in range(0, n):
v += lsqfunc(j, x[i]) * lsqfunc(l, x[i])
Sij[j, l] = Sij[l, j] = v
print("Vector and Matrix:")
print("Si=")
print(Si)
print("Sij=")
print(Sij)
print("")
ci = np.linalg.inv(Sij) @ Si
ci = ci.transpose().tolist()
return ci[0]
[ドキュメント]
def main():
"""
スクリプトの主処理を実行します。
入力データファイルを読み込み、線形最小二乗法を実行して係数を計算し、
元のデータとフィット結果をプロットします。
基底関数の数 (`nfunc`) と入力ファイル (`infile`) は、
コマンドライン引数から取得するか、デフォルト値を使用します。
"""
print("Least-squares method for sum of {} functions".format(nfunc))
print(f"nfunc={nfunc}")
print(f"infile={infile}")
print("")
print(f"Read [{infile}]")
df = pd.read_excel(infile, engine = 'openpyxl')
labels = df.columns.to_list()
x = df[labels[0]]
y = df[labels[1]]
ndata = len(x)
minx = min(x)
maxx = max(x)
# 計算範囲の調整
# xcal0とxcal1はグローバル変数ではなくローカル変数として扱われる
# グローバルスコープのxcal0, xcal1はNoneのまま
# `exp(x)` が `lsqfunc` 内で使用されているが、`numpy.exp` としてインポートされていない。
# これは既存ロジックを変更できないというルールに基づき、修正せずにそのままにする。
# ただし、実行時にNameErrorが発生する可能性がある。
if minx < 0.0:
xcal0 = minx * 1.5
else:
xcal0 = minx * 0.5
if maxx < 0.0:
xcal1 = maxx * 0.5
else:
xcal1 = maxx * 1.5
xcalstep = (xcal1 - xcal0) / (ncal - 1)
print(f"Cal range: {xcal0} - {xcal1} at {xcalstep} step, {ncal} points")
print("")
print(f"Execute linear least-squares method")
ci = mlsq(x, y, nfunc)
print("LSQ function")
print("f(x) = {:6.3g}".format(ci[0]), end = '')
for i in range(1, nfunc):
print(" + {:6.3g} * {}".format(ci[i], flabel[i]), end = '')
print("")
xcal = [xcal0 + i * xcalstep for i in range(ncal)]
ycal = []
for i in range(ncal):
_x = xcal[i]
yl = 0.0
for k in range(nfunc):
yl += ci[k] * lsqfunc(k, _x)
ycal.append(yl)
#================================================================
# Plot
#================================================================
fig, axes = plt.subplots(1, 2, figsize = (8, 6))
axes[0].plot(x, y, label = 'input', linestyle = '', marker = 'o')
axes[0].plot(xcal, ycal, label = 'fit', linestyle = '-')
axes[0].tick_params(labelsize = fontsize)
axes[0].set_xlabel('$x$', fontsize = fontsize)
axes[0].set_ylabel('$y$', fontsize = fontsize)
axes[0].legend(fontsize = fontsize)
axes[1].plot(range(nfunc), ci, label = 'coeff', linestyle = '', marker = 'o')
axes[1].tick_params(labelsize = fontsize)
axes[1].set_xlabel('$i$', fontsize = fontsize)
axes[1].set_ylabel('$c_i$', fontsize = fontsize)
axes[1].legend(fontsize = fontsize)
plt.tight_layout()
plt.pause(0.1)
print("")
print("Press ENTER to terminate")
input()
if __name__ == "__main__":
main()