import numpy as np
from numpy import sin, cos, tan, pi, exp, sqrt
from pprint import pprint
import sys

from tklib.tkobject import tkObject
from tklib.tkutils import lvlprint, merge_attributes, joinf

def ReduceOptParams(allvars, optid):
    """Return the entries of allvars whose flag in optid is truthy.

    The order of allvars is preserved. optid is indexed by position, so it
    must provide an entry for every element of allvars.
    """
    return [value for pos, value in enumerate(allvars) if optid[pos]]

def RecoverParams(optvars, allvars, optid):
    """Rebuild the full parameter list from the optimized subset.

    For every position of allvars with a truthy flag in optid, the next
    unused value of optvars is taken; all other positions keep the value
    from allvars. The result is a new list; the inputs are not modified.
    """
    merged = []
    cursor = 0  # index of the next unused entry of optvars
    for pos, fixed_value in enumerate(allvars):
        if not optid[pos]:
            merged.append(fixed_value)
            continue
        merged.append(optvars[cursor])
        cursor += 1
    return merged


class tkOptimizeData(tkObject):
    """Holds the parameter vector and the status flag of an optimization.

    ``update`` and ``ClassPath`` used below are not defined in this class;
    presumably they are inherited from tkObject -- TODO confirm.
    """
    def __init__(self, x0, **kwargs):
        """Store the initial parameters and apply extra keyword settings.

        x0 must support ``len()`` and ``.copy()``. The remaining keyword
        arguments are forwarded unchanged to ``self.update``.
        """
        self.nvar   = len(x0)
        self.x0     = x0
        # Working copy, so that changing self.x does not alter self.x0.
        self.x      = x0.copy()
        self.status = 0 # default is normal. status > 0 indicates an error
        self.update(**kwargs)

    def __del__(self):
        # Nothing to release.
        pass
        
    def __str__(self):
        """Return the value of ``self.ClassPath()``."""
        return self.ClassPath()


    def ReduceOptParams(self, allvars, optid):
        """Delegate to the module-level ReduceOptParams function."""
        return ReduceOptParams(allvars, optid)

    def RecoverParams(self, optvars, allvars, optid):
        """Delegate to the module-level RecoverParams function."""
        return RecoverParams(optvars, allvars, optid)


# DFP (Davidon-Fletcher-Powell)
def diff2array_dfp(x, optdata):
    """Update and return the DFP estimate B of the Hessian matrix.

    The gradient at x is evaluated component by component with
    ``optdata.diff1func(i, x, optdata)``. On the first call (no ``optdata.B``
    yet, or it is None) B is the identity. Afterwards B is updated with

        B <- (I - y s^T / y.s) B (I - s y^T / y.s) + y y^T / y.s

    where s is the step since the previous call and y the gradient change.
    If |y.s| < 1e-6 the estimate is reset to the identity, which makes the
    resulting search direction the steepest-descent one.

    Side effects: stores ``optdata.B``, ``optdata.xm`` (a private copy of x)
    and ``optdata.diffkm`` (the gradient at x).
    """
    nvar = optdata.nvar
    I = np.identity(nvar)

    diffk = np.empty(nvar)
    for i1 in range(nvar):
        diffk[i1] = optdata.diff1func(i1, x, optdata)

    # A missing attribute marks the first call; only that case is tolerated.
    B = getattr(optdata, "B", None)

    if B is None:
        B = I
    else:
        dxk = np.asarray(x, dtype=float) - optdata.xm
        yk = diffk - optdata.diffkm
        inn = np.inner(yk, dxk)
        if abs(inn) >= 1.0e-6:
            inn_inv = 1.0 / inn
            ykdxkt = np.outer(yk, dxk)
            dxkykt = np.outer(dxk, yk)
            ykyk = np.outer(yk, yk)
            B = (I - inn_inv * ykdxkt) @ B @ (I - inn_inv * dxkykt) + inn_inv * ykyk
        else:
            lvlprint(optdata.print_level, 1, "tkoptimizeobject.diff2array_dfp: Reset to SD")
            B = I

    optdata.B = B
    # Keep a copy: if the caller later modifies x in place, an aliased xm
    # would follow it and the next step would wrongly evaluate to zero.
    optdata.xm = np.array(x, dtype=float)
    optdata.diffkm = diffk
    return optdata.B

