import sys
import numpy as np
from numpy import sin, cos, tan, pi
from pprint import pprint
import csv
import openpyxl
import pandas as pd
import matplotlib.pyplot as plt


# Number of basis functions used in the fit; main() passes it to mlsq() as m.
nfunc = 4
# Default input workbook; main() reads it with pandas.read_excel.
infile  =  'random-sin.xlsx'

# Placeholders for the evaluation range of the fitted curve.
# NOTE(review): main() assigns its own local xcal0/xcal1, so in the visible
# code these module-level values are never read.
xcal0 = None
xcal1 = None
# Number of points at which main() evaluates the fitted curve.
ncal = 101

# Font size applied to tick labels, axis labels and legends in main().
fontsize = 16


# Optional positional command-line overrides:
#   argv[1] -> infile
#   argv[2] -> nfunc (converted with int(); a non-integer raises ValueError)
argv = sys.argv
narg = len(argv)
if narg >= 2:
    infile = argv[1]
if narg >= 3:
    nfunc  = int(argv[2])


# Human-readable label of each basis function; flabel[i] describes lsqfunc(i, x).
flabel = ['1', 'cos(2x)', 'sin(2x)', 'cos(x)', 'sin(x)', 'cos(3x)', 'sin(3x)', 'x', 'x^2', 'exp(x)']
def lsqfunc(i, x):
    """Evaluate the i-th basis function of the least-squares model at x.

    Args:
        i: Basis-function index, 0 to 9 (same ordering as ``flabel``).
        x: Point at which to evaluate; a scalar or a NumPy array.

    Returns:
        The basis-function value. Index 0 is the constant term and always
        returns the scalar 1.0, whatever the shape of ``x``.

    Raises:
        ValueError: If ``i`` is not a known basis-function index.
    """
    if i == 0:
        return 1.0
    elif i == 1:
        return cos(2.0*x)
    elif i == 2:
        return sin(2.0*x)
    elif i == 3:
        return cos(x)
    elif i == 4:
        return sin(x)
    elif i == 5:
        return cos(3.0*x)
    elif i == 6:
        return sin(3.0*x)
    elif i == 7:
        return x
    elif i == 8:
        return x**2
    elif i == 9:
        # `exp` is not among the names imported from numpy, so go through np.
        return np.exp(x)
    else:
        # Fail loudly instead of silently returning None for an unknown index.
        raise ValueError(
            "unknown basis function index {!r} (expected 0..{})".format(i, len(flabel) - 1)
        )

def mlsq(x, y, m):
    """Fit y ~ sum_k c_k * lsqfunc(k, x) by linear least squares.

    Builds the normal equations ``Sij @ c = Si`` with
    ``Si[l] = sum_i y_i * f_l(x_i)`` and
    ``Sij[j, l] = sum_i f_j(x_i) * f_l(x_i)``, prints both, and solves for
    the coefficients.

    Args:
        x: Sequence of sample positions (list, NumPy array or pandas Series).
        y: Sequence of sample values, same length as ``x``.
        m: Number of basis functions to use (indices 0 .. m-1 of lsqfunc).

    Returns:
        A list of ``m`` floats: the fitted coefficients c_0 .. c_{m-1}.

    Raises:
        ValueError: If ``x`` and ``y`` differ in length, or if lsqfunc has no
            basis function for one of the requested indices.
        numpy.linalg.LinAlgError: If the normal-equation matrix is singular.
    """
    # Convert to plain arrays so that elements are addressed by position;
    # indexing a pandas Series with [i] is label-based.
    xs = np.asarray(x, dtype=float)
    ys = np.asarray(y, dtype=float)
    n = len(xs)
    if len(ys) != n:
        raise ValueError("x and y must have the same length")

    # Evaluate each basis function once over all samples: basis[l, i] = f_l(x_i).
    basis = np.empty([m, n])
    for l in range(m):
        values = lsqfunc(l, xs)
        if values is None:
            # Assigning None into a float array would silently become NaN.
            raise ValueError("lsqfunc has no basis function with index {}".format(l))
        basis[l, :] = values  # a scalar (constant term) broadcasts over the row

    Si = (basis @ ys).reshape(m, 1)
    Sij = basis @ basis.T

    print("Vector and Matrix:")
    print("Si=")
    print(Si)
    print("Sij=")
    print(Sij)
    print("")

    # Solve the linear system directly rather than forming an explicit inverse.
    ci = np.linalg.solve(Sij, Si)
    return ci[:, 0].tolist()


def main():
    """Read (x, y) data from the Excel file, fit it, and plot the result.

    The first two columns of ``infile`` are taken as x and y. The data are
    fitted with the first ``nfunc`` basis functions via mlsq(), the fitted
    curve is evaluated on ``ncal`` equally spaced points over a range widened
    around the data, and two panels are drawn: data with fit, and the
    coefficients. The function then waits for ENTER on stdin.
    """
    print("Least-squares method for sum of {} functions".format(nfunc))
    print(f"nfunc={nfunc}")
    print(f"infile={infile}")

    print("")
    print(f"Read [{infile}]")
    table = pd.read_excel(infile, engine='openpyxl')
    columns = table.columns.to_list()
    x, y = table[columns[0]], table[columns[1]]

    # Widen the evaluation range by 50% of each end's magnitude, away from the data.
    lo, hi = min(x), max(x)
    xcal0 = lo * 1.5 if lo < 0.0 else lo * 0.5
    xcal1 = hi * 0.5 if hi < 0.0 else hi * 1.5
    xcalstep = (xcal1 - xcal0) / (ncal - 1)
    print(f"Cal range: {xcal0} - {xcal1} at {xcalstep} step, {ncal} points")

    print("")
    print(f"Execute linear least-squares method")
    ci = mlsq(x, y, nfunc)

    # Report the fitted model as a single line of text.
    print("LSQ function")
    pieces = ["f(x) = {:6.3g}".format(ci[0])]
    pieces.extend(
        " + {:6.3g} * {}".format(ci[k], flabel[k]) for k in range(1, nfunc)
    )
    print("".join(pieces))

    def model(point):
        # Sum of the fitted coefficients times their basis functions at `point`.
        total = 0.0
        for k in range(nfunc):
            total += ci[k] * lsqfunc(k, point)
        return total

    xcal = [xcal0 + step * xcalstep for step in range(ncal)]
    ycal = [model(point) for point in xcal]

    # ------------------------------------------------------------
    # Plot: left panel = data and fit, right panel = coefficients
    # ------------------------------------------------------------
    fig, axes = plt.subplots(1, 2, figsize=(8, 6))
    ax_fit, ax_coef = axes

    ax_fit.plot(x, y, label='input', linestyle='', marker='o')
    ax_fit.plot(xcal, ycal, label='fit', linestyle='-')
    ax_fit.tick_params(labelsize=fontsize)
    ax_fit.set_xlabel('$x$', fontsize=fontsize)
    ax_fit.set_ylabel('$y$', fontsize=fontsize)
    ax_fit.legend(fontsize=fontsize)

    ax_coef.plot(range(nfunc), ci, label='coeff', linestyle='', marker='o')
    ax_coef.tick_params(labelsize=fontsize)
    ax_coef.set_xlabel('$i$', fontsize=fontsize)
    ax_coef.set_ylabel('$c_i$', fontsize=fontsize)
    ax_coef.legend(fontsize=fontsize)

    plt.tight_layout()
    plt.pause(0.1)

    # Keep the figure window open until the user confirms.
    print("")
    print("Press ENTER to terminate")
    input()


if __name__ == "__main__":
    main()
