from math import exp, log, sqrt
import numpy as np
import scipy.signal


from tklib.tksci.tksci import pi


def Gaussian(x, x0, whalf):
    """Return an area-normalized Gaussian centered at ``x0``, evaluated at ``x``.

    ``whalf`` is the half width at half maximum (HWHM): the peak value is
    about ``sqrt(ln2 / pi) / whalf`` and the curve falls to half of that at
    ``|x - x0| == whalf``.  The constants sqrt(ln2 / pi) ~= 0.469718639 and
    sqrt(ln2) ~= 0.832554611 are used as 9-decimal literals.
    ``x`` must be a scalar (``math.exp`` is used).
    """
    peak = 0.469718639 / whalf         # sqrt(ln2 / pi) / whalf
    e_width = whalf / 0.832554611      # whalf / sqrt(ln2): the 1/e half width
    u = (x - x0) / e_width
    return peak * exp(-u * u)

def Hij(xstep, Wa, Grange, i, j):
    """Return kernel-matrix element H[i][j] on a uniform grid.

    The element is ``Gaussian`` (HWHM ``Wa``) evaluated at the signed grid
    distance ``(j - i) * xstep`` from its center.  ``Grange`` is accepted
    for interface compatibility but not used.
    """
    distance = (j - i) * xstep
    return Gaussian(distance, 0.0, Wa)


# Build a sampled, area-normalized Gaussian filter (returns x and y lists)
def make_filter(Wa, Grange, xstep):
    """Sample a Gaussian of HWHM ``Wa`` on a grid and normalize it to unit area.

    The grid is ``x = m * xstep`` for ``m = 0 .. 2 * half`` where
    ``half = int(Grange * Wa / xstep + 1.0001)``; the Gaussian is centered on
    the middle sample ``half * xstep``.  The area used for normalization is
    the trapezoidal integral over the sampled points.  Diagnostic lines are
    printed to stdout.

    Returns
    -------
    (xG, yG) : two lists of equal length (x positions and normalized values).
    """
    half = int(Grange * Wa / xstep + 1.0001)
    last = 2 * half
    center = half * xstep

    xG = [idx * xstep for idx in range(last + 1)]
    yG = [Gaussian(xv, center, Wa) for xv in xG]

    # Trapezoidal area, accumulated left to right.
    area = 0.0
    for lo in range(last):
        area += (yG[lo] + yG[lo + 1]) / 2.0 * (xG[lo + 1] - xG[lo])

    yG = [yv / area for yv in yG]

    print("   Range: {} in width".format(Grange * Wa))
    print("   i range: {} - {} at center {}".format(0, last, center))
    print("   ixGmax = ", last)
    print("   SG = ", area)

    return xG, yG


def convolve(xraw, yraw, ywf, **kwargs):
    """Convolve ``yraw`` with the window ``ywf`` and build a matching x axis.

    The window is normalized by ``sum(ywf)``, so a constant signal keeps its
    level.  Extra keyword arguments (e.g. ``mode``) go to ``np.convolve``.

    Returns ``(x, yconv)`` with ``len(x) == len(yconv)`` for every mode:

    * If the output is aligned 1:1 with the input (``mode='same'`` with a
      window no longer than the data, or a one-point window), ``xraw`` itself
      is returned.
    * Otherwise a uniform grid with step ``xraw[1] - xraw[0]`` is built.  Each
      output sample sits at the x position of the input sample under the
      window's center index ``(len(ywf) - 1) // 2``, consistently for the
      'full', 'same' and 'valid' modes.
    """
    yconv = np.convolve(yraw, ywf, **kwargs) / sum(ywf)
    n_raw = len(yraw)
    n_new = len(yconv)
    # Number of samples of the 'full' result (n_raw + len(ywf) - 1 long) that
    # numpy dropped for the requested mode.  The 'same'/'valid' outputs start
    # at index trimmed // 2 of the 'full' result.
    trimmed = (n_raw + len(ywf) - 1) - n_new
    offset = (len(ywf) - 1) // 2 - trimmed // 2
    if offset == 0 and n_new == n_raw:
        return xraw, yconv

    xmin = xraw[0]
    xstep = xraw[1] - xmin
    # A negative offset (e.g. mode='valid') moves the grid start to the right.
    xmin_new = xmin - offset * xstep
    x = np.array([xmin_new + i * xstep for i in range(n_new)])
    return x, yconv