def diff2array_dfpH(x, optdata):
    """Update and return the DFP estimate H of the inverse Hessian.

    The gradient at x is evaluated with ``optdata.diff1func(i, x, optdata)``.
    On the first call (no ``optdata.H`` yet, or it is None) H is the
    identity. Afterwards

        H <- H + s s^T / y.s - (H y)(H y)^T / (y^T H y)

    where s is the step since the previous call and y the gradient change.
    If |y.s| < 1e-6 the estimate is reset to the identity (steepest descent).

    Side effects: stores ``optdata.H``, ``optdata.xm`` (a private copy of x)
    and ``optdata.diffkm`` (the gradient at x).
    """
    nvar = optdata.nvar
    I = np.identity(nvar)

    diffk = np.empty(nvar)
    for i1 in range(nvar):
        diffk[i1] = optdata.diff1func(i1, x, optdata)

    # A missing attribute marks the first call; only that case is tolerated.
    H = getattr(optdata, "H", None)

    if H is None:
        H = I
    else:
        dxk = np.asarray(x, dtype=float) - optdata.xm
        yk = diffk - optdata.diffkm
        inn = np.inner(yk, dxk)
        if abs(inn) >= 1.0e-6:
            dxdx = np.outer(dxk, dxk)
            Hy = H @ yk
            yHy = yk @ Hy
            HyyH = np.outer(Hy, Hy)
            H = H + dxdx / inn - HyyH / yHy
        else:
            lvlprint(optdata.print_level, 1, "tkoptimizeobject.diff2array_dfpH: Reset to SD")
            H = I

    optdata.H = H
    # Keep a copy: if the caller later modifies x in place, an aliased xm
    # would follow it and the next step would wrongly evaluate to zero.
    optdata.xm = np.array(x, dtype=float)
    optdata.diffkm = diffk
    return optdata.H


# BFGS (Broyden-Fletcher-Goldfarb-Shanno)
def diff2array_bfgsH(x, optdata):
    """Update and return the BFGS estimate H of the inverse Hessian.

    The gradient at x is evaluated with ``optdata.diff1func(i, x, optdata)``.
    On the first call (no ``optdata.H`` yet, or it is None) H is the
    identity. Afterwards

        H <- (I - y s^T / y.s)^T H (I - y s^T / y.s) + s s^T / y.s

    where s is the step since the previous call and y the gradient change.
    If |y.s| < 1e-6 the estimate is reset to the identity (steepest descent)
    and a message is printed.

    Side effects: stores ``optdata.H``, ``optdata.xm`` (a private copy of x)
    and ``optdata.diffkm`` (the gradient at x).
    """
    nvar = optdata.nvar
    I = np.identity(nvar)

    diffk = np.empty(nvar)
    for i1 in range(nvar):
        diffk[i1] = optdata.diff1func(i1, x, optdata)

    # A missing attribute marks the first call; only that case is tolerated.
    H = getattr(optdata, "H", None)

    if H is None:
        H = I
    else:
        dxk = np.asarray(x, dtype=float) - optdata.xm
        yk = diffk - optdata.diffkm
        inn = np.inner(yk, dxk)
        if abs(inn) >= 1.0e-6:
            Imyx = I - np.outer(yk, dxk) / inn
            dxdx = np.outer(dxk, dxk)
            H = Imyx.transpose() @ (H @ Imyx) + dxdx / inn
        else:
            # The message now names this function (it used to say
            # diff2array_bfgs, which is a different routine).
            print("tkoptimizeobject.diff2array_bfgsH: Reset to SD")
            H = I

    optdata.H = H
    # Keep a copy: if the caller later modifies x in place, an aliased xm
    # would follow it and the next step would wrongly evaluate to zero.
    optdata.xm = np.array(x, dtype=float)
    optdata.diffkm = diffk
    return optdata.H

