Created
April 29, 2020 15:02
-
-
Save yanniskatsaros/6520ac2ac66e16d8c8534699c42d272c to your computer and use it in GitHub Desktop.
Convenience class for simplifying the process of computing FFT and IFFT. Provides helpful interactive diagnostic and trend plots using Plotly.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import plotly.graph_objects as go | |
from scipy.fft import fft, ifft, fftfreq | |
from scipy.signal import find_peaks | |
from plotly.subplots import make_subplots | |
class FourierTransform:
    """Compute the FFT of a 1-D signal, locate its dominant frequencies,
    and build interactive Plotly diagnostic and trend plots."""

    def __init__(self, x, time_step: float = 1):
        """
        Convenience class for computing the Fast Fourier Transform.

        Parameters
        ----------
        x: array_like
            The one-dimensional signal (may be complex) to be transformed.
        time_step: float; default = 1
            Sample spacing (inverse of the sampling rate).
        """
        self.x = x
        self.time_step = time_step
        # tracks whether self.transform() has completed successfully
        self.__fit = False

    def transform(self, **kwargs):
        """
        Compute the one-dimensional discrete Fourier Transform using
        `scipy.fft.fft`, find the peaks of its power spectrum with
        `scipy.signal.find_peaks`, and rebuild a filtered signal from the
        peak components only.

        Parameters
        ----------
        kwargs:
            Additional keyword arguments passed to `scipy.fft.fft`
            (n, axis, norm, overwrite_x, workers) and to
            `scipy.signal.find_peaks` (height, threshold, distance,
            prominence, width, wlen, rel_height, plateau_size).

        Notes
        -----
        Sets ``xfft``, ``xfft_freq``, ``amp``, ``power``, ``phase``,
        ``peak_freq``, ``peak_amp``, ``peak_power``, ``peak_phase`` and
        ``x_filtered``.

        ``amp`` is the raw magnitude ``|X[k]|`` and is not divided by the
        transform length N. Under the default ``norm=None`` ("backward"),
        consider a real component ``A*cos(2*pi*f*t + phi)`` that lies exactly
        on a frequency bin. It appears with magnitude ``A*N/2`` at both
        ``+f`` and ``-f``. Its ``phase`` is ``phi`` at ``+f`` (measured
        relative to a cosine) and ``-phi`` at ``-f``.
        """
        n = kwargs.get('n', None)
        axis = kwargs.get('axis', -1)
        norm = kwargs.get('norm', None)

        self.xfft = fft(
            self.x,
            n=n,
            axis=axis,
            norm=norm,
            overwrite_x=kwargs.get('overwrite_x', False),
            workers=kwargs.get('workers', None),
        )
        # Size the frequency grid from the transform itself, not from the
        # input. This keeps it aligned with `xfft` when `n` zero-pads or
        # truncates the signal, and lets plain lists be passed as `x`.
        self.xfft_freq = fftfreq(self.xfft.shape[axis], self.time_step)

        # amplitude (raw magnitude), power, and phase of each component;
        # see the Notes section above for how these relate to the
        # amplitude/phase of the underlying real sinusoids
        self.amp = np.abs(self.xfft)
        self.power = self.amp ** 2
        self.phase = np.angle(self.xfft)

        # Find the peak frequencies using the power. find_peaks returns
        # (indices, properties), and only the indices are kept.
        # NOTE(review): find_peaks compares each sample with its neighbours,
        # so the DC term at index 0 is presumably never selected -- confirm
        # if a mean offset matters.
        self._idx_peaks = find_peaks(
            self.power,
            height=kwargs.get('height', None),
            threshold=kwargs.get('threshold', None),
            distance=kwargs.get('distance', None),
            prominence=kwargs.get('prominence', None),
            width=kwargs.get('width', None),
            wlen=kwargs.get('wlen', None),
            rel_height=kwargs.get('rel_height', 0.5),
            plateau_size=kwargs.get('plateau_size', None),
        )[0]

        # store the parameters at the peak frequencies
        # FIXME - change this interface?
        self.peak_freq = self.xfft_freq[self._idx_peaks]
        self.peak_amp = self.amp[self._idx_peaks]
        self.peak_power = self.power[self._idx_peaks]
        self.peak_phase = self.phase[self._idx_peaks]

        # Keep only the peak components and invert them. The same `norm`
        # must be used as for the forward transform, otherwise the
        # reconstruction is mis-scaled.
        fft_filtered = np.zeros(self.xfft.shape, dtype='complex128')
        fft_filtered[self._idx_peaks] = self.xfft[self._idx_peaks]
        self.x_filtered = ifft(fft_filtered, norm=norm)

        # Mark the transformation as a success only after every attribute
        # the plotting methods rely on has been set.
        self.__fit = True

    def plot_trend(self):
        """
        Interactive plot of the original trend, as well as
        the filtered signal if `self.transform()` has been used.

        Returns
        -------
        plotly.graph_objects.Figure
        """
        fig = go.Figure()
        title = 'Original Signal'

        # use an arbitrary time domain (the sample index)
        t = np.arange(0, len(self.x))
        fig.add_trace(go.Scatter(
            x=t,
            y=self.x,
            mode='markers+lines',
            name='Original'
        ))

        if self.__fit:
            # x_filtered is longer or shorter than x when `n` was passed to
            # transform(), so it gets its own index axis
            fig.add_trace(go.Scatter(
                x=np.arange(0, len(self.x_filtered)),
                y=self.x_filtered.real,
                mode='lines',
                name='IFFT'
            ))
            title = 'Original Signal vs Inverse Fourier Transform'

        fig.update_layout(
            xaxis={
                'title': 'Time Units'
            },
            yaxis={
                'title': 'Signal'
            },
            title=title
        )
        return fig

    def plot_diagnostics(self):
        """
        Interactive diagnostic plot of the power, amplitude, and phase
        found for each frequency in the decomposition, with the peak
        frequencies highlighted.

        Returns
        -------
        plotly.graph_objects.Figure

        Raises
        ------
        ValueError
            If `self.transform()` has not been called yet.
        """
        if not self.__fit:
            raise ValueError('Transform the data first with `self.transform()` before plotting diagnostics.')

        diagnostics = {
            'Power': self.power,
            'Amplitude': self.amp,
            'Phase': self.phase,
        }
        n_rows = len(diagnostics)
        fig = make_subplots(rows=n_rows, cols=1, shared_xaxes=True)

        for row, (name, diag) in enumerate(diagnostics.items(), start=1):
            fig.add_trace(
                go.Bar(
                    x=self.xfft_freq,
                    y=diag,
                    name=name
                ),
                row=row,
                col=1
            )
            fig.add_trace(
                go.Scatter(
                    x=self.xfft_freq[self._idx_peaks],
                    y=diag[self._idx_peaks],
                    mode='markers',
                    name=f'{name} - Peak Frequency'
                ),
                row=row,
                col=1
            )
            # the x axes are shared, so only the bottom subplot gets a title
            if row == n_rows:
                fig.update_xaxes(
                    {
                        'title': 'Frequency [Hz]',
                    },
                    row=row,
                    col=1
                )
            fig.update_yaxes(
                {
                    'title': name,
                },
                row=row,
                col=1
            )

        fig.update_layout(
            title='Fourier Transform'
        )
        return fig
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Example: build a signal from three sine curves, decompose it with
# FourierTransform, and compare the IFFT of the peak components with it.
# NOTE(review): assumes `np`, `go`, and `FourierTransform` from the class
# file are already in scope (e.g. both run in the same notebook session).