def convolve_xydata(xlist, ylist, whalf, x0, x1, xstep, **kwargs):
    """Broaden discrete points into a curve on a uniform grid with Gaussians.

    Each point ``xlist[i]`` adds ``ylist[i] * Gaussian(x, xlist[i], whalf)``
    to the grid samples within about ``6 * whalf`` of it.  The weight is 1.0
    when ``ylist`` is None.  The grid is ``x0 + m * xstep`` for
    ``m < nx = int((x1 - x0) / xstep + 1e-3)``, so it ends one step short of
    ``x1`` when the span is a whole number of steps.  ``kwargs`` is accepted
    but unused.

    Returns
    -------
    (xx, yy) : grid positions and accumulated intensities (lists).
    """
    npts = int((x1 - x0) / xstep + 1.0e-3)
    xx = [x0 + m * xstep for m in range(npts)]
    yy = [0.0] * npts
    # Window half-width in samples (6 * whalf, far into the Gaussian tail).
    half_window = int((whalf * 6.0 / xstep) + 1.0e-3)

    for i, xi in enumerate(xlist):
        weight = 1.0 if ylist is None else ylist[i]
        ic = int((xi - x0) / xstep + 1.0e-3)
        lo = max(0, ic - half_window)
        hi = min(npts, ic + half_window + 1)
        for m in range(lo, hi):
            yy[m] += weight * Gaussian(x0 + m * xstep, xi, whalf)

    return xx, yy


def extend_smooth(x, y, nzero, nlin, xstep = 0.0):
    """Prepend zeros to ``(x, y)`` and ramp the first data points up linearly.

    ``nzero`` zero samples are added before the data, extending the grid
    downward with the step ``x[1] - x[0]``.  The first ``nlin`` data values
    are multiplied by a ramp that rises linearly from 0 to 1; all remaining
    values are copied unchanged.  With ``nlin == 1`` the single ramp point
    gets weight 0 (the ramp's starting value).  Diagnostics go to stdout.

    Parameters
    ----------
    x, y : data, with x assumed uniformly spaced (only x[1] - x[0] is used).
    nzero : number of zero samples to prepend.
    nlin : number of leading data points to ramp; must not exceed len(x).
    xstep : ignored; the step is always taken from ``x`` (kept for
        backward compatibility).

    Returns
    -------
    (xx, yy) : numpy arrays of length ``nzero + len(x)``.
    """
    xmin = x[0]
    xstep = x[1] - x[0]
    xmin_new = x[0] - nzero * xstep
    n_new = nzero + len(x)
    print("extend_smooth:")
    print("  Add {} zeros at top of the data".format(nzero))
    print("    xmin changes: {} => {}".format(xmin, xmin_new))
    print("  Reshape {} input data with a linear filter".format(nlin))

    xx = np.array([xmin_new + i * xstep for i in range(n_new)])
    yy = np.zeros(n_new)
    for i in range(nlin):
        # Ramp weight i / (nlin - 1); a one-point ramp has no slope.
        k = i / (nlin - 1) if nlin > 1 else 0.0
        yy[i + nzero] = k * y[i]
    for i in range(nlin, len(x)):
        # Data sample i always lands in output slot nzero + i.
        yy[i + nzero] = y[i]
    return xx, yy


def deconvolute_fft(xRaw, yRaw, xG, yG):
    """Deconvolve ``yRaw`` by the filter ``yG`` through FFT division.

    The data and the filter are zero-padded to the smallest power of two that
    holds both, transformed, divided, and transformed back.  The data model
    is that ``yRaw`` is the convolution of the unknown signal with ``yG``
    divided by ``sum(yG)`` (as ``convolve`` in this module produces), so the
    quotient is scaled back by ``sum(yG)``.

    The filter's center lies at index ``(len(xG) - 1) / 2``, not 0, so the
    inverse FFT returns the signal shifted by that many samples.  Its x axis
    ``xGFFT`` therefore starts at ``xRaw[0] + (len(xG) - 1) / 2 * xstep``.

    Returns
    -------
    (xGFFT, ydeconv, xRawFFT, yRawFFT, xRawFFT, yGFFT)
        The x axis and values (list of float) of the deconvolved signal, the
        padded x axis and padded raw data, then the padded filter (paired
        with ``xRawFFT``).
    """
    k = sum(yG)

    n = len(xRaw)
    # Smallest power of two that fits both the data and the filter
    # (integer arithmetic: no float-log rounding for large n).
    nfft = 1
    while nfft < max(n, len(yG)):
        nfft *= 2

    xmin = xRaw[0]
    xstep = xRaw[1] - xmin
    xRawFFT = [xmin + i * xstep for i in range(nfft)]
    yRawFFT = np.insert(yRaw, len(yRaw), np.zeros(nfft - n))
    yGFFT = np.insert(yG, len(yG), np.zeros(nfft - len(yG)))
    # The offset of the filter center from index 0 shifts the origin of the
    # inverse-FFT result; compensate on the x axis.
    xminG = xmin + (len(xG) - 1) / 2 * xstep
    xGFFT = [xminG + i * xstep for i in range(nfft)]

    ycFFTed = np.fft.fft(yRawFFT) / np.fft.fft(yGFFT)
    # For real inputs the imaginary part is only round-off: take .real
    # explicitly instead of float(complex), which warns per element.
    ydeconv = (np.fft.ifft(ycFFTed).real * k).tolist()

    return xGFFT, ydeconv, xRawFFT, yRawFFT, xRawFFT, yGFFT


