import sys
import numpy as np
from numpy import sin, cos, tan, pi, exp, sqrt


from tklib.tkargs import tkArgs
from tklib.tksci.tkoptimize import tkOptimize
from tklib.tkgraphic.tkplot import tkPlot


# Target function: 'ellipsoid' (efunc), 'ellipsoid2' (efuncb), 'circle' (efuncc)
# or 'general' (gfunc). Any unrecognized value also selects gfunc.
functype = 'general'

# Optimization algorithm: simplex, sd, cg, newton, broyden, dfp, bfgs
# (the variable keeps its historical name 'algorism').
algorism = 'simplex'

# Line-search mode: '' or 'none'/'newton', 'one', 'simple', 'exact', 'golden', 'armijo'.
# When algorism == 'newton' this becomes 'newton' unless it is given on the command line.
lsmode = 'armijo'

# Optimization parameters passed to tkOptimize.add_parameter():
#   x0    : initial point
#   optid : per-parameter switch (presumably 1 = optimize this parameter)
#   dx    : for simplex, the range used to make the initial simplex
x0    = [0.0, 0.0]
optid = [  1,   1]
dx = np.array([3.0, 3.0])

# Colormap for the contour plot, e.g. hsv, cool, winter, gray, gist_gray, cividis, Spectral.
# The special value 'line' draws plain black contour lines instead of a colormap.
colormap = 'cool'

# Number of contour levels
nlevels = 51


# Optional positional command-line overrides:
#   argv[1], argv[2] : initial point x0[0], x0[1]
#   argv[3]          : algorism
#   argv[4]          : lsmode
#   argv[5]          : functype
#   argv[6]          : colormap
#   argv[7]          : nlevels
#   argv[8]          : initial_simplex
#   argv[9]          : simplex_scale
# The simplex-only settings default to None ("not given on the command line") so
# that they are always defined, even when fewer arguments are supplied.
initial_simplex = None
simplex_scale = None

if len(sys.argv) >= 2:
    x0[0] = float(sys.argv[1])

if len(sys.argv) >= 3:
    x0[1] = float(sys.argv[2])

if len(sys.argv) >= 4:
    algorism = sys.argv[3]

# Newton's method defaults to the 'newton' line-search mode; an explicit argv[4] still wins.
if algorism == 'newton':
    lsmode = 'newton'

if len(sys.argv) >= 5:
    lsmode = sys.argv[4]

if len(sys.argv) >= 6:
    functype = sys.argv[5]

if len(sys.argv) >= 7:
    colormap = sys.argv[6]

if len(sys.argv) >= 8:
    nlevels = int(sys.argv[7])

if len(sys.argv) >= 9:
    initial_simplex = sys.argv[8]

if len(sys.argv) >= 10:
    simplex_scale = float(sys.argv[9])


# ---- Global optimization settings ----
# Iteration limit, convergence tolerances on x and f, and print control.
nmaxiter = 200
tolx = 1.0e-5
tolf = 1.0e-5
iprintinterval = 1
print_level = 4

# For the Newton-Raphson method (presumably a step damping factor).
dump = 0.3

# ---- Line-search settings ----
ls_nmaxiter = 100
ls_h = 1.0e-3          # lsmode == 'exact': x step to calculate first and second derivatives
ls_alphaeps = 1.0e-2   # lsmode == 'golden'
ls_xrange = 0.5        # lsmode == 'golden'

# ls_alpha: 0.03 when functype contains 'ellipsoid' or 'circle', otherwise 0.01;
# the 'golden' and 'armijo' line-search modes always use 1.0.
ls_alpha = 0.03 if ('ellipsoid' in functype or 'circle' in functype) else 0.01
if lsmode in ('golden', 'armijo'):
    ls_alpha = 1.0

# ---- Graph plot settings ----
fplot = 1              # 1: draw the contour/wireframe plots and the search path
ngdata = 51            # grid points per axis for the surface sampling
xgmin, xgmax = -4.0, 4.0
ygmin, ygmax = -4.0, 4.0
tsleep = 0.2           # value handed to graph.pause() by the callback

# ---- File configuration ----
parameter_path = 'params.prm'
finalparameter_path = 'final_params.prm'

