import sys
import numpy as np
from numpy import exp
import pandas as pd
import scipy.signal
import matplotlib.pyplot as plt
import matplotlib.widgets as wg


from tklib.tkutils import replace_path, getarg, getintarg, getfloatarg, pint, pfloat
from tklib.tkparams import tkParams
from tklib.tkvariousdata import tkVariousData
from tklib.tkapplication import tkApplication


"""
Peak search: find peaks in 1-D (x, y) data as zero crossings of the
Savitzky-Golay-smoothed third derivative.
"""


#=============================
# parameters
#=============================
def usage(app):
    """Print a one-line usage hint for this script (``app`` is unused)."""
    script = sys.argv[0]
    print(f"usage: python {script} args")

def initialize(app):
    """Create ``app.cparams`` and populate it with the default parameters.

    update_vars() may later override some of these from the command line.
    """
    app.cparams = tkParams()
    cparams = app.cparams

    # Defaults, applied in this exact order.
    defaults = {
        'mode': 'peak search',

        # Input file, x range, and the column selectors passed to
        # tkVariousData.FindDataArray().
        'infile': 'test/xrd.xlsx',
        'xmin': -1e100,
        'xmax': 1e100,
        'xlabel': '0',
        'ylabel': '1',

        # peak_search() parameters
        'threshold': 100.0,
        'ydiff1_threshold': 1.0e-2,
        'norder': 4,
        'nsmooth': 11,

        # Figure settings. The *_test variants are used for the
        # multi-panel (mode != 'peak search') plot.
        'figsize': [12, 8],
        'figsize_test': [12, 8],
        'fontsize': 16,
        'legend_fontsize': 14,
        'fontsize_test': 12,
        'legend_fontsize_test': 12,
    }
    for attr_name, default_value in defaults.items():
        setattr(cparams, attr_name, default_value)

def update_vars(app):
    """Override ``app.cparams`` defaults with positional command-line arguments.

    Each parameter is read by position (1 = mode, 2 = infile, ..., 10 = nsmooth)
    using getarg / getfloatarg / getintarg. Presumably the position is the
    sys.argv index; confirm in tklib.tkutils. When an argument is absent,
    the current value is passed as the fallback.
    """
    cparams = app.cparams

    # (attribute, reader) in command-line order, starting at position 1
    positional = (
        ('mode',             getarg),
        ('infile',           getarg),
        ('xmin',             getfloatarg),
        ('xmax',             getfloatarg),
        ('xlabel',           getarg),
        ('ylabel',           getarg),
        ('threshold',        getfloatarg),
        ('ydiff1_threshold', getfloatarg),
        ('norder',           getintarg),
        ('nsmooth',          getintarg),
    )
    for position, (attr_name, reader) in enumerate(positional, start = 1):
        setattr(cparams, attr_name, reader(position, getattr(cparams, attr_name)))

