import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt


# Figure layout selector: "fft" draws a single panel; any other value
# (e.g. "plot") creates a second panel, which is currently left empty.
mode = "fft"
#mode = "plot"


# Input table; read_data() expects the columns 'k', 'E(k)', 'k,e' and 'E(k),e'.
infile = "interpolate_fft_test.xlsx"
# The electronic band structure E(k) has inversion symmetry, E(-k) = E(k),
# so it is sufficient to provide only half of E(k) in the first Brillouin zone.
# In that case the E(-k) half must be generated and added to the E(k) data.
# Set do_mirror = True if only half of E(k) is provided.
do_mirror = False


# Default k range and sampling, used when samples are generated from
# periodic_function() instead of being read from a file.
krange = [-0.5, 0.5]
n_samples = 10
interp_factor = 10  # number of interpolated points per input sample
n_interp = n_samples * interp_factor


# Optional command-line overrides: [infile] [do_mirror (integer)] [mode]
argv = sys.argv
narg = len(argv)
if narg > 1: infile = argv[1] 
if narg > 2: do_mirror = int(argv[2])
if narg > 3: mode = argv[3]


def periodic_function(k):
    """Evaluate the built-in test model -cos(2*pi*k) * (1 + 5*k**2).

    The expression is evaluated element-wise, so ``k`` may be a scalar
    or a NumPy array; the result has the same shape as ``k``.
    """
    carrier = -np.cos(2.0 * np.pi * k)
    envelope = 1.0 + 5.0 * k**2
    return carrier * envelope
    # Alternative test models that were tried:
    #    0.1 * k**4 + k * k
    #    np.sin(k) + 0.5 * np.sin(2*k)

def read_data(infile, do_mirror = False):
    """Read sampled and reference E(k) data from a table file.

    Parameters
    ----------
    infile : str or path-like
        Table containing the columns 'k' and 'E(k)' (the sample points)
        and 'k,e' and 'E(k),e' (the reference curve).  A name ending in
        '.csv' is read with pandas.read_csv; anything else is passed to
        pandas.read_excel as before.
    do_mirror : bool or int
        If true, prepend the reflected half (k -> -k, E unchanged) to
        every series.  The first entry of each series is not duplicated,
        so it is expected to be the symmetry point (presumably k = 0).

    Returns
    -------
    tuple of four lists
        (x, y, xe, ye): sample k, sample E(k), reference k, reference E(k).

    Raises
    ------
    KeyError
        If one of the four required columns is missing.
    """
    if str(infile).lower().endswith('.csv'):
        df = pd.read_csv(infile)
    else:
        df = pd.read_excel(infile)

    # The sample and reference columns may have different lengths, in which
    # case the shorter pair is padded with NaN.  Drop incomplete rows
    # pair-wise so that each k value stays aligned with its E(k) value.
    samples = df[['k', 'E(k)']].dropna()
    reference = df[['k,e', 'E(k),e']].dropna()
    x = samples['k'].tolist()
    y = samples['E(k)'].tolist()
    xe = reference['k,e'].tolist()
    ye = reference['E(k),e'].tolist()

    if not do_mirror:
        return x, y, xe, ye

    def _reflect(values, negate):
        # values[:0:-1] walks from the last element down to index 1,
        # skipping index 0 so the first point is not duplicated.
        tail = values[:0:-1]
        if negate:
            tail = [-v for v in tail]
        return tail + values

    return (_reflect(x, True), _reflect(y, False),
            _reflect(xe, True), _reflect(ye, False))


# Report the run configuration.
print()
print(f"Input file: {infile}")
print(f"Add mirror E(k) data: {do_mirror}")

if not infile:
    # No input file given: sample the built-in model function instead.
    print()
    print("Generate samples")
    # Coarse grid that will be interpolated (end point excluded).
    x = np.linspace(krange[0], krange[1], n_samples, endpoint=False)
    y = periodic_function(x)
    # Fine grid used to draw the exact curve for comparison.
    xe = np.linspace(krange[0], krange[1], n_interp, endpoint=False)
    ye = periodic_function(xe)