def diff2array_bfgs(x, optdata):
    """Update and return the BFGS estimate B of the Hessian matrix.

    The gradient at x is evaluated with ``optdata.diff1func(i, x, optdata)``.
    On the first call (no ``optdata.B`` yet, or it is None) B is the
    identity. Afterwards

        B <- B + y y^T / y.s - (B s)(B s)^T / (s^T B s)

    where s is the step since the previous call and y the gradient change.
    If |y.s| < 1e-6 the estimate is reset to the identity (steepest descent)
    and a message is printed.

    Side effects: stores ``optdata.B``, ``optdata.xm`` (a private copy of x)
    and ``optdata.diffkm`` (the gradient at x).
    """
    nvar = optdata.nvar
    I = np.identity(nvar)

    diffk = np.empty(nvar)
    for i1 in range(nvar):
        diffk[i1] = optdata.diff1func(i1, x, optdata)

    # A missing attribute marks the first call; only that case is tolerated.
    B = getattr(optdata, "B", None)

    if B is None:
        B = I
    else:
        dxk = np.asarray(x, dtype=float) - optdata.xm
        yk = diffk - optdata.diffkm
        inn = np.inner(yk, dxk)
        if abs(inn) >= 1.0e-6:
            ykyk = np.outer(yk, yk)
            Bdxk = B @ dxk
            dxkBdxk = dxk @ Bdxk
            BdxdxB = np.outer(Bdxk, Bdxk)
            B = B + ykyk / inn - BdxdxB / dxkBdxk
        else:
            print("tkoptimizeobject.diff2array_bfgs: Reset to SD")
            B = I

    optdata.B = B
    # Keep a copy: if the caller later modifies x in place, an aliased xm
    # would follow it and the next step would wrongly evaluate to zero.
    optdata.xm = np.array(x, dtype=float)
    optdata.diffkm = diffk
    return optdata.B

# Broyden
def diff2array_broyden(x, optdata):
    """Update and return the Broyden estimate B of the Hessian matrix.

    The gradient at x is evaluated with ``optdata.diff1func(i, x, optdata)``.
    On the first call (no ``optdata.B`` yet, or it is None) B is the
    identity. Afterwards

        B <- B + (y - B s) s^T / s.s

    where s is the step since the previous call and y the gradient change.
    If s.s < 1e-6 the estimate is reset to the identity (steepest descent)
    and a message is printed.

    Side effects: stores ``optdata.B``, ``optdata.xm`` (a private copy of x)
    and ``optdata.diffkm`` (the gradient at x).
    """
    nvar = optdata.nvar
    I = np.identity(nvar)

    diffk = np.empty(nvar)
    for i1 in range(nvar):
        diffk[i1] = optdata.diff1func(i1, x, optdata)

    # A missing attribute marks the first call; only that case is tolerated.
    B = getattr(optdata, "B", None)

    if B is None:
        B = I
    else:
        dxk = np.asarray(x, dtype=float) - optdata.xm
        yk = diffk - optdata.diffkm
        inn = np.inner(dxk, dxk)
        if abs(inn) >= 1.0e-6:
            Bdxk = B @ dxk
            B = B + np.outer(yk - Bdxk, dxk) / inn
        else:
            print("tkoptimizeobject.diff2array_broyden: Reset to SD")
            B = I

    optdata.B = B
    # Keep a copy: if the caller later modifies x in place, an aliased xm
    # would follow it and the next step would wrongly evaluate to zero.
    optdata.xm = np.array(x, dtype=float)
    optdata.diffkm = diffk
    return optdata.B

def diff2array_broydenH(x, optdata):
    """Update and return the Broyden estimate H of the inverse Hessian.

    The gradient at x is evaluated with ``optdata.diff1func(i, x, optdata)``.
    On the first call (no ``optdata.H`` yet, or it is None) H is the
    identity. Afterwards

        H <- H + (s - H y)(s^T H) / (s^T H y)

    where s is the step since the previous call and y the gradient change.
    When the denominator s^T H y is zero (e.g. a zero step) or not a number,
    the update is undefined; H is then reset to the identity (steepest
    descent) and a message is printed instead of producing NaN entries.

    Side effects: stores ``optdata.H``, ``optdata.xm`` (a private copy of x)
    and ``optdata.diffkm`` (the gradient at x).
    """
    nvar = optdata.nvar
    I = np.identity(nvar)

    diffk = np.empty(nvar)
    for i1 in range(nvar):
        diffk[i1] = optdata.diff1func(i1, x, optdata)

    # A missing attribute marks the first call; only that case is tolerated.
    H = getattr(optdata, "H", None)

    if H is None:
        H = I
    else:
        dxk = np.asarray(x, dtype=float) - optdata.xm
        yk = diffk - optdata.diffkm
        Hyk = H @ yk
        dxkH = dxk @ H
        dxkHyk = np.inner(dxk, Hyk)
        # Only an exactly vanishing (or NaN) denominator is rejected, so
        # small but valid steps are still used for the update.
        if abs(dxkHyk) > 0.0:
            H = H + np.outer(dxk - Hyk, dxkH) / dxkHyk
        else:
            print("tkoptimizeobject.diff2array_broydenH: Reset to SD")
            H = I

    optdata.H = H
    # Keep a copy: if the caller later modifies x in place, an aliased xm
    # would follow it and the next step would wrongly evaluate to zero.
    optdata.xm = np.array(x, dtype=float)
    optdata.diffkm = diffk
    return optdata.H