# make an example dataset consisting of three sine curves
time_step = 1
y_arrs = []
amplitudes = (6, 3, 12)
periods = (15, 30, 180)
phases = (np.pi/2, np.pi, 3*np.pi/2)
t = np.arange(0, 365, time_step)

# plot the curves as we build them
fig = go.Figure()
for a, p, phi in zip(amplitudes, periods, phases):
    freq = 1/p
    y = a * np.sin(2*np.pi*freq*t + phi)
    y_arrs.append(y)
    name = f'$A = {a}, f = {freq:.3f}, \\phi = {phi:.3f}$'
    fig.add_trace(go.Scatter(
        x=t,
        y=y,
        mode='markers+lines',
        name=name
    ))

# combine the curves into a single signal
y = np.array(y_arrs).sum(axis=0)
fig.add_trace(go.Scatter(
    x=t,
    y=y,
    mode='markers+lines',
    name='Total'
))
fig.update_layout(
    xaxis={
        'title': 'Time [s]'
    },
    yaxis={
        'title': 'Amplitude'
    },
    title='Original Signal'
)
fig.show()

# use FFT to decompose the signal; the constructor needs the signal itself
# and its sample spacing
ft = FourierTransform(y, time_step=time_step)
ft.transform()

# visualize peak frequencies
ft.plot_diagnostics().show()

# visualize original vs filtered IFFT
ft.plot_trend().show()

# get the filtered signal from the fit
y_filtered = ft.x_filtered
Author
yanniskatsaros
commented
Apr 29, 2020
Still can't quite figure out how to get the amplitude and phase as "parameters" from the fit. I've tried something like the following, but it doesn't match what I'd expect from my three decomposed input signals (the frequency does match, however).
# Each peak is one term of the inverse DFT,
#     x[n] = (1/N) * sum_k X[k] * exp(2j*pi*k*n/N),
# so (with the default norm) its magnitude must be divided by N, and it
# contributes a *cosine*: np.angle(X[k]) is the phase relative to cos.
# With np.sin, each +f / -f pair cancels (sin is odd; A is equal and f and
# phi are negated), which is why the sum collapsed to a near-flat line.
# NOTE(review): assumes `t` holds the sample times n * time_step used
# to build the signal, and that transform() was called without `norm`.
A = ft.peak_amp.reshape(-1, 1)
f = ft.peak_freq.reshape(-1, 1)
phi = ft.peak_phase.reshape(-1, 1)
N = ft.xfft.size
y_decomp = (A / N) * np.cos(2*np.pi*f*t + phi)
# Summing the +f / -f pairs reproduces ft.x_filtered.real. Each real
# sinusoid has amplitude 2*A/N at its positive frequency, and its phase
# expressed for a sine is phi + pi/2.
y_hat = y_decomp.sum(axis=0)
>>> A
array([[2191.09720184],
[ 522.84866425],
[ 900.21792866],
[ 900.21792866],
[ 522.84866425],
[2191.09720184]])
>>> f
array([[ 0.00547945],
[ 0.03287671],
[ 0.06575342],
[-0.06575342],
[-0.03287671],
[-0.00547945]])
>>> phi
array([[-3.0568728 ],
[ 2.04664007],
[ 1.02881312],
[-1.02881312],
[-2.04664007],
[ 3.0568728 ]])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment