# https://watlab-blog.com/2020/09/26/fft-find-peaks/

import numpy as np
from scipy import signal
from scipy import fftpack
from matplotlib import pyplot as plt
 
# フーリエ変換して複素スペクトル、振幅、周波数を出力する関数
def fft(y, dt):
    spec = fftpack.fft(y)                                           # フーリエスペクトル（複素数）
    freq = fftpack.fftfreq(len(spec), d=dt)                         # 周波数軸を生成
    amp = np.sqrt(spec.real ** 2 + spec.imag ** 2) / (len(spec)/2)  # 振幅を計算
    return spec, amp, freq
 
# 波形(x, y)からn個のピークを幅wで検出する関数(xは0から始まる仕様）
def findpeaks(x, y, n, w):
    index_all = list(signal.argrelmax(y, order=w))                  # scipyのピーク検出
    index = []                                                      # ピーク指標の空リスト
    peaks = []                                                      # ピーク値の空リスト
 
    # n個分のピーク情報(指標、値）を格納
    for i in range(n):
        index.append(index_all[0][i])
        peaks.append(y[index_all[0][i]])
    index = np.array(index) * x[1]                                  # xの分解能x[1]をかけて指標を物理軸に変換
    return index, peaks
 
# サンプル波形を生成
dt = 0.0001
t = np.arange(0, 10, dt)
noise = 1 * np.random.normal(loc=0, scale=1, size=len(t))
y = signal.sawtooth(2 * np.pi * 1 * t) + noise
 
# フーリエ変換する
spec, amp, freq = fft(y, dt)
 
# ピーク検出する
index, peaks = findpeaks(freq, amp, n=100, w=6)
 
# ここからグラフ描画-------------------------------------
# フォントの種類とサイズを設定する。
plt.rcParams['font.size'] = 14
plt.rcParams['font.family'] = 'Times New Roman'
 
# 目盛を内側にする。
plt.rcParams['xtick.direction'] = 'in'
plt.rcParams['ytick.direction'] = 'in'
 
# グラフの上下左右に目盛線を付ける。
fig = plt.figure(figsize=(8, 5))
ax1 = fig.add_subplot(211)
ax1.yaxis.set_ticks_position('both')
ax1.xaxis.set_ticks_position('both')
ax2 = fig.add_subplot(212)
ax2.yaxis.set_ticks_position('both')
ax2.xaxis.set_ticks_position('both')
 
# 軸のラベルを設定する。
ax1.set_xlabel('Time [s]')
ax1.set_ylabel('Amp.')
ax2.set_xlabel('Frequency [Hz]')
ax2.set_ylabel('Amp.')
 
# スケール設定。
ax2.set_xlim(0, 10)
ax2.set_yscale('log')
 
# データプロットの準備とともに、ラベルと線の太さ、凡例の設置を行う。
ax1.plot(t, y, label='sample', lw=1)
ax1.legend()
ax2.plot(freq[:int(len(freq)/2)], amp[:int(len(freq)/2)], label='sample', lw=1)
ax2.scatter(index, peaks, label='peaks', color='red')
ax2.legend()
 
# レイアウト設定
fig.tight_layout()
 
# グラフを表示する。
plt.show()
plt.close()
# ---------------------------------------------------