Skip to content

Instantly share code, notes, and snippets.

@harrymander
Last active February 26, 2024 20:50
Show Gist options
  • Save harrymander/09dfd9a07d8978e033282a5dc4b1e991 to your computer and use it in GitHub Desktop.
Save harrymander/09dfd9a07d8978e033282a5dc4b1e991 to your computer and use it in GitHub Desktop.
from scipy.fft import rfft, rfftfreq
import numpy as np
def dtft(x, w):
    """
    Evaluate the discrete-time Fourier transform of x at one frequency.

    Computes X(w) = sum_n x[n] * exp(-1j * w * n) for n = 0, ..., x.size - 1.

    Args:
        x: array-like of samples (lists/tuples are accepted)
        w: angular frequency in radians/sample
    Returns: the complex DTFT value X(w)
    """
    # Coerce so plain sequences work too; the original required an ndarray
    # (it read x.size directly).
    x = np.asarray(x)
    n = np.arange(x.size)
    return np.sum(x * np.exp(-1j * w * n))
def sinusoid_freq(x, *, fs=1, niter=2, nfft=None):
    """
    Find the exact frequency of a sinusoid from its FFT.

    Uses method from Guo & Blu, 'Super-Resolving a Frequency Band'
    (https://read.nxtbook.com/ieee/signal_processing/signal_processing_nov_2023/super_resolving_a_frequency_b.html)

    The peak bin of an nfft-point real FFT gives a coarse angular frequency
    w0. Each iteration evaluates |DTFT(x)| half a bin (pi/nfft rad/sample)
    either side of w0 and moves w0 with the paper's closed-form formula.

    NOTE: the closed-form update relies on the |DTFT| numerators being equal
    at the two evaluation points. That holds when their spacing (2*pi/nfft)
    is a multiple of 2*pi/len(x), i.e. when len(x) is a multiple of nfft.
    Otherwise each step is approximate and more iterations may be needed.

    Args:
        x: real array-like, time-domain signal of the sinusoid
        fs: sample rate
        niter: number of applications of the method
        nfft: number of points to use in FFT; if None, uses length of x.
            When nfft < len(x) the coarse FFT only sees the first nfft
            samples (rfft crops its input), but the refinement uses all of x.
    Returns: the frequency of x, in the same units as fs
    """
    # Accept any array-like, as documented; x.size below needs an ndarray.
    x = np.asarray(x)
    if nfft is None:
        nfft = x.size

    # Coarse estimate: angular frequency (rad/sample) of the largest bin.
    imax = abs(rfft(x, n=nfft)).argmax()
    w0 = 2 * np.pi * imax / nfft

    half_bin = np.pi / nfft
    for _ in range(niter):
        # One-bin-wide band centred on the current estimate.
        w1 = w0 - half_bin
        w2 = w0 + half_bin
        X1 = abs(dtft(x, w1))
        X2 = abs(dtft(x, w2))
        # Guo & Blu: w0 = mid + 2*atan(tan((w2 - w1)/4) * (X2 - X1)/(X2 + X1))
        w0 = (w1 + w2)/2 + 2*np.arctan(
            np.tan((w2 - w1)/4) * (X2 - X1) / (X2 + X1)
        )
    return fs * w0 / 2 / np.pi
if __name__ == "__main__":
    # Example
    # Create a 68.2 Hz sinusoid sampled at 1.02 kHz, 2 s long.
    fs = 1020
    f0 = 68.2
    # Integer arange gives exactly 2*fs samples; a float step such as 1/fs
    # can gain or lose a sample through rounding.
    t = np.arange(2 * fs) / fs
    x = np.sin(2*np.pi*f0*t)

    # Add some noise (seeded so the example output is reproducible).
    rng = np.random.default_rng(0)
    x += 0.1 * rng.standard_normal(x.size)

    # Coarse FFT length; x.size is an exact multiple of it here.
    nfft = x.size // 3

    # Estimate frequency with just max FFT (freq is in cycles/sample).
    X = abs(rfft(x, n=nfft))
    freq = rfftfreq(nfft)
    print(f'Max FFT = {fs * freq[X.argmax()]} Hz')

    # Estimate using the method
    print(f'Exact freq = {sinusoid_freq(x, fs=fs, nfft=nfft)} Hz')

    # Plotting dependency is only needed for this demo.
    import matplotlib.pyplot as plt

    # Dense DTFT of the full signal within +/- 2 coarse bins of f0.
    f0_norm = f0 / fs
    dtft_freq = np.linspace(f0_norm - 2/nfft, f0_norm + 2/nfft, 1000)
    dtft_X = np.abs([dtft(x, 2*np.pi*f) for f in dtft_freq])
    plt.plot(dtft_freq, dtft_X, '--', label='DTFT')

    # NOTE: X is the FFT of only the first nfft samples (rfft crops when
    # n < x.size), so these stems are not expected to lie on the DTFT curve
    # of the full x plotted above.
    dtft_mask = (freq >= dtft_freq[0]) & (freq <= dtft_freq[-1])
    # The marker format must be given by keyword: a positional format after
    # the data is parsed as linefmt, which the linefmt keyword overrides.
    plt.stem(
        freq[dtft_mask], X[dtft_mask],
        linefmt='r', markerfmt='rx', basefmt=' ',
        label='DFT',
    )
    plt.ylim(0, plt.ylim()[-1])
    plt.margins(0)
    plt.legend()
    plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment