import sys
import numpy as np
from numpy import exp
import openpyxl
import pandas as pd
import matplotlib.pyplot as plt


# ----------------------------------------------------------------
# Default parameters
# (infile, nG, wG and lmda can be overridden from the command line)
# ----------------------------------------------------------------
infile = 'random-poly-Gauss.xlsx'   # Excel workbook read by main()
nG = 51                             # number of Gaussian basis functions
wG = 0.3                            # width shared by all Gaussians (see lsqfunc)
lmda = 0.0                          # ridge term added to the diagonal of the normal matrix
x0 = []                             # Gaussian centers; filled in by main()

# main() assigns local variables named xcal0/xcal1, so these
# module-level values stay None.
xcal0 = None
xcal1 = None
ncal = 101                          # number of points on the fitted curve

fontsize = 16                       # font size used for the figure

# Positional arguments: [infile [nG [wG [lmda]]]]
argv = sys.argv
narg = len(argv)
if narg > 1:
    infile = argv[1]
if narg > 2:
    nG = int(argv[2])
if narg > 3:
    wG = float(argv[3])
if narg > 4:
    lmda = float(argv[4])


def lsqfunc(i, x):
    """Evaluate the i-th Gaussian basis function at x.

    The function is exp(-((x - x0[i]) / wG)**2), where the center x0[i]
    and the width wG are read from the module-level variables.
    """
    scaled = (x - x0[i]) / wG
    return exp(-(scaled * scaled))

def Ridge(x, y, m, lmda = 0.0):
    """Fit the coefficients of m basis functions to (x, y) by ridge regression.

    Builds the design matrix Phi[i, j] = lsqfunc(j, x[i]) and solves the
    regularized normal equations

        (Phi^T Phi + lmda * I) c = Phi^T y

    The right-hand side vector (Si) and the matrix (Sij) are printed before
    solving.

    Args:
        x: sample positions (list, NumPy array or pandas Series). Values are
            taken positionally, so a Series with a non-default index is fine.
        y: observed values; only the first len(x) entries are used.
        m: number of basis functions (lsqfunc is called with indices 0..m-1).
        lmda: value added to every diagonal element of Phi^T Phi.

    Returns:
        A list of m floats: the fitted coefficients c[0..m-1].

    Raises:
        ValueError: if y has fewer entries than x.
        numpy.linalg.LinAlgError: if the regularized matrix is singular.
    """
    # Convert once to plain arrays: avoids label-based lookups on a pandas
    # Series (x[i] would fail or pick the wrong row for a non-default index).
    xs = np.asarray(x, dtype=float)
    n = len(xs)
    ys = np.asarray(y, dtype=float)[:n]
    if len(ys) < n:
        raise ValueError(f"y has {len(ys)} values but x has {n}")

    # Evaluate every basis function exactly once per sample (n * m calls)
    # instead of re-evaluating it inside the accumulation loops.
    phi = np.array(
        [[lsqfunc(j, xi) for j in range(m)] for xi in xs],
        dtype=float,
    ).reshape(n, m)

    Si = (phi.T @ ys).reshape(m, 1)
    Sij = phi.T @ phi
    Sij[np.diag_indices(m)] += lmda

    print("Vector and Matrix:")
    print("Si=")
    print(Si)
    print("Sij=")
    print(Sij)
    print("")

    # Solve the linear system directly; forming the explicit inverse is
    # slower and less accurate for ill-conditioned matrices.
    ci = np.linalg.solve(Sij, Si)
    return ci[:, 0].tolist()


def main():
    """Read (x, y) data from the Excel file, fit it and plot the result.

    Steps:
      1. Read the first two columns of `infile` as x and y.
      2. Place nG Gaussian centers at equal spacing between min(x) and
         max(x) and store them in the module-level list x0, which
         lsqfunc() reads.
      3. Compute the coefficients with Ridge() and print them.
      4. Evaluate the fitted curve at ncal equally spaced points and plot
         the data with the fit (left panel) and the coefficients (right
         panel).
      5. Wait for ENTER before returning.

    Raises:
        ValueError: if the workbook has fewer than two columns.
    """
    global x0  # rebound below; lsqfunc() reads the centers from this list

    print("Ridge Gaussian regression")
    print(f"infile={infile}")
    print(f"nGauss={nG}")
    print(f"  width={wG}")
    print(f"Ridge lambda={lmda}")

    print("")
    print(f"Read [{infile}]")
    df = pd.read_excel(infile, engine = 'openpyxl')
    labels = df.columns.to_list()
    if len(labels) < 2:
        raise ValueError(
            f"[{infile}] must contain at least two columns (x, y); "
            f"found {len(labels)}"
        )
    x = df[labels[0]]
    y = df[labels[1]]

    # Range of the fitted curve; these are locals and do not touch the
    # module-level xcal0/xcal1.
    xcal0 = min(x)
    xcal1 = max(x)
    # A single point has no spacing: use step 0 instead of dividing by zero.
    xcalstep = (xcal1 - xcal0) / (ncal - 1) if ncal > 1 else 0.0

    # Gaussian centers span the data range. With nG == 1 the only center is
    # placed at the lower bound (step 0) rather than dividing by zero.
    xGmin = xcal0
    xGmax = xcal1
    xGstep = (xGmax - xGmin) / (nG - 1) if nG > 1 else 0.0
    x0 = [xGmin + i * xGstep for i in range(nG)]

    print("")
    print("Execute linear least-squares method")
    ci = Ridge(x, y, nG, lmda)

    for i, (center, coeff) in enumerate(zip(x0, ci)):
        print(f" x0[{i}]={center:6.3g}: c[{i}]={coeff:6.3g}")

    # Fitted curve: sum of all weighted Gaussians at each plot position.
    xcal = [xcal0 + i * xcalstep for i in range(ncal)]
    ycal = [
        sum((ci[k] * lsqfunc(k, _x) for k in range(nG)), 0.0)
        for _x in xcal
    ]

    #================================================================
    # Plot
    #================================================================
    fig, axes = plt.subplots(1, 2, figsize = (8, 6))
    axes[0].plot(x, y,       label = 'input', linestyle = '', marker = 'o')
    axes[0].plot(xcal, ycal, label = 'fit',   linestyle = '-')
    axes[0].tick_params(labelsize = fontsize)
    axes[0].set_xlabel('$x$', fontsize = fontsize)
    axes[0].set_ylabel('$y$', fontsize = fontsize)
    axes[0].legend(fontsize = fontsize)

    axes[1].plot(range(nG), ci, label = 'coeff', linestyle = '-', linewidth = 0.5, marker = 'o')
    axes[1].tick_params(labelsize = fontsize)
    axes[1].set_xlabel('$i$', fontsize = fontsize)
    axes[1].set_ylabel('$c_i$', fontsize = fontsize)

    plt.tight_layout()
    # Give the GUI a moment to draw the figure, then block on input() so
    # the script does not exit right after plotting.
    plt.pause(0.1)

    print("")
    print("Press ENTER to terminate")
    input()


# Run the fit only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == "__main__":
    main()
