import sys
import numpy as np
from numpy import sin, cos, tan, pi
from pprint import pprint
import csv
import openpyxl
import pandas as pd
import matplotlib.pyplot as plt


# Command line: [infile [norder]]
argv = sys.argv
narg = len(argv)

# Input workbook: first positional argument, otherwise the bundled sample name.
infile = argv[1] if narg > 1 else 'random-poly.xlsx'
# Polynomial order of the fit: second positional argument, otherwise 3.
norder = int(argv[2]) if narg > 2 else 3

# Module-level placeholders for the calculation range (left as None here).
xcal0 = None
xcal1 = None
# Number of points on the calculated curve.
ncal = 101

# Font size used for tick labels, axis labels and legends.
fontsize = 16


def mlsq(x, y, m, iPrint = 0):
    """Fit a polynomial of order ``m`` to the points (x, y) by least squares.

    The coefficients are obtained from the normal equations
    ``Sij @ c = Si`` where ``Sij[j, l] = sum_i x_i**(j+l)`` and
    ``Si[l] = sum_i y_i * x_i**l``.

    Parameters
    ----------
    x, y : sequence of numbers (list, numpy array, pandas Series, ...)
        Sample abscissas and ordinates. Elements are taken by position.
    m : int
        Polynomial order; ``m + 1`` coefficients are fitted.
    iPrint : int, optional
        When equal to 1, the vector ``Si`` and the matrix ``Sij`` are printed.

    Returns
    -------
    list of float
        ``[c_0, c_1, ..., c_m]`` such that ``f(x) = sum_k c_k * x**k``.

    Raises
    ------
    ValueError
        If ``x`` and ``y`` do not contain the same number of elements.
    numpy.linalg.LinAlgError
        If the normal matrix is singular (e.g. too few distinct x values).
    """
    # Work on positional float arrays: avoids label-based lookups on pandas
    # Series and integer overflow when raising integer data to high powers.
    xs = np.asarray(x, dtype = float).ravel()
    ys = np.asarray(y, dtype = float).ravel()
    if xs.size != ys.size:
        raise ValueError(
            "x and y must have the same length ({} != {})".format(xs.size, ys.size))

    # Column k holds x_i**k for k = 0 .. 2m; every entry of Si and Sij is a
    # (weighted) column sum of this table, so the powers are computed once.
    powers = np.vander(xs, 2 * m + 1, increasing = True)

    # Right-hand side: Si[l] = sum_i y_i * x_i**l, kept as a column vector.
    Si = (powers[:, :m+1].T @ ys).reshape(m + 1, 1)

    # Sij[j, l] depends only on j + l, so index the power sums by j + l.
    power_sums = powers.sum(axis = 0)
    idx = np.arange(m + 1)
    Sij = power_sums[idx[:, None] + idx[None, :]]

    if iPrint == 1:
        print("Vector and Matrix:")
        print("Si=")
        pprint(Si)
        print("Sij=")
        pprint(Sij)
        print("")

    # Solve the linear system directly instead of forming the inverse:
    # cheaper and more accurate for ill-conditioned normal matrices.
    ci = np.linalg.solve(Sij, Si)
    return ci[:, 0].tolist()


def main():
    """Read (x, y) data from ``infile``, fit a polynomial and plot the result.

    Uses the module-level settings ``infile``, ``norder``, ``ncal`` and
    ``fontsize``. The first two columns of the workbook are taken as x and y.
    The fitted curve is evaluated on ``ncal`` evenly spaced points and drawn
    next to the input data; a second panel shows the coefficients. The
    function waits for ENTER on stdin before returning.
    """
    print("Least-squares method for polynomial order {}".format(norder))
    print(f"norder={norder}")
    print(f"infile={infile}")

    print("")
    print(f"Read [{infile}]")
    table = pd.read_excel(infile, engine = 'openpyxl')
    columns = table.columns.to_list()
    x = table[columns[0]]
    y = table[columns[1]]

    # Calculation range: each data bound is scaled by 20 % of its own value
    # in the outward direction (the factor depends on the sign of the bound).
    x_lo = min(x)
    x_hi = max(x)
    cal_lo = x_lo * (1.2 if x_lo < 0.0 else 0.8)
    cal_hi = x_hi * (0.8 if x_hi < 0.0 else 1.2)
    cal_step = (cal_hi - cal_lo) / (ncal - 1)
    print(f"Cal range: {cal_lo} - {cal_hi} at {cal_step} step, {ncal} points")

    print("")
    print("Execute linear least-squares method")
    ci = mlsq(x, y, norder, iPrint = 1)

    # Report the fitted polynomial as a single line of text.
    print("LSQ function")
    higher_terms = "".join(
        " + {} * x^{}".format(ci[k], k) for k in range(1, norder + 1))
    print("f(x) = {}".format(ci[0]) + higher_terms)

    def fitted(t):
        # Evaluate c_0 + c_1*t + ... + c_norder*t^norder, lowest power first.
        total = ci[0]
        for k in range(1, norder + 1):
            total += ci[k] * pow(t, k)
        return total

    xcal = [cal_lo + i * cal_step for i in range(ncal)]
    ycal = [fitted(t) for t in xcal]

    #================================================================
    # Plot
    #================================================================
    fig, (ax_fit, ax_coef) = plt.subplots(1, 2, figsize = (8, 6))

    # Left panel: input points and fitted curve.
    ax_fit.plot(x, y, label = 'input', linestyle = '', marker = 'o')
    ax_fit.plot(xcal, ycal, label = 'fit', linestyle = '-')
    ax_fit.tick_params(labelsize = fontsize)
    ax_fit.set_xlabel('$x$', fontsize = fontsize)
    ax_fit.set_ylabel('$y$', fontsize = fontsize)
    ax_fit.legend(fontsize = fontsize)

    # Right panel: coefficient value against its power index.
    ax_coef.plot(range(norder+1), ci, label = 'coeff', linestyle = '', marker = 'o')
    ax_coef.tick_params(labelsize = fontsize)
    ax_coef.set_xlabel('$i$', fontsize = fontsize)
    ax_coef.set_ylabel('$c_i$', fontsize = fontsize)
    ax_coef.legend(fontsize = fontsize)

    plt.tight_layout()
    # Short pause lets the GUI event loop draw the figure without blocking.
    plt.pause(0.1)

    print("")
    print("Press ENTER to terminate")
    input()


# Run the fit only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