def convolve_func(x, width, func_type = 'gauss'):
    """Return a unit-area broadening kernel evaluated at offset ``x``.

    func_type == 'lorentz':
        Lorentzian ``1 / (pi * width) / (1 + (x / width)**2)``, where
        ``width`` is the half width at half maximum.
    any other value:
        Gaussian ``exp(-(x / width)**2) / (sqrt(pi) * width)``, where
        ``width`` is the 1/e half width.

    Both kernels integrate to 1, so callers can multiply by the grid step
    to obtain convolution weights.
    """
    dvx = x / width
    if func_type == 'lorentz':
        # coeff normalizes the Lorentzian to unit area, as in the Gauss branch.
        coeff = 1.0 / width / pi
        return coeff / (1.0 + dvx * dvx)
    coeff = 1.0 / sqrt(pi)/ width
    return coeff * exp(-dvx*dvx)

def convolution(x, y, width, func_type, nskip = None, nxmin_plot = None):
    """Convolve sampled data ``y(x)`` with the kernel ``convolve_func``.

    Each sample ``y[j]`` is spread onto ``x[j+k]`` for ``|k| <= di`` (about
    five kernel widths) with weight ``|dx| * convolve_func(|dx| * k, width,
    func_type)``.  This is a discrete version of the convolution integral.

    Parameters
    ----------
    x, y : abscissa and data; x is assumed uniformly spaced (only
        ``x[1] - x[0]`` is used) and may be ascending or descending.
    width : kernel width; ``<= 0`` disables the convolution.
    func_type : 'lorentz', or anything else for a Gaussian (see convolve_func).
    nskip : if given, only every ``nskip``-th sample is used and returned
        (e.g. for fast plotting).  Each used sample then represents
        ``nskip * |dx|`` of x, so the result keeps the same scale.
    nxmin_plot : if given together with ``nskip`` and ``nskip`` would yield
        fewer than ``nxmin_plot`` points, the stride becomes
        ``max(1, int(len(x) / nxmin_plot))``.

    Returns
    -------
    ``ys`` (list) when ``nskip`` is None, otherwise ``(x_sub, ys_sub)`` lists.
    When ``width <= 0``, ``y`` (or ``(x, y)``) is returned unchanged.
    """
    if width <= 0.0:
        if nskip is None:
            return y
        return x, y

    ndata = len(x)
    # abs() keeps descending grids working (a negative dx gave an empty k range).
    dx = abs(x[1] - x[0])
    # integration half-range (5 widths), converted to a number of list indices
    di = int((width * 5.0) / dx + 1.1)

    step = 1 if nskip is None else nskip
    if nskip is not None and nxmin_plot is not None and int(ndata / step) < nxmin_plot:
        # max(1, ...) avoids a zero stride when ndata < nxmin_plot.
        step = max(1, int(ndata / nxmin_plot))

    # Only every step-th sample enters the sum, so each one stands for step * dx.
    weight = dx * step
    # The kernel depends only on k: evaluate it once rather than per sample.
    kernel = [(k, weight * convolve_func(dx * k, width, func_type))
              for k in range(-(di // step) * step, di + 1, step)]

    # convolved data
    ys = [0.0] * ndata
    for j in range(0, ndata, step):
        y0 = y[j]
        for k, f in kernel:
            if 0 <= j + k < ndata:
                ys[j + k] += y0 * f

    if nskip is None:
        return ys
    _x = [x[i] for i in range(0, ndata, step)]
    _y = [ys[i] for i in range(0, ndata, step)]
    return _x, _y

def convolute_by_func(x, y, func, x0, width):
    """Convolve sampled data ``y(x)`` with an arbitrary kernel ``func``.

    Each sample ``y[j]`` is spread onto ``x[j+k]`` for ``|k| <= di`` (about
    ``5 * width / |dx|`` samples) with weight
    ``|dx| * func(x[j+k], x[j], width)``.  ``func`` is therefore called as
    ``func(x, center, width)``, the same argument order as ``Gaussian``.

    Parameters
    ----------
    x, y : abscissa and data; x is assumed uniformly spaced (only
        ``x[1] - x[0]`` is used) and may be ascending or descending.
    func : kernel callable ``func(x, center, width)``.
    x0 : ignored; kept for backward compatibility.  The kernel is always
        centered on each source sample ``x[j]``.
    width : kernel width passed to ``func``; also sets the integration range.

    Returns
    -------
    list of float with the same length as ``x``.
    """
    ndata = len(x)
    # abs() keeps descending grids working (a negative dx gave an empty range).
    dx = abs(x[1] - x[0])
    # integration half-range (5 widths), converted to a number of list indices
    di = int((width * 5.0) / dx + 1.1)

    # convolved data
    ys = [0.0] * ndata
    for j in range(ndata):
        center = x[j]
        y0 = y[j]
        # Clip the k range to valid indices instead of testing every k.
        for k in range(max(-di, -j), min(di, ndata - 1 - j) + 1):
            f = dx * func(x[j + k], center, width)
            ys[j + k] += y0 * f

    return ys

def deconvolute_deconvolve(xRaw, yRaw, xG, yG):
    """Deconvolve ``yRaw`` by ``yG`` with ``scipy.signal.deconvolve``.

    ``scipy.signal.deconvolve`` performs polynomial division.  It returns a
    quotient with ``len(yRaw) - len(yG) + 1`` points whose full convolution
    with ``yG`` reproduces ``yRaw`` up to the remainder.  The quotient is
    multiplied by ``sum(yG)`` to undo a sum-normalized filter.  Only
    ``len(xG)`` is used from ``xG``: the result's x axis starts at
    ``xRaw[len(xG) // 2]`` (the filter center for odd-length filters) and has
    exactly as many points as the quotient.

    Returns
    -------
    (x, IDec) : a slice of ``xRaw`` and the scaled quotient (ndarray).  The
    remainder is discarded.
    """
    k = sum(yG)

    IDec, _remainder = scipy.signal.deconvolve(yRaw, yG)
    # deconvolve returns a plain (empty) list when yG is longer than yRaw.
    IDec = np.asarray(IDec) * k
    nGhalf = int(len(xG) / 2)
    # Slice by the quotient's length so x and y always match.  For
    # even-length filters, ndata - 2 * nGhalf would be one sample short.
    return xRaw[nGhalf:nGhalf + len(IDec)], IDec


def deconvolute_jacobi(xRaw, yRaw, xG, yG, fig, ax):
    """Deconvolve ``yRaw`` iteratively with Jacobi updates, plotting every sweep.

    The kernel matrix H is built element by element with ``Hij`` (a Gaussian
    of width ``Wa``).  ``xG``, ``yG`` and ``fig`` are accepted but not used.
    Each sweep updates every point from the previous iterate only:

        y[i] = yPrev[i] + (yRaw[i] - (H @ yPrev)[i] / Sg[i]) / H[i][i] * dump

    where ``Sg[i]`` is the i-th row sum of H.  After each sweep the estimate
    is smoothed, optionally clipped at zero, and drawn on ``ax[0]``.  The
    iteration stops when ``max|y - yPrev| / max|yRaw| < eps`` or after
    ``nmaxiter`` sweeps.

    NOTE(review): this reads module-level names that are not defined in this
    file as shown: Wa, Grange, nmaxiter, dump, nsmooth, zero_correction, eps,
    sleeptime, SmoothingByPolynomialFit and plt.  They are presumably set by
    a driver script; confirm before calling.

    Returns
    -------
    (xRaw, y) : the input x values and the final estimate.
    """
    global Wa, Grange

    print("Deconvolution by Jacobi method")
    print("")

    xstep = xRaw[1] - xRaw[0]

    xgmin = min(xRaw)
    xgmax = max(xRaw)

    n = len(xRaw)
    # Row sums of H, used to normalize each row of H @ y.
    Sg = np.zeros(n)
    for i in range(n):
        for j in range(n):
            Sg[i] += Hij(xstep, Wa, Grange, i, j)
    print("Filter area w.r.t. i: Sg=", Sg[int(n/2)])

    ymax = max([abs(yRaw[i]) for i in range(n)])
    y     = yRaw.copy()
    yPrev = yRaw.copy()

    for it in range(nmaxiter):
        Hx = np.zeros(n)
        print("iter=", it)

        for i in range(n):
            # Jacobi: every product uses the previous iterate yPrev.
            for j in range(n):
                Hx[i] += Hij(xstep, Wa, Grange, i, j) * yPrev[j]
            h = Hij(xstep, Wa, Grange, i, i)
            y[i] = yPrev[i] + (yRaw[i] - Hx[i] / Sg[i]) / h * dump

        y = SmoothingByPolynomialFit(y, nsmooth)
        if zero_correction:
            for i in range(n):
                if y[i] < 0.0:
                    y[i] = 0.0

        ax[0].cla()
        ax[0].plot(xRaw, yRaw, label = 'raw/initial')
        ax[0].plot(xRaw, y, label = 'updated')
        ax[0].set_xlim([xgmin, xgmax])
        # The vertical range comes from the y data, not from xRaw.
        ygmax = max([max(yRaw), max(y)])
        ax[0].set_ylim([0.0, ygmax])
        ax[0].legend()
        plt.tight_layout()
        plt.pause(sleeptime)

        max_err = max([abs(y[i] - yPrev[i]) for i in range(n)])
        rel_err = max_err / ymax
        print("  max error: ", max_err, "  relative error: ", rel_err, "  eps=", eps)
        if rel_err < eps:
            print("Converged at max_err={} ({} relative) < {}".format(max_err, rel_err, eps))
            break

        yPrev = y.copy()
    else:
        print("Not converged")

    return xRaw, y

def deconvolute_gauss_seidel(xRaw, yRaw, xG, yG, fig, ax):
    """Deconvolve ``yRaw`` iteratively with Gauss-Seidel updates, plotting every sweep.

    The kernel matrix H is built element by element with ``Hij`` (a Gaussian
    of width ``Wa``).  ``xG``, ``yG`` and ``fig`` are accepted but not used.
    Within a sweep, products for ``j < i`` use values already updated in
    this sweep (``y``), and those for ``j >= i`` use the previous iterate
    (``yPrev``):

        y[i] = yPrev[i] + (yRaw[i] - Hx[i] / Sg[i]) / H[i][i] * dump

    where ``Sg[i]`` is the i-th row sum of H.  After each sweep the estimate
    is smoothed, optionally clipped at zero, and drawn on ``ax[0]``.  The
    iteration stops when ``max|y - yPrev| / max|yRaw| < eps`` or after
    ``nmaxiter`` sweeps.

    NOTE(review): this reads module-level names that are not defined in this
    file as shown: Wa, Grange, nmaxiter, dump, nsmooth, zero_correction, eps,
    sleeptime, SmoothingByPolynomialFit and plt.  They are presumably set by
    a driver script; confirm before calling.

    Returns
    -------
    (xRaw, y) : the input x values and the final estimate.
    """
    global Wa, Grange

    print("Deconvolution by Gauss-Seidel method")
    print("")

    xstep = xRaw[1] - xRaw[0]

    xgmin = min(xRaw)
    xgmax = max(xRaw)

    n = len(xRaw)
    # Row sums of H, used to normalize each row of H @ y.
    Sg = np.zeros(n)
    for i in range(n):
        for j in range(n):
            Sg[i] += Hij(xstep, Wa, Grange, i, j)
    print("Filter area w.r.t. i: Sg=", Sg[int(n/2)])

    ymax = max([abs(yRaw[i]) for i in range(n)])
    y     = yRaw.copy()
    yPrev = yRaw.copy()

    for it in range(nmaxiter):
        Hx    = np.zeros(n)
        print("iter=", it)

        for i in range(n):
            # Gauss-Seidel: already-updated values for j < i ...
            for j in range(i):
                Hx[i] += Hij(xstep, Wa, Grange, i, j) * y[j]
            # ... and the previous iterate for j >= i.
            for j in range(i, n):
                Hx[i] += Hij(xstep, Wa, Grange, i, j) * yPrev[j]
            h = Hij(xstep, Wa, Grange, i, i)
            y[i] = yPrev[i] + (yRaw[i] - Hx[i] / Sg[i]) / h * dump

        y = SmoothingByPolynomialFit(y, nsmooth)
        if zero_correction:
            for i in range(n):
                if y[i] < 0.0:
                    y[i] = 0.0

        ax[0].cla()
        ax[0].plot(xRaw, yRaw, label = 'raw/initial')
        ax[0].plot(xRaw, y, label = 'updated')
        ax[0].set_xlim([xgmin, xgmax])
        # The vertical range comes from the y data, not from xRaw.
        ygmax = max([max(yRaw), max(y)])
        ax[0].set_ylim([0.0, ygmax])
        ax[0].legend()
        plt.tight_layout()
        plt.pause(sleeptime)

        max_err = max([abs(y[i] - yPrev[i]) for i in range(n)])
        rel_err = max_err / ymax
        print("  max error: ", max_err, "  relative error: ", rel_err, "  eps=", eps)
        if rel_err < eps:
            print("Converged at max_err={} ({} relative) < {}".format(max_err, rel_err, eps))
            break

        yPrev = y.copy()
    else:
        print("Not converged")

    return xRaw, y


def main():
    """Placeholder entry point; it currently performs no work."""
    return None


if __name__ == '__main__':
    main()



"""
sub Deconvolution
{
	my ($pXObs, $pYObs, $AppFunction, $Wa, $CorrectZero, $Dumping, $nMaxIter, $EPS) = @_;
	$CorrectZero = 1      if(!defined $CorrectZero);
	$Dumping     = 0.2    if(!defined $Dumping);
	$nMaxIter    = 30     if(!defined $nMaxIter);
	$EPS         = 1.0e-3 if(!defined $EPS);

	my $nData = @$pXObs;
	my @Y;
	$Y[0] = [];
	for(my $i = 0 ; $i < $nData ; $i++) {
		$Y[0]->[$i] = $pYObs->[$i];
		$Y[0]->[$i] = 0.0 if($CorrectZero and $Y[0]->[$i] < 0.0);
	}

	my $AppFunc = sub { Sci::Lorentzian(@_); };
	if($AppFunction =~ /^g/i) {
		print "Use Gaussian\n";
		$AppFunc = sub { Sci::Gaussian(@_); };
	}
	else {
		print "Use Lorentian\n";
	}

	my $dX = $pXObs->[1] - $pXObs->[0];
	my $absdX = abs($dX);
	my $hii = &$AppFunc(0.0, 0.0, $Wa) * $absdX;
	my $K = $Dumping / $hii;
print "hii=$hii K=$K\n";
	my $iter = 1;
	for($iter = 1 ; $iter <= $nMaxIter ; $iter++) {
		my $iprev = $iter-1;
		$Y[$iter] = [];
		for(my $i = 0 ; $i < $nData ; $i++) {
			my $xi = $pXObs->[$i];
			my $S0 = 0.0;
			if($i > 0) {
				for(my $j = 0 ; $j <= $i-1 ; $j++) {
					my $xj = $pXObs->[$j];
					my $hij = &$AppFunc($xi, $xj, $Wa) * $absdX;
					$S0 += $hij * $Y[$iter]->[$j];
				}
			}
			for(my $j = $i ; $j < $nData ; $j++) {
				my $xj = $pXObs->[$j];
				my $hij = &$AppFunc($xi, $xj, $Wa) * $absdX;
				$S0 += $hij * $Y[$iprev]->[$j];
			}
			$Y[$iter]->[$i] = $Y[$iprev]->[$i] + $K * ($pYObs->[$i] - $S0);
			$Y[$iter]->[$i] = 0.0 if($CorrectZero and $Y[$iter]->[$i] < 0.0);
		}
		my $diff =DeconvolutionDiff($Y[$iter-1], $Y[$iter]);
print "iter: $iter (sigma2=$diff)\n";
		if($diff < $EPS) {
print "Converged at iter=$iter\n";
			last;
		}
	}
	$iter = $nMaxIter-1 if($iter >= $nMaxIter);
	return ($Y[$iter], @Y);
}

sub DeconvolutionDiff
{
	my ($Y1, $Y2) = @_;
	
	my $Y2sum = 0.0;
	my $Y2diff = 0.0;
	for(my $i = 0 ; $i < @$Y1 ; $i++) {
		my $y1 = $Y1->[$i];
		my $y2 = $Y2->[$i];
		$Y2sum += $y1 * $y1;
		my $d = $y1 - $y2;
		$Y2diff += $d * $d;
	}
	$Y2diff /= $Y2sum if($Y2sum > 0.0);
	return $Y2diff;
}
"""
    