フィッティング範囲を指定、1.多項式最小二乗法を行い、グラフにプロット

06-leastsq-plot-range.py

# scipy.optimizeを使って多項式最小二乗の結果をグラフにプロットしてみる
# http://ailaby.com/least_square/
# https://qiita.com/Morio/items/d75159bac916174e7654
# https://rikei-fufu.com/2019/06/28/post-1357-python-plot/

import csv
from pprint import pprint
import numpy as np
from scipy import optimize
from matplotlib import pyplot as plt


"""
フィッティング範囲を指定、1.多項式最小二乗法を行い、グラフにプロット
"""



#=============================
# 大域変数の定義
#=============================
# CSVファイル
infile = 'data.csv'

# フィッティングパラメータ初期値。線形最小二乗の場合は適当
ai0 = [0, 0, 0]

fitrange = [0.5, 5.5]

print("fit range: ", fitrange)

#=============================
# 最小化する関数の定義
#=============================
def ycal(ai, x):
    return ai[0] + ai[1] * x + ai[2] * x * x

def residual(ai, x, y):
    res = []
    for i in range(len(x)):
        res.append(y[i] - ycal(ai, x[i]))
    return res


#=============================
# csvファイルの読み込み
#=============================
i = 0
x = []
y = []
with open(infile, "r") as f:
    reader = csv.reader(f)

    for row in reader:
        if i == 0:
            header = row
        else:
            xi = float(row[0])
            if fitrange[0] <= xi <= fitrange[1]:
                x.append(xi)
                y.append(float(row[1]))
        i += 1

print("header:", header)

print("x:", x)
print("y:", y)


#=============================
# scipy.optimize()による最小化
#=============================
print("")
print("polynomial fit by scipy.optimize() start:")
# leastsqの戻り値は、最適化したパラメータのリストと、最適化の結果
ai, ret = optimize.leastsq(residual, ai0, args= (x, y))
print(" lsq result: ai=", ai)
print(" y = {} + {} * x + {} * x^2".format(ai[0], ai[1], ai[2]))

xmin = min(x)
xmax = max(x)
ncal = 100
xstep = (xmax - xmin) / (ncal - 1)
xc = []
yc = []
for i in range(ncal):
    xi = xmin + i * xstep
    yi = ycal(ai, xi)
    xc.append(xi)
    yc.append(yi)

plt.plot(x, y, label='raw data', marker = 'o', linestyle = 'None')
plt.plot(xc, yc, label='fitted', linestyle = 'dashed')
plt.xlabel(header[0])
plt.ylabel(header[1])
plt.legend()

plt.show()