cms.llsq.lsq_polynomial のソースコード

import csv
import numpy as np
from numpy import sin, cos, tan, pi
import sys
import matplotlib.pyplot as plt


# Sampling interval [x0, x1] and the number of equally spaced sample points.
x0 = 0.0
x1 = 5.0
ndata = 101

# Order of the fitted polynomial; when run as a script, the first
# command-line argument (if given) replaces this default.
norder = 3
if __name__ == "__main__" and len(sys.argv) > 1:
    norder = int(sys.argv[1])

# CSV file that main() writes the x, y and fitted y columns to.
outfile = 'lsq-polynomial.csv'


# Synthetic data generator: a fixed cubic polynomial plus a random
# offset drawn uniformly from [0, 50) on every call.
def func(x):
    """Return a noisy sample of the cubic 0.5 + 0.1*x - 0.2*x**2 + 0.5*x**3.

    A fresh offset, 50 * np.random.rand() (i.e. in [0, 50)), is added on
    every call, so repeated calls with the same ``x`` give different values.
    """
    offset = 50.0 * np.random.rand()
    # Each term keeps the original left-to-right multiplication order so the
    # floating-point result is bit-for-bit unchanged.
    linear = 0.1 * x
    quadratic = 0.2 * x * x
    cubic = 0.5 * x * x * x
    return offset + 0.5 + linear - quadratic + cubic
def mlsq(x, y, m):
    """Fit a polynomial of order ``m`` to the points (x, y) by least squares.

    Builds the normal equations ``Sij @ c = Si`` with
    ``Si[l] = sum_i y_i * x_i**l`` and ``Sij[j, l] = sum_i x_i**(j + l)``,
    prints both for inspection, and solves them for the coefficients.

    Parameters
    ----------
    x, y : sequence of float
        Sample abscissae and ordinates (1-D, same length).
    m : int
        Polynomial order; the fit has ``m + 1`` coefficients.

    Returns
    -------
    list of float
        Coefficients ``[c0, c1, ..., cm]`` of ``c0 + c1*x + ... + cm*x**m``.

    Raises
    ------
    ValueError
        If ``x`` and ``y`` do not have the same shape.
    numpy.linalg.LinAlgError
        If the normal-equation matrix is singular (e.g. fewer distinct
        points than coefficients).
    """
    xa = np.asarray(x, dtype=float)
    ya = np.asarray(y, dtype=float)
    if xa.shape != ya.shape:
        raise ValueError(
            "x and y must have the same length, got {} and {}".format(
                xa.shape, ya.shape))

    # Vandermonde matrix: column k holds x**k, so V.T @ y and V.T @ V are
    # exactly the normal-equation sums, computed without Python loops.
    vander = np.vander(xa, m + 1, increasing=True)
    Si = (vander.T @ ya).reshape(m + 1, 1)
    Sij = vander.T @ vander

    print("Vector and Matrix:")
    print("Si=")
    print(Si)
    print("Sij=")
    print(Sij)
    print("")

    # solve() factorizes Sij instead of forming its explicit inverse, which
    # is cheaper and more accurate for the often ill-conditioned normal
    # equations of a polynomial fit.
    ci = np.linalg.solve(Sij, Si)
    return ci[:, 0].tolist()
def _eval_poly(ci, xv):
    """Evaluate ci[0] + ci[1]*xv + ... + ci[-1]*xv**(len(ci)-1)."""
    yl = ci[0]
    for k in range(1, len(ci)):
        yl += ci[k] * pow(xv, k)
    return yl


def main():
    """Generate scattered data with func(), fit it with mlsq(), save and plot.

    Reads the module-level settings ``x0``, ``x1``, ``ndata``, ``norder`` and
    ``outfile``. Writes the columns x, y, y(LSQ) to ``outfile``, draws the
    raw and fitted curves with matplotlib, then waits for ENTER and exits
    via ``sys.exit()``.
    """
    print("Least-squares method for polynomial order {}".format(norder))

    print("Make data by func(x) with random scattering")
    xstep = (x1 - x0) / (ndata - 1)
    print("x = ({}, {}, {})".format(x0, x1, xstep))
    print("ndata={}".format(ndata))
    x = np.zeros(ndata)
    y = np.zeros(ndata)
    # Sequential loop keeps the same order of np.random calls inside func().
    for i in range(ndata):
        x[i] = x0 + i * xstep
        y[i] = func(x[i])
    print("")

    ci = mlsq(x, y, norder)
    print("LSQ function")
    print("f(x) = {}".format(ci[0]), end='')
    for i in range(1, norder + 1):
        print(" + {} * x^{}".format(ci[i], i), end='')
    print("")

    yfit = [_eval_poly(ci, xi) for xi in x]

    # The context manager flushes and closes the CSV before the blocking
    # plot/input below, so the file is complete while the window is open.
    with open(outfile, 'w') as f:
        fout = csv.writer(f, lineterminator='\n')
        fout.writerow(['x', 'y', 'y(LSQ)'])
        for xi, yi, yl in zip(x, y, yfit):
            fout.writerow([xi, yi, yl])

    #=============================
    # Plot graphs
    #=============================
    fig = plt.figure()
    ax1 = fig.add_subplot(1, 1, 1)
    ax1.plot(x, y, label='raw data', linestyle='none', marker='o',
             markersize=1.5)
    ax1.plot(x, yfit, label='fitted')
    ax1.set_xlabel("x")
    ax1.set_ylabel("y")
    ax1.legend()
    plt.tight_layout()
    plt.pause(0.1)

    print("Press ENTER to exit>>", end='')
    input()
    # sys.exit() is always available; the bare exit() builtin is only
    # installed by the site module.
    sys.exit()
# Run the demo only when executed as a script, not on import.
if __name__ == "__main__":
    main()