import sys
import csv
import numpy as np
from numpy import exp, log, log10, sin, cos, tan, arcsin, arccos, arctan, sinh, cosh, tanh, sqrt, abs
import pandas as pd
from scipy.optimize import minimize
import matplotlib.pyplot as plt


from tklib.tksci.tksci import h, h_bar, hbar, e, kB, NA, c, pi, pi2, torad, todeg, basee
from tklib.tksci.tksci import me, mp, mn, u0, e0, e2_4pie0, a0, R, F, g
from tklib.tksci.tksci import acos, asin, atan, cosh, sinh, tanh
from tklib.tksci.tksci import degcos, degsin, degtan, degacos, degasin, degatan
from tklib.tksci.tksci import eVTonm, nmToeV
from tklib.tksci.tksci import factorial, Factorize, Gaussian, Lorentzian, combination, gamma


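# Values accepted for `method` (passed to scipy.optimize.minimize):
#nelder-mead    Downhill simplex
#powell         Modified Powell
#cg             conjugate gradient (Polak-Ribiere method)
#bfgs           BFGS method
#newton-cg      Newton-CG
#trust-ncg      trust-region Newton-CG method
#dogleg         trust-region dog-leg method
#l-bfgs-b       limited-memory BFGS with bound constraints
#tnc            truncated Newton (TNC)
#cobyla         COBYLA (constrained optimization by linear approximation)
#slsqp          Sequential Least SQuares Programming
#trust-constr   trust-region algorithm for constrained optimization
#trust-exact    nearly exact trust-region method
#trust-krylov   trust-region method with a Krylov-subspace solver
# See the scipy.optimize.minimize documentation for details of each method.
# NOTE: dogleg, trust-ncg, trust-exact and trust-krylov also require a Hessian,
# which this script does not supply.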
method = "nelder-mead"


#==========================================
# Function to minimize and optimizer settings
#==========================================
# Expression evaluated by minimize_func(); the variables are x[0], x[1], ...
func = '2.0 * (x[0] - 3.0)**2 + (x[1] - 1.0)**4 + 2.0'
# Initial point as a comma-separated string (converted to a float list below)
x0s = "0.0,0.0"

# Finite-difference step used by diff1().
# NOTE(review): this rebinds the `h` imported from tklib.tksci.tksci above
# (presumably Planck's constant), so `h` inside `func` means this step.
h = 0.01

maxiter = 100
tol = 1.0e-5


#==========================================
# Graph parameters
#==========================================
fplot = 1        # 1: plot the function slice and each iterate
ngdata = 201     # number of grid points along the x[0] axis
xgmin = -4.0
xgmax = 4.0
tsleep = 0.3     # pause passed to plt.pause() after each iterate is drawn


# Optional positional command-line overrides, in this order:
#   method func x0s h maxiter tol xgmin xgmax
argv = sys.argv
n = len(argv)
_cli_overrides = (
    ("method", str),
    ("func", str),
    ("x0s", str),
    ("h", float),
    ("maxiter", int),
    ("tol", float),
    ("xgmin", float),
    ("xgmax", float),
)
for _position, (_name, _convert) in enumerate(_cli_overrides, start=1):
    if n > _position:
        globals()[_name] = _convert(argv[_position])
del _cli_overrides, _position, _name, _convert

x0s = list(map(float, x0s.split(',')))


def minimize_func(x):
    """Evaluate the user-supplied expression `func` at the point `x`.

    The expression sees this module's globals (the imported math functions
    and constants and the module settings) plus the local name ``x``.

    Parameters
    ----------
    x : sequence of float
        Point of evaluation; referenced as x[0], x[1], ... inside `func`.

    Returns
    -------
    The value the expression evaluates to.
    """
    # NOTE(review): eval() executes arbitrary code and `func` may come from
    # the command line -- never run this with untrusted input.
    namespace = {"x": x}
    return eval(func, globals(), namespace)

# Supplying the first derivative enables gradient methods such as cg and bfgs.
def diff1(x, step=None, fun=None):
    """Approximate the gradient of the objective by central differences.

    Passed to scipy.optimize.minimize as ``jac``.

    Parameters
    ----------
    x : array_like of float
        Point at which the gradient is evaluated.
    step : float, optional
        Finite-difference step; defaults to the module-level ``h``.
    fun : callable, optional
        Scalar function of a 1-D array; defaults to ``minimize_func``.

    Returns
    -------
    numpy.ndarray of float
        Element i is ``(fun(x + step*e_i) - fun(x - step*e_i)) / (2*step)``.
    """
    if step is None:
        step = h
    if fun is None:
        fun = minimize_func

    # Work in float: copying an integer array would truncate x[i] +/- step
    # and the stored derivatives, silently yielding a zero gradient.
    x = np.asarray(x, dtype=float)
    grad = np.empty_like(x)
    xx = x.copy()
    for i in range(len(x)):
        xx[i] = x[i] - step
        ym = fun(xx)
        xx[i] = x[i] + step
        yp = fun(xx)
        xx[i] = x[i]  # restore before perturbing the next coordinate
        grad[i] = (yp - ym) / (2.0 * step)

    return grad
    
# A callback lets us monitor the optimization while it runs.
# scipy passes only the current variable vector, so the iteration counter
# and the plotting objects are kept in module-level globals.
iter = 0     # number of callback invocations so far (NOTE: shadows builtin iter)
ax   = None  # matplotlib Axes; set by main() when fplot == 1
xg   = None  # x[0] grid of the plotted slice; set by main()
def callback(xk, *args):
    """Report one optimizer iteration and, when plotting, draw the new point.

    scipy calls this with the current parameter vector ``xk``.  The
    ``trust-constr`` method calls ``callback(xk, state)`` instead, so extra
    positional arguments are accepted and ignored.
    """
    global iter

    fmin = minimize_func(xk)
    print(f"callback {iter}: xk={xk}  func={fmin}")
    iter += 1

    # Add the slice through the updated point, and the point itself, to the graph
    if fplot == 1:
        ycal = [minimize_func([_x, *xk[1:]]) for _x in xg]
        ax.plot(xg, ycal, color = 'blue', linestyle = '-', linewidth = 0.2)
        # Reuse fmin instead of evaluating the function at xk a second time
        ax.plot(xk[0], fmin, linestyle = '', marker = 'o', markersize = 5.0, markerfacecolor = 'red')
        plt.pause(tsleep)


#==========================================
# Main routine
#==========================================
# Methods that use no gradient information.  scipy emits a RuntimeWarning
# when `jac` is passed to them, so diff1 is supplied only to the other methods.
_GRADIENT_FREE_METHODS = {"nelder-mead", "powell", "cobyla", "cobyqa"}


def main():
    """Minimize `func` starting from `x0s` and print the result.

    All settings come from the module-level configuration, optionally
    overridden on the command line.  When ``fplot == 1``, a 1-D slice of the
    function along x[0] is plotted, with the other variables fixed at their
    initial values.  `callback` then adds each iterate to that graph.
    """
    global ax, xg

    print("")
    print("Minimize given function")
    print(f"method {method}")
    print(f"func {func}")
    print(f"graph range {xgmin} - {xgmax}")
    print(f"initial values {x0s}")
    print(f"tol {tol}")

    if '***' in method:
        input("\nError: Choose method >>\n")
        # sys.exit() does not depend on the `site` module that provides exit()
        sys.exit()

    # x[0] grid of the plotted slice; also read by callback()
    xgstep = (xgmax - xgmin) / (ngdata - 1)
    xg = [xgmin + xgstep * i for i in range(ngdata)]

    # Graph of the function at the initial point (evaluated only when plotting)
    if fplot == 1:
        yini = [minimize_func([_x, *x0s[1:]]) for _x in xg]
        figure = plt.figure(figsize = (5, 5))
        ax = figure.add_subplot(1, 1, 1)
        ax.plot(xg, yini, color = 'black', linestyle = '-', linewidth = 0.5)
        ax.plot(x0s[0], minimize_func(x0s), linestyle = '', marker = 'o', markersize = 5.0, markerfacecolor = 'black')
        plt.pause(0.001)

    # Only gradient-based methods get the finite-difference gradient
    jac = None if method.lower() in _GRADIENT_FREE_METHODS else diff1

    print("")
    print("Minimize:")
    res = minimize(minimize_func, x0s, jac = jac, method = method, tol = tol, callback = callback,
                   options = {'maxiter': maxiter, "disp": True})
    print("")
    if res.success:
        print(f"Function [{func}] takes the minimum")
        print(f"   at y={res.fun}")
        print(f"   with x={res.x}")
        # `nit` may be absent for some methods / scipy versions
        print(f" iteration: {res.get('nit', 'n/a')}")
    else:
        print("Function did not converge")
        print(res)

    if fplot == 1:
        input("Press ENTER to terminate:")


if __name__ == "__main__":
    main()