# first derivative of func(x)
def ediff1(i, x, optdata = None):
    """Return df/dx_i of efunc(x) = x0^2 + 9*x1^2.

    i == 0 gives 2*x0, i == 1 gives 18*x1; any other index gives None.
    optdata is accepted for interface compatibility and is unused.
    """
    for axis, coeff in ((0, 2.0), (1, 18.0)):
        if i == axis:
            return coeff * x[axis]
    return None

def gdiff1(i, x, optdata = None):
    """Return df/dx_i of the quartic gfunc.

    i == 0: -10 - 60*x0 + 4.5*x0^2 + 12*x0^3 + 3*x1^2
    i == 1:  30 - 60*x1 + 12*x1^3 + 6*x0*x1
    Any other index gives None. optdata is unused.
    """
    if i not in (0, 1):
        return None
    u, v = x[0], x[1]
    if i == 0:
        return - 10.0 - 60.0 * u + 4.5 * u*u + 12.0 * u*u*u + 3.0 * v * v
    return 30.0 - 60.0 * v + 12.0 * v*v*v + 6.0 * u * v

# for Newton-Raphson method, second derivative of func(x)
def gdiff2(i, j, x, optdata = None):
    """Return the Hessian element d^2 f / dx_i dx_j of gfunc.

    (0, 0): -60 + 9*x0 + 36*x0^2
    (0, 1) and (1, 0): 6*x1
    (1, 1): -60 + 36*x1^2 + 6*x0
    Any other index pair gives None. optdata is unused.
    """
    pair = (i, j)
    if pair == (0, 0):
        u = x[0]
        return -60.0 + 9.0 * u + 36.0 * u*u
    if pair in ((0, 1), (1, 0)):
        return 6.0 * x[1]
    if pair == (1, 1):
        u, v = x[0], x[1]
        return -60.0 + 36.0 * v*v + 6.0 * u
    return None

def gfunc(x, optdata = None):
    """General quartic test function of two variables.

    f = -3 - 10*x0 - 30*x0^2 + 1.5*x0^3 + 3*x0^4
        + 30*x1 - 30*x1^2 + 3*x1^4 + 3*x0*x1^2
    optdata is unused.
    """
    u, v = x[0], x[1]
    return (-3.0 - 10.0 * u - 30.0 * u*u + 1.5 * u*u*u + 3.0 * u*u*u*u
            + 30.0 * v - 30.0 * v*v + 3.0 * v*v*v*v
            + 3.0 * u * v * v)

def ediff2(i, j, x, optdata = None):
    """Return the Hessian element of efunc, which is constant diag(2, 18).

    x and optdata are unused; any index pair outside {0, 1} x {0, 1} gives None.
    """
    hessian = {(0, 0): 2.0, (0, 1): 0.0, (1, 0): 0.0, (1, 1): 18.0}
    return hessian.get((i, j))

def efunc(x, optdata = None):
    """Axis-aligned ellipsoid f = x0^2 + 9*x1^2 (optdata is unused)."""
    u, v = x[0], x[1]
    return u*u + 9.0*v*v


# Coefficients of the tilted ellipsoid efuncb: f = ae2xx*x0^2 + ae2xy*x0*x1 + ae2yy*x1^2
ae2xx =  1.0
ae2xy = -2.0
ae2yy =  3.0
def ediff2b(i, j, x, optdata = None):
    """Return the Hessian element d^2 f / dx_i dx_j of efuncb.

    The Hessian is constant:
        [[2*ae2xx, ae2xy  ],
         [ae2xy,   2*ae2yy]]
    x and optdata are unused. Any index pair outside {0, 1} x {0, 1} gives None.
    """
    # Diagonal terms use the ae2* coefficients of this function, not the
    # 'circle' constant acxx (and there is no acyy at all).
    if i == 0 and j == 0:
        return 2.0 * ae2xx
    if (i == 0 and j == 1) or (i == 1 and j == 0):
        return ae2xy
    if i == 1 and j == 1:
        return 2.0 * ae2yy
    return None

