interpolate3d_fft.py ダウンロード/コピー
interpolate3d_fft.py
interpolate3d_fft.py
1"""
23次元周期データのFFTを用いた補間処理を提供します。
3
4このモジュールは、フーリエ変換とゼロパディングを利用して、3次元空間に周期的に配置された離散データを高解像度化するための関数を含みます。
5主に科学技術計算や画像処理の分野で、データの補間やアップサンプリングに利用されます。
6
7:doc:`interpolate3d_fft_usage`
8"""
9import numpy as np
10import matplotlib.pyplot as plt
11from mpl_toolkits.mplot3d import Axes3D
12from scipy.fft import fftn, ifftn, fftshift, ifftshift
13
14def interpolate_3d_periodic_data_fft(data, interp_factor=(2, 2, 2)):
15 """
16 3次元の等間隔周期データをFFT(高速フーリエ変換)で補間します。
17
18 この関数は、与えられた3次元データをフーリエ変換(FFT)により周波数領域に変換します。
19 周波数領域において、元のデータが占める領域以外の部分をゼロで埋める(ゼロパディング)ことで、
20 より高周波数成分がゼロであると仮定し、実空間での高解像度補間を実現します。
21 ゼロパディングされた周波数領域データに逆フーリエ変換(IFFT)を適用することで、
22 指定された補間倍率でサンプリングされた高解像度のデータが得られます。
23 補間は周期境界条件を仮定して行われます。
24 `fftshift` と `ifftshift` を使用して、FFTのDC成分が配列の中心にくるように処理することで、
25 ゼロパディングが適切に行われます。
26
27 :param data: np.ndarray
28 補間したい3次元の周期データ。形状は `(Nx, Ny, Nz)` である必要があります。
29 要素のデータ型は実数または複素数に対応します。
30 :param interp_factor: tuple
31 各次元 `(x, y, z)` の補間倍率を示す3つの整数 `(interp_x, interp_y, interp_z)` のタプル。
32 各要素は1以上の整数である必要があります。`1` を指定した場合、その次元は補間されません。
33 :returns: np.ndarray
34 補間された3次元データ。元のデータ型が実数であれば実数データが返されます。
35 この実装では最終的に `np.real` を適用しているため、実数部のみが返されます。
36 形状は `(Nx * interp_x, Ny * interp_y, Nz * interp_z)` となります。
37 :raises ValueError:
38 入力データが3次元のNumPy配列でない場合、または `interp_factor` が適切な形式でない場合に発生します。
39 """
40 if not (isinstance(data, np.ndarray) and data.ndim == 3):
41 raise ValueError("入力データは3次元のNumPy配列である必要があります。")
42 if not (isinstance(interp_factor, tuple) and len(interp_factor) == 3 and all(isinstance(f, int) and f >= 1 for f in interp_factor)):
43 raise ValueError("interp_factorは各次元の整数補間倍率を示す3つの要素を持つタプルである必要があります。")
44
45 Nx, Ny, Nz = data.shape
46 interp_Nx, interp_Ny, interp_Nz = interp_factor
47
48 # 1. FFTを適用して周波数領域のデータを得る
49 F_data = fftshift(fftn(data))
50
51 # 2. 周波数領域でのゼロパディング
52 # 新しいサイズを計算
53 new_Nx = Nx * interp_Nx
54 new_Ny = Ny * interp_Ny
55 new_Nz = Nz * interp_Nz
56
57 # 新しいゼロパディングされた配列を初期化
58 F_data_padded = np.zeros((new_Nx, new_Ny, new_Nz), dtype=F_data.dtype)
59
60 # 元の周波数成分を新しい配列の中央にコピー
61 start_x = (new_Nx - Nx) // 2
62 end_x = start_x + Nx
63 start_y = (new_Ny - Ny) // 2
64 end_y = start_y + Ny
65 start_z = (new_Nz - Nz) // 2
66 end_z = start_z + Nz
67
68 F_data_padded[start_x:end_x, start_y:end_y, start_z:end_z] = F_data
69
70 # 3. 逆FFTを適用して補間されたデータを得る
71 interpolated_data = np.real(ifftn(ifftshift(F_data_padded))) * (new_Nx * new_Ny * new_Nz) / (Nx * Ny * Nz)
72
73 return interpolated_data
74
75# --- 使用例 ---
76if __name__ == "__main__":
77 # 1. サンプル3次元データの作成
78 Nx_orig, Ny_orig, Nz_orig = 10, 10, 10
79 x_orig = np.linspace(0, 2 * np.pi, Nx_orig, endpoint=False)
80 y_orig = np.linspace(0, 2 * np.pi, Ny_orig, endpoint=False)
81 z_orig = np.linspace(0, 2 * np.pi, Nz_orig, endpoint=False)
82 X_orig, Y_orig, Z_orig = np.meshgrid(x_orig, y_orig, z_orig, indexing='ij')
83
84 original_data = np.sin(X_orig * 2) + np.cos(Y_orig * 3) + np.sin(Z_orig * 1.5)
85
86 print(f"元のデータの形状: {original_data.shape}")
87
88 # 2. FFTによる補間
89 interp_factor = (4, 4, 4) # 各次元を4倍に補間
90 interpolated_data = interpolate_3d_periodic_data_fft(original_data, interp_factor)
91
92 print(f"補間後のデータの形状: {interpolated_data.shape}")
93
94 # **ここから修正点**
95 # interpolated_data の形状から新しい次元数を取得する
96 new_Nx, new_Ny, new_Nz = interpolated_data.shape
97 # **修正点ここまで**
98
99 # 3. 結果の可視化 (スライス表示)
100 slice_idx_orig_x = Nx_orig // 2
101 slice_idx_interp_x = new_Nx // 2 # 補間後のデータに対応するインデックス
102
103 fig = plt.figure(figsize=(14, 7))
104
105 ax1 = fig.add_subplot(121)
106 c1 = ax1.imshow(original_data[slice_idx_orig_x, :, :], origin='lower', cmap='viridis',
107 extent=[0, 2 * np.pi, 0, 2 * np.pi])
108 fig.colorbar(c1, ax=ax1, fraction=0.046, pad=0.04)
109 ax1.set_title(f'Original Data (X-slice at {slice_idx_orig_x})')
110 ax1.set_xlabel('Y-axis')
111 ax1.set_ylabel('Z-axis')
112
113 ax2 = fig.add_subplot(122)
114 c2 = ax2.imshow(interpolated_data[slice_idx_interp_x, :, :], origin='lower', cmap='viridis',
115 extent=[0, 2 * np.pi, 0, 2 * np.pi])
116 fig.colorbar(c2, ax=ax2, fraction=0.046, pad=0.04)
117 ax2.set_title(f'Interpolated Data (X-slice at {slice_idx_interp_x})')
118 ax2.set_xlabel('Y-axis')
119 ax2.set_ylabel('Z-axis')
120
121 plt.tight_layout()
122 plt.show()
123
124 # オプション:3D可視化(データ点が多すぎると描画が重くなる可能性があります)
125 fig_3d = plt.figure(figsize=(12, 6))
126
127 ax_orig_3d = fig_3d.add_subplot(121, projection='3d')
128 sc_orig = ax_orig_3d.scatter(X_orig.flatten(), Y_orig.flatten(), Z_orig.flatten(),
129 c=original_data.flatten(), cmap='viridis', s=20)
130 fig_3d.colorbar(sc_orig, ax=ax_orig_3d, shrink=0.5, aspect=5)
131 ax_orig_3d.set_title('Original Data Points')
132 ax_orig_3d.set_xlabel('X')
133 ax_orig_3d.set_ylabel('Y')
134 ax_orig_3d.set_zlabel('Z')
135
136 ax_interp_3d = fig_3d.add_subplot(122, projection='3d')
137 # **ここから修正点**
138 # new_Nx, new_Ny, new_Nz がここで定義されているため、これらを使用できます
139 x_interp = np.linspace(0, 2 * np.pi, new_Nx, endpoint=False)
140 y_interp = np.linspace(0, 2 * np.pi, new_Ny, endpoint=False)
141 # **修正点ここまで**
142 X_interp_slice, Y_interp_slice = np.meshgrid(x_interp, y_interp, indexing='ij')
143
144 z_slice_idx = new_Nz // 2 # 補間後のデータに対応するインデックス
145 surf_interp = ax_interp_3d.plot_surface(X_interp_slice, Y_interp_slice,
146 interpolated_data[:, :, z_slice_idx],
147 cmap='viridis', edgecolor='none')
148 fig_3d.colorbar(surf_interp, ax=ax_interp_3d, shrink=0.5, aspect=5)
149 ax_interp_3d.set_title(f'Interpolated Data (Z-slice at {z_slice_idx})')
150 ax_interp_3d.set_xlabel('X')
151 ax_interp_3d.set_ylabel('Y')
152 ax_interp_3d.set_zlabel('Value')
153
154 plt.tight_layout()
155 plt.show()