import sys
import numpy as np
from numpy import sqrt
import openpyxl
import pandas as pd
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LinearRegression, Lasso, LassoCV
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
import matplotlib.pyplot as plt


# Workbook that is read when no path is supplied on the command line.
infile = 'random-poly-ML.xlsx'

# Settings that are not referenced elsewhere in the visible code
# (names suggest a calculation range and point count -- TODO confirm).
xcal0, xcal1 = None, None
ncal = 101

# Font size applied to the axis labels of the plot.
fontsize = 16


# An optional first command-line argument replaces the default workbook path.
argv = sys.argv
narg = len(argv)
infile = argv[1] if narg >= 2 else infile


def main():
    """Fit a linear regression to the descriptor columns of a workbook and plot it.

    The workbook named by the module-level ``infile`` is read with pandas.
    The first column is used as the x axis of the plot, the second column as
    the regression target, and every remaining column as a descriptor.
    Descriptors are standardized with ``StandardScaler`` before
    ``LinearRegression`` is fitted, so the printed intercept and coefficients
    refer to the scaled descriptors. MAE, MSE and RMSE are computed on the
    same rows the model was fitted to. Finally the data and the fitted curve
    are drawn and the function waits for ENTER before returning.

    Raises:
        ValueError: if the workbook has fewer than three columns, i.e. there
            is no descriptor column to fit.
    """
    print("Linear regression for polynomial")
    print(f"infile={infile}")

    print("")
    print(f"Read [{infile}]")
    df = pd.read_excel(infile, engine='openpyxl')
    labels = df.columns.tolist()
    if len(labels) < 3:
        raise ValueError(
            f"[{infile}] needs at least 3 columns "
            f"(x, y, descriptors...); found {len(labels)}"
        )

    # Descriptors: every column after the first two.
    x_labels = labels[2:]
    x = df[x_labels]
    # Target variable: the second column.
    y = df[labels[1]]

    ndata = len(df.index)
    ndescriptors = len(x_labels)
    # Presumably the descriptor columns start at the zeroth power, which makes
    # the polynomial order one less than their count -- TODO confirm against
    # the workbook layout.
    norder = ndescriptors - 1
    print("ndata=", ndata)
    print("ndescriptors=", ndescriptors)
    # x_labels already excludes the first two columns; slicing it again would
    # hide the first two descriptor names.
    print("  x_labels=", x_labels)
    print("norder=", norder)

    print("")
    print("LinearRegression:")
    scaler = StandardScaler()
    scaler.fit(x)
    x_scaled = scaler.transform(x)
    model = LinearRegression()
    model.fit(x_scaled, y)

    # Predictions on the fitting rows themselves (no hold-out set is used).
    y_cal = model.predict(x_scaled)
    mae = mean_absolute_error(y, y_cal)
    mse = mean_squared_error(y, y_cal)
    rmse = sqrt(mse)
    print(f"Mean absolute error (MAE)  : {mae}")
    print(f"Mean squared error (MSE)   : {mse}")
    print(f"Root MSE (RMSE)            : {rmse}")
    print(f"    intercept: {model.intercept_}")
    for label, coef in zip(x_labels, model.coef_):
        print(f"    {label:>10}: {coef:12.4g}")

    # ================================================================
    # Plot
    # ================================================================
    # x and y values used for plotting.
    x_plot = df[labels[0]]
    y_plot = df[labels[1]]

    plt.scatter(x_plot, y_plot)
    # Draw the fitted curve in ascending x so that rows stored in arbitrary
    # order do not produce a zig-zag line.
    order = np.argsort(x_plot.to_numpy(), kind="stable")
    plt.plot(x_plot.to_numpy()[order], np.asarray(y_cal)[order], color='red')
    plt.xlabel('$x$', fontsize=fontsize)
    plt.ylabel('$y$', fontsize=fontsize)

    # A short pause lets the figure window render before blocking on input().
    plt.pause(0.1)

    print("")
    print("Press ENTER to terminate")
    input()


# Run the regression workflow only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