def peak_search(x, y, nsmooth, norder, threshold, ydiff1_threshold, is_print = False):
    """Find peaks in (x, y) as sign changes of the smoothed third derivative.

    The data are smoothed and differentiated with Savitzky-Golay filters
    (window ``nsmooth``, polynomial order ``norder``). Every index where
    y''' changes sign is a candidate. A candidate is accepted as a peak when
      * the smoothed intensity there is >= ``threshold``,
      * |dy/dx| / y there is <= ``ydiff1_threshold`` times the largest such
        ratio over the whole data set, and
      * y'' there is not positive (positive y'' marks a minimum).

    Parameters
    ----------
    x : sequence of float
        Abscissa. It is assumed to be equally spaced; the step is x[1] - x[0].
    y : sequence of float
        Ordinate, same length as ``x``.
    nsmooth : int
        Savitzky-Golay window length. scipy requires it to be > ``norder``
        and <= len(x).
    norder : int
        Savitzky-Golay polynomial order.
    threshold : float
        Minimum smoothed intensity for a peak.
    ydiff1_threshold : float
        Relative limit on |dy/dx| / y (a fraction of its maximum).
    is_print : bool
        If True, print the accept/reject decision for every candidate.

    Returns
    -------
    xpeaks : list of [int, float, float]
        ``[index, x[index], width]`` for each accepted peak. ``index`` is the
        first sample after the y''' sign change. ``width`` is the distance to
        the farther of the neighbouring y''' sign changes; the array edge is
        used on a side that has none.
    inf : dict
        Intermediate results: "ysmooth", "ydiff1", "ydiff2", "ydiff3",
        "max_diff1_ratio", "diff1_ratio_th".

    Raises
    ------
    ValueError
        If fewer than 2 points are given or x and y differ in length.
        scipy also raises ValueError for an invalid nsmooth/norder.
    """
    n = len(x)
    if n < 2:
        raise ValueError(f"peak_search: need at least 2 data points, got {n}")
    if len(y) != n:
        raise ValueError(f"peak_search: x and y differ in length ({n} != {len(y)})")

    inf = {}

    # Sample spacing converts per-sample derivatives into d/dx
    h = x[1] - x[0]
    ysmooth = scipy.signal.savgol_filter(y, nsmooth, norder, deriv = 0)
    ydiff1  = scipy.signal.savgol_filter(y, nsmooth, norder, deriv = 1) / h
    # y'' and y''' come from differentiating the previous derivative again
    ydiff2  = scipy.signal.savgol_filter(ydiff1, nsmooth, norder, deriv = 1) / h
    ydiff3  = scipy.signal.savgol_filter(ydiff2, nsmooth, norder, deriv = 1) / h
    # Extra smoothing pass on y''' (presumably to damp noise-induced sign
    # changes, which would otherwise become candidates)
    ydiff3  = scipy.signal.savgol_filter(ydiff3, nsmooth, norder, deriv = 0)

    # |dy/dx| / y for every sample. Below `threshold` a small offset is added
    # to the denominator so that a smoothed intensity of exactly 0 does not
    # cause a division by zero.
    denom = np.where(ysmooth < threshold, ysmooth + 1.0e-5, ysmooth)
    diff1_ratio = np.abs(ydiff1 / denom)
    max_diff1_ratio = max(diff1_ratio)
    diff1_ratio_th = max_diff1_ratio * ydiff1_threshold

    last = n - 1

    def find_previous_zero(x, y, i0):
        """Return (i, x[i]) for the nearest i < i0 with y[i0] * y[i] <= 0, else (None, None)."""
        y0 = y[i0]
        for i in range(i0 - 1, -1, -1):
            if y0 * y[i] <= 0.0:
                return i, x[i]
        return None, None

    def find_next_zero(x, y, i0):
        """Return (i, x[i]) for the nearest i > i0 with y[i0] * y[i] <= 0, else (None, None)."""
        y0 = y[i0]
        for i in range(i0 + 1, len(x)):
            if y0 * y[i] <= 0.0:
                return i, x[i]
        return None, None

    def find_zeros(x, y):
        """Walk the sign changes of y left to right and keep those that pass the peak tests."""
        i0 = 0
        xpeaks = []
        while True:
            inext, xpeak = find_next_zero(x, y, i0)
            if inext is None:
                break

            ytop  = ysmooth[inext]
            diff1 = abs(ydiff1[inext])
            ratio = diff1 / ytop
            diff2 = ydiff2[inext]
            if is_print:
                print(f"x={xpeak:8.3g} ytop={ytop:8g} >? {threshold:8g} "
                        + f"|dy/dx|/ytop={ratio:8g} <? {diff1_ratio_th:8g}", end = '')
            if ytop < threshold:
                if is_print:
                    print("   too weak: excluded")
            elif ratio > diff1_ratio_th:
                if is_print:
                    print("   |dy/dx| / ytop too large: excluded")
            elif 0.0 < diff2:
                if is_print:
                    print("   minimum: excluded")
            else:
                if is_print:
                    print("   marked as a peak")

                # Neighbouring sign changes on either side. inext is the first
                # sample past the current crossing, so the backward search
                # must start from inext - 1. Starting at inext would only
                # rediscover this same crossing.
                ip, _ = find_next_zero(x, y, inext)
                im, _ = find_previous_zero(x, y, inext - 1)
                # No further crossing on a side: the lobe extends to the array edge
                if ip is None:
                    ip = last
                if im is None:
                    im = 0
                # Use the larger of the two distances as the width
                if ip - inext > inext - im:
                    width = x[ip] - x[inext]
                else:
                    width = x[inext] - x[im]
                xpeaks.append([inext, xpeak, width])

            i0 = inext

        return xpeaks

    xpeaks = find_zeros(x, ydiff3)

    inf["ysmooth"] = ysmooth
    inf["ydiff1"]  = ydiff1
    inf["ydiff2"]  = ydiff2
    inf["ydiff3"]  = ydiff3
    inf["max_diff1_ratio"] = max_diff1_ratio
    inf["diff1_ratio_th"]  = diff1_ratio_th

    return xpeaks, inf