def searchdir_broydenH(x = None, diff1func = None, icg = None, gradfkm = None, dkm = None, optdata = None):
    """Search direction from the Broyden inverse-Hessian estimate.

    Evaluates the gradient with ``diff1func(i, x, optdata)``, obtains the
    matrix H from diff2array_broydenH and returns dx = -H g. The gradient,
    H and dx are also stored as ``optdata.Si``, ``optdata.Hij`` and
    ``optdata.dx``. icg, gradfkm and dkm are returned unchanged.
    """
    ndim = len(x)

    # Gradient of the target function at x.
    gradient = np.empty(ndim)
    for k in range(ndim):
        gradient[k] = diff1func(k, x, optdata)

    inv_hessian = diff2array_broydenH(x, optdata)
    optdata.Si = gradient
    optdata.Hij = inv_hessian

    # -H g as a column-vector product, flattened back to one dimension.
    column = gradient.reshape(ndim, 1)
    product = -inv_hessian @ column
    optdata.dx = product.transpose()[0]
    return optdata.dx, icg, gradfkm, dkm

def searchdir_broyden(x = None, diff1func = None, icg = None, gradfkm = None, dkm = None, optdata = None):
    """Search direction from the Broyden Hessian estimate.

    Evaluates the gradient with ``diff1func(i, x, optdata)``, obtains the
    matrix B from diff2array_broyden and returns the solution dx of
    B dx = -g. The gradient, B and dx are also stored as ``optdata.Si``,
    ``optdata.Sij`` and ``optdata.dx``. icg, gradfkm and dkm are returned
    unchanged.
    """
    ndim = len(x)

    # Gradient of the target function at x.
    gradient = np.empty(ndim)
    for k in range(ndim):
        gradient[k] = diff1func(k, x, optdata)

    hessian = diff2array_broyden(x, optdata)
    optdata.Si = gradient
    optdata.Sij = hessian

    solution = np.linalg.solve(hessian, gradient)
    optdata.dx = -solution
    return optdata.dx, icg, gradfkm, dkm

def searchdir_dfpH(x = None, diff1func = None, icg = None, gradfkm = None, dkm = None, optdata = None):
    """Search direction from the DFP inverse-Hessian estimate.

    Evaluates the gradient with ``diff1func(i, x, optdata)``, obtains the
    matrix H from diff2array_dfpH and returns dx = -H g. The gradient, H
    and dx are also stored as ``optdata.Si``, ``optdata.Hij`` and
    ``optdata.dx``. icg, gradfkm and dkm are returned unchanged.
    """
    size = len(x)

    # Gradient of the target function at x.
    grad = np.empty(size)
    for idx in range(size):
        grad[idx] = diff1func(idx, x, optdata)

    h_inverse = diff2array_dfpH(x, optdata)
    optdata.Si = grad
    optdata.Hij = h_inverse

    # -H g as a column-vector product, flattened back to one dimension.
    grad_column = grad.reshape(size, 1)
    step_column = -h_inverse @ grad_column
    optdata.dx = step_column.transpose()[0]
    return optdata.dx, icg, gradfkm, dkm

def searchdir_dfp(x = None, diff1func = None, icg = None, gradfkm = None, dkm = None, optdata = None):
    """Search direction from the DFP Hessian estimate.

    Evaluates the gradient with ``diff1func(i, x, optdata)``, obtains the
    matrix B from diff2array_dfp and returns the solution dx of B dx = -g.
    The gradient, B and dx are also stored as ``optdata.Si``,
    ``optdata.Sij`` and ``optdata.dx``. icg, gradfkm and dkm are returned
    unchanged.
    """
    size = len(x)

    # Gradient of the target function at x.
    grad = np.empty(size)
    for idx in range(size):
        grad[idx] = diff1func(idx, x, optdata)

    b_matrix = diff2array_dfp(x, optdata)
    optdata.Si = grad
    optdata.Sij = b_matrix

    solved = np.linalg.solve(b_matrix, grad)
    optdata.dx = -solved
    return optdata.dx, icg, gradfkm, dkm

