import numpy as np
from numpy import sin, cos, tan, pi, exp, sqrt
from pprint import pprint

from tklib.tkobject import tkObject
import tklib.tkre as tkre
from tklib.tksci.tkoptimizeobject import tkOptimizeData
from tklib.tkutils import lvlprint, merge_attributes, joinf


class tkSimplexVertex(tkObject):
    """A single vertex of the simplex used by simplex().

    Attributes:
        nvar: number of coordinates.
        x:    coordinate list of length nvar, initialised to 0.0.
        f:    objective value at x; None until the optimizer evaluates it.
    Any extra keyword arguments are passed on to self.update().
    """

    def __init__(self, nvar, **args):
        self.nvar = nvar
        self.x = [0.0] * self.nvar
        self.f = None
        self.update(**args)

    def __del__(self):
        pass

    def __str__(self):
        return self.ClassPath()

class tkSimplexVariable(tkOptimizeData):
    """State container for the simplex (Nelder-Mead) optimizer simplex().

    Holds the starting point and initial steps, the stopping tolerances,
    printing options, the reflection / contraction / expansion coefficients,
    the nvar + 1 vertices, and the per-iteration bookkeeping that simplex()
    fills in. Extra keyword arguments are passed on to self.update().
    """

    def __init__(self, x0, dx, itmax = 100, tolx = 1.0e-5, tolf = 1.0e-5,
                callback = None, lprint = 1, iprintinterval = 10,
                alpha = 1.0, beta = 0.5, gamma = 2.0, eps = 1.0e-50, **kwargs):
        nvar = len(x0)
        # nvar variables -> a simplex with nvar + 1 vertices.
        self.nvar = nvar
        self.nvtx = nvar + 1

        # Starting point: x0 is kept by reference, x is an independent copy.
        self.x0 = x0
        self.x = x0.copy()
        # NOTE: dx is stored by reference and (near-)zero entries are fixed
        # up in place below, so the caller's dx is modified as well.
        self.dx = dx

        # Stopping criteria.
        self.itmax = itmax
        self.tolx = tolx
        self.tolf = tolf

        # Per-iteration hook and verbosity control.
        self.callback = callback
        self.lprint = lprint
        self.iprintinterval = iprintinterval

        # Coefficients plus the complements simplex() uses directly:
        # reflection xr = alpha1 * centroid - alpha * xh,
        # expansion  gamma * xr + gamma1 * centroid,
        # contraction beta * xh + beta1 * centroid.
        self.alpha = alpha
        self.alpha1 = self.alpha + 1.0
        self.beta = beta
        self.beta1 = 1.0 - self.beta
        self.gamma = gamma
        self.gamma1 = 1.0 - self.gamma
        self.eps = eps

        # Replace (near-)zero initial steps by |dx| + 1.
        for k in range(nvar):
            step = self.dx[k]
            if abs(step) < eps:
                self.dx[k] = abs(step) + 1.0

        self.vtx = [tkSimplexVertex(nvar) for _ in range(self.nvtx)]
        self.xr = [0.0] * nvar  # reflected trial point

        # Iteration bookkeeping, written by simplex().
        self.iter = self.il = self.ih = None
        self.fl = self.fh = self.fs = None
        self.fe = self.fr = self.fc = None

        self.status = 0  # 0 = normal; > 0 indicates an error / early stop

        self.update(**kwargs)

    def __del__(self):
        pass

    def __str__(self):
        return self.ClassPath()


def _simplex_init_vertices(vtx, x, dx, nvar, initial_simplex):
    """Write the starting simplex around point x into vtx (in place).

    Every vertex first becomes a copy of x. Vertex i+1 is then displaced
    along coordinate i by -dx[i] ('negative') or +dx[i] (any other value).
    For 'symmetric', vertex 0 is additionally moved to x - dx.
    """
    for v in vtx:
        for j in range(nvar):
            v.x[j] = x[j]
    if initial_simplex == 'negative':
        for i in range(nvar):
            vtx[i + 1].x[i] = x[i] - dx[i]
    else:
        for i in range(nvar):
            vtx[i + 1].x[i] = x[i] + dx[i]
    if initial_simplex == 'symmetric':
        for i in range(nvar):
            vtx[0].x[i] = x[i] - dx[i]


def _simplex_spread(vtx, nvar):
    """Return the largest |xa[j] - xb[j]| over all vertex pairs and coordinates.

    For each coordinate, the largest pairwise difference is max - min, so
    this needs one pass instead of a loop over all vertex pairs.
    """
    if nvar == 0 or len(vtx) < 2:
        return 0.0
    coords = np.array([v.x for v in vtx], dtype=float)
    return float((coords.max(axis=0) - coords.min(axis=0)).max())


def _simplex_replace_vertex(vertex, xnew, fnew, nvar):
    """Copy xnew element-wise into vertex.x and set vertex.f = fnew."""
    for j in range(nvar):
        vertex.x[j] = xnew[j]
    vertex.f = fnew