else:
    print()
    print(f"Read [{infile}]")
    x, y, xe, ye = read_data(infile, do_mirror)
    print("x=", x)
    print("y=", y)
    # The k range is taken from the full data set, before the last point
    # is removed below.
    krange[0], krange[1] = min(x), max(x)
    n = len(x)
    # Drop the final point of both series so that the samples cover the
    # interval with its end point excluded, as np.linspace(...,
    # endpoint=False) does in the generated-data branch.
    head = slice(n - 1)
    x, y = x[head], y[head]
    n_samples = len(x)
    n_interp = n_samples * interp_factor

# Step 2: Compute the FFT
y_fft = np.fft.fft(y)

# Step 3: Pad zeros to the FTed data for interpolation
# np.fft.fft returns the zero-frequency term first, followed by the
# positive-frequency terms and then the negative-frequency terms.  The
# zeros therefore have to be inserted between the two halves.
# For odd n_samples there are (n_samples + 1) // 2 non-negative and
# n_samples // 2 negative frequencies; slicing with -n_samples//2 would
# floor to -(n_samples + 1) // 2 and move the highest positive-frequency
# coefficient into a negative-frequency slot.  For even n_samples the
# split is unchanged (the Nyquist term stays in the negative half; only
# the real part of the result is used afterwards).
n_pos = (n_samples + 1) // 2   # DC term + positive frequencies
n_neg = n_samples // 2         # negative frequencies (incl. Nyquist if even)
y_fft_padded = np.zeros(n_interp, dtype=complex)
y_fft_padded[:n_pos] = y_fft[:n_pos]
if n_neg > 0:
    # Explicit start indices avoid the "[-0:]" pitfall when n_neg is 0.
    y_fft_padded[n_interp - n_neg:] = y_fft[n_samples - n_neg:]

# Interpolated grid: n_interp points over the same k range, end point excluded.
x_interp = np.linspace(krange[0], krange[1], n_interp, endpoint=False)
# ifft normalises by n_interp whereas the spectrum came from n_samples
# points, so rescale by n_interp / n_samples = interp_factor.
y_interp = np.fft.ifft(y_fft_padded) * interp_factor


# Plot the original and interpolated functions
fontsize = 16

# Set the default font size (tick labels etc.) to 16.
# This has to happen before plt.subplots(): rcParams defaults are read when
# text artists are created, so changing them afterwards does not restyle
# text that already exists on the axes.
plt.rcParams["font.size"] = fontsize

# "fft" mode uses a single panel; any other mode reserves a second panel.
if mode == "fft":
    fig, axes = plt.subplots(1, 1, figsize=(10, 8))
    axes = [axes]  # wrap so that axes[0] works in both layouts
else:
    fig, axes = plt.subplots(2, 1, figsize=(10, 8))

ax = axes[0]
ax.plot(x, y, 'o', label = 'input data', markersize = 6)
# The interpolated signal is complex; only its real part is plotted.
ax.plot(x_interp, y_interp.real, '-', label = 'interpolated', marker = 'x', markersize = 3)
ax.plot(xe, ye, '-', label='exact')
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.legend()
#ax.title('Interpolation of a periodic function using FFT')

# Disabled code for the second panel, kept as a string literal.  It refers
# to names (freq1, Amp1, freqf1, Ampf1, ffmax) that are not defined in the
# visible part of this script, so it cannot be enabled as is.
'''
if mode == 'plot':
    ax = axes[1]
    ax.plot(freqf1, Ampf1, marker = "x")
    ax.plot(freq1, Amp1, marker = "o")
    ax.set_xlabel("Frequency (Hz)", fontsize = fontsize)
    ax.set_ylabel("Amplitude", fontsize = fontsize)
    ax.set_xlim(0, ffmax)
    ylim = ax.get_ylim()
    ax.vlines(freq1,  *ylim, color = "red",  linestyles = "dashed", linewidth = 1.0)
    ax.vlines(freqf1, *ylim, color = "blue", linestyles = "dotted", linewidth = 0.6)
'''


# Draw the figure, then keep the window open until the user presses Enter.
plt.pause(1.0e-4)
input(">>")
