import sys
import os.path
import csv
import re
from math import exp, log
import numpy as np
import scipy.signal
import matplotlib.pyplot as plt


# NOTE(review): the empty string below follows the imports, so it is a
# no-op expression rather than the module docstring.
"""
"""


#===================
# Global parameters
#===================
# Input file: whitespace-separated table whose first line holds the column
# names and whose first column is x (see read_data).
infile = "PDOS-SnSe.dat"

# Analysis range: only rows with xmin <= x <= xmax are used.
xmin = -4.5 # eV
xmax =  2.0 # eV


def read_data(infile, xmin = None, xmax = None):
    """Read a whitespace-separated numeric table with a one-line header.

    The first line is split into column names.  Every following non-blank
    line is split into numbers and appended column by column.  Rows whose
    first value (x) lies outside [xmin, xmax] are skipped; a bound of None
    disables that side of the filter.

    Parameters
    ----------
    infile : str
        Path of the text file to read.
    xmin, xmax : float or None
        Inclusive x range to keep.

    Returns
    -------
    header : list of str
        Column names taken from the first line.
    data : list of list of float
        data[j] holds column j of every kept row.

    If the file cannot be opened an error is printed and the interpreter
    exits with status 1.
    """
    try:
        infp = open(infile, "r")
    except OSError:
        print("Error: Can not read [{}]".format(infile))
        sys.exit(1)

    header = []
    data = []
    with infp:
        for lineno, line in enumerate(infp):
            if lineno == 0:
                header = line.split()
                data = [[] for _ in header]
                continue

            a = line.split()
            if not a:
                # Blank lines (e.g. a trailing newline) carry no data.
                continue
            x = float(a[0])
            if (xmin is not None and x < xmin) or (xmax is not None and xmax < x):
                continue
            for j, value in enumerate(a):
                data[j].append(float(value))

    return header, data

# Quick look at the input file: plot every data column against column 0.
# NOTE(review): this preview ends with exit(), so nothing further down in the
# file (argument parsing, the deconvolution routines, main) is ever reached.
header, data = read_data(infile, xmin, xmax)
print("header=", header)

fig = plt.figure(figsize = (15,7))
ax = [fig.add_subplot(1, 1, 1)]

xcol = data[0] if data else None
for label, column in zip(header[1:], data[1:]):
    print("header=", label)
    ax[0].plot(xcol, column, label = label)
ax[0].legend()

plt.pause(0.001)
print("Press ENTER to exit>>", end = '')
input()

exit()















# Deconvolution method: mode = [gs|jacobi|convolve|fft]
#   gs       : iterative Gauss-Seidel update (deconvolute_gauss_seidel)
#   jacobi   : iterative Jacobi update (deconvolute_jacobi)
#   convolve : scipy.signal.deconvolve (deconvolute_deconvolve)
#   fft      : division in Fourier space (deconvolute_fft)
mode = 'gs'

# Filter parameters for the Gauss function
Wa     = 0.121223  # eV; half width at half maximum of the Gaussian filter
Grange = 2.0       # half-extent of the sampled window in units of Wa (make_wf)

# Pause (passed to plt.pause) after each graph update
sleeptime = 0.001

# for Jacobi/Gauss-Seidel method
dump     = 1.0     # factor multiplying every update step (damping/relaxation)
nmaxiter = 300     # maximum number of iterations
eps      = 1.0e-4  # stop when max|y - yPrev| / max|yRaw| < eps
nsmooth  = 5       # window length of SmoothingByPolynomialFit after each sweep

# 1: set negative values to 0 after each Jacobi/Gauss-Seidel sweep
zero_correction = 0

# np.convolve mode for the 'convolve' pre-processing step: convmode = [same|full|]
convmode = 'same'
# Pre-processing; any combination of [convolve|extend|average], e.g. 'convolve+extend'
smoothmode = 'convolve+extend'

# only for 'convolve' and 'fft': 'extended data' parameters.  main() passes
# nw*kzero zeros and nw*klin ramp points to extend_smooth, where nw is half
# the filter length in samples.
kzero = 5
klin  = 5

#===================
# Command-line arguments
#===================
argv = sys.argv
scriptname = argv[0]  # shown in the usage() examples
def usage():
    """Print the command-line help for both argument layouts.

    One layout is for the 'convolve'/'fft' modes, the other for the
    iterative 'gs'/'jacobi' modes.  The script name in the examples is the
    module-level ``scriptname`` (argv[0]).
    """
    print("usage: python {} file mode convmode smoothmode xmin xmax Wa Grange kzero klin".format(scriptname))
    print("       mode = [convolve|fft]")
    print("       convmode = [''|full|same]")
    print("       smoothmode = combination of [none|convolve|extend|average]")
    print("  ex.: python {} SnSe-DOS.csv convolve same convolve+extend -4.5 2.0 0.12 2.0 1 5".format(scriptname))
    print("  ex.: python {} SnSe-DOS.csv fft full convolve+extend -4.5 2.0 0.12 2.0 5 5".format(scriptname))
    print("")
    print("usage: python {} file mode xmin xmax Wa dump nmaxiter eps nsmooth zeroc".format(scriptname))
    print("       mode = [gs|jacobi]   gs=Gauss-Seidel method   jacobi=Jacobi method")
    print("       zeroc = [0|1]  zero correction after a Jacobi/Gauss-Seidel cycle")
    print("  ex.: python {} pes.csv gs -6.0 2.0 0.12 1.0 300 1.0e-4 5 0".format(scriptname))
    print("")

# Without arguments there is nothing to process: show the help and stop.
if len(argv) == 1:
    usage()
    sys.exit()

if len(argv) >= 2:
    infile = argv[1]
if len(argv) >= 3:
    mode = argv[2]

if mode == 'convolve' or mode == 'fft':
    # file mode convmode smoothmode xmin xmax Wa Grange kzero klin
    if len(argv) >= 4:
        convmode = argv[3]
    if len(argv) >= 5:
        smoothmode = argv[4]
    if len(argv) >= 6:
        xmin = float(argv[5])
    if len(argv) >= 7:
        xmax = float(argv[6])
    if len(argv) >= 8:
        Wa = float(argv[7])
    if len(argv) >= 9:
        Grange = float(argv[8])
    # kzero/klin multiply sample counts (nw*kzero zeros and nw*klin ramp
    # points in extend_smooth), so they must be integers like the defaults;
    # floats would make range()/np.zeros() fail later.
    if len(argv) >= 10:
        kzero = int(argv[9])
    if len(argv) >= 11:
        klin = int(argv[10])
elif mode == 'gs' or mode == 'jacobi':
    # file mode xmin xmax Wa dump nmaxiter eps nsmooth zeroc
    convmode   = ''
    smoothmode = ''
    if len(argv) >= 4:
        xmin = float(argv[3])
    if len(argv) >= 5:
        xmax = float(argv[4])
    if len(argv) >= 6:
        Wa = float(argv[5])
    if len(argv) >= 7:
        dump = float(argv[6])
    if len(argv) >= 8:
        nmaxiter = int(argv[7])
    if len(argv) >= 9:
        eps = float(argv[8])
    if len(argv) >= 10:
        nsmooth = int(argv[9])
    if len(argv) >= 11:
        zero_correction = int(argv[10])
else:
    print("")
    print("Error: Invalid mode=[{}]".format(mode))
    usage()
    sys.exit(1)

# Output file: the input path without its extension + '-deconvoluted.csv'.
# (Computed without rebinding the module-level name 'header'.)
outcsvfile = os.path.splitext(infile)[0] + '-deconvoluted.csv'


def savecsv(outfile, header, datalist):
    """Write column-oriented data to a CSV file.

    Parameters
    ----------
    outfile : str
        Destination path; an existing file is overwritten.
    header : sequence of str
        Written as the first row.
    datalist : sequence of sequences
        datalist[j] is column j.  Columns of unequal length are padded with
        empty cells rather than aborting part-way through the file.

    Failure to open or write the file is reported on stdout and otherwise
    ignored, so a failed save does not abort the caller.
    """
    from itertools import zip_longest

    print("Write to [{}]".format(outfile))
    try:
        # newline='' is what the csv module expects; lineterminator keeps '\n'.
        with open(outfile, 'w', newline='') as f:
            fout = csv.writer(f, lineterminator='\n')
            fout.writerow(header)
            # Transpose the columns into rows.
            fout.writerows(zip_longest(*datalist, fillvalue=''))
    except OSError:
        print("Error: Can not write to [{}]".format(outfile))


def Gaussian(x, x0, whalf):
    """Area-normalised Gaussian centred at x0 with half width at half maximum whalf.

    G(x) = sqrt(ln2/pi) / whalf * exp(-ln2 * ((x - x0) / whalf)**2), so that
    G(x0 +- whalf) = G(x0) / 2 and the integral over x is 1.  The two numeric
    constants are truncated decimal values of sqrt(ln2/pi) and sqrt(ln2).
    """
    peak = 0.469718639 / whalf       # sqrt(ln 2 / pi) / whalf
    width = whalf / 0.832554611      # whalf / sqrt(ln 2)
    u = (x - x0) / width
    return peak * exp(-u * u)

def Hij(xstep, Wa, Grange, i, j):
    """Return element (i, j) of the Gaussian blurring matrix.

    The value depends only on the index offset: it is Gaussian() with half
    width Wa evaluated at (j - i) * xstep.  Grange is accepted for interface
    compatibility but is currently ignored, so no distance cut-off is applied.
    """
    distance = (j - i) * xstep
    return Gaussian(distance, 0.0, Wa)

def make_wf(Wa, Grange, xstep):
    """Sample a unit-area Gaussian window function on a uniform grid.

    The grid is x = 0, xstep, ..., 2*ixG0*xstep with
    ixG0 = int(Grange*Wa/xstep + 1.0001), and the Gaussian peak is placed at
    its centre, x = ixG0*xstep.  The samples are divided by their
    trapezoidal integral, so the sampled window integrates to 1.

    Parameters
    ----------
    Wa : float
        Half width at half maximum of the Gaussian (same unit as x).
    Grange : float
        Half-extent of the window in units of Wa.
    xstep : float
        Grid spacing (must be non-zero).

    Returns
    -------
    xG, yG : list of float
        Grid positions and normalised window values, 2*ixG0 + 1 points each.
    """
    ixG0   = int(Grange * Wa / xstep + 1.0001)
    ixGmax = 2 * ixG0
    xG0    = ixG0 * xstep

    xG = [i * xstep for i in range(ixGmax + 1)]
    yG = [Gaussian(x, xG0, Wa) for x in xG]

    # Trapezoidal integral of the sampled window.
    SG = 0.0
    for i in range(1, len(xG)):
        SG += (yG[i - 1] + yG[i]) / 2.0 * (xG[i] - xG[i - 1])

    yG = [y / SG for y in yG]

    print("   Range: {} in width".format(Grange * Wa))
    print("   i range: {} - {} at center {}".format(0,ixGmax, xG0))
    print("   ixGmax = ", ixGmax)
    print("   SG = ", SG)

    return xG, yG


def convolve(xraw, yraw, ywf, **kwargs):
    """Convolve data with a window function and rebuild the matching x grid.

    yconv = np.convolve(yraw, ywf, **kwargs) / sum(ywf); the window is
    normalised to unit sum.  When the output length differs from the input
    length ('full' makes it longer, 'valid' shorter), a new x grid with the
    same step is built so that x and y stay aligned.  This assumes the window
    is centred on its middle sample, which holds for odd-length windows such
    as those from make_wf.

    Parameters
    ----------
    xraw : sequence of float
        Uniform x grid (at least two points).
    yraw : sequence of float
        Values on xraw.
    ywf : sequence of float
        Window function samples.
    **kwargs
        Forwarded to np.convolve.  mode='' or mode=None selects numpy's
        default ('full'), matching the "convmode = ['']" option in usage().

    Returns
    -------
    x, yconv
        x is xraw itself when the length is unchanged, otherwise a new array.
    """
    if kwargs.get('mode', 'full') in ('', None):
        kwargs.pop('mode')

    yconv = np.convolve(yraw, ywf, **kwargs) / sum(ywf)
    n_new = len(yconv)
    dn = n_new - len(yraw)
    if dn == 0:
        return xraw, yconv

    # Output sample k lies at input index k - offset (offset < 0 for 'valid').
    offset = int(dn / 2)
    xmin = xraw[0]
    xstep = xraw[1] - xmin
    xmin_new = xmin - offset * xstep
    print("convolve: the length of the output data has been changed "
            + "from {} to {}".format(len(yraw), n_new))
    print("  Add offset = {}".format(offset))
    print("  xmin changes: {} => {}".format(xmin, xmin_new))
    x = xmin_new + np.arange(n_new) * xstep
    return x, yconv

