import sys
import csv
import numpy as np
from numpy import exp, log, log10, sin, cos, tan, arcsin, arccos, arctan, sinh, cosh, tanh, sqrt, abs
import pandas as pd
from scipy import optimize
import matplotlib.pyplot as plt


from tklib.tksci.tksci import h, h_bar, hbar, e, kB, NA, c, pi, pi2, torad, todeg, basee
from tklib.tksci.tksci import me, mp, mn, u0, e0, e2_4pie0, a0, R, F, g
from tklib.tksci.tksci import acos, asin, atan, cosh, sinh, tanh
from tklib.tksci.tksci import degcos, degsin, degtan, degacos, degasin, degatan
from tklib.tksci.tksci import eVTonm, nmToeV
from tklib.tksci.tksci import factorial, Factorize, Gaussian, Lorentzian, combination, gamma


#==========================================
# Target function and solver parameters
#==========================================
# Expression whose root (func(x) = 0) is searched for.  It is evaluated with
# eval() in target_func(), so it may use any name imported at the top of this
# file.  NOTE: this default is >= 2 for every real x, so it has no real root;
# pass a different expression on the command line to obtain a solution.
func = '2.0 * (x - 3.0)**2 + (x - 1.0)**4 + 2.0'

# Root-finding method ('brent', 'bisection' or 'newton'); used when no method
# is given on the command line.
method = 'newton'

x0    = 0.0     # initial guess for 'newton'
h     = 0.01    # step of the central difference in diff1()
                # (NOTE: rebinds the `h` imported from tklib.tksci above)
alpha = 0.0     # damping added to |f'(x)| by 'newton'

nmaxiter = 100  # maximum number of iterations
tol = 1.0e-5    # convergence tolerance (passed as eps and delta to the solvers)


#==========================================
# Graph parameters
#==========================================
fplot  = 1      # 1: plot the function and every solver step
ngdata = 201    # number of points used to draw the function
xgmin  = -4.0   # graph range; also the initial bracket for 'brent'/'bisection'
xgmax  =  4.0
tsleep = 0.3    # pause in seconds after drawing each solver step


# Optional positional command-line overrides, in this order:
#   method func x0 h alpha nmaxiter tol xgmin xgmax
argv = sys.argv
n = len(argv)
if n >= 2:
    method = argv[1]
if n >= 3:
    func = argv[2]
if n >= 4:
    x0 = float(argv[3])
if n >= 5:
    h = float(argv[4])
if n >= 6:
    alpha = float(argv[5])
if n >= 7:
    # nmaxiter is the name main() passes to the solvers.
    nmaxiter = int(argv[6])
if n >= 8:
    tol = float(argv[7])
if n >= 9:
    xgmin = float(argv[8])
if n >= 10:
    xgmax = float(argv[9])