def exec_peak_search(app):
    """Run the peak search on the data file named in ``app.cparams`` and plot it.

    Steps:
      1. Open a log file; stdout is also redirected to it.
      2. Print and validate the parameters.
      3. Read the x/y columns from ``cparams.infile`` and restrict them to
         [xmin, xmax].
      4. Call peak_search().
      5. Plot the input, the smoothed curve, and a Gaussian for each peak.

    If ``cparams.mode`` is not 'peak search', three extra panels (y''', y',
    y'') are drawn. The function finishes with
    ``app.terminate("Press ENTER to exit>>", ...)`` and then sys.exit().

    Parameter and data errors are reported through app.terminate(). The code
    after those calls assumes it does not return (TODO confirm in tkApplication).
    """
    cparams = app.cparams

    # NOTE(review): the log file is opened with mode 'w'. Confirm that
    # replace_path() with its default template never returns infile itself.
    cparams.logfile = app.replace_path(cparams.infile)
    # NOTE(review): outfile is computed, but nothing in this function writes it.
    cparams.outfile = app.replace_path(cparams.infile, template = ["{dirname}", "{filebody}-searched.xlsx"])
    print(f"Open logfile [{cparams.logfile}]")
    app.redirect(targets = ["stdout", cparams.logfile], mode = 'w')

    print("")
    print(f"Peak search in the data [{cparams.infile}]")
    print("mode    : ", cparams.mode)
    print("infile  : ", cparams.infile)
    print("xlabel  : ", cparams.xlabel)
    print("ylabel  : ", cparams.ylabel)
    print("  x range : ", cparams.xmin, cparams.xmax)
    print("norder    : ", cparams.norder)
    print("nsmooth   : ", cparams.nsmooth)
    if cparams.nsmooth % 2 == 0:
        cparams.nsmooth += 1
        print(f"  Warning: nsmooth must be odd. Changed to {cparams.nsmooth}")
    print("threshold : ", cparams.threshold)
    print("dy/dx threshold : ", cparams.ydiff1_threshold)

    if isinstance(cparams.xlabel, str) and '*** choose ***' in cparams.xlabel:
        app.terminate("\nError: Choose x label\nTerminated.\n", usage = usage, pause = True)
    if isinstance(cparams.ylabel, str) and '*** choose ***' in cparams.ylabel:
        app.terminate("\nError: Choose y label\nTerminated.\n", usage = usage, pause = True)
    # savgol_filter requires polyorder < window_length, so equality must also be rejected
    if cparams.nsmooth <= cparams.norder:
        app.terminate("\nError: nsmooth must be larger than norder\n", usage = usage, pause = True)

    print("Read data from [{}]".format(cparams.infile))
    datafile = tkVariousData(cparams.infile)
    header, datalist = datafile.Read_minimum_matrix(close_fp = True)
    xlabel, _xin = datafile.FindDataArray(cparams.xlabel, flag = 'i')
    ylabel, _yin = datafile.FindDataArray(cparams.ylabel, flag = 'i')
    if _xin is None:
        app.terminate(f"Error: Can not find the data with [{cparams.xlabel}]", usage = usage, pause = True)
    if _yin is None:
        app.terminate(f"Error: Can not find the data with [{cparams.ylabel}]", usage = usage, pause = True)

    def is_unset(limit):
        # None, '' and '*' all mean "no limit on this side"
        return limit is None or limit == '' or limit == '*'

    use_xmin = not is_unset(cparams.xmin)
    use_xmax = not is_unset(cparams.xmax)
    x = []
    y = []
    for i, xv in enumerate(_xin):
        if use_xmin and xv < cparams.xmin:
            continue
        if use_xmax and cparams.xmax < xv:
            continue
        x.append(xv)
        y.append(_yin[i])
    ndata = len(x)
    print("  ndata = ", ndata)
    # peak_search needs x[1] and a Savitzky-Golay window that fits in the data
    if ndata < max(2, cparams.nsmooth):
        app.terminate(f"\nError: ndata ({ndata}) must be at least max(2, nsmooth = {cparams.nsmooth})\nTerminated.\n",
                      usage = usage, pause = True)

    xpeaks, inf = peak_search(x, y, cparams.nsmooth, cparams.norder,
                                cparams.threshold, cparams.ydiff1_threshold, is_print = True)
    ysmooth = inf["ysmooth"]
    ydiff1  = inf["ydiff1"]
    ydiff2  = inf["ydiff2"]
    ydiff3  = inf["ydiff3"]