def extend_smooth(x, y, nzero, nlin, xstep = 0.0):
    """Prepend zeros to the data and ramp its leading edge up linearly.

    The result has ``nzero`` zero samples placed before x[0] on the same
    grid, followed by the input.  The first ``nlin`` input samples are
    multiplied by a linear ramp from 0 to 1, so the data rises smoothly from
    the added zeros.  The remaining samples are copied unchanged.

    Parameters
    ----------
    x, y : sequence of float
        Uniform x grid (at least two points) and the values on it.
    nzero : int
        Number of zero samples to prepend (float values are truncated).
    nlin : int
        Number of leading samples to ramp, clipped to [0, len(x)].  0 or 1
        means no ramp.
    xstep : float
        Ignored; the step is taken from x[1] - x[0].  Kept for compatibility.

    Returns
    -------
    xx, yy : numpy.ndarray
        Extended grid and values, nzero + len(x) points each.
    """
    # Callers pass products such as nw*kzero, which may be floats.
    nzero = int(nzero)
    nlin = max(0, min(int(nlin), len(x)))

    xmin = x[0]
    xstep = x[1] - x[0]
    xmin_new = x[0] - nzero * xstep
    n_new = nzero + len(x)
    print("extend_smooth:")
    print("  Add {} zeros at top of the data".format(nzero))
    print("    xmin changes: {} => {}".format(xmin, xmin_new))
    print("  Reshape {} input data with a linear filter".format(nlin))

    xx = xmin_new + np.arange(n_new) * xstep
    yy = np.zeros(n_new)
    for i in range(nlin):
        # Ramp 0 .. 1 across the first nlin samples; a 1-point ramp is a no-op.
        k = i / (nlin - 1) if nlin > 1 else 1.0
        yy[nzero + i] = k * y[i]
    # The rest of the data follows the ramp unchanged (input index i -> output
    # index nzero + i, for every i >= nlin).
    yy[nzero + nlin:] = y[nlin:]
    return xx, yy

def SmoothingBySimpleAverage(y, n):
    """Moving-average smoothing over a centred window of 2*int(n/2)+1 points.

    Near the ends only the part of the window that lies inside the data is
    averaged.  If that part is empty (only possible for a negative n), the
    input value is kept.  Returns a new list; y itself is not modified.
    """
    half = int(n / 2)
    size = len(y)
    smoothed = []
    for center in range(size):
        lo = max(center - half, 0)
        hi = min(center + half + 1, size)
        total = 0.0
        count = 0
        for k in range(lo, hi):
            total += y[k]
            count += 1
        smoothed.append(total / count if count > 0 else y[center])
    return smoothed

def SmoothingByPolynomialFit(y, n):
    """Smooth data with quadratic/cubic Savitzky-Golay weights.

    Each output point is a weighted average over y[i-m .. i+m], with
    m = int(n/2) and weights

        w_j = (3*m*(m+1) - 1 - 5*j*j) / W23,   W23 = (4*m*m - 1)*(2*m + 3)/3.

    Near the ends the window is truncated and the result is renormalised by
    the sum of the weights actually used.  If that sum is not positive, the
    input value is kept.

    Parameters
    ----------
    y : sequence of float
        Data to smooth.
    n : int
        Nominal window length.  The window used has 2*int(n/2)+1 points, so
        an even n behaves like n+1, and n <= 1 returns the data unchanged.

    Returns
    -------
    list of float
    """
    m = int(n / 2)
    W23 = (4.0 * m * m - 1.0) * (2.0 * m + 3.0) / 3.0
    # One weight per offset j = -m..m: 2*m+1 entries (can exceed n for even n).
    w23j = [(3.0 * m * (m + 1.0) - 1.0 - 5.0 * j * j) / W23 for j in range(-m, m + 1)]

    ndata = len(y)
    ys = []
    for i in range(ndata):
        acc = 0.0
        wsum = 0.0
        for j in range(-m, m + 1):
            k = i + j
            if k < 0 or k >= ndata:
                continue
            acc += w23j[j + m] * y[k]
            wsum += w23j[j + m]
        ys.append(acc / wsum if wsum > 0 else y[i])
    return ys