def brent(func, diff1func = None, xa = 10.0, xb = 10.0, dump = 0.0, nmaxiter = 100, eps = 1.0e-5, delta = 1.0e-5, callback = None, IsPrint = True):
    """Find a root of func inside the bracket [xa, xb] by Brent's method.

    Each step tries inverse quadratic interpolation (or the secant method when
    interpolation is not possible) and falls back to bisection whenever the
    interpolated point is not trusted, so [xa, xb] always keeps a sign change.

    Parameters
    ----------
    func : callable
        Scalar function f(x) whose root is sought.
    diff1func, dump :
        Unused; accepted so that all solvers in this file share one signature.
    xa, xb : float
        Initial bracket; f(xa) and f(xb) must not have the same sign.
    nmaxiter : int
        Maximum number of iterations.
    eps : float
        Convergence tolerance on the bracket width |xa - xb|.
    delta : float
        Step-size threshold below which a bisection step is forced.
    callback : callable or None
        Called as callback(xs) with each newly evaluated point; a return
        value of 0 aborts the iteration.
    IsPrint : bool
        Print progress messages when true.

    Returns
    -------
    tuple
        (x, f(x), status, [xa, xb]) where status is the iteration index on
        success, -1 when convergence was not reached (iteration limit or
        callback abort; x is then the current best estimate), and -2 with
        x = f(x) = 1.0e10 when the initial bracket is invalid.
    """
    fa = func(xa)
    fb = func(xb)
    if fa * fb > 0.0:
        if IsPrint == 1:
            print("brent:: Error: "
                  "Initial xa and xb should be chosen as f(xa)*f(xb) < 0")
            print("     xa={:12.6g}  f(xa)={:12.6g}".format(xa, fa))
            print("     xb={:12.6g}  f(xb)={:12.6g}".format(xb, fb))
        return 1.0e10, 1.0e10, -2, [xa, xb]

    # Keep xb as the best estimate: |f(xb)| <= |f(xa)|.
    if abs(fa) < abs(fb):
        xa, xb = xb, xa
        fa, fb = fb, fa
    if IsPrint == 1:
        print("  xa = {:12.8f}  fa = {:12.4g}".format(xa, fa))
        print("  xb = {:12.8f}  fb = {:12.4g}".format(xb, fb))

    # xc is the previous xb (fc = f(xc) is tracked, not re-evaluated) and
    # xd the one before it; xd is only read once mflag has been cleared.
    xc, fc = xa, fa
    xd = xc
    mflag = 1   # 1 when the previous step was a bisection

    for i in range(nmaxiter):
        # xb is an exact root.
        if fb == 0.0:
            if IsPrint == 1:
                print("  Success: Convergence reached at x = {}".format(xb))
            return xb, fb, i, [xa, xb]
        if abs(xa - xb) < eps:
            xh = (xa + xb) / 2.0
            fh = func(xh)
            if IsPrint == 1:
                print("  Success: Convergence reached at x = {}".format(xh))
            return xh, fh, i, [xa, xb]

        if fa != fc and fb != fc:
            # Inverse quadratic interpolation through (xa,fa), (xb,fb), (xc,fc);
            # fa != fb holds because they have opposite signs.
            xs = xa * fb * fc / (fa - fb) / (fa - fc) \
               + xb * fa * fc / (fb - fa) / (fb - fc) \
               + xc * fa * fb / (fc - fa) / (fc - fb)
        else:
            # Secant method through (xa,fa) and (xb,fb).
            xs = xb - fb * (xb - xa) / (fb - fa)

        # Brent's acceptance tests: fall back to bisection unless xs lies
        # between (3xa + xb)/4 and xb and the steps keep shrinking.
        x4 = (3.0 * xa + xb) / 4.0
        if not (x4 < xs < xb or xb < xs < x4):
            mflag = 1
        elif mflag == 1 and abs(xs - xb) >= abs(xb - xc) / 2.0:
            mflag = 1
        elif mflag == 0 and abs(xs - xb) >= abs(xc - xd) / 2.0:
            mflag = 1
        elif mflag == 1 and abs(xb - xc) < delta:
            mflag = 1
        elif mflag == 0 and abs(xc - xd) < delta:
            mflag = 1
        else:
            mflag = 0

        if mflag == 1:
            xs = (xa + xb) / 2.0  # bisection

        fs = func(xs)
        xd = xc
        xc, fc = xb, fb

        # Keep the sign change inside [xa, xb].
        if fa * fs < 0.0:
            xb, fb = xs, fs
        else:
            xa, fa = xs, fs
        if abs(fa) < abs(fb):
            xa, xb = xb, xa
            fa, fb = fb, fa

        if callback is not None and callback(xs) == 0:
            break

        if IsPrint == 1:
            print("  Iter {}: xa,b = {:12.8f} - {:12.8f}  fa,b = {:12.4g} - {:12.8g}  mflag = {}"
                .format(i, xa, xb, fa, fb, mflag))
            print("     xs = {:12.8f}  fs = {:12.4g}".format(xs, fs))
    else:
        if IsPrint == 1:
            print("  Failed: Convergence did not reach")

    # Not converged (iteration limit or callback abort): report the best estimate.
    return xb, fb, -1, [xa, xb]


