import csv
import sys
import numpy as np
from numpy import sin, cos, tan, pi
from pprint import pprint

from tklib.tkobject import tkObject
from tklib.tkparams import tkParams
from tklib.tksci.tkoptimizeobject import tkOptimizeData, ReduceOptParams, RecoverParams
from tklib.tksci.tkoptimize_simplex import simplex
from tklib.tksci.tkoptimize_gradient import gradient
from tklib.tksci.tkoptimizeobject import searchdir_sd, searchdir_cg, searchdir_newton, diff2array_newton
from tklib.tksci.tkoptimizeobject import searchdir_dfp, diff2array_dfp, searchdir_dfpH, diff2array_dfpH
from tklib.tksci.tkoptimizeobject import searchdir_bfgs, diff2array_bfgs, searchdir_bfgsH, diff2array_bfgsH
from tklib.tksci.tkoptimizeobject import searchdir_broyden, diff2array_broyden, searchdir_broydenH, diff2array_broydenH
from tklib.tksci.tkoptimizeobject import linesearch_none, linesearch_one, linesearch_simple
from tklib.tksci.tkoptimizeobject import linesearch_exact
from tklib.tksci.tkoptimizeobject import linesearch_golden,linesearch_armijo
from tklib.tkutils import pint, pfloat, safe_getelement, check_attributes, merge_attributes, joinf, save_csv
from tklib.tkutils import print_data, format_strlist
from tklib.tkinifile import tkIniFile


def mlsq(x, y, m, iPrint = 0, print_level = None):
    """
        LSQ for m-th order polynomial

        Fits y ~ sum_{l=0}^{m} c[l] * x**l by solving the normal equations
        Sij @ c = Si with
            Si[l, 0]  = sum_i y[i] * x[i]**l
            Sij[j, l] = sum_i x[i]**(j + l)

        x, y: data to be fitted (only the first len(x) values of y are used)
        m: order of polynomial
        iPrint: output control [0|1]; overridden by print_level when given
        Return: (ci, Si, Sij)
            ci : list of m+1 coefficients, ci[l] multiplies x**l
            Si : (m+1, 1) ndarray, right-hand side of the normal equations
            Sij: (m+1, m+1) ndarray, normal matrix
        Raises numpy.linalg.LinAlgError if the normal matrix is singular.
    """
    if print_level is not None: iPrint = print_level

    xa = np.asarray(x, dtype = float)
    ya = np.asarray(y, dtype = float)[:len(xa)]

    # Design (Vandermonde) matrix V[i, l] = x[i]**l; the normal equations are
    # then V^T y and V^T V, computed in native code instead of nested loops.
    V = np.vander(xa, m + 1, increasing = True)
    Si  = (V.T @ ya).reshape(m + 1, 1)
    Sij = V.T @ V

    if iPrint == 1:
        print("tkoptimize.mslq:: Vector and Matrix:") 
        print("Si=")
        pprint(Si)
        print("Sij=")
        pprint(Sij)
        print("")

    # solve() is more accurate and cheaper than forming the explicit inverse.
    ci = np.linalg.solve(Sij, Si)

    return ci[:, 0].tolist(), Si, Sij

def mlsq_general(x, y, m, lsqfunc, iPrint = 0, print_level = None):
    """
        LSQ for linear combination of m functions

        Fits y ~ sum_{l=0}^{m-1} c[l] * lsqfunc(l, x) by solving the normal
        equations Sij @ c = Si with
            Si[l, 0]  = sum_i y[i] * lsqfunc(l, x[i])
            Sij[j, l] = sum_i lsqfunc(j, x[i]) * lsqfunc(l, x[i])

        x, y: data to be fitted (only the first len(x) values of y are used)
        m: number of functions
        lsqfunc(i, x): i-th basis function evaluated at x
        iPrint: output control [0|1]; overridden by print_level when given
        Return: (ci, Si, Sij)
            ci : list of m coefficients
            Si : (m, 1) ndarray, right-hand side of the normal equations
            Sij: (m, m) ndarray, normal matrix
        Raises numpy.linalg.LinAlgError if the normal matrix is singular.
    """
    if print_level is not None: iPrint = print_level

    n = len(x)
    # Design matrix F[i, l] = lsqfunc(l, x[i]). Each basis function is
    # evaluated once per data point (m*n calls) instead of once per matrix
    # element and data point.
    F = np.array([[lsqfunc(l, xi) for l in range(m)] for xi in x],
                 dtype = float).reshape(n, m)
    ya = np.asarray(y, dtype = float)[:n]
    Si  = (F.T @ ya).reshape(m, 1)
    Sij = F.T @ F

    if iPrint == 1:
        print("tkoptimize.mslq_general:: Vector and Matrix:") 
        print("Si=")
        pprint(Si)
        print("Sij=")
        pprint(Sij)
        print("")

    # solve() is more accurate and cheaper than forming the explicit inverse.
    ci = np.linalg.solve(Sij, Si)

    return ci[:, 0].tolist(), Si, Sij

