"""
FFTを用いた周期関数の補間およびデータ処理を行うモジュール。
概要:
入力データに対してFFTを適用し、周波数領域でゼロパディングを行うことでデータ解像度を向上させます。
逆FFTにより、滑らかな周期関数の補間を実現します。
詳細説明:
1. Excelファイルまたは内部生成関数から離散データ(k, E(k))を取得します。
2. 反転対称性(鏡像)が必要な場合、k < 0 の領域を自動生成して結合します。
3. FFTを実行し、高周波成分にゼロを挿入(ゼロパディング)して逆FFTを行うことで、
データ点数を増やした高解像度なプロファイルを生成します。
4. 結果をExcelファイルに出力し、元データ、補間結果、および厳密解を比較プロットします。
関連リンク: :doc:`FFT_interpolate_usage`
"""
import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
#===================
# デフォルトパラメータ
#===================
MODE_DEF = "fft"
INFILE_DEF = "interpolate_fft_test.xlsx"
DO_MIRROR_DEF = False
INTERP_FACTOR = 10
KRANGE_DEF = [-0.5, 0.5]
N_SAMPLES_DEF = 10
[ドキュメント]
def periodic_function(k):
"""
テスト用のサンプル周期関数。
:param k: float or numpy.ndarray: 入力値。
:returns: 対応する関数の値。
"""
return -np.cos(2.0 * np.pi * k) * (1.0 + 5.0 * k**2)
[ドキュメント]
def read_data(infile, do_mirror=False):
"""
Excelファイルからバンド構造データ(k, E(k))を読み込み、必要に応じて鏡像処理を行います。
:param infile: str: 入力ファイルパス。
:param do_mirror: bool: 反転対称データ E(-k) を追加するかどうか。
:returns: tuple: (x, y, xe, ye) のデータリスト。
"""
try:
df = pd.read_excel(infile)
# NaNを除外
x = df['k'].dropna().values.tolist()
y = df['E(k)'].dropna().values.tolist()
# オプション列の読み込み(存在しない場合は空リスト)
xe = df['k,e'].dropna().values.tolist() if 'k,e' in df.columns else []
ye = df['E(k),e'].dropna().values.tolist() if 'E(k),e' in df.columns else []
if do_mirror:
# 負の領域を反転コピーして生成
_x = [-x[i] for i in range(len(x) - 1, 0, -1)] + x
_y = [y[i] for i in range(len(y) - 1, 0, -1)] + y
_xe = ([-xe[i] for i in range(len(xe) - 1, 0, -1)] + xe) if xe else []
_ye = ([-ye[i] for i in range(len(ye) - 1, 0, -1)] + ye) if ye else []
return _x, _y, _xe, _ye
return x, y, xe, ye
except Exception as e:
print(f"Error reading {infile}: {e}")
return [], [], [], []
[ドキュメント]
def main():
"""
メインの実行ルーチン。データの読み込み、FFT処理、補間、保存、プロットを制御します。
"""
# 引数パース
argv = sys.argv
infile = argv[1] if len(argv) > 1 else INFILE_DEF
do_mirror = bool(int(argv[2])) if len(argv) > 2 else DO_MIRROR_DEF
mode = argv[3] if len(argv) > 3 else MODE_DEF
print(f"\nInput file: {infile}")
print(f"Add mirror data: {do_mirror}")
# データの準備
if os.path.exists(infile):
print(f"Reading data from [{infile}]")
x_raw, y_raw, xe, ye = read_data(infile, do_mirror)
if not x_raw: return
n = len(x_raw)
# 周期性を保つため最後の1点を除いて処理
x = np.array(x_raw[:n-1])
y = np.array(y_raw[:n-1])
k_min, k_max = min(x_raw), max(x_raw)
else:
print("Input file not found. Generating sample data.")
k_min, k_max = KRANGE_DEF
x = np.linspace(k_min, k_max, N_SAMPLES_DEF, endpoint=False)
y = periodic_function(x)
xe = np.linspace(k_min, k_max, N_SAMPLES_DEF * INTERP_FACTOR, endpoint=False)
ye = periodic_function(xe)
n_samples = len(x)
n_interp = n_samples * INTERP_FACTOR
# STEP 2: FFTの計算
y_fft = np.fft.fft(y)
dt = (x[-1] - x[0]) / (len(x) - 1) if len(x) > 1 else 1.0
freq = np.fft.fftfreq(len(y), d=dt)
y_fft_centered = np.fft.fftshift(y_fft)
freq_centered = np.fft.fftshift(freq)
# STEP 3: ゼロパディングによる解像度向上
y_fft_padded = np.zeros(n_interp, dtype=complex)
half = n_samples // 2
y_fft_padded[:half] = y_fft[:half]
y_fft_padded[-half:] = y_fft[-half:]
# IFFTによる補間
x_interp = np.linspace(k_min, k_max, n_interp, endpoint=False)
y_interp = np.fft.ifft(y_fft_padded) * INTERP_FACTOR
# --- Excel出力 ---
file_body = os.path.splitext(os.path.basename(infile))[0]
output_excel = f"{file_body}-fft-interpolated.xlsx"
pd.DataFrame({
'Interpolated X': x_interp,
'Interpolated Y': y_interp.real
}).to_excel(output_excel, index=False)
print(f"Interpolated data saved to '{output_excel}'")
# --- プロット設定 ---
plt.rcParams["font.size"] = 14
if mode == "fft":
fig, axes = plt.subplots(1, 1, figsize=(8, 6))
axes = [axes]
else:
fig, axes = plt.subplots(2, 1, figsize=(10, 8))
# 時間(座標)領域プロット
ax = axes[0]
ax.plot(x, y, 'o', label='Input data', alpha=0.6)
ax.plot(x_interp, y_interp.real, '-', label='FFT Interpolated', marker='x', markersize=3)
if len(xe) > 0:
ax.plot(xe, ye, '--', label='Exact', alpha=0.5)
ax.set_xlabel('k (Lattice Coordinate)')
ax.set_ylabel('E(k) (Energy)')
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3)
# 周波数(スペクトル)領域プロット
if mode == 'plot':
ax = axes[1]
ax.plot(freq_centered, np.abs(y_fft_centered), label="Magnitude", marker="o", markersize=2)
ax.set_xlabel("Frequency")
ax.set_ylabel("Amplitude")
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show(block=False)
input("Press ENTER to exit>>")
if __name__ == "__main__":
main()