def searchdir_bfgsH(x = None, diff1func = None, icg = None, gradfkm = None, dkm = None, optdata = None):
    """Search direction from the BFGS inverse-Hessian estimate.

    Evaluates the gradient with ``diff1func(i, x, optdata)``, obtains the
    matrix H from diff2array_bfgsH and returns dx = -H g. The gradient, H
    and dx are also stored as ``optdata.Si``, ``optdata.Hij`` and
    ``optdata.dx``. icg, gradfkm and dkm are returned unchanged.
    """
    count = len(x)

    # Gradient of the target function at x.
    first_derivs = np.empty(count)
    for comp in range(count):
        first_derivs[comp] = diff1func(comp, x, optdata)

    inverse_estimate = diff2array_bfgsH(x, optdata)
    optdata.Si = first_derivs
    optdata.Hij = inverse_estimate

    # -H g as a column-vector product, flattened back to one dimension.
    as_column = first_derivs.reshape(count, 1)
    negative_product = -inverse_estimate @ as_column
    optdata.dx = negative_product.transpose()[0]
    return optdata.dx, icg, gradfkm, dkm

def searchdir_bfgs(x = None, diff1func = None, icg = None, gradfkm = None, dkm = None, optdata = None):
    """Search direction from the BFGS Hessian estimate.

    Evaluates the gradient with ``diff1func(i, x, optdata)``, obtains the
    matrix B from diff2array_bfgs and returns the solution dx of B dx = -g.
    The gradient, B and dx are also stored as ``optdata.Si``,
    ``optdata.Sij`` and ``optdata.dx``. icg, gradfkm and dkm are returned
    unchanged.
    """
    count = len(x)

    # Gradient of the target function at x.
    first_derivs = np.empty(count)
    for comp in range(count):
        first_derivs[comp] = diff1func(comp, x, optdata)

    estimate = diff2array_bfgs(x, optdata)
    optdata.Si = first_derivs
    optdata.Sij = estimate

    linear_solution = np.linalg.solve(estimate, first_derivs)
    optdata.dx = -linear_solution
    return optdata.dx, icg, gradfkm, dkm

# for Newton-Raphson
def diff2array_newton(x, optdata):
    """Build the symmetric matrix of second derivatives at x.

    ``optdata.diff2func(i, j, x, optdata)`` is called once for every pair
    with j >= i and the value is mirrored to the (j, i) entry. The matrix is
    stored as ``optdata.Sij`` and returned.
    """
    second_deriv = optdata.diff2func
    size = optdata.nvar
    hessian = np.empty((size, size))
    for row in range(size):
        for col in range(row, optdata.nvar):
            value = second_deriv(row, col, x, optdata)
            # Fill the upper entry and its mirror image.
            hessian[col, row] = value
            hessian[row, col] = value
    optdata.Sij = hessian
    return hessian


def searchdir_newton(x = None, diff1func = None, icg = None, gradfkm = None, dkm = None, optdata = None):
    """Damped Newton-Raphson search direction.

    The gradient g comes from ``diff1func(i, x, optdata)`` and the matrix of
    second derivatives S from diff2array_newton. A multiple of the identity,
    scaled by ``optdata.dump``, is added to a copy of S and dx is the
    solution of (damped S) dx = -g.

    With allow_localmaximum set (the hard-coded default) the shift has the
    sign of the trace of S, so a negative-curvature problem is not forced
    towards a minimum:
        |trace| > 1       -> shift = dump * trace
        -1 <= trace < 0   -> shift = -dump
        otherwise         -> shift = +dump
    Otherwise dump is added repeatedly (at most 100 times) until the trace
    becomes positive.

    Side effects: ``optdata.Si`` (gradient), ``optdata.Sij`` (undamped S),
    ``optdata.Sijdumped`` (damped copy) and ``optdata.dx``.
    icg, gradfkm and dkm are returned unchanged.
    """
    n = len(x)
    Si = np.empty(n)
    for i in range(n):
        Si[i] = diff1func(i, x, optdata)
    Sij = diff2array_newton(x, optdata)

    optdata.Si = Si
    optdata.Sij = Sij

    allow_localmaximum = 1

    dump = optdata.dump
    # Damp a copy so that optdata.Sij keeps the undamped second derivatives
    # and optdata.Sijdumped is a distinct matrix.
    damped = np.array(Sij, dtype=float)
    tr = np.trace(damped)
    if allow_localmaximum:
        # The former test "tr < 1.0 or tr > 1.0" was true for every trace
        # except exactly 1.0, which made the two branches below unreachable
        # and left a zero-trace matrix without any damping.
        if abs(tr) > 1.0:
            damped += dump * tr * np.eye(n)
        elif tr < 0.0:
            damped -= dump * np.eye(n)
        else:
            damped += dump * np.eye(n)
    else:
        if tr < 0.0:
            dump *= 1.0 + abs(tr)
        for i1 in range(100):
            damped += dump * np.eye(n)
            tr = np.trace(damped)
            if tr > 0.0:
                break

    optdata.dx = -np.linalg.solve(damped, Si)
    optdata.Sijdumped = damped
    return optdata.dx, icg, gradfkm, dkm