def ediff1b(i, x, optdata = None):
    """Return df/dx_i of efuncb (module-level coefficients ae2xx, ae2xy, ae2yy).

    i == 0: 2*ae2xx*x0 + ae2xy*x1
    i == 1: ae2xy*x0 + 2*ae2yy*x1
    Any other index gives None. optdata is unused.
    """
    if i not in (0, 1):
        return None
    u, v = x[0], x[1]
    if i == 0:
        return 2.0 * ae2xx * u + ae2xy * v
    return ae2xy * u + 2.0 * ae2yy * v

def efuncb(x, optdata = None):
    """Tilted ellipsoid f = ae2xx*x0^2 + ae2xy*x0*x1 + ae2yy*x1^2 (optdata is unused)."""
    u, v = x[0], x[1]
    square_u = ae2xx * u * u
    cross_uv = ae2xy * u * v
    square_v = ae2yy * v * v
    return square_u + cross_uv + square_v

# Coefficient of the 'circle' test function efuncc: f = acxx * (x0^2 + x1^2)
acxx =  3.0
def ediff2c(i, j, x, optdata = None):
    """Return the Hessian element of efuncc, which is constant 2*acxx * I.

    x and optdata are unused; any index pair outside {0, 1} x {0, 1} gives None.
    """
    diagonal = 2.0 * acxx
    hessian = {(0, 0): diagonal, (0, 1): 0.0, (1, 0): 0.0, (1, 1): diagonal}
    return hessian.get((i, j))

def ediff1c(i, x, optdata = None):
    """Return df/dx_i of efuncc: 2*acxx*x_i for i in {0, 1}, otherwise None (optdata unused)."""
    if i not in (0, 1):
        return None
    coord = x[0] if i == 0 else x[1]
    return 2.0 * acxx * coord

def efuncc(x, optdata = None):
    """Circular paraboloid f = acxx * (x0^2 + x1^2) (optdata is unused)."""
    u, v = x[0], x[1]
    radius_sq = u*u + v*v
    return acxx * radius_sq


# Global function variables: pick the objective and its first/second derivatives
# according to functype. Any unrecognized functype falls back to the quartic gfunc.
diff1, diff2, func = {
    'ellipsoid':  (ediff1,  ediff2,  efunc),
    'ellipsoid2': (ediff1b, ediff2b, efuncb),
    'circle':     (ediff1c, ediff2c, efuncc),
}.get(functype, (gdiff1, gdiff2, gfunc))

# callback function. return code > 0 indicates error and simplex iteration will be terminated
def cfunc(optdata):
    """Per-iteration callback that redraws the current search state.

    For the simplex method the current simplex is drawn as a closed blue
    polygon through optdata.vtx[0 .. nvtx-1]. For the other methods, the
    current point (from optdata.RecoverParams) is appended to ax.xlist[0] /
    ax.ylist[0] of the contour axes, presumably the data of the path line
    drawn first. That data is then pushed to the graph with set_data.

    Always returns 0 so that the optimization continues.
    """
    # fplot/graph are attached to optdata by the driver script; when they are
    # missing, treat it as "plotting disabled" instead of raising AttributeError.
    if getattr(optdata, 'fplot', 0) != 1:
        return 0

    graph = optdata.graph
    if optdata.method == 'simplex':
        vertices = [optdata.vtx[k].x for k in range(optdata.nvtx)]
        # Repeat the first vertex so the polygon is closed.
        xs = [v[0] for v in vertices] + [optdata.vtx[0].x[0]]
        ys = [v[1] for v in vertices] + [optdata.vtx[0].x[1]]
        graph.plot(0, xs, ys, color = 'blue', marker = '', linestyle = '-')
    else:
        # Only this branch needs the current point in full-parameter form.
        x_cur = optdata.RecoverParams(optdata.x0, optdata.x0_all, optdata.optid)
        ax = graph.axes[0]
        xt = ax.xlist[0]
        yt = ax.ylist[0]
        xt.append(x_cur[0])
        yt.append(x_cur[1])
        graph.set_data(0, 0, 0, xt, yt)

    graph.pause(optdata.tsleep)
    return 0