def mlsq_general_optid(xlist, ylist, ai, linid, lsqfunc, print_level = 1):
    """
        Linear LSQ for a subset of the coefficients of a linear combination of functions.

        Model: y(x) = sum_k ai[k] * lsqfunc(k, x), k = 0 .. len(linid)-1.
        Coefficients with a nonzero linid[k] are fitted; coefficients with
        linid[k] == 0 are held at ai[k], and their contribution is subtracted
        from every data point before the fit.

        xlist, ylist : data to be fitted
        ai           : current values of all coefficients (indexed like linid)
        linid        : per-coefficient flag, nonzero = fit, 0 = fixed
        lsqfunc(k, x): k-th basis function evaluated at x
        print_level  : 0 silent; >= 1 prints setup, matrices and results;
                       >= 2 additionally prints the data
        Return: (ai_new, ai_all, Si, Sij)
            ai_new : list of the fitted coefficients, in linid order
            ai_all : ai.copy() with the fitted coefficients replaced
            Si     : (nmatrix, 1) ndarray, right-hand side of the normal equations
            Sij    : (nmatrix, nmatrix) ndarray, normal matrix
        Raises numpy.linalg.LinAlgError if the normal matrix is singular.
    """
    if print_level:
        print()
        print("mlsq_general_optid:")

    ndata = len(xlist)
    nvars = len(linid)

    # ci_list[k]: row/column of coefficient k in the reduced system, '-' if fixed.
    ci_list = []
    nmatrix = 0
    for flag in linid:
        if flag:
            ci_list.append(nmatrix)
            nmatrix += 1
        else:
            ci_list.append('-')
    fitted = [k for k in range(nvars) if linid[k]]
    
    if print_level:
        print("  nvars       =", nvars)
        print("  nmatrix     =", nmatrix)
        print("  ai          =", format_strlist(ai,    "{:8.3g}", separator = ' '))
        print("  linid       =", format_strlist(linid, "{:8d}", separator = ' '))
        print("  matrix index=", format_strlist(ci_list, "{:>8}", separator = ' '))
        print("  nData       =", ndata)
    
    # Remove the contribution of the fixed coefficients from each data point
    # (indexed by the data index, not the coefficient index).
    ylist_conv = np.array(ylist, dtype = float)
    for idata, x in enumerate(xlist):
        for k in range(nvars):
            if not linid[k]:
                ylist_conv[idata] -= ai[k] * lsqfunc(k, x)

    if print_level >= 2:
        print("  xlist     =", format_strlist(xlist, "{:6.3g}", separator = ' '))
        print("  ylist     =", format_strlist(ylist, "{:6.3g}", separator = ' '))
        print("  ylist_conv=", format_strlist(ylist_conv, "{:6.3g}", separator = ' '))

    # Design matrix over the fitted coefficients only: F[idata, c] = lsqfunc(fitted[c], x).
    F = np.array([[lsqfunc(k, x) for k in fitted] for x in xlist],
                 dtype = float).reshape(ndata, nmatrix)
    Si  = (F.T @ ylist_conv[:ndata]).reshape(nmatrix, 1)
    Sij = F.T @ F

    if print_level >= 1:
        print("optimiz_mup.mslq_general3:: Vector and Matrix:") 
        print("  Si =", format_strlist(Si[:, 0], "{:10.4g}"))
        for irow, row in enumerate(Sij):
            if irow == 0:
                print("  Sij=", format_strlist(row, "{:10.4g}"))
            else:
                print("      ", format_strlist(row, "{:10.4g}"))

    # solve() is more accurate and cheaper than forming the explicit inverse.
    ai_new = np.linalg.solve(Sij, Si)[:, 0].tolist()

    if print_level >= 1:
        print("  ai(new)=", format_strlist(ai_new, "{:10.4g}"))
        print("  ai(org)=", format_strlist(ai, "{:10.4g}"))
    
    # Same "nonzero means fitted" rule as used to build the reduced system.
    ai_all = ai.copy()
    for c, k in enumerate(fitted):
        ai_all[k] = ai_new[c]

    if print_level >= 1:
        print("  ai_all=", format_strlist(ai_all, "{:10.4g}"))

    return ai_new, ai_all, Si, Sij

class tkOptimize(tkObject):
    """ 
    tklib Optimize class
    """

    def __init__(self, **args):
        """
            Create an optimizer with default settings.
            Keyword arguments are passed to self.update() after the defaults are set
            (presumably overriding attributes of the same name).
        """
        self.datapath = None
        self.otherparameters = {}
        self.varname = []
        self.x0      = []
        self.dx      = []
        self.optid   = []
        self.method   = 'cg'
        self.nmaxiter = 200
        self.tolx     = 1.0e-5
        self.tolf     = 1.0e-5
        self.dump     = 0.3
        self.lsmode         = 'armijo'
        self.ls_func        = None
        self.searchdir_func = None
        self.ls_nmaxiter    = 100
        self.ls_alphaeps    = 1.0e-5
        self.ls_feps        = 1.0e-5
        self.ls_alpha       = 1.0
        self.ls_h           = 0.0e-3
        self.ls_xrange      = 0.5
        self.ls_dump        = 0.3
        self.print_level    = 4
        self.iprintinterval = 10
        self.callback       = None
        self.func           = None
        # read by make_optdata(); previously only created by set_functions()
        self.fitfunc        = None
        self.ycalfunc       = None
        self.diff1func      = None
        self.diff2func      = None
        self.diff2arrayfunc = None
        # read by optimize() for the 'simplex' method; previously only created by set_method()
        self.initial_simplex = None
        self.optdata        = tkParams()
        self.update(**args)