def searchdir_sd(x = None, diff1func = None, icg = None, gradfkm = None, dkm = None, optdata = None):
    """Steepest-descent search direction: the negative gradient at x.

    The gradient is evaluated component by component with
    ``diff1func(i, x, optdata)``. The direction is stored as ``optdata.dx``
    and returned together with the unchanged icg, gradfkm and dkm.
    """
    size = len(x)
    gradient = np.empty(size)
    for comp in range(size):
        # gradient of the target function
        gradient[comp] = diff1func(comp, x, optdata)

    descent = -gradient
    optdata.dx = descent
    return optdata.dx, icg, gradfkm, dkm

def searchdir_cg(x = None, diff1func = None, icg = None, gradfkm = None, dkm = None, optdata = None):
    n = len(x)
    gradfk = np.empty(n)
#    print("n=", n)
    for i in range(0, n):
# gradient of the target function
       gradfk[i] = optdata.diff1func(i, x, optdata)

    if icg == 0:
# search direction
        dk = -gradfk
    else:
        yk = gradfk - gradfkm
        inn1 = np.inner(gradfk, yk)
        inn2 = np.inner(dkm, yk)
# search direction
        dk = -gradfk + inn1 / inn2 * dkm
    gradfkm = gradfk.copy()
    dkm     = dk.copy()
    icg += 1
# reset if the number of conjugate gradients exceeds the dimension of the parameters
    if icg >= n:
        icg = 0

    optdata.dx  = dk
    return optdata.dx, icg, gradfkm, dkm


def linesearch_armijo(func, x0, sdir, 
            xrange = None, h = None, alpha = None, dump = None, nmaxiter = None, 
            alphaeps = None, feps = None,
            optdata = None):
    """Backtracking line search with an Armijo-type acceptance test.

    Starting from the step length alpha, the trial point x0 + alpha * sdir
    is accepted when its function value is not above f(x0) and not above
    f(x0) - xi * alpha * (sdir . sdir). Otherwise alpha is multiplied by tau
    and the test is repeated, at most nmaxiter times. alphaeps and feps
    default to ``optdata.ls_alphaeps`` and ``optdata.ls_feps``.

    Returns (status, x, dx, alpha):
      * accepted trial: iteration index, x0 + dx, dx, alpha;
      * stagnation (decrease below feps and alpha below alphaeps): iteration
        index, best point seen, best_alpha * sdir, best_alpha;
      * iterations exhausted: -1 and the last trial point.

    NOTE(review): when no trial improves on f(x0) the stagnation return pairs
    x0 with the initial alpha, and the exhausted return reports alpha after
    it was shrunk once more than dx -- confirm callers expect this.
    xrange, h and dump are not used here.
    """
    # Tolerances fall back to the values stored on optdata.
    alphaeps = optdata.ls_alphaeps if alphaeps is None else alphaeps
    feps = optdata.ls_feps if feps is None else feps

    xi  = 0.5   # 0 < xi < 1
    tau = 0.5   # 0 < tau < 1
    sdir_sq = np.inner(sdir, sdir)
    f_start = func(x0, optdata)
    f_last = f_start
    best_alpha = alpha
    best_x = x0.copy()
    best_f = f_start
    for it in range(nmaxiter):
        dx = alpha * sdir
        f_trial = func(x0 + dx, optdata)
        f_required = f_start - xi * alpha * sdir_sq
        point_text = joinf(x0 + dx, "%12.6g", ", ")
        lvlprint(optdata.print_level, 4, "  lsarmijo:{}".format(it), 
                    "  alpha={:12.8g}  x=({})  fxi={:12.8g}  fp={:12.8g}".format(alpha, point_text, f_required, f_trial))

        not_above_start = f_trial <= f_start

        # Remember the best trial seen so far.
        if not_above_start and f_trial <= best_f:
            best_alpha = alpha
            best_x = x0 + dx
            best_f = f_trial

        # Sufficient decrease reached: accept this trial.
        if not_above_start and f_trial <= f_required:
            return it, x0 + dx, dx, alpha

        # Neither the function value nor alpha changes noticeably any more.
        if f_last - f_trial < feps and alpha < alphaeps:
            return it, best_x, best_alpha * sdir, best_alpha

        alpha *= tau
        f_last = f_trial

    return -1, x0 + dx, dx, alpha

