import csv
import numpy as np
from numpy import sin, cos, tan, pi
import sys
import matplotlib.pyplot as plt


# Sampling range [x0, x1] and number of evenly spaced sample points.
x0 = 0.0
x1 = 5.0
ndata = 101
# Polynomial order of the fit; the first command-line argument overrides it.
norder = 3
# Parse argv only when run as a script, so importing this module (e.g. from a
# test runner whose own argv[1] is not an integer) has no side effects.
if __name__ == "__main__" and len(sys.argv) >= 2:
    try:
        norder = int(sys.argv[1])
    except ValueError:
        sys.exit("polynomial order must be an integer, got {!r}".format(sys.argv[1]))
    if norder < 0:
        sys.exit("polynomial order must be >= 0, got {}".format(norder))

# CSV file that main() writes the samples and fitted values to.
outfile = 'lsq-polynomial.csv'


def func(x):
    """Return a noisy sample of the cubic 0.5 + 0.1*x - 0.2*x**2 + 0.5*x**3.

    The "scattering" is one offset drawn per call as 50 * np.random.rand(),
    i.e. uniformly from [0, 50), so results depend on NumPy's global random
    state and are always shifted upward.
    """
    noise = 50.0 * np.random.rand()
    # Keep the left-to-right summation order so the float result is unchanged.
    return noise + 0.5 + 0.1 * x - 0.2 * x * x + 0.5 * x * x * x


def mlsq(x, y, m):
    """Fit a polynomial of order *m* to (x, y) by linear least squares.

    Builds the normal equations Sij @ c = Si with
        Si[l]     = sum_i y[i] * x[i]**l
        Sij[j, l] = sum_i x[i]**(j + l)
    prints Si and Sij, and solves for the coefficients c.

    Args:
        x: 1-D sequence/array of sample abscissae.
        y: 1-D sequence/array of sample ordinates, same length as *x*.
        m: polynomial order (>= 0).

    Returns:
        List of m+1 Python floats [c0, c1, ..., cm], lowest order first,
        so that f(x) = c0 + c1*x + ... + cm*x**m.

    Raises:
        ValueError: if m < 0, if x and y differ in length, or if there are
            fewer than m+1 samples (the normal matrix would be singular).
        numpy.linalg.LinAlgError: may be raised by np.linalg.solve if Sij
            is still singular (e.g. too few distinct x values).
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    if m < 0:
        raise ValueError("polynomial order must be >= 0, got {}".format(m))
    n = len(x)
    if len(y) != n:
        raise ValueError("x and y differ in length ({} vs {})".format(n, len(y)))
    if n < m + 1:
        raise ValueError(
            "need at least {} points for order {}, got {}".format(m + 1, m, n))

    # powers[i, k] = x[i]**k for every exponent k = 0 .. 2m the sums need.
    powers = x[:, np.newaxis] ** np.arange(2 * m + 1)
    moments = powers.sum(axis=0)  # moments[k] = sum_i x[i]**k
    Si = (powers[:, :m + 1].T @ y).reshape(m + 1, 1)
    # Sij depends only on j + l, so it is a Hankel matrix of the moments.
    idx = np.arange(m + 1)
    Sij = moments[idx[:, np.newaxis] + idx[np.newaxis, :]]

    print("Vector and Matrix:")
    print("Si=")
    print(Si)
    print("Sij=")
    print(Sij)
    print("")

    # solve() avoids forming inv(Sij) explicitly: cheaper and more accurate.
    ci = np.linalg.solve(Sij, Si)
    return ci.ravel().tolist()


def _make_data(xstart, xend, npoints):
    """Sample func() on an evenly spaced grid and print the grid description.

    Returns (x, y) with x[i] = xstart + i * (xend - xstart) / (npoints - 1)
    and y[i] = func(x[i]) for i in 0 .. npoints-1.

    Raises:
        ValueError: if npoints < 2 (the grid step would be undefined).
    """
    if npoints < 2:
        raise ValueError("need at least 2 data points, got {}".format(npoints))
    xstep = (xend - xstart) / (npoints - 1)
    print("x = ({}, {}, {})".format(xstart, xend, xstep))
    print("ndata={}".format(npoints))
    x = xstart + xstep * np.arange(npoints)
    # func() draws a fresh random offset per call, so evaluate it point by point.
    y = np.array([func(xi) for xi in x])
    return x, y


def _print_polynomial(ci):
    """Print 'f(x) = c0 + c1 * x^1 + ... + cm * x^m' for coefficients ci."""
    terms = "".join(" + {} * x^{}".format(c, k) for k, c in enumerate(ci[1:], start=1))
    print("f(x) = {}{}".format(ci[0], terms))


def _write_csv(path, x, y, yfit):
    """Write columns x, y, y(LSQ) to *path* as CSV with a header row."""
    rows = zip(np.asarray(x).tolist(), np.asarray(y).tolist(), np.asarray(yfit).tolist())
    # 'with' guarantees the file is flushed and closed before the plot
    # blocks on input(); the original left it open until interpreter exit.
    with open(path, 'w') as f:
        fout = csv.writer(f, lineterminator='\n')
        fout.writerow(['x', 'y', 'y(LSQ)'])
        fout.writerows(rows)


def _plot(x, y, yfit):
    """Plot the raw samples and the fitted curve, then wait for ENTER."""
    fig = plt.figure()
    ax1 = fig.add_subplot(1, 1, 1)
    ax1.plot(x, y, label='raw data', linestyle='none', marker='o', markersize=1.5)
    ax1.plot(x, yfit, label='fitted')
    ax1.set_xlabel("x")
    ax1.set_ylabel("y")
    ax1.legend()
    plt.tight_layout()

    # pause() displays the figure and runs the GUI loop briefly; input()
    # then keeps the window around until the user presses ENTER.
    plt.pause(0.1)
    print("Press ENTER to exit>>", end='')
    input()


def main():
    """Fit a polynomial of order `norder` to noisy samples of func().

    Uses the module settings x0, x1, ndata, norder and outfile: samples
    func() on [x0, x1], fits the samples with mlsq(), prints the fitted
    polynomial, writes x / y / y(LSQ) to `outfile` as CSV, plots both and
    exits after ENTER is pressed.
    """
    print("Least-squares method for polynomial order {}".format(norder))
    print("Make data by func(x) with random scattering")
    x, y = _make_data(x0, x1, ndata)
    print("")

    ci = mlsq(x, y, norder)

    print("LSQ function")
    _print_polynomial(ci)

    # ci is in ascending order (c0 first), which is what polyval expects.
    yfit = np.polynomial.polynomial.polyval(x, ci)
    _write_csv(outfile, x, y, yfit)
    _plot(x, y, yfit)

    # sys.exit() rather than the site-provided exit(), which is intended for
    # the interactive shell and is missing when Python runs with -S.
    sys.exit()


# Run the fit/plot demo only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