def _surface_grid():
    """Sample func on an ngdata x ngdata grid over [xgmin, xgmax] x [ygmin, ygmax].

    Returns (xg, yg, zg) indexed as [ix, iy]: x varies along axis 0 and y along axis 1.
    """
    xs = np.linspace(xgmin, xgmax, ngdata)
    ys = np.linspace(ygmin, ygmax, ngdata)
    xg, yg = np.meshgrid(xs, ys, indexing = 'ij')
    zg = np.empty_like(xg)
    # func is evaluated point by point so scalar-only target functions also work
    for idx in np.ndindex(zg.shape):
        zg[idx] = func([xg[idx], yg[idx]])
    return xg, yg, zg


def _setup_graph(x_start):
    """Create the contour + wireframe figure, mark the start point and return the tkPlot."""
    xg, yg, zg = _surface_grid()

    graph = tkPlot(figsize = (10, 5))
    graph.add_subplot(1, 2, 0)
    graph.add_subplot(1, 2, 1, projection = '3d')

    f = func(x_start)
    print("x0 = ({}, {}): f = {}".format(x_start[0], x_start[1], f))
    # This first line on the contour axes is the path that cfunc extends for non-simplex methods.
    graph.plot(0, [x_start[0]], [x_start[1]], color = 'blue', linestyle = '-', linewidth = 0.5,
               fillstyle = 'full', marker = 'o', markersize = 5)
    if colormap == 'line':
        graph.plot_contour(0, xg, yg, zg, levels = nlevels,
                           cmap = 'line', colors = ['black'], linewidths = 0.5, aspect = 'equal')
    else:
        graph.plot_contour(0, xg, yg, zg, levels = nlevels, cmap = colormap, aspect = 'equal')
    graph.set_axtitle(0, 'contour')

    graph.plot_wireframe(1, xg, yg, zg, rstride = 2, cstride = 2)
    graph.pause()
    return graph


def main():
    """Run the optimization demo.

    The steps are:
    1. Register the x/y parameters with tkOptimize and select the algorithm and
       line-search mode.
    2. Optionally plot the target function.
    3. Run the optimization (cfunc redraws each iteration) and print and save
       the final parameters.
    """
    # initial_simplex / simplex_scale are optional command-line settings. Look them
    # up defensively so main() also works when they were never defined.
    init_simplex = globals().get('initial_simplex')
    scale = globals().get('simplex_scale')

    print("Find minimum point by steepest-descend / conjugate gradient / simplex methods")
    print("")
    print("x0= ({}, {})".format(x0[0], x0[1]))
    print("algorism=", algorism)
    print("  For simplex:")
    print(f"    initial_simplex={init_simplex}")
    print(f"    scale={scale}")
    print("lsmode=", lsmode)
    print("functype=", functype)
    print("colormap=", colormap)
    print("")

    if scale is not None:
        # A single scale overrides the per-parameter ranges configured in dx.
        dx[:] = scale

    opt = tkOptimize()
    opt.add_parameter("x", x0[0], dx[0], optid[0])
    opt.add_parameter("y", x0[1], dx[1], optid[1])
    if init_simplex is None:
        # Not given: presumably tkOptimize falls back to its own default initial simplex.
        opt.set_method(algorism, lsmode)
    else:
        opt.set_method(algorism, lsmode, initial_simplex = init_simplex)
    opt.set_functions(fitfunc = func, func = func, diff1func = diff1, diff2func = diff2)
    # NOTE(review): nmaxiter, tolx, tolf, print_level, dump and the ls_* settings defined
    # at module level are not forwarded here (the call that passed them was commented
    # out), so tkOptimize's own defaults apply. Confirm the initialize() signature first.
    opt.initialize(callback = cfunc)
    var = opt.make_optdata(tsleep = tsleep)
    # cfunc reads optdata.fplot, so set it even when plotting is disabled.
    var.fplot = fplot

    print("")
    print("Initial parameters:")
    opt.print_parameters()

    if fplot == 1:
        var.graph = _setup_graph(x0)

    # Run the optimization.
    xmin, fmin, _ = opt.optimize()
    print("")
    print("Optimized parameters:")
    opt.print_parameters(x0 = xmin, f = fmin)

    print("")
    if opt.save_parameters(finalparameter_path):
        print("Final results are stored to [{}]".format(finalparameter_path))
    else:
        print("Error: Can not save the results to [{}]".format(finalparameter_path))

    if fplot == 1:
        print("Press enter to terminate:", end = '')
        input()


if __name__ == "__main__":
    main()