def simplex(func = None, x0 = None, dx = None, itmax = 100, tolx = 1.0e-5, tolf = 1.0e-5,
        initial_simplex = 'positive',
        optdata = None, callback = None, lprint = 1, iprintinterval = 10,
        alpha = 1.0, beta = 0.5, gamma = 2.0,eps = 1.0e-50):
    """
        Minimize func with the Nelder-Mead (downhill simplex) method.

        Input:
            func: objective, called as func(x, var). x is an indexable of
                  nvar numbers and var is the optimizer state object. The
                  result must support <, > and subtraction.
            x0: starting point (must support len() and .copy()).
            dx: initial step for each variable. Entries with |dx[i]| < eps
                are replaced in place by tkSimplexVariable.
            itmax: maximum number of iterations.
            tolx: tolerance on the largest coordinate difference between
                  vertices; disabled when tolx <= eps.
            tolf: tolerance on fh - fl; disabled when tolf <= eps.
            initial_simplex: 'positive' (x0 + dx[i] along axis i),
                  'negative' (x0 - dx[i] along axis i) or 'symmetric'
                  (like 'positive', with vertex 0 moved to x0 - dx).
            optdata: optional state object. It receives the simplex work
                  arrays and coefficients and is used (and returned)
                  instead of a fresh tkSimplexVariable.
            callback: Function called at every iteration as callback(var)
                    Return code should be 0 (or None) for continuation,
                                          >0 will terminate iteration
            lprint, iprintinterval: verbosity level and print period for
                  lvlprint; iprintinterval == 0 disables periodic output.
            alpha, beta, gamma: reflection, contraction and expansion
                  coefficients.
            eps: threshold below which tolx / tolf count as disabled.
        Returns:
            (xbest, fbest, var). When the callback stops the iteration it
            returns (None, None, var), with var.status set to the callback's
            return code.
    """
    var = tkSimplexVariable(x0, dx, itmax, tolx, tolf,
                callback, lprint, iprintinterval,
                alpha, beta, gamma, eps)
    if optdata is not None:
        # Continue with the caller's state object. Hand it the work arrays
        # and coefficients prepared above.
        optdata.nvtx   = var.nvtx
        optdata.vtx    = var.vtx
        optdata.xr     = var.xr
        optdata.alpha  = var.alpha
        optdata.alpha1 = var.alpha1
        optdata.beta   = var.beta
        optdata.beta1  = var.beta1
        optdata.gamma  = var.gamma
        optdata.gamma1 = var.gamma1
        var = optdata
    vtx = var.vtx
    nvar = len(x0)
    # Working buffer for the centroid and trial points. It must never be
    # bound to a vertex's coordinate array: writing the centroid into it
    # would silently move that vertex.
    x = x0.copy()

    _simplex_init_vertices(vtx, x, dx, nvar, initial_simplex)
    for v in vtx:
        v.f = func(v.x, var)

    var.il = 0
    # Copies, so later in-place vertex updates do not change these.
    var.x0 = vtx[0].x.copy()
    var.x  = vtx[0].x.copy()
    # vtx[0] was just evaluated; reuse the value instead of calling func again.
    var.fl = vtx[0].f
    var.f  = var.fl

    for it in range(itmax):
        var.iter = it
        # Detailed progress every iprintinterval iterations, starting at it == 1.
        verbose = iprintinterval != 0 and (it - 1) % iprintinterval == 0

        if callback is not None:
            var.x = vtx[var.il].x.copy()
            var.f = var.fl
            ret = callback(var)
            # None is treated like 0 (continue) rather than raising TypeError.
            if ret is not None and ret > 0:
                var.status = ret
                return None, None, var

        if verbose and lprint <= 2:
            lvlprint(lprint, 1, "Simplex at iter {}:".format(it))
            for i in range(var.nvtx):
                xstr = joinf(vtx[i].x, "%10.8g", ", ")
                lvlprint(lprint, 1, "  ", i, ": (", xstr, ")  f={:16.12g}".format(vtx[i].f))

        # Rank vertices: ih/fh = worst (highest f), il/fl = best (lowest f).
        var.ih = var.il = 0
        var.fh = var.fl = vtx[0].f
        for i in range(1, var.nvtx):
            fi = vtx[i].f
            if fi > var.fh:
                var.ih, var.fh = i, fi
            if fi < var.fl:
                var.il, var.fl = i, fi
        # fs = highest value among the vertices other than the worst one.
        var.fs = var.fl
        for i in range(var.nvtx):
            if i != var.ih and vtx[i].f > var.fs:
                var.fs = vtx[i].f

        if verbose:
            lvlprint(lprint, 2,
                    "ITER={:3d}  FL, FS, FH={},{},{}".format(it, var.fl, var.fs, var.fh))

        # Convergence tests; a tolerance <= eps disables that criterion.
        if tolx > eps:
            xint = _simplex_spread(vtx, nvar)
            if tolf > eps:
                if var.fh - var.fl < tolf and xint < tolx:
                    lvlprint(lprint, 0, "(SUBR.SMPLX)1 CONVERGENCE AT ITER={}".format(it))
                    lvlprint(lprint, 0, "  fh={} - fl={} < tolf={}, xint={} < tolx={}"
                                .format(var.fh, var.fl, tolf, xint, tolx))
                    for j in range(nvar):
                        x[j] = vtx[var.il].x[j]
                    return x, vtx[var.il].f, var
            elif xint < tolx:
                lvlprint(lprint, 0, "(SUBR.SMPLX)2 CONVERGENCE AT ITER={}".format(it))
                return vtx[var.il].x, vtx[var.il].f, var
        elif tolf > eps and var.fh - var.fl < tolf:
            lvlprint(lprint, 0, "(SUBR.SMPLX)3 CONVERGENCE AT ITER={}".format(it))
            return vtx[var.il].x, vtx[var.il].f, var

        ih = var.ih
        # Centroid of all vertices except the worst one, written into x.
        for j in range(nvar):
            s = 0.0
            for v in vtx:
                s += v.x[j]
            x[j] = (s - vtx[ih].x[j]) / nvar

        # Reflect the worst vertex through the centroid.
        for j in range(nvar):
            var.xr[j] = var.alpha1 * x[j] - var.alpha * vtx[ih].x[j]
        var.fr = func(var.xr, var)
        if verbose:
            lvlprint(lprint, 2, "REFLECTION APPLIED.  FR ={}".format(var.fr), "  XR(*)=", var.xr)

        if var.fr <= var.fs:
            if var.fr < var.fl:
                # The reflected point is the new best: try expanding further.
                for j in range(nvar):
                    x[j] = var.gamma * var.xr[j] + var.gamma1 * x[j]
                var.fe = func(x, var)
                if verbose:
                    lvlprint(lprint, 2, "EXPANSION APPLIED.   FE ={}".format(var.fe), "   XE(*)=", x)

                if var.fe < var.fr:
                    _simplex_replace_vertex(vtx[ih], x, var.fe, nvar)
                    if verbose:
                        lvlprint(lprint, 2, "XH(*)1 IS REPLACED BY XE(*)")
                else:
                    _simplex_replace_vertex(vtx[ih], var.xr, var.fr, nvar)
                    if verbose:
                        lvlprint(lprint, 2, "XH(*)2 IS REPLACED BY XR(*)")
            else:
                _simplex_replace_vertex(vtx[ih], var.xr, var.fr, nvar)
                if verbose:
                    lvlprint(lprint, 2, "XH(*)3 IS REPLACED BY XR(*)")
        else:
            if var.fr < var.fh:
                _simplex_replace_vertex(vtx[ih], var.xr, var.fr, nvar)
                if verbose:
                    lvlprint(lprint, 2, "XH(*)4 IS REPLACED BY XR(*)")
            # Contract between the (possibly updated) worst vertex and the centroid.
            xc = np.empty(nvar)
            for j in range(nvar):
                xc[j] = var.beta * vtx[ih].x[j] + var.beta1 * x[j]
            x = xc
            var.fc = func(x, var)
            if verbose:
                lvlprint(lprint, 2,"CONTRACTION APPLIED. FC ={}".format(var.fc), "  XC(*)=", x)
            # Compare with the worst vertex's current value. That value is
            # fr if the reflected point was accepted just above.
            if var.fc < vtx[ih].f:
                _simplex_replace_vertex(vtx[ih], x, var.fc, nvar)
                if verbose:
                    lvlprint(lprint, 2, "XH(*)5 IS REPLACED BY XC(*)")
            else:
                # Shrink every vertex halfway toward the best one. The best
                # vertex itself is unchanged and is not re-evaluated. Each
                # vertex gets a fresh array so x never aliases a vertex.
                il = var.il
                xl = vtx[il].x
                for i in range(var.nvtx):
                    if i == il:
                        continue
                    xs = np.empty(nvar)
                    for j in range(nvar):
                        xs[j] = 0.5 * (vtx[i].x[j] + xl[j])
                    vtx[i].x = xs
                    vtx[i].f = func(xs, var)
                if verbose:
                    lvlprint(lprint, 2, "REDUCTION APPLIED AROUND THE POINT: ", vtx[il].x)

    # Iteration limit reached (or itmax <= 0): always report it, then return
    # the best vertex. The last iteration may have changed the ranking.
    lvlprint(lprint, 0, "\nError: ITERATION TERMINATED DUE TO ITMAX")
    var.il = min(range(var.nvtx), key=lambda i: vtx[i].f)
    var.fl = vtx[var.il].f
    return vtx[var.il].x, vtx[var.il].f, var