def deconvolute_fft(xRaw, yRaw, xG, yG):
    """Deconvolve data by dividing its FFT by the FFT of the window function.

    Data and window are zero-padded to the smallest power of two that holds
    both.  The inverse transform is multiplied by sum(yG): convolve() divides
    by sum(yG), so this restores the original scale, as
    deconvolute_deconvolve does.

    The window's peak sits at index len(xG)//2 rather than 0, which shifts
    the deconvolved samples by that many steps.  The returned x grid
    compensates for this shift.  The transform is circular, so the padded
    tail wraps around.

    Returns
    -------
    xDec, yDec : list of float
        Shifted grid and deconvolved values (nfft points each).
    xRawFFT, yRawFFT
        Grid and zero-padded input that was transformed.
    xRawFFT, yGFFT
        The same grid again and the zero-padded window.
    """
    k = sum(yG)

    print("Deconvolution by FFT")
    n = len(xRaw)
    # Smallest power of two >= max(n, len(yG)), computed exactly in integers
    # (a float log2 estimate can round down just above a power of two).
    nlog = (max(n, len(yG)) - 1).bit_length()
    nfft = 2 ** nlog
    print("  Data number is changed from {} to 2^{} = {} for FFT".format(n, nlog, nfft))

    xmin = xRaw[0]
    xstep = xRaw[1] - xmin
    xRawFFT = [xmin + i * xstep for i in range(nfft)]
    yRawFFT = np.concatenate((np.asarray(yRaw, dtype=float), np.zeros(nfft - n)))
    yGFFT   = np.concatenate((np.asarray(yG, dtype=float), np.zeros(nfft - len(yG))))
    # The filter's centre is offset from the origin by len(xG)//2 samples, and
    # that offset moves the origin of the inverse FFT.  Integer division keeps
    # the shift on the grid.
    xminG = xmin + (len(xG) // 2) * xstep
    print("  xmin: ", xmin, xminG)
    xGFFT = [xminG + i * xstep for i in range(nfft)]

    ycFFTed = np.fft.fft(yRawFFT) / np.fft.fft(yGFFT)
    # For real inputs the imaginary part is only round-off; keep the real part
    # explicitly instead of relying on an implicit complex -> float cast.
    ydeconv = (np.fft.ifft(ycFFTed).real * k).tolist()

    print("")

    return xGFFT, ydeconv, xRawFFT, yRawFFT, xRawFFT, yGFFT

def deconvolute_deconvolve(xRaw, yRaw, xG, yG):
    """Deconvolve data with scipy.signal.deconvolve (polynomial division).

    The quotient is multiplied by sum(yG) to undo the unit-sum normalisation
    applied in convolve().  It has len(yRaw) - len(yG) + 1 samples and is
    paired with the x values starting len(xG)//2 samples into xRaw.

    Returns
    -------
    x
        Slice of xRaw with the same length as the deconvolved data.
    IDec : numpy.ndarray
        Deconvolved values.
    """
    k = sum(yG)

    print("Deconvolution by scipy.signal.deconvolve")
    print("")
    IDec, remainder = scipy.signal.deconvolve(yRaw, yG)
    IDec *= k
    nGhalf = len(xG) // 2
    # Slice by the quotient length so x and y always match; for an
    # even-length window ndata - 2*nGhalf would be one sample short.
    return xRaw[nGhalf:nGhalf + len(IDec)], IDec

def deconvolute_jacobi(xRaw, yRaw, xG, yG, fig, ax):
    """Deconvolve data iteratively with a damped Jacobi-type update.

    The blurring matrix is H[i][j] = Hij(xstep, Wa, Grange, i, j), a Gaussian
    of (j - i) * xstep, and Sg[i] is the sum of row i.  Starting from
    y = yRaw, every sweep updates all points from the previous iterate:

        y[i] = yPrev[i] + (yRaw[i] - (H yPrev)[i] / Sg[i]) / H[i][i] * dump

    Each sweep then:
      * smooths y with SmoothingByPolynomialFit(y, nsmooth);
      * sets negative values to 0 when zero_correction is set;
      * redraws ax[0].

    Iteration stops when max|y - yPrev| / max|yRaw| < eps, or after
    nmaxiter sweeps.

    Reads the module-level settings Wa, Grange, dump, nmaxiter, eps, nsmooth,
    zero_correction and sleeptime.  xG, yG and fig are accepted for
    interface compatibility but are not used.

    Returns
    -------
    xRaw, y
        The input grid and the deconvolved values.
    """
    print("Deconvolution by Jacobi method")
    print("")

    xstep = xRaw[1] - xRaw[0]

    xgmin = min(xRaw)
    xgmax = max(xRaw)

    n = len(xRaw)
    # Hij depends only on the offset j - i, so sample it once per offset
    # instead of n*n times per sweep:  hk[n - 1 + d] == Hij(.., i, i + d).
    hk = [Hij(xstep, Wa, Grange, 0, d) for d in range(-(n - 1), n)]

    def row(i):
        # Row i of H: row(i)[j] == Hij(xstep, Wa, Grange, i, j).
        return hk[n - 1 - i:2 * n - 1 - i]

    Sg = []
    for i in range(n):
        total = 0.0
        for hij in row(i):
            total += hij
        Sg.append(total)
    print("Filter area w.r.t. i: Sg=", Sg[int(n/2)])

    ymax = max(abs(v) for v in yRaw)
    y     = yRaw.copy()
    yPrev = yRaw.copy()

    for it in range(nmaxiter):
        print("iter=", it)

        for i in range(n):
            Hi = row(i)
            Hx = 0.0
            # Jacobi: every product uses the previous iterate only.
            for hij, yj in zip(Hi, yPrev):
                Hx += hij * yj
            y[i] = yPrev[i] + (yRaw[i] - Hx / Sg[i]) / Hi[i] * dump

        y = SmoothingByPolynomialFit(y, nsmooth)
        if zero_correction:
            y = [0.0 if v < 0.0 else v for v in y]

        ax[0].cla()
        ax[0].plot(xRaw, yRaw, label = 'raw/initial')
        ax[0].plot(xRaw, y, label = 'updated')
        ax[0].set_xlim([xgmin, xgmax])
        # Upper y limit from the data values (not from the x grid).
        ax[0].set_ylim([0.0, max(max(yRaw), max(y))])
        ax[0].legend()
        plt.tight_layout()
        plt.pause(sleeptime)

        max_err = max(abs(a - b) for a, b in zip(y, yPrev))
        # An all-zero input has no scale; fall back to the absolute error.
        rel_err = max_err / ymax if ymax > 0 else max_err
        print("  max error: ", max_err, "  relative error: ", rel_err, "  eps=", eps)
        if rel_err < eps:
            print("Converged at max_err={} ({} relative) < {}".format(max_err, rel_err, eps))
            break

        yPrev = y.copy()
    else:
        print("Not converged")

    return xRaw, y

def deconvolute_gauss_seidel(xRaw, yRaw, xG, yG, fig, ax):
    """Deconvolve data iteratively with a damped Gauss-Seidel-type update.

    The blurring matrix is H[i][j] = Hij(xstep, Wa, Grange, i, j), a Gaussian
    of (j - i) * xstep, and Sg[i] is the sum of row i.  Starting from
    y = yRaw, every sweep updates the points in order i = 0..n-1:

        y[i] = yPrev[i] + (yRaw[i] - (H z)[i] / Sg[i]) / H[i][i] * dump

    Here z[j] is the value already updated in this sweep for j < i, and
    yPrev[j] otherwise.  Each sweep then:
      * smooths y with SmoothingByPolynomialFit(y, nsmooth);
      * sets negative values to 0 when zero_correction is set;
      * redraws ax[0].

    Iteration stops when max|y - yPrev| / max|yRaw| < eps, or after
    nmaxiter sweeps.

    Reads the module-level settings Wa, Grange, dump, nmaxiter, eps, nsmooth,
    zero_correction and sleeptime.  xG, yG and fig are accepted for
    interface compatibility but are not used.

    Returns
    -------
    xRaw, y
        The input grid and the deconvolved values.
    """
    print("Deconvolution by Gauss-Seidel method")
    print("")

    xstep = xRaw[1] - xRaw[0]

    xgmin = min(xRaw)
    xgmax = max(xRaw)

    n = len(xRaw)
    # Hij depends only on the offset j - i, so sample it once per offset
    # instead of n*n times per sweep:  hk[n - 1 + d] == Hij(.., i, i + d).
    hk = [Hij(xstep, Wa, Grange, 0, d) for d in range(-(n - 1), n)]

    def row(i):
        # Row i of H: row(i)[j] == Hij(xstep, Wa, Grange, i, j).
        return hk[n - 1 - i:2 * n - 1 - i]

    Sg = []
    for i in range(n):
        total = 0.0
        for hij in row(i):
            total += hij
        Sg.append(total)
    print("Filter area w.r.t. i: Sg=", Sg[int(n/2)])

    ymax = max(abs(v) for v in yRaw)
    y     = yRaw.copy()
    yPrev = yRaw.copy()

    for it in range(nmaxiter):
        print("iter=", it)

        for i in range(n):
            Hi = row(i)
            Hx = 0.0
            # Gauss-Seidel: points already updated in this sweep (j < i) are
            # used immediately; the rest come from the previous iterate.
            for j in range(i):
                Hx += Hi[j] * y[j]
            for j in range(i, n):
                Hx += Hi[j] * yPrev[j]
            y[i] = yPrev[i] + (yRaw[i] - Hx / Sg[i]) / Hi[i] * dump

        y = SmoothingByPolynomialFit(y, nsmooth)
        if zero_correction:
            y = [0.0 if v < 0.0 else v for v in y]

        ax[0].cla()
        ax[0].plot(xRaw, yRaw, label = 'raw/initial')
        ax[0].plot(xRaw, y, label = 'updated')
        ax[0].set_xlim([xgmin, xgmax])
        # Upper y limit from the data values (not from the x grid).
        ax[0].set_ylim([0.0, max(max(yRaw), max(y))])
        ax[0].legend()
        plt.tight_layout()
        plt.pause(sleeptime)

        max_err = max(abs(a - b) for a, b in zip(y, yPrev))
        # An all-zero input has no scale; fall back to the absolute error.
        rel_err = max_err / ymax if ymax > 0 else max_err
        print("  max error: ", max_err, "  relative error: ", rel_err, "  eps=", eps)
        if rel_err < eps:
            print("Converged at max_err={} ({} relative) < {}".format(max_err, rel_err, eps))
            break

        yPrev = y.copy()
    else:
        print("Not converged")

    return xRaw, y


#======================
# main
#======================
def main():
    """Run the complete deconvolution using the module-level settings.

    Steps:
      1. Read ``infile`` restricted to [xmin, xmax].
      2. Build the Gaussian window with make_wf.
      3. Pre-process the data according to ``smoothmode`` ('average',
         'extend', 'convolve').
      4. Deconvolve with the method selected by ``mode``.
      5. Plot input, deconvolved data and window.
      6. Save to ``outcsvfile`` and wait for ENTER.
    """
    print("infile  : ", infile)
    print("outfile : ", outcsvfile)
    print("mode    : ", mode)
    if mode == 'convolve' or mode == 'fft':
        print("For mode = 'convolve' or 'fft':")
        print("  convmode: ", convmode)
        print("  x range : ", xmin, xmax)
    if mode == 'gs' or mode == 'jacobi':
        print("For mode = 'gs' or 'jacobi':")
        print("  dump=", dump)
        print("  nmaxiter=", nmaxiter)
        print("  eps=", eps)
        print("  nsmooth=", nsmooth)
        print("  zero correction=", zero_correction)
    print("")

    # read_data returns (header, columns); column 0 is x and column 1 is y.
    header, data = read_data(infile, xmin, xmax)
    if len(data) < 2 or len(data[0]) < 2:
        print("Error: [{}] needs at least two columns and two rows in the x range".format(infile))
        return
    xin, yin = data[0], data[1]

    nindata = len(xin)
    xinmin  = xin[0]
    xinstep = xin[1] - xin[0]
    xinmax  = xinmin + xinstep * (nindata - 1)
    print("Input data")
    print("  ndata = ", nindata)
    print("  x range: {} - {} at {} step".format(xinmin, xinmax, xinstep))
    print("Smooth mode: ", smoothmode)
    print("")

    xRaw  = xin
    yRaw  = yin
    xstep = xinstep
    print("Filter: Wa = {}  Grange = {}".format(Wa, Grange))
    xG, yG = make_wf(Wa, Grange, xstep)
    print("")

    fig = plt.figure(figsize = (15,7))
    ax = [fig.add_subplot(3, 1, k) for k in (1, 2, 3)]

    # Make the data to be deconvolved.
    nw = int(len(xG) / 2)
    if 'average' in smoothmode:
        yRaw = SmoothingBySimpleAverage(yRaw, 31)
    if 'extend' in smoothmode:
        xRaw, yRaw = extend_smooth(xRaw, yRaw, nw*kzero, nw*klin, xstep)
    if 'convolve' in smoothmode:
        xconv, yconv = convolve(xRaw, yRaw, yG, mode = convmode)
    else:
        xconv, yconv = xRaw, yRaw

    # Deconvolution
    if mode == 'fft':
        xDec, yDec, xRawFFT, yRawFFT, xGFFT, yGFFT = deconvolute_fft(xconv, yconv, xG, yG)
        xconv = xRawFFT
        yconv = yRawFFT
        ax[2].plot(xGFFT, yGFFT, label = 'WF for FFT')
    elif mode == 'jacobi':
        xDec, yDec = deconvolute_jacobi(xconv, yconv, xG, yG, fig, ax)
        yG = [Gaussian(xRaw[i], 0.0, Wa) for i in range(len(xRaw))]
        ax[2].plot(xRaw, yG, label = 'WF for Jacobi method')
    elif mode == 'gs':
        xDec, yDec = deconvolute_gauss_seidel(xconv, yconv, xG, yG, fig, ax)
        yG = [Gaussian(xRaw[i], 0.0, Wa) for i in range(len(xRaw))]
        ax[2].plot(xRaw, yG, label = 'WF for Gauss-Seidel method')
    else:
        xDec, yDec = deconvolute_deconvolve(xconv, yconv, xG, yG)
        ax[2].plot(xG, yG, label = 'WF for convolve')

    xgmin = min([min(xconv), min(xDec)])
    xgmax = max([max(xconv), max(xDec)])

    ax[0].cla()
    ax[0].plot(xconv, yconv, label = 'input(convoluted)')
    ax[0].plot(xDec, yDec, label = 'deconvoluted')
    ax[1].plot(xRaw, yRaw, label = 'input(raw)')
    ax[1].plot(xconv, yconv, label = 'input(convoluted)')
    ax[0].set_xlim([xgmin, xgmax])
    ax[0].set_ylim([0.0, max([max(yRaw), max(yDec)])])
    ax[1].set_xlim([xgmin, xgmax])
    ax[1].set_ylim([0.0, max(yRaw)])
    ax[2].set_xlim([xgmin, xgmax])
    ax[2].set_ylim([0.0, max(yG)])

    ax[0].legend()
    ax[1].legend()
    ax[2].legend()
    plt.tight_layout()

    plt.pause(0.001)

    print("")
    print("Save deconvoluted data to [{}]".format(outcsvfile))
    # xDec is saved too (as a trailing column, so the first three columns keep
    # their positions): for 'fft' it is shifted and for 'convolve' it is
    # shorter than xconv, so y(deconvoluted) must not be read against 'x'.
    savecsv(outcsvfile, ['x', 'y(input)', 'y(deconvoluted)', 'x(deconvoluted)'],
            [xconv, yconv, yDec, xDec])

    print("Press ENTER to exit>>", end = '')
    input()


if __name__ == '__main__':
    # The help text is printed on every run, even when arguments were given;
    # main() then uses the settings parsed at module level above.
    # NOTE(review): unreachable while the preview block near the top of the
    # file ends with exit().
    usage()
    main()

