import sys
import numpy as np
from numpy import sin, cos, tan, pi
from pprint import pprint


from tklib.tkobject import tkObject



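"""
    Callback interface used by the Equation solvers:

        def callback(obj)

    obj is the Equation instance. It is called once per iteration, after the
    per-iteration attributes (iter, x, f, dx, ...) have been updated.
    Returning 0 stops the iteration; any other return value lets it continue.
"""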


def get_zeros(func, xmin = 0.0, xmax = 10.0, dx = 0.1, eps = 1.0e-10, nmaxiter = 50, h = 1.0e-10, dump = 1.0, print_level = 0):
    """
    Compute the zeros of the function func(x) in [xmin, xmax].

    The range is scanned on the grid np.arange(xmin, xmax + dx, dx).

    - A grid point where func is exactly 0 is recorded as a zero.
    - Every sign change between two consecutive grid points is refined
      with a damped Newton iteration, starting from the right-hand grid
      point. The iteration uses a forward-difference derivative.

    Parameters
    ----------
    func : callable
        Scalar function f(x).
    xmin, xmax : float
        Scan range.
    dx : float
        Grid step of the scan.
    eps : float
        A Newton refinement is accepted when |x_{k+1} - x_k| < eps.
    nmaxiter : int
        Maximum number of Newton iterations per sign change. A sign change
        whose refinement does not converge contributes no zero.
    h : float
        Step of the forward-difference derivative.
    dump : float
        Damping factor; each Newton step is divided by dump.
    print_level : int
        Print the Newton iterations when non-zero.

    Returns
    -------
    list of float
        The zeros in the order they were found along the grid. Newton
        iterates are not clamped to the bracketing grid interval, so a
        refined zero may lie outside it.
    """
    zeros = []
    nnode = 0          # number of refined sign changes (used in printing)
    prev_r = None
    prev_f = None      # None until the first grid point has been evaluated
    for r in np.arange(xmin, xmax + dx, dx):
        f = func(r)

        if f == 0.0:
            zeros.append(r)
        # Compare with the previous grid point only once one exists. This
        # is independent of the sign of xmin: no sign change is skipped at
        # negative x, and none is invented before the first grid point.
        elif prev_f is not None and f * prev_f < 0.0:
            r0 = prev_r
            r1 = r
            for i in range(nmaxiter):
                if print_level:
                    print(f" iter#{i}/{nmaxiter}: nnode={nnode} dr={r1:8.4g} - {r0:8.4g} = {r1-r0:8.4g}")
                if abs(r1 - r0) < eps:
                    zeros.append(r1)
                    break

                r0 = r1
                f1 = func(r1)

                # Forward-difference derivative with step h.
                rh = r1 + h
                fh = func(rh)
                fdiff = (fh - f1) / (rh - r1)

                # Newton step, scaled down by the damping factor.
                r1 = r1 - f1 / fdiff / dump

            nnode += 1

        prev_r = r
        prev_f = f

    return zeros


class Equation(tkObject):
    """
    tklib equation solver for a scalar equation.

    The solver is selected by ``method``:

    - 'scf'       : fixed-point iteration x = func(x) with linear mixing
                    x_next = (1 - kmix) * x + kmix * func(x), from x0.
    - 'newton'    : damped Newton-Raphson for func(x) = 0 from x0, using
                    diff1func(x) as the derivative (a central difference
                    with step ``delta`` when diff1func is None).
    - 'bisection' : bisection for func(x) = 0 in the bracket [xa, xb].
    - 'brent'     : Brent's method for func(x) = 0 in the bracket [xa, xb].

    Settings are passed as keyword arguments to the constructor, to solve()
    or to the individual solver methods. Each of these hands them to
    ``update()``, which is inherited from tkObject and is assumed to store
    them as attributes.

    An optional ``callback(obj)`` is called once per iteration with this
    object, after its per-iteration attributes (iter, x, f, dx, ...) have
    been updated. Returning 0 stops the iteration.
    """

    def __init__(self, **args):
        self.method    = None     # 'scf', 'newton', 'bisection' or 'brent'
        self.func      = None     # function of x to solve
        self.diff1func = None     # derivative of func (newton)
        self.x0        = None     # initial guess (scf, newton)
        self.xa        = None     # bracket end (bisection, brent)
        self.xb        = None     # other bracket end (bisection, brent)
        self.callback  = None     # per-iteration callback(obj); 0 stops
        self.kmix      = 0.3      # mixing ratio (scf)
        self.dump      = 0.3      # added to |f'(x)| to damp the step (newton)
        self.nmaxiter  = 100      # maximum number of iterations
        self.eps       = 1.0e-5   # convergence threshold on the step / bracket
        self.delta     = 1.0e-5   # bisection-fallback tolerance (brent); derivative step (newton)
        self.isprint   = 0        # print progress when 1
        self.update(**args)

    def __str__(self):
        return "equation object by {}".format(self.method)

    def solve(self, **args):
        """
        Solve the equation with the solver named by self.method.

        Keyword arguments are applied with update() and forwarded to the
        solver. The results are also stored in self.x, self.f, self.iter
        and self.dx. The bracketing methods additionally store
        self.xrange and set self.dx to its width.

        Returns the estimated solution x. self.iter is negative when the
        solver failed:
        - -1: not converged, or stopped by the callback;
        - -2: invalid initial bracket.

        For an unknown method, prints an error and exits the program.
        """
        self.update(**args)
        method = self.method
        if method in ('scf', 'newton'):
            solver = self.scf if method == 'scf' else self.newton1d
            x, f, niter, dx = solver(**args)
            self.x    = x
            self.f    = f
            self.iter = niter
            self.dx   = dx
            return x
        if method in ('bisection', 'brent'):
            solver = self.bisection if method == 'bisection' else self.brent
            x, f, niter, xrange = solver(**args)
            self.x      = x
            self.f      = f
            self.iter   = niter
            self.xrange = xrange
            self.dx     = xrange[1] - xrange[0]
            return x

        print("Error in tkequation.solve: Invalid method [{}]".format(self.method))
        # sys.exit() instead of the site-provided exit(), which is not
        # guaranteed to exist (e.g. under `python -S` or frozen builds).
        sys.exit()

    def brent(self, **args):
        """
        Find a root of func(x) = 0 in the bracket [xa, xb] by Brent's method.

        Each step tries inverse quadratic interpolation, or the secant
        method when interpolation is not applicable. It falls back to
        bisection when the trial point is not acceptable. f(xa) and f(xb)
        must not have the same sign.

        Returns
        -------
        (x, f(x), iter, [xa, xb])
            iter is one of:
            - the iteration index at convergence;
            - -1 when nmaxiter was reached or the callback stopped the
              iteration (x is then the best estimate xb);
            - -2 when the initial bracket is invalid (x = f = 1.0e10).
        """
        self.update(**args)
        func      = self.func
        xa        = self.xa
        xb        = self.xb
        nmaxiter  = self.nmaxiter
        eps       = self.eps
        delta     = self.delta
        callback  = self.callback
        IsPrint   = self.isprint

        fa = func(xa)
        fb = func(xb)
        if fa * fb > 0.0:
            if IsPrint == 1:
                print("tkaequation.breant:: Error: " \
                      "Initial xa and xb should be chosen as f(xa)*f(xb) < 0")
                print("     xa={:12.6g}  f(xa)={:12.6g}".format(xa, fa))
                print("     xb={:12.6g}  f(xb)={:12.6g}".format(xb, fb))
            return 1.0e10, 1.0e10, -2, [xa, xb]

        # Keep xb as the best estimate: |f(xb)| <= |f(xa)|.
        if abs(fa) < abs(fb):
            xa, xb = xb, xa
            fa, fb = fb, fa
        if IsPrint == 1:
            print("  xa = {:12.8f}  fa = {:12.4g}".format(xa, fa))
            print("  xb = {:12.8f}  fb = {:12.4g}".format(xb, fb))

        # xc, fc: previous best estimate; xd: the one before it. xd is only
        # read when mflag == 0, which never happens in the first iteration.
        xc, fc = xa, fa
        xd = xc
        mflag = 1

        for i in range(nmaxiter):
            # Because |f(xb)| <= |f(xa)|, an exact zero always ends up in xb.
            if fb == 0.0:
                if IsPrint == 1:
                    print("  Success: Convergence reached at x = {}".format(xb))
                return xb, fb, i, [xa, xb]
            if abs(xa - xb) < eps:
                xh = (xa + xb) / 2.0
                fh = func(xh)
                if IsPrint == 1:
                    print("  Success: Convergence reached at x = {}".format(xh))
                return xh, fh, i, [xa, xb]

            if (fa != fc and fb != fc
                    and abs(xa - xc) > 1.0e-10 and abs(xb - xc) > 1.0e-10):
                # inverse quadratic interpolation (distinct f values keep
                # all denominators non-zero)
                xs = xa * fb * fc / (fa - fb) / (fa - fc) \
                   + xb * fa * fc / (fb - fa) / (fb - fc) \
                   + xc * fa * fb / (fc - fa) / (fc - fb)
            else:
                # secant method; fa != fb because fb != 0 and the bracket
                # keeps fa and fb of opposite sign
                xs = xb - fb * (xb - xa) / (fb - fa)

            # Decide whether to reject xs and bisect instead (mflag = 1).
            x4 = (3.0 * xa + xb) / 4.0
            if not (x4 < xs < xb or xb < xs < x4):
                mflag = 1
            elif mflag == 1 and abs(xs - xb) >= abs(xb - xc) / 2.0:
                mflag = 1
            elif mflag == 0 and abs(xs - xb) >= abs(xc - xd) / 2.0:
                mflag = 1
            elif mflag == 1 and abs(xb - xc) < delta:
                mflag = 1
            elif mflag == 0 and abs(xc - xd) < delta:
                mflag = 1
            else:
                mflag = 0

            if mflag == 1:
                xs = (xa + xb) / 2.0  # bisection

            fs = func(xs)
            xd = xc
            xc, fc = xb, fb

            # Keep the half of the bracket that still changes sign.
            if fa * fs < 0.0:
                xb = xs
                fb = fs
            else:
                xa = xs
                fa = fs
            if abs(fa) < abs(fb):
                xa, xb = xb, xa
                fa, fb = fb, fa

            self.iter   = i
            self.x      = xs
            self.f      = fs
            self.xa     = xa
            self.fa     = fa
            self.xb     = xb
            self.fb     = fb
            self.xhalf  = xs
            self.fhalf  = fs
            self.dx     = xb - xa
            self.mflag  = mflag

            if callback is not None:
                if callback(self) == 0:
                    break

            if IsPrint == 1:
                print("  Iter {}: xa,b = {:12.8f} - {:12.8f}  fa,b = {:12.4g} - {:12.8g}  mflag = {}"
                    .format(i, xa, xb, fa, fb, mflag))
                print("     xs = {:12.8f}  fs = {:12.4g}".format(xs, fs))

        # Reached when nmaxiter is exhausted or the callback stopped the loop.
        if IsPrint == 1:
            print("  Failed: Convergence did not reach")
        return xb, fb, -1, [xa, xb]

    def bisection(self, **args):
        """
        Find a root of func(x) = 0 in the bracket [xa, xb] by bisection.

        If f(xa) and f(xb) have the same sign, the bracket is first widened
        by dx on both sides, for up to 10 attempts. dx comes from the
        keyword argument ``dx`` and defaults to 10% of the initial width.

        Returns
        -------
        (x, f(x), iter, [xa, xb])
            iter is one of:
            - the iteration index at convergence;
            - -1 when nmaxiter was reached or the callback stopped the
              iteration;
            - -2 when no sign change was found (x = f = 1.0e10).
        """
        self.update(**args)
        func      = self.func
        xa        = self.xa
        xb        = self.xb
        nmaxiter  = self.nmaxiter
        eps       = self.eps
        callback  = self.callback
        IsPrint   = self.isprint
        dx = args.get('dx')
        if dx is None:
            dx = 0.1 * abs(xb - xa)

        for i in range(10):
            fa = func(xa)
            fb = func(xb)
            if fa * fb > 0.0:
                xa -= dx
                xb += dx
            else:
                break
        else:
            print("tkaequation.bisection:: Error: Initial xa and xb should be chosen as f(xa)*f(xb) < 0")
            print("     xa={:12.6g}  f(xa)={:12.6g}".format(xa, fa))
            print("     xb={:12.6g}  f(xb)={:12.6g}".format(xb, fb))
            return 1.0e10, 1.0e10, -2, [xa, xb]

        # Defined even when nmaxiter <= 0 so the failure return is valid.
        xhalf, fhalf = xa, fa
        for i in range(nmaxiter):
            xhalf = (xa + xb) / 2.0
            fhalf = func(xhalf)

            if IsPrint == 1:
                print("  Iter {}: xhalf = {:12.4g}  f(xhalf) = {:12.4g}".format(i, xhalf, fhalf))
            if fhalf == 0.0 or (abs(xa - xhalf) < eps and abs(xb - xhalf) < eps):
                if IsPrint == 1:
                    print("  Success: Convergence reached at x = {}".format(xhalf))
                return xhalf, fhalf, i, [xa, xb]
            # `<=` keeps the root inside the bracket when f(xa) is exactly 0.
            if fa * fhalf <= 0.0:
                xb = xhalf
                fb = fhalf
            else:
                xa = xhalf
                fa = fhalf

            self.iter   = i
            self.x      = xhalf
            self.f      = fhalf
            self.xa     = xa
            self.fa     = fa
            self.xb     = xb
            self.fb     = fb
            self.xhalf  = xhalf
            self.fhalf  = fhalf
            self.dx     = xb - xa

            if callback is not None:
                if callback(self) == 0:
                    break

        # Reached when nmaxiter is exhausted or the callback stopped the loop.
        if IsPrint == 1:
            print("  Failed: Convergence did not reach")
        return xhalf, fhalf, -1, [xa, xb]

    def newton1d(self, **args):
        """
        Find a root of func(x) = 0 by a damped Newton-Raphson iteration.

        The derivative is diff1func(x), or a central difference with step
        self.delta when diff1func is None. Its magnitude is increased by
        self.dump (keeping its sign) to shorten each step.

        Returns
        -------
        (x, f(x), iter, dx)
            On convergence (|dx| < eps), x is the midpoint of the last two
            iterates and iter is the iteration index. Otherwise iter is -1
            (nmaxiter reached, or stopped by the callback).
        """
        self.update(**args)
        func      = self.func
        diff1func = self.diff1func
        x0        = self.x0
        dump      = self.dump
        nmaxiter  = self.nmaxiter
        eps       = self.eps
        delta     = self.delta
        callback  = self.callback
        IsPrint   = self.isprint

        if diff1func is None:
            def _central_diff(x):
                # Central-difference derivative with step delta.
                return (func(x + delta) - func(x - delta)) / (2.0 * delta)
            diff1func = _central_diff

        x = x0
        f = dx = None
        for i in range(nmaxiter):
            f  = func(x)
            f1 = diff1func(x)
            # Push f1 away from zero by `dump`, keeping its sign.
            if f1 < 0.0:
                f1 -= dump
            else:
                f1 += dump
            xnext = x - f / f1
            dx = xnext - x

            self.iter  = i
            self.x     = x
            self.f     = f
            self.f1    = f1
            self.xnext = xnext
            self.dx    = dx

            if callback is not None:
                if callback(self) == 0:
                    break

            if IsPrint == 1:
                print("Iter {:5d}: x: {:>16.12f} => {:>16.12f}, dx = {:>10.4g}".format(i, x, xnext, dx))
            if abs(dx) < eps:
                if IsPrint == 1:
                    print("  Success: Convergence reached: dx = {} < eps = {}".format(dx, eps))
                x = (x + xnext) / 2.0
                f = func(x)
                return x, f, i, dx
            x = xnext

        # Reached when nmaxiter is exhausted or the callback stopped the loop.
        if IsPrint == 1:
            print("  Failed: Convergence did not reach: dx = {} > eps = {}".format(dx, eps))
        return x, f, -1, dx

    def scf(self, **args):
        """
        Solve x = func(x) by fixed-point iteration with linear mixing.

        Each step sets x_next = (1 - kmix) * x + kmix * func(x), starting
        from x0. Sets self.status / self.status_str to 1 / 'Converged' or
        0 / 'Not converged'.

        Returns
        -------
        (x, func(x), iter, dx)
            On convergence (|dx| < eps), x is the midpoint of the last two
            iterates and iter is the iteration index. Otherwise iter is -1
            (nmaxiter reached, or stopped by the callback).
        """
        self.update(**args)
        func     = self.func
        x0       = self.x0
        kmix     = self.kmix
        nmaxiter = self.nmaxiter
        eps      = self.eps
        callback = self.callback
        IsPrint  = self.isprint
        x = x0

        f = func(x)
        for i in range(nmaxiter):
            x1  = f
            dx1 = x1 - x
            xnext = (1.0 - kmix) * x + kmix * x1
            dx    = xnext - x

            self.iter  = i
            self.x     = x
            self.f     = f
            self.x1    = x1
            self.dx1   = dx1
            self.xnext = xnext
            self.dx    = dx

            if callback is not None:
                if callback(self) == 0:
                    break

            if IsPrint == 1:
                print("Iter {:5d}: x: {:>16.12f} => {:>16.12f}, dx = {:>10.4g}".format(i, x, xnext, dx))

            if abs(dx) < eps:
                self.status     = 1
                self.status_str = 'Converged'
                if IsPrint == 1:
                    print("  Success: Convergence reached: dx = {} < eps = {}".format(dx, eps))
                x = (x + xnext) / 2.0
                f = func(x)
                return x, f, i, dx

            x = xnext
            f = func(x)

        self.status     = 0
        self.status_str = 'Not converged'
        if IsPrint == 1:
            print("  Failed: Convergence did not reach: dx = {} > eps = {}".format(dx, eps))
        return x, f, -1, dx