#    def __del__(self):
#        print("{} destroyed".format(self.name))

    def __str__(self):
        """Short description naming the selected optimization method."""
        return "optimization object by " + format(self.method)


    def initialize(self, **kwargs):
        """Apply keyword settings to this object via self.update()."""
        self.update(**kwargs)

    def set_datapath(self, path):
        """Remember the data file path (stored in self.datapath)."""
        self.datapath = path
    
    # An earlier two-argument set_method(algorism, lsmode) used to live here. It was
    # dead code: the full set_method(method, lsmode, ...) defined later in this class
    # body replaced it when the class was created, so it has been removed.

    def ycal(self, xlist = None, params = None, optdata = None):
        """
            ycal: evaluate the model through self.ycalfunc
            Input:
                xlist  : variables passed to ycalfunc
                params : optimization parameters; self.x is used when None
                optdata: passed through to ycalfunc
            Return:
                the value returned by self.ycalfunc(xlist, params, optdata)
        """
        # NOTE(review): self.x is not set in __init__; presumably provided via update() — confirm
        effective_params = self.x if params is None else params
        return self.ycalfunc(xlist, effective_params, optdata)

    def ycal_list(self, xarraylist, params = None, optdata = None):
        """
            Evaluate ycal for every sample of column-organized data.
            xarraylist[j][i] is variable j of sample i; the number of samples is
            taken from len(xarraylist[0]). Returns a list with one ycal value per sample.
        """
        nsamples = len(xarraylist[0])
        return [self.ycal([column[i] for column in xarraylist], params, optdata)
                for i in range(nsamples)]
   
    def save_csv(self, path, headerlist, datalist, is_print = 0):
        """
            Forward to the module-level save_csv (imported from tklib.tkutils),
            passing this object as its first argument, and return its result.
        """
        # Inside the method body `save_csv` resolves to the module-level import,
        # not to this method.
        forwarded = (path, headerlist, datalist, is_print)
        return save_csv(self, *forwarded)

    def read_parameters(self, path, update_params = 0, AddSection = 0):
        """
            read_parameters: Read an ini file (as written by save_parameters).
            Input:
                path         : ini file path
                update_params: if nonzero, apply the configuration values and every
                               "value:dx:optid" entry to this object
                AddSection   : passed through to tkIniFile.ReadAll
            Return:
                the mapping returned by ReadAll, or None if ReadAll returned None
        """
        ini = tkIniFile(path, OpenFile = 0)
        inf = ini.ReadAll(path, AddSection = AddSection)
        if inf is None:
            return None

        if update_params:
            # NOTE(review): datapath is read but never applied to self.datapath — confirm intent
            infile   = safe_getelement(inf, 'datapath', self.datapath)
            algorism = safe_getelement(inf, 'method', self.method)
            lsmode   = safe_getelement(inf, 'lsmode', self.lsmode)
            nmaxiter       = pint  (safe_getelement(inf, 'nmaxiter'),       self.nmaxiter)
            tolx           = pfloat(safe_getelement(inf, 'tolx'),           self.tolx)
            tolf           = pfloat(safe_getelement(inf, 'tolf'),           self.tolf)
            dump           = pfloat(safe_getelement(inf, 'dump'),           self.dump)
            ls_xrange      = pfloat(safe_getelement(inf, 'ls_xrange'),      self.ls_xrange)
            ls_h           = pfloat(safe_getelement(inf, 'ls_h'),           self.ls_h)
            ls_alpha       = pfloat(safe_getelement(inf, 'ls_alpha'),       self.ls_alpha)
            ls_dump        = pfloat(safe_getelement(inf, 'ls_dump'),        self.ls_dump)
            ls_nmaxiter    = pint  (safe_getelement(inf, 'ls_nmaxiter'),    self.ls_nmaxiter)
            ls_alphaeps    = pfloat(safe_getelement(inf, 'ls_alphaeps'),    self.ls_alphaeps)
            print_level    = pfloat(safe_getelement(inf, 'print_level'),    self.print_level)
            iprintinterval = pint  (safe_getelement(inf, 'iprintinterval'), self.iprintinterval)

            self.set_method(algorism, lsmode)
            self.initialize(
                nmaxiter = nmaxiter, tolx = tolx, tolf = tolf,
                print_level = print_level, iprintinterval = iprintinterval,
                dump = dump,
                # ls_dump comes from its own key (it used to be overwritten by dump)
                ls_xrange = ls_xrange, ls_h = ls_h, ls_alpha = ls_alpha, ls_dump = ls_dump, 
                ls_nmaxiter = ls_nmaxiter, ls_alphaeps = ls_alphaeps
                )

            # Entries of the form "value:dx:optid" describe optimization parameters.
            for key in inf.keys():
                vals = inf[key].split(':')
                if len(vals) < 3:
                    continue
                try:
                    self.update_parameter(key, vals[0], dx = vals[1], optid = vals[2])
                except Exception:
                    # best effort: skip entries that cannot be applied, but do not
                    # swallow KeyboardInterrupt / SystemExit like a bare except would
                    pass

        return inf

    def save_parameters(self, path, update_params = 0):
        """
            save_parameters: Write configuration and parameters to an ini file.
            Input:
                path         : ini file path
                update_params: if nonzero, copy self.xmin (when it exists) into self.x0 first
            Return:
                0 if the tkIniFile object is falsy, otherwise 1
            Sections written:
                Configuration  : the optimizer settings listed below ('' if an attribute is missing)
                OtherParameters: every entry of self.otherparameters, plus 'fmin'
                Parameters     : 'nparams' and one "value:dx:optid" entry per parameter
        """
        ini = tkIniFile(path, OpenFile = 0)
        if not ini:
            return 0

        if update_params:
            try:
                self.x0 = self.xmin
            except AttributeError:
                # xmin is only set once an optimization has run (or via update())
                pass

        for key in ['datapath', 'method', 'nmaxiter', 'tolx', 'tolf',
                    'dump',
                    'lsmode', 'ls_xrange', 'ls_h', 'ls_alpha', 'ls_dump', 
                    'ls_nmaxiter', 'ls_alphaeps',
                    'print_level', 'iprintinterval']:
            try:
                val = "{}".format(getattr(self, key))
            except AttributeError:
                val = ''
            ini.WriteString("Configuration", key, val)

        for key in self.otherparameters.keys():
            ini.WriteString("OtherParameters", key, self.otherparameters[key])

        ini.WriteString("Parameters", "nparams", len(self.varname))
        for i, name in enumerate(self.varname):
            entry = "{}:{}:{}".format(self.x0[i], self.dx[i], self.optid[i])
            ini.WriteString("Parameters", name, entry)

        # 1.0e10 stands in for "no optimization result yet"
        fmin = getattr(self, 'fmin', 1.0e10)
        ini.WriteString("OtherParameters", 'fmin', fmin)

        return 1

    def print_parameters(self, varname = None, x0 = None, dx = None, optid = None, f = None):
        """
            Print one line per parameter (name, value, optid, dx); any list not
            given falls back to the corresponding attribute of this object.
            If f is given it is printed last.
        """
        varname = self.varname if varname is None else varname
        x0      = self.x0      if x0 is None      else x0
        dx      = self.dx      if dx is None      else dx
        optid   = self.optid   if optid is None   else optid

        for i, value in enumerate(x0):
            line = "   {:>10}={:12.6g} (id={}) (dx={:8.3g})".format(varname[i], value, optid[i], dx[i])
            print(line)
        if f is not None:
            print("  f=", f)

    def make_optdata(self, **args):
        """
            Build a tkOptimizeData from the current x0, method and function settings,
            store it in self.optdata and return it. Extra keyword arguments are
            forwarded to tkOptimizeData.
        """
        optdata = tkOptimizeData(
            self.x0,
            method = self.method,
            lsmode = self.lsmode,
            fitfunc = self.fitfunc,
            func = self.func,
            diff1func = self.diff1func,
            diff2func = self.diff2func,
            diff2arrayfunc = self.diff2arrayfunc,
            dump = self.dump,
            print_level = self.print_level,
            iprintinterval = self.iprintinterval,
            **args)
        self.optdata = optdata
        return self.optdata

    def get_parameter(self, varname):
        """
            Return the value of an optimization parameter (from x0) or, failing
            that, of an "other" parameter; None if the name is unknown.
        """
        for i, name in enumerate(self.varname):
            if varname == name:
                return self.x0[i]
        for key, value in self.otherparameters.items():
            if varname == key:
                return value

        return None

    def update_parameter(self, varname, val, dx = None, optid = None):
        """
            Update an existing parameter.
            For an optimization parameter, val (and dx / optid when given) are
            parsed with pfloat / pint and stored. For an existing "other" parameter,
            val is stored as-is. Returns 1 if a parameter was updated, 0 otherwise.
        """
        for i, name in enumerate(self.varname):
            if varname == name:
                self.x0[i] = pfloat(val, None)
                if dx is not None:
                    self.dx[i] = pfloat(dx, None)
                if optid is not None:
                    self.optid[i] = pint(optid, None)
                return 1

        if varname not in self.otherparameters.keys():
            return 0

        self.otherparameters[varname] = val
        return 1

    def add_otherparameter(self, varname, val):
        """Add or overwrite a non-optimized ("other") parameter."""
        self.otherparameters.update({varname: val})

    def add_parameter(self, varname, val, scale, optid, **args):
        """
            add_parameter: Add optimization parameter with name, value, scale and id
            If a parameter with the same name already exists, it is overwritten
            in place; otherwise it is appended.
            Input:
                varname: name of the optimization parameter
                val    : value
                scale  : parameter range to build initial simplex for the simplex method
                optid  : optimization flag. 0: fixed, 1: optimized
                **args : passed to self.update()
        """
        idx = next((i for i, name in enumerate(self.varname) if name == varname), -1)
        if idx < 0:
            self.varname.append(varname)
            self.x0.append(val)
            self.dx.append(scale)
            self.optid.append(optid)
        else:
            self.x0[idx]    = val
            self.dx[idx]    = scale
            self.optid[idx] = optid

        self.update(**args)

    def set_functions(self, fitfunc = None, func = None, diff1func = None, 
                    diff2func = None, diff2arrayfunc = None, **args):
        """
            set_functions: Store the function pointers on this object
            (arguments left as None overwrite previous values with None).
            Input:
                fitfunc(x:list, params:list)
                        Function to fit
                func(params:list, optdata)
                        Target function
                diff1func(i:int, params:list, optdata)
                        First derivative of target function w.r.t. params[i]
                diff2func(i:int, j:int, params:list, optdata)
                        Second derivative of target function w.r.t. params[i] and params[j]
                diff2arrayfunc(params:list, optdata)
                        Hessian matrix of target function
                **args: passed to self.update()
        """
        for attr, value in (('fitfunc',        fitfunc),
                            ('func',           func),
                            ('diff1func',      diff1func),
                            ('diff2func',      diff2func),
                            ('diff2arrayfunc', diff2arrayfunc)):
            setattr(self, attr, value)
        self.update(**args)

    def set_method(self, method = None, lsmode = None, 
                    searchdir_func = None, ls_func = None, initial_simplex = None, **args):
        """
            Select the optimization method and line-search mode.
            When searchdir_func / ls_func are not given they are chosen from
            method / lsmode:
              - search direction: table below, steepest descent for unknown methods
              - line search: table below; for an unlisted lsmode, no line search
                for 'newton' and linesearch_one otherwise
            method, lsmode, initial_simplex and the chosen functions are stored
            on this object; **args is passed to self.update().
        """
        if searchdir_func is None:
            searchdir_table = {
                'newton':   searchdir_newton,
                'sd':       searchdir_sd,
                'cg':       searchdir_cg,
                'broyden':  searchdir_broydenH,
                'broydenB': searchdir_broyden,
                'dfp':      searchdir_dfpH,
                'dfpB':     searchdir_dfp,
                'bfgs':     searchdir_bfgsH,
                'bfgsB':    searchdir_bfgs,
            }
            searchdir_func = searchdir_table.get(method, searchdir_sd)

        if ls_func is None:
            linesearch_table = {
                'newton': linesearch_none,
                'none':   linesearch_none,
                'simple': linesearch_simple,
                'exact':  linesearch_exact,
                'golden': linesearch_golden,
                'armijo': linesearch_armijo,
            }
            fallback = linesearch_none if method == 'newton' else linesearch_one
            ls_func = linesearch_table.get(lsmode, fallback)

        self.method          = method
        self.initial_simplex = initial_simplex
        self.lsmode          = lsmode
        self.searchdir_func  = searchdir_func
        self.ls_func         = ls_func
        self.update(**args)