def linesearch_golden(func, x0, sdir,
            xrange = None, h = None, alpha = None, dump = None, nmaxiter = None,
            alphaeps = None, feps = None,
            optdata = None):
    """Golden-section line search along sdir on the interval [0, xrange].

    x holds four step lengths (both ends of the bracket and two interior
    points placed with the golden ratio). Each iteration evaluates func at
    the interior points and drops one end of the bracket. Once the bracket
    is narrower than alphaeps, the midpoint of the interior points is
    returned if it beats the best value recorded so far, otherwise the
    recorded best step.

    Returns (status, x, dx, alpha) with x = x0 + dx and dx = alpha * sdir;
    status is the iteration index on convergence and -1 when nmaxiter
    iterations were not enough (the current midpoint is returned then).
    h, alpha, dump and feps are not used here.
    """
    eta = 0.381966011   # (sqrt(5) - 1) / (sqrt(5) + 1)
    x = [0.0, 0.0 + xrange*eta, xrange - xrange*eta, xrange]
    f0 = func(x0, optdata)
    f3 = func(x0 + xrange * sdir, optdata)
    fmin = f0
    xmin = 0.0
    if f3 < f0:
        fmin = f3
        xmin = xrange

    # Defined up front so that nmaxiter == 0 still returns a valid result.
    alpha2 = (x[1] + x[2]) / 2.0
    dx = alpha2 * sdir
    for i in range(nmaxiter):
        alpha2 = (x[1] + x[2]) / 2.0
        dx = alpha2 * sdir
        if abs(x[3] - x[0]) < alphaeps:
            x2 = x0 + dx
            f = func(x2, optdata)
            if f < fmin:
                print("f<fmin:", f, fmin)
                return i, x2, dx, alpha2
            else:
                print("f>fmin:", f, fmin, xmin)
                dx = xmin * sdir
                return i, x0 + dx, dx, xmin

        f1 = func(x0 + x[1] * sdir, optdata)
        f2 = func(x0 + x[2] * sdir, optdata)
        # xmin is recorded BEFORE the bracket is rebuilt so that it is the
        # step length at which fmin was actually evaluated.
        if f0 < f1 and f0 < f2 and f0 < f3:
            fmin = f0
            xmin = x[0]     # f0 belongs to the current left end
            dxe = x[2] - x[0]
            x = [x[0], x[0] + dxe * eta, x[2] - dxe * eta, x[2]]
            f3 = f2
        elif f1 < f2:
            fmin = f1
            xmin = x[1]
            dxe = x[2] - x[0]
            x = [x[0], x[0] + dxe * eta, x[2] - dxe * eta, x[2]]
            f3 = f2
        else: # f1 >= f2
            fmin = f2
            xmin = x[2]
            dxe = x[3] - x[1]
            x = [x[1], x[1] + dxe * eta, x[3] - dxe * eta, x[3]]
            f0 = f1
        xstr = joinf(x, "%12.6g", ", ")
        lvlprint(optdata.print_level, 4, "  lsgolden:{}".format(i), 
                "  xp=(", xstr, ")  fmin={:14.10g}".format(fmin))
    return -1, x0 + dx, dx, alpha2
    
