cms.llsq.lsq_general のソースコード

import csv
import numpy as np
from numpy import sin, cos, tan, pi
import sys
import matplotlib.pyplot as plt


# Sampling interval [x0, x1] and the number of sample points in it.
x0, x1 = 0.0, 5.0
ndata = 101
# Number of basis functions used in the fit; when run as a script this
# may be overridden by the first command-line argument.
nfunc = 7

if __name__ == "__main__" and len(sys.argv) >= 2:
    nfunc = int(sys.argv[1])

# CSV file that main() writes the raw and fitted data to.
outfile = 'lsq-general.csv'

def func(x):
    """Return a randomly scattered sample of the reference curve at ``x``.

    The noise-free curve is ``0.5 + 0.1*sin(2x) + 0.3*cos(2x)``.  On every
    call a uniform offset from ``0.5 * np.random.rand()`` (i.e. in
    ``[0, 0.5)``) is added, so results depend on numpy's global RNG state.
    """
    noise = np.random.rand() * 0.5
    return noise + 0.5 + 0.1 * sin(2.0 * x) + 0.3 * cos(2.0 * x)
# Human-readable labels for the basis functions, in the same order as the
# indices accepted by lsqfunc(i, x): constant, then sin/cos pairs.
flabel = [
    '1',
    'sin(x)', 'cos(x)',
    'sin(2x)', 'cos(2x)',
    'sin(3x)', 'cos(3x)',
]
def lsqfunc(i, x):
    """Evaluate the i-th basis function of a truncated Fourier series at x.

    The basis is ordered ``1, sin(x), cos(x), sin(2x), cos(2x), sin(3x),
    cos(3x), ...``.  Index 0 is the constant ``1.0``.  For ``i >= 1`` the
    frequency is ``k = (i + 1) // 2``; odd indices give ``sin(k*x)`` and
    even indices give ``cos(k*x)``.  Indices 0-6 reproduce the original
    seven hard-coded functions, and larger indices continue the series
    instead of silently returning ``None``.

    Args:
        i: Non-negative integer basis index.
        x: Evaluation point(s); a scalar or a numpy array.

    Returns:
        The basis value(s).  For ``i == 0`` this is the scalar ``1.0``
        regardless of the shape of ``x``.

    Raises:
        ValueError: if ``i`` is negative.
    """
    if i < 0:
        raise ValueError("basis index must be non-negative, got {}".format(i))
    if i == 0:
        return 1.0
    k = (i + 1) // 2
    # Sine and cosine of the same frequency k occupy adjacent indices.
    if i % 2 == 1:
        return sin(k * x)
    return cos(k * x)
def mlsq(x, y, m, basis=None):
    """Fit ``y ~ sum_l c_l * basis(l, x)`` by linear least squares.

    Builds the normal equations ``Sij @ c = Si`` with
    ``Si[l] = sum_i y[i] * basis(l, x[i])`` and
    ``Sij[j, l] = sum_i basis(j, x[i]) * basis(l, x[i])``, prints both,
    and solves for the coefficients.

    Args:
        x: Sequence of sample abscissae.
        y: Sequence of sample values; must have the same length as ``x``.
        m: Number of basis functions (indices ``0 .. m-1``).
        basis: Callable ``basis(l, xi)`` returning a scalar.  Defaults to
            this module's ``lsqfunc``.

    Returns:
        List of ``m`` Python floats: the fitted coefficients ``c_0..c_{m-1}``.

    Raises:
        ValueError: if ``x`` and ``y`` differ in length.
        numpy.linalg.LinAlgError: if ``Sij`` is singular, e.g. when there
            are fewer distinct sample points than basis functions.
    """
    if basis is None:
        basis = lsqfunc
    n = len(x)
    if len(y) != n:
        raise ValueError("x and y must have the same length ({} != {})".format(n, len(y)))

    # Design matrix A[k, i] = basis(k, x[i]).  Each basis value is computed
    # once here instead of repeatedly inside a triple loop (O(m*n) calls
    # rather than O(m^2*n)).  reshape() keeps the 2-D shape when m or n is 0.
    A = np.array([[basis(k, xi) for xi in x] for k in range(m)], dtype=float).reshape(m, n)
    yv = np.asarray(y, dtype=float)

    Si = (A @ yv).reshape(m, 1)
    Sij = A @ A.T

    print("Vector and Matrix:")
    print("Si=")
    print(Si)
    print("Sij=")
    print(Sij)
    print("")

    # solve() is more accurate and cheaper than forming inv(Sij) explicitly.
    ci = np.linalg.solve(Sij, Si)
    return ci[:, 0].tolist()