def bisection(func, diff1func = None, xa = 10.0, xb = 10.0, dump = 0.0, nmaxiter = 100, eps = 1.0e-5, delta = 1.0e-5, callback = None, IsPrint = True):
    """Find a root of func inside the bracket [xa, xb] by bisection.

    Parameters
    ----------
    func : callable
        Scalar function f(x) whose root is sought.
    diff1func, dump, delta :
        Unused; accepted so that all solvers in this file share one signature.
    xa, xb : float
        Initial bracket; f(xa) and f(xb) must not have the same sign.
    nmaxiter : int
        Maximum number of halvings.
    eps : float
        Converged when the midpoint is within eps of both bracket ends.
    callback : callable or None
        Called as callback(xhalf) after each halving; returning 0 aborts.
    IsPrint : bool
        Print progress messages when true.

    Returns
    -------
    tuple
        (x, f(x), status, [xa, xb]) where status is the iteration index on
        success (0 when an initial end point is already an exact root), -1
        when not converged (iteration limit or callback abort), and -2 with
        x = f(x) = 1.0e10 when the initial bracket is invalid.
    """
    fa = func(xa)
    fb = func(xb)
    if fa * fb > 0.0:
        if IsPrint == 1:
            print("bisection:: Error: Initial xa and xb should be chosen as f(xa)*f(xb) < 0")
            print("     xa={:12.6g}  f(xa)={:12.6g}".format(xa, fa))
            print("     xb={:12.6g}  f(xb)={:12.6g}".format(xb, fb))
        return 1.0e10, 1.0e10, -2, [xa, xb]

    # An end point that is an exact root would be lost by the strict sign test
    # below (fa * fhalf < 0 is false when fa == 0), so return it immediately.
    for xr, fr in ((xa, fa), (xb, fb)):
        if fr == 0.0:
            if IsPrint == 1:
                print("  Success: Convergence reached at x = {}".format(xr))
            return xr, fr, 0, [xa, xb]

    # Best estimate so far; returned as-is if no iteration is performed.
    xhalf, fhalf = (xa, fa) if abs(fa) <= abs(fb) else (xb, fb)
    for i in range(nmaxiter):
        xhalf = (xa + xb) / 2.0
        fhalf = func(xhalf)

        if IsPrint == 1:
            print("  Iter {}: xhalf = {:12.4g}  f(xhalf) = {:12.4g}".format(i, xhalf, fhalf))
        if fhalf == 0.0 or (abs(xa - xhalf) < eps and abs(xb - xhalf) < eps):
            if IsPrint == 1:
                print("  Success: Convergence reached at x = {}".format(xhalf))
            return xhalf, fhalf, i, [xa, xb]

        # Keep the half-interval on which the sign changes.
        if fa * fhalf < 0.0:
            xb = xhalf
        else:
            xa, fa = xhalf, fhalf

        if callback is not None and callback(xhalf) == 0:
            break
    else:
        if IsPrint == 1:
            print("  Failed: Convergence did not reach")

    # Not converged (iteration limit or callback abort).
    return xhalf, fhalf, -1, [xa, xb]

def newton1d(func, diff1func = None, xa = 10.0, xb = None, dump = 0.0, nmaxiter = 100, eps = 1.0e-5, delta = 1.0e-5, callback = None, IsPrint = True):
    """Find a root of func by the (optionally damped) Newton-Raphson method.

    Parameters
    ----------
    func : callable
        Scalar function f(x) whose root is sought.
    diff1func : callable or None
        First derivative f'(x).  When None, a central difference with step
        `delta` is used instead.
    xa : float
        Initial guess.
    xb :
        Unused; accepted so that all solvers in this file share one signature.
    dump : float
        Damping added to |f'(x)|, which shortens each Newton step.
    nmaxiter : int
        Maximum number of iterations.
    eps : float
        Converged when the step |dx| < eps.
    delta : float
        Finite-difference step used only when diff1func is None.
    callback : callable or None
        Called as callback(x) with the current point each iteration;
        returning 0 aborts.
    IsPrint : bool
        Print progress messages when true.

    Returns
    -------
    tuple
        (x, f(x), status, dx) where status is the iteration index on success
        and -1 when not converged (iteration limit, callback abort, or
        nmaxiter == 0; dx is then inf if no step was taken).
    """
    if diff1func is None:
        def diff1func(t):
            # Central-difference estimate of f'(t).
            return (func(t + delta) - func(t - delta)) / (2.0 * delta)

    x = xa
    dx = float("inf")   # reported when no iteration is performed
    for i in range(nmaxiter):
        f  = func(x)
        f1 = diff1func(x)
        # Damping: move f'(x) further from zero, shortening the step f/f'.
        if f1 < 0.0:
            f1 -= dump
        else:
            f1 += dump

        # Avoid dividing by a (nearly) vanishing derivative.
        if abs(f1) < 1.0e-10:
            f1 = 1.0

        xnext = x - f / f1
        dx = xnext - x

        if callback is not None and callback(x) == 0:
            break

        if IsPrint == 1:
            print("Iter {:5d}: x: {:>16.12f} => {:>16.12f}, dx = {:>10.4g}".format(i, x, xnext, dx))

        if abs(dx) < eps:
            if IsPrint == 1:
                print("  Success: Convergence reached: dx = {} < eps = {}".format(dx, eps))
            x = (x + xnext) / 2.0
            f = func(x)

            return x, f, i, dx

        x = xnext
    else:
        if IsPrint == 1:
            print("  Failed: Convergence did not reach: dx = {} > eps = {}".format(dx, eps))

    # Not converged: return the current x together with its own function value.
    return x, func(x), -1, dx