#    def ReduceOptParams(self, allvars, optid):
#        return ReduceOptParams(allvars, optid)

#    def RecoverParams(self, optvars, allvars, optid):
#        return RecoverParams(optvars, allvars, optid)

    def optfunc(self, x, optdata = None):
        """
            Target function on the reduced parameter vector x: expand x to the full
            parameter list with RecoverParams (fixed values from self.x0_all) and
            call self.func on it.
        """
        full_params = RecoverParams(x, self.x0_all, self.optid)
        return self.func(full_params, optdata)

    def optdiff1func(self, i, x, optdata = None):
        """
            First derivative on the reduced parameter vector x: expand x with
            RecoverParams and call self.diff1func(i, ...). Returns None when no
            diff1func is set.
        """
        if self.diff1func is not None:
            full_params = RecoverParams(x, self.x0_all, self.optid)
            return self.diff1func(i, full_params, optdata)
        return None

    def optdiff2func(self, i, j, x, optdata = None):
        """
            Second derivative on the reduced parameter vector x: expand x with
            RecoverParams and call self.diff2func(i, j, ...). Returns None when no
            diff2func is set.
        """
        if self.diff2func is not None:
            full_params = RecoverParams(x, self.x0_all, self.optid)
            return self.diff2func(i, j, full_params, optdata)
        return None

    def optdiff2arrayfunc(self, x, optdata = None, opdata = None):
        """
            Hessian on the reduced parameter vector x: expand x with RecoverParams
            and call self.diff2arrayfunc. Returns None when no diff2arrayfunc is set.
            Input:
                x      : reduced parameter vector
                optdata: passed through to diff2arrayfunc
                opdata : legacy (misspelled) keyword for optdata, used when optdata is None
        """
        # The parameter used to be named `opdata` while the body read `optdata`,
        # which raised NameError; keep the old keyword working.
        if optdata is None:
            optdata = opdata

        if self.diff2arrayfunc is None:  
            return None

        x_all = RecoverParams(x, self.x0_all, self.optid)
        return self.diff2arrayfunc(x_all, optdata)

    def optimize(self, **args):
        """
            optimize: Run the optimization selected by self.method.
            Only parameters with nonzero optid are optimized: self.x0 is reduced with
            ReduceOptParams for the back end and the result is expanded back with
            RecoverParams. self.x0 is always restored to the full list afterwards,
            including on an invalid method or an exception.
            Input:
                **args: passed to self.update() (and to marquart / mlsq back ends)
            Return:
                (ci, S2, optdata): full parameter list, final target value and the
                optimizer data; also stored as self.xmin / self.fmin.
                For an unknown method: ([], -1.0, self.optdata).
        """
        self.update(**args)

        # NOTE(review): no mlsq / mlsq_general method is defined on this class in the
        # visible source (the module-level functions take x, y, m[, lsqfunc]) — confirm.
        if self.method == 'mlsq':
            ci = self.mlsq(**args)
            return ci, -1.0, self

        if self.method == 'mlsq_general':
            ci = self.mlsq_general(**args)
            return ci, -1.0, self

        self.x0_all                  = self.x0
        self.optdata.x0_all          = self.x0_all
        self.optdata.optid           = self.optid
        self.optdata.callback        = self.callback
        self.optdata.RecoverParams   = RecoverParams
        self.optdata.ReduceOptParams = ReduceOptParams
        if type(self.diff1func) is str and (self.diff1func == '2-points' or self.diff1func == '3-points'):
            # NOTE(review): optdiff1func itself calls self.diff1func, so after this
            # assignment it calls itself without bound — confirm the intended
            # finite-difference hook.
            self.diff1func = self.optdiff1func

        x0_original = self.x0
        self.x0 = ReduceOptParams(self.x0, self.optid)
        try:
            result = self._dispatch_optimizer(**args)
        finally:
            # Restore the full parameter list on every path.
            self.x0 = x0_original

        if result is None:
            print("")
            print("Error in tkoptimize.optimize:: Invalid method [{}]".format(self.method))
            return [], -1.0, self.optdata

        ci, S2, optdata = result
        ci = RecoverParams(ci, self.x0_all, self.optid)
        self.xmin = ci
        self.fmin = S2
        return ci, S2, optdata

    def _dispatch_optimizer(self, **args):
        """
            Run the back end for self.method on the reduced self.x0.
            Return (ci, S2, optdata) with ci in the reduced space, or None when
            self.method is not recognized.
        """
        method = self.method

        if method == 'simplex':
            # NOTE(review): dx is passed un-reduced while x0 is reduced — confirm simplex expects this
            return simplex(
                func = self.optfunc, x0 = self.x0, dx = self.dx, initial_simplex = self.initial_simplex,
                itmax = self.nmaxiter, tolx = self.tolx, tolf = self.tolf, 
                optdata = self.optdata, callback = self.callback, 
                lprint = self.print_level, iprintinterval = self.iprintinterval)

        if method == 'marquart':
            return self.marquart(
                func = self.optfunc, diff1func = self.optdiff1func,
                x0 = self.x0, xsource = self.xsource, ysource = self.ysource, 
                optdata = self.optdata, callback = self.callback, 
                nmaxiter = self.nmaxiter, tolx = self.tolx, tolf = self.tolf,
                dump = self.dump,
                print_level = self.print_level, iprintinterval = self.iprintinterval,
                **args)

        if method == 'sd':
            self._print_method_header(searchdir_sd)
            return self._run_gradient(self.optdiff1func, self.searchdir_func)

        if method == 'cg':
            self._print_method_header(searchdir_cg, ("tolx: ", self.tolx), ("tolf: ", self.tolf))
            return self._run_gradient(self.optdiff1func, self.searchdir_func)

        if method == 'newton':
            self._print_method_header(searchdir_newton, ("diff2func:", self.diff2func))
            return self._run_gradient(self.optdiff1func, searchdir_newton,
                                      diff2func = self.optdiff2func)

        if method == 'broydenB':
            self.diff2arrayfunc = diff2array_broyden
            self._print_method_header(searchdir_broyden, ("diff2arrayfunc:", self.diff2arrayfunc))
            return self._run_gradient(self.optdiff1func, searchdir_broyden)

        # Quasi-Newton variants: (search-direction function, Hessian-update function shown in the log)
        quasi_newton = {
            'broyden': (searchdir_broydenH, diff2array_broydenH),
            'dfpB':    (searchdir_dfp,      diff2array_dfp),
            'dfp':     (searchdir_dfpH,     diff2array_dfpH),
            'bfgs':    (searchdir_bfgsH,    diff2array_bfgsH),
            'bfgsB':   (searchdir_bfgs,     diff2array_bfgs),
        }
        if method in quasi_newton:
            searchdir_func, diff2array_func = quasi_newton[method]
            self._print_method_header(searchdir_func, ("diff2arrayfunc:", diff2array_func))
            # These branches used to hand the full-parameter self.func the reduced x0.
            # Wrap callables so they see the full parameter list like the other
            # branches; None / strings are passed to gradient unchanged.
            diff1func = self.optdiff1func if callable(self.diff1func) else self.diff1func
            return self._run_gradient(diff1func, searchdir_func)

        return None

    def _print_method_header(self, searchdir_func, *extra):
        """Print the diagnostic header for a gradient-type method; extra is (label, value) pairs."""
        print("method: {}".format(self.method))
        print("searchdir_func:", searchdir_func)
        print("ls_func:", self.ls_func)
        print("diff1func:", self.diff1func)
        for label, value in extra:
            print(label, value)
        print("")

    def _run_gradient(self, diff1func, searchdir_func, **extra):
        """Call gradient() on the reduced problem with this object's shared settings."""
        return gradient(
            func = self.optfunc, diff1func = diff1func,
            searchdir_func = searchdir_func, ls_func = self.ls_func,
            x0 = self.x0, optdata = self.optdata, callback = self.callback, 
            nmaxiter = self.nmaxiter, tolx = self.tolx, tolf = self.tolf,
            print_level = self.print_level, iprintinterval = self.iprintinterval,
            ls_xrange = self.ls_xrange, ls_h = self.ls_h, ls_alpha = self.ls_alpha, 
            ls_dump = self.ls_dump, ls_nmaxiter = self.ls_nmaxiter, 
            ls_alphaeps = self.ls_alphaeps, ls_feps = self.ls_feps,
            **extra)


    def marquart_S2(self, x, y, c, func):
        """
            Objective function for the Marquart method: residual sum of squares
                S2 = sum_k (func(x[k], c) - y[k])**2
            x[], y[]: data to be fitted (y is indexed by the positions of x)
            c[]: parameters to be optimized
            func(x, c): model value at a single data point x
            Returns S2 (0.0 when x is empty).
        """
        s2 = 0.0
        for k, xk in enumerate(x):
            residual = func(xk, c) - y[k]
            s2 += residual * residual
        return s2


    def marquart(self, dump = 0.3, nmaxiter = 100, tolx = 1.0e-5, tolf = 1.0e-5,
                callback = None, print_level = 1, iprintinterval = 10, 
                optdata = None,
                **kwargs):
        """
            Non-linear LSQ by Marquart method
            args: 'func', 'xsource', 'ysource', 'diff1func', 
                  'nmaxiter', 'tolx', 'tolf', 'dump', 'callback', 
                  'print_level', 'iprintinterval'
            xd[], yd[]: data to be fitted
            c[]: parameters to be optimized
            y[i] = func(i, xd, c): funcion to be fitted
            diff1func(i, xd): dy / dx[i]
            iPrint: output control [0|1]
        """

        self.dump           = dump
        self.nmaxiter       = nmaxiter
        self.tolx           = tolx
        self.tolf           = tolf
        self.callback       = callback
        self.print_level    = print_level
        self.iprintinterval = iprintinterval
        self.update(**kwargs)
        if check_attributes(self, 1, 'func', 'xsource', 'ysource', 'diff1func'):
            self.status = -1
            return None, None, self
        merge_attributes(self, optdata)

        xd             = self.xsource
        yd             = self.ysource
        c0             = self.x0.copy()
        n              = len(c0)
        ndata          = len(xd)
        func           = self.func
        diff1func      = self.diff1func
        iPrint         = self.print_level