def linesearch_exact(func, x0, sdir,
            xrange = None, h = None, alpha = None, dump = None, nmaxiter = None,
            alphaeps = None, feps = None,
            optdata = None):
    """Line search by repeated one-dimensional Newton steps along sdir.

    At each iteration the first and second derivatives along the unit
    vector sdir / |sdir| are estimated by central differences with spacing
    h, and the point is moved by f1 / (|f2| + dump) along that unit vector.
    The point with the lowest function value seen is remembered. Iteration
    stops when the last move, divided by |sdir| once more, is smaller than
    alphaeps in magnitude.

    x0 is not modified. Returns (status, x, dx, alpha) where x is the best
    point found, alpha the accumulated step in units of sdir that leads to
    it and dx = alpha * sdir, so that x == x0 + dx. status is the iteration
    index at which the best point was found (-2 if no iteration improved on
    x0), or -1 when nmaxiter iterations were not enough.
    xrange, alpha and feps are not used here.
    """
    dr2 = np.inner(sdir, sdir)
    dr  = sqrt(dr2)
    # Work on a private copy; updating x0 in place would silently change
    # the caller's array.
    x = np.array(x0, dtype=float)
    f0 = func(x, optdata)
    alphamin = 0.0
    xmin = x.copy()
    fmin = f0
    imin = -2
    alpha_total = 0.0   # accumulated displacement from x0, in units of sdir
    for i in range(nmaxiter):
        fm = func(x + h * sdir / dr, optdata)
        fp = func(x - h * sdir / dr, optdata)
        # fm is the value ahead and fp the value behind, so f1 is minus the
        # slope along sdir, i.e. the Newton step has the right sign.
        f1 = (fp - fm) / 2.0 / h
        f2 = (fp - 2.0 * f0 + fm) / h / h
        alpha2 = f1 / (abs(f2) + dump)
        x = x + alpha2 * sdir / dr
        alpha_total += alpha2 / dr
        f0 = func(x, optdata)
        xstr = joinf(x, "%12.6g", ", ")
        lvlprint(optdata.print_level, 4, "  lsexact:{}".format(i), 
                "  alpha2={:12.6g}  x=({})".format(alpha2, xstr), 
                "  f0={:12.6g}  fmin={:12.6g}".format(f0, fmin))
        if f0 <= fmin:
            # Record the total displacement, not only the last increment,
            # so the returned dx matches the returned point.
            alphamin = alpha_total
            xmin = x.copy()
            fmin = f0
            imin = i
        if abs(alpha2/dr) < alphaeps:
            return imin, xmin, alphamin * sdir, alphamin
    return -1, xmin, alphamin * sdir, alphamin


def linesearch_one(func, x0, sdir,
            xrange = None, h = None, alpha = None, dump = None, nmaxiter = None,
            alphaeps = None, feps = None,
            optdata = None):
    """Take a single step of fixed length alpha along sdir.

    No function evaluation is made. Returns (1, x0 + alpha * sdir,
    alpha * sdir, alpha).
    """
    new_point = x0 + alpha * sdir
    displacement = alpha * sdir
    return 1, new_point, displacement, alpha

def linesearch_simple(func, x0, sdir,
            xrange = None, h = None, alpha = None, dump = None, nmaxiter = None,
            alphaeps = None, feps = None,
            optdata = None):
    """March along sdir in equal steps until the function value increases.

    Trial points x0 + i * alpha * sdir (i = 1 .. nmaxiter) are evaluated in
    turn. As soon as a trial is worse than the previous point, the previous
    point is returned together with the step that leads to it.

    Returns (status, x, dx, alpha) with x == x0 + dx and dx = alpha * sdir.
    status is the index of the first rejected trial, or -1 when no trial was
    rejected within nmaxiter steps (the last trial is returned then; for
    nmaxiter == 0 that is x0 itself with a zero step).
    xrange, h, dump, alphaeps and feps are not used here.
    """
    f0 = func(x0, optdata)
    xprev = x0.copy()
    x = x0.copy()
    alpha_prev = 0.0    # step length belonging to xprev
    alpha2 = 0.0        # defined even when the loop body never runs
    for i in range(1, nmaxiter+1):
        alpha2 = i * alpha
        x = x0 + alpha2 * sdir
        f1 = func(x, optdata)

        xstr = joinf(x, "%12.6g", ", ")
        lvlprint(optdata.print_level, 4, "  ls_simple:{}".format(i), "  x=(", xstr, ")  f1=", f1)
        if f1 > f0:
            # Report the step of the accepted point xprev, not the step of
            # the rejected trial.
            return i, xprev, alpha_prev * sdir, alpha_prev
        xprev = x.copy()
        alpha_prev = alpha2
        f0 = f1
    return -1, x, alpha2 * sdir, alpha2

def linesearch_none(func, x0, sdir,
            xrange = None, h = None, alpha = None, dump = None, nmaxiter = None,
            alphaeps = None, feps = None,
            optdata = None):
    """Apply the full step sdir without any line search.

    Returns (1, x0 + sdir, sdir, 1.0); all other arguments are ignored.
    """
    unit_alpha = 1.0
    moved = x0 + sdir
    return 1, moved, sdir, unit_alpha