def target_func(x, x2 = 0):
    """Evaluate the module-level expression string ``func`` at ``x``.

    The expression sees this module's globals (so the math names imported at
    the top of the file are available) plus ``x`` as a local.  ``x2`` is
    accepted but ignored.

    NOTE(review): eval() executes arbitrary code and ``func`` can come from
    the command line -- only run this script with trusted input.
    """
    local_names = {"x": x}
    return eval(func, globals(), local_names)

# Numerical first derivative of target_func; main() passes it to newton1d.
def diff1(x):
    """Return the central-difference estimate of d(target_func)/dx at x.

    Uses the module-level step h: (f(x + h) - f(x - h)) / (2 h).
    """
    backward = target_func(x - h)
    forward = target_func(x + h)
    return (forward - backward) / 2.0 / h


# Solver callback: lets the script monitor the progress of the root search.
# The solvers call it with only the current x, so extra state such as the
# call counter is kept in module-level globals:
#   iter -- number of callback calls so far (NOTE(review): shadows builtin iter)
#   ax   -- matplotlib Axes, assigned in main() when fplot == 1
#   xg   -- x grid of the function plot, assigned in main()
iter, ax, xg = 0, None, None
def callback(x):
    """Count one solver step and, when plotting is enabled, mark it in red.

    Always returns 1 so the solver continues (returning 0 would stop it).
    """
    global iter

    y = target_func(x)
    iter += 1

    if fplot != 1:
        return 1

    # Add the newly visited point to the graph.
    ax.plot(x, y, linestyle = '', marker = 'o', markersize = 5.0,
            markeredgecolor = 'red', markerfacecolor = 'red')
    plt.pause(tsleep)
    return 1


#==========================================
# Main routine
#==========================================
def main():
    """Solve func(x) = 0 with the chosen method, then print and plot the result.

    All settings come from the module-level parameters (optionally overridden
    on the command line).  Waits for ENTER before returning.
    """
    global ax, xg

    print("")
    print("Solve given function")
    print(f"method {method}")
    print(f"func        {func}")
    print(f"graph range {xgmin} - {xgmax}")
    print(f"initial x   {x0}")
    print(f"alpha       {alpha}")
    print(f"h           {h}")
    print(f"tol         {tol}")
    print(f"nmaxiter    {nmaxiter}")

    # Refuse to run while func or method still contains '***'.
    if '***' in func:
        input("\nError: Choose func >> ")
        sys.exit()
    if '***' in method:
        input("\nError: Choose method >> ")
        sys.exit()

    # Sample the function over the graph range.
    xgstep = (xgmax - xgmin) / (ngdata - 1)
    xg = [xgmin + xgstep * i for i in range(ngdata)]
    yini = [target_func(_x) for _x in xg]

    # Plot the function, the initial guess and the y = 0 line.
    if fplot == 1:
        figure = plt.figure(figsize = (5, 5))
        ax = figure.add_subplot(1, 1, 1)
        ax.plot(xg, yini, color = 'black', linestyle = '-', linewidth = 0.5)
        ax.plot(x0, target_func(x0), linestyle = '', marker = 'o', markersize = 8.0, markeredgecolor = 'black', markerfacecolor = 'black')
        ax.plot([xgmin, xgmax], [0.0, 0.0], linestyle = 'dashed', linewidth = 0.5, color = 'red')

        plt.pause(0.001)

    print("")
    print(f"Solve [{func}] by [{method}]:")
    if method == 'brent':
        x, f, ret, range_fin = brent(target_func, xa = xgmin, xb = xgmax, nmaxiter = nmaxiter, eps = tol, delta = tol, callback = callback, IsPrint = 1)
    elif method == 'bisection':
        x, f, ret, range_fin = bisection(target_func, xa = xgmin, xb = xgmax, nmaxiter = nmaxiter, eps = tol, delta = tol, callback = callback, IsPrint = 1)
    elif method == 'newton':
        x, f, ret, range_fin = newton1d(target_func, diff1func = diff1, xa = x0, dump = alpha, nmaxiter = nmaxiter, eps = tol, delta = tol, callback = callback, IsPrint = 1)
    else:
        print("")
        print(f"Error: Can not find the method [{method}]")
        sys.exit()

    if ret == -2:
        print("")
        print("Error: Initial range invalid")
    else:
        # ax only exists when plotting is enabled.
        if fplot == 1:
            ax.plot(x, f, linestyle = '', marker = 's', markersize = 12.0, markeredgecolor = 'blue', markerfacecolor = 'blue')
            plt.pause(0.001)

        print("")
        if ret == -1:
            print("Warning: Convergence was not reached; showing the last estimate")
        print(f"x = {x}")
        print(f"y = {f}")
        print(f"range = {range_fin}")

    print("Press ENTER to terminate:", end = '')
    input()


if __name__ == "__main__":
    main()
