import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from scipy.fft import fftn, ifftn, fftshift, ifftshift

def interpolate_3d_periodic_data_fft(data, interp_factor=(2, 2, 2)):
    """
    3次元の等間隔周期データをFFTで補間します。

    Args:
        data (np.ndarray): 補間したい3次元の周期データ。形状は (Nx, Ny, Nz)。
        interp_factor (tuple): 各次元の補間倍率。(interp_x, interp_y, interp_z)

    Returns:
        np.ndarray: 補間された3次元データ。
    """
    if not (isinstance(data, np.ndarray) and data.ndim == 3):
        raise ValueError("入力データは3次元のNumPy配列である必要があります。")
    if not (isinstance(interp_factor, tuple) and len(interp_factor) == 3 and all(isinstance(f, int) and f >= 1 for f in interp_factor)):
        raise ValueError("interp_factorは各次元の整数補間倍率を示す3つの要素を持つタプルである必要があります。")

    Nx, Ny, Nz = data.shape
    interp_Nx, interp_Ny, interp_Nz = interp_factor

    # 1. FFTを適用して周波数領域のデータを得る
    F_data = fftshift(fftn(data))

    # 2. 周波数領域でのゼロパディング
    # 新しいサイズを計算
    new_Nx = Nx * interp_Nx
    new_Ny = Ny * interp_Ny
    new_Nz = Nz * interp_Nz

    # 新しいゼロパディングされた配列を初期化
    F_data_padded = np.zeros((new_Nx, new_Ny, new_Nz), dtype=F_data.dtype)

    # 元の周波数成分を新しい配列の中央にコピー
    start_x = (new_Nx - Nx) // 2
    end_x = start_x + Nx
    start_y = (new_Ny - Ny) // 2
    end_y = start_y + Ny
    start_z = (new_Nz - Nz) // 2
    end_z = start_z + Nz

    F_data_padded[start_x:end_x, start_y:end_y, start_z:end_z] = F_data

    # 3. 逆FFTを適用して補間されたデータを得る
    interpolated_data = np.real(ifftn(ifftshift(F_data_padded))) * (new_Nx * new_Ny * new_Nz) / (Nx * Ny * Nz)

    return interpolated_data

# --- 使用例 ---
if __name__ == "__main__":
    # 1. サンプル3次元データの作成
    Nx_orig, Ny_orig, Nz_orig = 10, 10, 10
    x_orig = np.linspace(0, 2 * np.pi, Nx_orig, endpoint=False)
    y_orig = np.linspace(0, 2 * np.pi, Ny_orig, endpoint=False)
    z_orig = np.linspace(0, 2 * np.pi, Nz_orig, endpoint=False)
    X_orig, Y_orig, Z_orig = np.meshgrid(x_orig, y_orig, z_orig, indexing='ij')

    original_data = np.sin(X_orig * 2) + np.cos(Y_orig * 3) + np.sin(Z_orig * 1.5)

    print(f"元のデータの形状: {original_data.shape}")

    # 2. FFTによる補間
    interp_factor = (4, 4, 4) # 各次元を4倍に補間
    interpolated_data = interpolate_3d_periodic_data_fft(original_data, interp_factor)

    print(f"補間後のデータの形状: {interpolated_data.shape}")

    # **ここから修正点**
    # interpolated_data の形状から新しい次元数を取得する
    new_Nx, new_Ny, new_Nz = interpolated_data.shape
    # **修正点ここまで**

    # 3. 結果の可視化 (スライス表示)
    slice_idx_orig_x = Nx_orig // 2
    slice_idx_interp_x = new_Nx // 2 # 補間後のデータに対応するインデックス

    fig = plt.figure(figsize=(14, 7))

    ax1 = fig.add_subplot(121)
    c1 = ax1.imshow(original_data[slice_idx_orig_x, :, :], origin='lower', cmap='viridis',
                    extent=[0, 2 * np.pi, 0, 2 * np.pi])
    fig.colorbar(c1, ax=ax1, fraction=0.046, pad=0.04)
    ax1.set_title(f'Original Data (X-slice at {slice_idx_orig_x})')
    ax1.set_xlabel('Y-axis')
    ax1.set_ylabel('Z-axis')

    ax2 = fig.add_subplot(122)
    c2 = ax2.imshow(interpolated_data[slice_idx_interp_x, :, :], origin='lower', cmap='viridis',
                    extent=[0, 2 * np.pi, 0, 2 * np.pi])
    fig.colorbar(c2, ax=ax2, fraction=0.046, pad=0.04)
    ax2.set_title(f'Interpolated Data (X-slice at {slice_idx_interp_x})')
    ax2.set_xlabel('Y-axis')
    ax2.set_ylabel('Z-axis')

    plt.tight_layout()
    plt.show()

    # オプション：3D可視化（データ点が多すぎると描画が重くなる可能性があります）
    fig_3d = plt.figure(figsize=(12, 6))

    ax_orig_3d = fig_3d.add_subplot(121, projection='3d')
    sc_orig = ax_orig_3d.scatter(X_orig.flatten(), Y_orig.flatten(), Z_orig.flatten(),
                                 c=original_data.flatten(), cmap='viridis', s=20)
    fig_3d.colorbar(sc_orig, ax=ax_orig_3d, shrink=0.5, aspect=5)
    ax_orig_3d.set_title('Original Data Points')
    ax_orig_3d.set_xlabel('X')
    ax_orig_3d.set_ylabel('Y')
    ax_orig_3d.set_zlabel('Z')

    ax_interp_3d = fig_3d.add_subplot(122, projection='3d')
    # **ここから修正点**
    # new_Nx, new_Ny, new_Nz がここで定義されているため、これらを使用できます
    x_interp = np.linspace(0, 2 * np.pi, new_Nx, endpoint=False)
    y_interp = np.linspace(0, 2 * np.pi, new_Ny, endpoint=False)
    # **修正点ここまで**
    X_interp_slice, Y_interp_slice = np.meshgrid(x_interp, y_interp, indexing='ij')

    z_slice_idx = new_Nz // 2 # 補間後のデータに対応するインデックス
    surf_interp = ax_interp_3d.plot_surface(X_interp_slice, Y_interp_slice,
                                            interpolated_data[:, :, z_slice_idx],
                                            cmap='viridis', edgecolor='none')
    fig_3d.colorbar(surf_interp, ax=ax_interp_3d, shrink=0.5, aspect=5)
    ax_interp_3d.set_title(f'Interpolated Data (Z-slice at {z_slice_idx})')
    ax_interp_3d.set_xlabel('X')
    ax_interp_3d.set_ylabel('Y')
    ax_interp_3d.set_zlabel('Value')

    plt.tight_layout()
    plt.show()