spectrum_.FFT_interpolate のソースコード

"""
FFTを用いた周期関数の補間およびデータ処理を行うモジュール。

概要:
    入力データに対してFFTを適用し、周波数領域でゼロパディングを行うことでデータ解像度を向上させます。
    逆FFTにより、滑らかな周期関数の補間を実現します。

詳細説明:
    1. Excelファイルまたは内部生成関数から離散データ(k, E(k))を取得します。
    2. 反転対称性(鏡像)が必要な場合、k < 0 の領域を自動生成して結合します。
    3. FFTを実行し、高周波成分にゼロを挿入(ゼロパディング)して逆FFTを行うことで、
       データ点数を増やした高解像度なプロファイルを生成します。
    4. 結果をExcelファイルに出力し、元データ、補間結果、および厳密解を比較プロットします。

関連リンク: :doc:`FFT_interpolate_usage`
"""

import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os

#===================
# デフォルトパラメータ
#===================
MODE_DEF = "fft"
INFILE_DEF = "interpolate_fft_test.xlsx"
DO_MIRROR_DEF = False
INTERP_FACTOR = 10
KRANGE_DEF = [-0.5, 0.5]
N_SAMPLES_DEF = 10

[ドキュメント] def periodic_function(k): """ テスト用のサンプル周期関数。 :param k: float or numpy.ndarray: 入力値。 :returns: 対応する関数の値。 """ return -np.cos(2.0 * np.pi * k) * (1.0 + 5.0 * k**2)
[ドキュメント] def read_data(infile, do_mirror=False): """ Excelファイルからバンド構造データ(k, E(k))を読み込み、必要に応じて鏡像処理を行います。 :param infile: str: 入力ファイルパス。 :param do_mirror: bool: 反転対称データ E(-k) を追加するかどうか。 :returns: tuple: (x, y, xe, ye) のデータリスト。 """ try: df = pd.read_excel(infile) # NaNを除外 x = df['k'].dropna().values.tolist() y = df['E(k)'].dropna().values.tolist() # オプション列の読み込み(存在しない場合は空リスト) xe = df['k,e'].dropna().values.tolist() if 'k,e' in df.columns else [] ye = df['E(k),e'].dropna().values.tolist() if 'E(k),e' in df.columns else [] if do_mirror: # 負の領域を反転コピーして生成 _x = [-x[i] for i in range(len(x) - 1, 0, -1)] + x _y = [y[i] for i in range(len(y) - 1, 0, -1)] + y _xe = ([-xe[i] for i in range(len(xe) - 1, 0, -1)] + xe) if xe else [] _ye = ([-ye[i] for i in range(len(ye) - 1, 0, -1)] + ye) if ye else [] return _x, _y, _xe, _ye return x, y, xe, ye except Exception as e: print(f"Error reading {infile}: {e}") return [], [], [], []
[ドキュメント] def main(): """ メインの実行ルーチン。データの読み込み、FFT処理、補間、保存、プロットを制御します。 """ # 引数パース argv = sys.argv infile = argv[1] if len(argv) > 1 else INFILE_DEF do_mirror = bool(int(argv[2])) if len(argv) > 2 else DO_MIRROR_DEF mode = argv[3] if len(argv) > 3 else MODE_DEF print(f"\nInput file: {infile}") print(f"Add mirror data: {do_mirror}") # データの準備 if os.path.exists(infile): print(f"Reading data from [{infile}]") x_raw, y_raw, xe, ye = read_data(infile, do_mirror) if not x_raw: return n = len(x_raw) # 周期性を保つため最後の1点を除いて処理 x = np.array(x_raw[:n-1]) y = np.array(y_raw[:n-1]) k_min, k_max = min(x_raw), max(x_raw) else: print("Input file not found. Generating sample data.") k_min, k_max = KRANGE_DEF x = np.linspace(k_min, k_max, N_SAMPLES_DEF, endpoint=False) y = periodic_function(x) xe = np.linspace(k_min, k_max, N_SAMPLES_DEF * INTERP_FACTOR, endpoint=False) ye = periodic_function(xe) n_samples = len(x) n_interp = n_samples * INTERP_FACTOR # STEP 2: FFTの計算 y_fft = np.fft.fft(y) dt = (x[-1] - x[0]) / (len(x) - 1) if len(x) > 1 else 1.0 freq = np.fft.fftfreq(len(y), d=dt) y_fft_centered = np.fft.fftshift(y_fft) freq_centered = np.fft.fftshift(freq) # STEP 3: ゼロパディングによる解像度向上 y_fft_padded = np.zeros(n_interp, dtype=complex) half = n_samples // 2 y_fft_padded[:half] = y_fft[:half] y_fft_padded[-half:] = y_fft[-half:] # IFFTによる補間 x_interp = np.linspace(k_min, k_max, n_interp, endpoint=False) y_interp = np.fft.ifft(y_fft_padded) * INTERP_FACTOR # --- Excel出力 --- file_body = os.path.splitext(os.path.basename(infile))[0] output_excel = f"{file_body}-fft-interpolated.xlsx" pd.DataFrame({ 'Interpolated X': x_interp, 'Interpolated Y': y_interp.real }).to_excel(output_excel, index=False) print(f"Interpolated data saved to '{output_excel}'") # --- プロット設定 --- plt.rcParams["font.size"] = 14 if mode == "fft": fig, axes = plt.subplots(1, 1, figsize=(8, 6)) axes = [axes] else: fig, axes = plt.subplots(2, 1, figsize=(10, 8)) # 時間(座標)領域プロット ax = axes[0] ax.plot(x, y, 'o', label='Input data', alpha=0.6) ax.plot(x_interp, y_interp.real, '-', label='FFT Interpolated', marker='x', markersize=3) if len(xe) > 0: ax.plot(xe, ye, '--', label='Exact', alpha=0.5) ax.set_xlabel('k (Lattice Coordinate)') ax.set_ylabel('E(k) (Energy)') ax.legend(fontsize=10) ax.grid(True, alpha=0.3) # 周波数(スペクトル)領域プロット if mode == 'plot': ax = axes[1] ax.plot(freq_centered, np.abs(y_fft_centered), label="Magnitude", marker="o", markersize=2) ax.set_xlabel("Frequency") ax.set_ylabel("Amplitude") ax.legend(fontsize=10) ax.grid(True, alpha=0.3) plt.tight_layout() plt.show(block=False) input("Press ENTER to exit>>")
if __name__ == "__main__": main()