# calculate initial parameters
        sum = None
        if print_level == 1:
            sum = self.marquart_S2(xd, yd, c0, func)
            print("c0 = ({}, {}): f = {}".format(c0[0], c0[1], sum))

        fprev = sum
        Afi   = np.empty([n, 1])
        AAij  = np.empty([n, n])
# optimization start
        for iter in range(nmaxiter):
            self.iter = iter
            for i in range(0, n):
                Afi[i, 0] = 0.0
                for k in range(ndata):
                   Afi[i, 0] += diff1func(i, xd[k], c0, self) * (func(xd[k], c0) - yd[k])
                for j in range(i, n):
                    AAij[i, j] = 0.0
                    for k in range(ndata):
                        AAij[i, j] += diff1func(i, xd[k], c0, self) * diff1func(j, xd[k], c0)
                    AAij[j, i] = AAij[i, j]

#        print("At*f=", Afi)
#        print("At*A=", AAij)
#        print("")
            tr = np.trace(AAij)
            AAij += dump * tr
#        AAij = 1000 * np.eye(n)
#        dc = -np.linalg.inv(AAij) @ Afi
            dc = -np.linalg.solve(AAij, Afi)
            dc = dc.transpose()[0]
            c0prev = c0
            c0 = c0 + dc
            dcmax = max(abs(dc))
            sum = self.marquart_S2(xd, yd, c0, func)

            if callback is not None:
                self.islinesearch = 0
                self.xprev  = c0prev
                self.x      = c0
                self.f      = sum
                self.dx     = dc
                self.dxmax  = dcmax
                self.status = 1
                if callback(self) == 0:
                    break

            if print_level == 1 and iter % iprintinterval ==0:
                print("   dc=", dc)
                print("{:4d}: c = ({:12.6g}, {:12.6g}) dcmax = {:12.6g}  S = {:12.6g}".format(
                        iter, c0[0], c0[1], dcmax, sum))

            if dcmax < tolx or (fprev is not None and abs(sum - fprev) < tolf):
                self.ci = c0
                self.S2 = sum
                self.dcmax = dcmax
                self.iter = iter
                self.status     = 1
                self.status_str = 'Converged'
                return c0, sum, self

            fprev = sum

        self.status     = 0
        self.status_str = 'Not cconverged'
        return [], -1.0, self


    def mlsq(self, **args):
        """
            LSQ for an m-th order polynomial
                y = c[0] + c[1] * x + ... + c[m] * x**m
            obtained by solving the normal equations Sij @ c = Si.

            args are passed to self.update(); afterwards these attributes are read:
                xsource, ysource: data to be fitted (y is indexed by the positions of x)
                norder: order m of the polynomial
                iprint: output control [0|1]; 1 prints Si and Sij
            Stores Si (shape (m+1, 1)), Sij (shape (m+1, m+1)) and ci on self and
            returns the coefficient list [c0, c1, ..., cm].
        """
        self.update(**args)
        m      = self.norder
        iPrint = self.iprint
        x = np.asarray(self.xsource, dtype = float)
        y = self.ysource

        n  = len(x)
        yv = np.array([y[i] for i in range(n)], dtype = float)

        # Design matrix V[i, l] = x[i]**l (columns in increasing power)
        V   = np.vander(x, m + 1, increasing = True)
        Si  = (V.T @ yv).reshape(m + 1, 1)
        Sij = V.T @ V

        if iPrint == 1:
            print("tkoptimize.mslq:: Vector and Matrix:") 
            print("Si=")
            pprint(Si)
            print("Sij=")
            pprint(Sij)
            print("")

        self.Si  = Si
        self.Sij = Sij
        # solve() avoids forming inv(Sij) explicitly; polynomial normal
        # equations are frequently ill-conditioned.
        ci = np.linalg.solve(Sij, Si)[:, 0].tolist()
        self.ci = ci

        return ci


    def mlsq_general(self, **args):
        """
            LSQ for a linear combination of m basis functions
                y = c[0] * lsqfunc(0, x) + ... + c[m-1] * lsqfunc(m-1, x)
            obtained by solving the normal equations Sij @ c = Si.

            args are passed to self.update(); afterwards these attributes are read:
                xsource, ysource: data to be fitted (y is indexed by the positions of x)
                norder: number m of basis functions
                lsqfunc(l, x): value of the l-th basis function at x
                iprint: output control [0|1]; 1 prints sizes, Si and Sij
            Stores Si, Sij and ci on self (as mlsq does) and returns the
            coefficient list [c0, ..., c(m-1)].
        """
        self.update(**args)
        x       = self.xsource
        y       = self.ysource
        m       = self.norder
        lsqfunc = self.lsqfunc
        iPrint  = self.iprint

        n = len(x)
        # Design matrix F[i, l] = lsqfunc(l, x[i]): each basis function is
        # evaluated once per data point instead of once per matrix element.
        F = np.empty([n, m])
        for i in range(n):
            for l in range(m):
                F[i, l] = lsqfunc(l, x[i])
        yv = np.array([y[i] for i in range(n)], dtype = float)

        Si  = (F.T @ yv).reshape(m, 1)
        Sij = F.T @ F

        if iPrint == 1:
            # Size report was previously printed unconditionally (debug output).
            print("n=", n, "  m=", m)
            print("tkoptimize.mslq_general:: Vector and Matrix:") 
            print("Si=")
            pprint(Si)
            print("Sij=")
            pprint(Sij)
            print("")

        self.Si  = Si
        self.Sij = Sij
        # solve() is more accurate and cheaper than forming inv(Sij).
        ci = np.linalg.solve(Sij, Si)[:, 0].tolist()
        self.ci = ci
        return ci


def main():
    """
        Minimal smoke test: build an optimizer and print its name and method.
        NOTE(review): `Optimize` is not defined in the visible part of this file,
        and `name` / `method` are invoked as callables here - confirm both
        against the class definition.
    """
    optimizer = Optimize(method = 'mslq', name = 'test')
    for attr in ("name", "method"):
        print(attr, "=", getattr(optimizer, attr)())


# Run the smoke test only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()