#=============================
# prepare graph
#=============================
    def Gauss(x, x0, a, I0):
        """Gaussian I0 * exp(-((x - x0) / a)^2)."""
        X = (x - x0) / a
        return I0 * exp(-X * X)

    print("")
    print("plot")
    dx = x[1] - x[0]

    def plot_input(ax_input):
        """Draw the input, smoothed curve, peak markers, half-width bars and Gaussians."""
        maxI = max(y)
        # Peak markers are drawn as short bars just below zero
        bar_range = [-0.05 * maxI, -0.01 * maxI]

        ax_input.plot(x, y,       label = 'input',      linestyle = '', marker = 'o', markersize = 0.5, markerfacecolor = 'black', markeredgecolor = 'black')
        ax_input.plot(x, ysmooth, label = 'smoothened', linestyle = '-', color = 'blue', linewidth = 0.5)
        ax_input.plot(ax_input.get_xlim(), [0.0, 0.0], linestyle = 'dashed', color = 'red', linewidth = 0.5)
        for idx, _x, _w in xpeaks:
            _x = x[idx]
            _I = ysmooth[idx]
            _Ihalf = _I / 2.0
            # 0.832554611 = sqrt(ln 2): _w is treated as the half width at half maximum
            a_g = _w / 0.832554611
            # Draw the Gaussian over about +-3 widths around the peak
            nx = int(_w / dx * 3.0 + 1.00001)
            xx = [x[i1] for i1 in range(max([0, idx - nx]), min([idx + nx, ndata]))]
            gf = [Gauss(xv, _x, a_g, _I) for xv in xx]
            ax_input.plot([_x, _x], bar_range, linestyle = '-', color = 'black', linewidth = 0.5)
            ax_input.plot([_x - _w, _x + _w], [_Ihalf, _Ihalf], linestyle = '-', color = 'green', linewidth = 0.5)
            ax_input.plot(xx, gf, linestyle = '-', color = 'red', linewidth = 0.5)
        ax_input.set_ylabel(ylabel, fontsize = cparams.fontsize)
        ax_input.legend(fontsize = cparams.legend_fontsize)

    if cparams.mode == 'peak search':
        fig, axes = plt.subplots(1, 1, sharex = 'all', figsize = cparams.figsize)
        axes.tick_params(labelsize = cparams.fontsize)
        plot_input(axes)
        axes.set_xlabel(xlabel, fontsize = cparams.fontsize)
    else:
        # Diagnostic layout: input plus third, first and second derivatives
        fig, axes = plt.subplots(4, 1, sharex = 'all', figsize = cparams.figsize_test)
        axes[0].tick_params(labelsize = cparams.fontsize_test)
        axes[1].tick_params(labelsize = cparams.fontsize_test)

        ax_input = axes[0]
        ax_diff3 = axes[1]
        ax_diff1 = axes[2]
        ax_diff2 = axes[3]

        plot_input(ax_input)

        ax_diff3.plot(x, ydiff3, linestyle = '-', color = 'green', linewidth = 0.5)
        ax_diff3.plot(axes[0].get_xlim(), [0.0, 0.0], linestyle = 'dashed', color = 'red', linewidth = 0.5)
        ax_diff3.set_ylabel('Third differential', fontsize = cparams.fontsize_test)

        ax_diff1.plot(x, ydiff1)
        ax_diff1.plot(axes[0].get_xlim(), [0.0, 0.0], linestyle = 'dashed', color = 'red', linewidth = 0.5)
        ax_diff1.set_ylabel('First differential', fontsize = cparams.fontsize_test)

        ax_diff2.plot(x, ydiff2, linestyle = '-', color = 'green', linewidth = 0.5)
        ax_diff2.plot(axes[0].get_xlim(), [0.0, 0.0], linestyle = 'dashed', color = 'red', linewidth = 0.5)
        ax_diff2.set_xlabel(xlabel, fontsize = cparams.fontsize_test)
        ax_diff2.set_ylabel('Second differential', fontsize = cparams.fontsize_test)

    plt.tight_layout()

    plt.pause(0.1)
    app.terminate("Press ENTER to exit>>", usage = usage, pause = True)

    sys.exit()


def main(app):
    """Run the peak search using the parameters already stored in ``app.cparams``."""
    exec_peak_search(app)


if __name__ == '__main__':
    app = tkApplication()

    # Defaults first, then positional command-line overrides
    startup_steps = (
        ("Initialize parameters", initialize),
        ("Update parameters by command-line arguments", update_vars),
    )
    for message, step in startup_steps:
        print(message)
        step(app)

    main(app)
 