def _make_data():
    """Sample ``func`` on ``ndata`` evenly spaced points in [x0, x1]; return (x, y)."""
    print("Make data by func(x) with random scattering")
    xstep = (x1 - x0) / (ndata - 1)
    print("x = ({}, {}, {})".format(x0, x1, xstep))
    print("ndata={}".format(ndata))
    x = x0 + np.arange(ndata) * xstep
    # func draws fresh random noise on every call, so call it once per
    # point, in order.
    y = np.array([func(xi) for xi in x], dtype=float)
    return x, y


def _print_fit(ci):
    """Print the fitted function as ``c0 + c1 * label1 + ...``."""
    print("LSQ function")
    print("f(x) = {}".format(ci[0]), end='')
    for i in range(1, len(ci)):
        # Fall back to a generic label when there are more terms than flabel entries.
        label = flabel[i] if i < len(flabel) else "lsqfunc({}, x)".format(i)
        print(" + {} * {}".format(ci[i], label), end='')
    print("")


def _write_csv(path, x, y, yfit):
    """Write ``x, y, y(LSQ)`` rows to ``path`` and close the file."""
    # The with-block guarantees the data is flushed and the file closed
    # before the script goes on to plot and wait for user input.
    with open(path, 'w') as f:
        fout = csv.writer(f, lineterminator='\n')
        fout.writerow(['x', 'y', 'y(LSQ)'])
        fout.writerows(zip(x, y, yfit))


def _plot(x, y, yfit):
    """Plot the raw samples as dots and the fitted curve as a line."""
    fig = plt.figure()
    ax1 = fig.add_subplot(1, 1, 1)
    ax1.plot(x, y, label='raw data', linestyle='none', marker='o', markersize=1.5)
    ax1.plot(x, yfit, label='fitted')
    ax1.set_xlabel("x")
    ax1.set_ylabel("y")
    ax1.legend()
    plt.tight_layout()
    plt.pause(0.1)


def main():
    """Generate noisy samples, fit them, write a CSV file and plot the result.

    Reads the module-level settings ``x0``, ``x1``, ``ndata``, ``nfunc`` and
    ``outfile``.  Samples ``func`` on an evenly spaced grid, fits the first
    ``nfunc`` basis functions of ``lsqfunc`` with ``mlsq``, and prints the
    fitted expression.  It then writes ``x, y, y(LSQ)`` rows to ``outfile``,
    shows a plot, and waits for ENTER before exiting through ``sys.exit()``.
    """
    if nfunc < 1:
        sys.exit("nfunc must be at least 1, got {}".format(nfunc))
    print("Least-squares method for sum of {} functions".format(nfunc))

    x, y = _make_data()
    print("")

    ci = mlsq(x, y, nfunc)
    _print_fit(ci)

    # Evaluate the fitted model with the same basis used for fitting.
    yfit = [sum(c * lsqfunc(k, xi) for k, c in enumerate(ci)) for xi in x]
    _write_csv(outfile, x, y, yfit)

    #=============================
    # Plot graphs
    #=============================
    _plot(x, y, yfit)

    print("Press ENTER to exit>>", end='')
    input()
    # sys.exit is always available; the bare exit() builtin is only
    # injected by the site module.
    sys.exit()
if __name__ == "__main__":
    # Run the demo only when executed as a script, not when imported.
    main()