Skip to content

Instantly share code, notes, and snippets.

@yanniskatsaros
Created April 29, 2020 15:02
Show Gist options
  • Save yanniskatsaros/6520ac2ac66e16d8c8534699c42d272c to your computer and use it in GitHub Desktop.
Save yanniskatsaros/6520ac2ac66e16d8c8534699c42d272c to your computer and use it in GitHub Desktop.
Convenience class for simplifying the process of computing FFT and IFFT. Provides helpful interactive diagnostic and trend plots using Plotly.
import warnings

import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from scipy.fft import fft, ifft, fftfreq
from scipy.signal import find_peaks
class FourierTransform:
    """
    Convenience class for computing the Fast Fourier Transform of a signal,
    locating its peak frequencies, and reconstructing a filtered signal from
    those peak components alone. Interactive plots are produced with Plotly.

    Parameters
    ----------
    x: array_like
        The signal (input array can be complex) to be transformed.
    time_step: float; default = 1
        Sample spacing (inverse of the sampling rate).
    """

    # keyword arguments `transform()` forwards to `scipy.fft.fft`, mapped to
    # the value used when the caller does not supply them
    _FFT_DEFAULTS = {
        'n': None,
        'axis': -1,
        'norm': None,
        'overwrite_x': False,
        'workers': None,
    }

    # keyword arguments `transform()` forwards to `scipy.signal.find_peaks`,
    # mapped to the value used when the caller does not supply them
    _PEAK_DEFAULTS = {
        'height': None,
        'threshold': None,
        'distance': None,
        'prominence': None,
        'width': None,
        'wlen': None,
        'rel_height': 0.5,
        'plateau_size': None,
    }

    def __init__(self, x, time_step: float = 1):
        """
        Convenience class for computing the Fast Fourier Transform.

        Parameters
        ----------
        x: array_like
            The signal (input array can be complex) to be transformed.
            Converted with `np.asarray`, so lists and tuples are accepted;
            an existing ndarray is stored without copying.
        time_step: float; default = 1
            Sample spacing (inverse of the sampling rate). Must be positive.

        Raises
        ------
        ValueError
            If `time_step` is not positive.
        """
        if time_step <= 0:
            raise ValueError(f'time_step must be positive, got {time_step!r}')
        # `transform()` and the plots rely on ndarray attributes (.shape)
        self.x = np.asarray(x)
        self.time_step = time_step
        # tracks whether self.transform() has completed successfully
        self.__fit = False

    def transform(self, **kwargs):
        """
        Compute the one-dimensional discrete Fourier Transform using
        `scipy.fft.fft`, find the peak frequencies of the power spectrum with
        `scipy.signal.find_peaks`, and reconstruct `x_filtered` from only
        those peak components with `scipy.fft.ifft`.

        Sets the attributes `xfft`, `xfft_freq`, `amp`, `power`, `phase`,
        `peak_freq`, `peak_amp`, `peak_power`, `peak_phase` and `x_filtered`.

        Parameters
        ----------
        kwargs:
            Keyword arguments forwarded to `scipy.fft.fft`
            (`n`, `axis`, `norm`, `overwrite_x`, `workers`) and to
            `scipy.signal.find_peaks` (`height`, `threshold`, `distance`,
            `prominence`, `width`, `wlen`, `rel_height`, `plateau_size`).
            Any other keyword is ignored with a `UserWarning`.

        Notes
        -----
        Without any `find_peaks` threshold, every local maximum of the power
        spectrum counts as a peak. Peak detection assumes a 1-D spectrum.
        """
        unknown = kwargs.keys() - self._FFT_DEFAULTS.keys() - self._PEAK_DEFAULTS.keys()
        if unknown:
            # surface typos such as `heigth=` instead of silently dropping them
            warnings.warn(
                f'Ignoring unrecognized keyword argument(s): {sorted(unknown)}',
                stacklevel=2,
            )
        fft_args = {key: kwargs.get(key, default) for key, default in self._FFT_DEFAULTS.items()}
        peak_args = {key: kwargs.get(key, default) for key, default in self._PEAK_DEFAULTS.items()}

        self.xfft = fft(self.x, **fft_args)
        # the frequency grid must match the length of the transformed axis,
        # which differs from x.size when `n` pads/truncates the input
        n_fft = self.xfft.shape[fft_args['axis']]
        self.xfft_freq = fftfreq(n_fft, self.time_step)

        # amplitude, power, and phase of each component. Under the default
        # norm=None these are unnormalized: a real sinusoid of amplitude A
        # lying exactly on a frequency bin has |X_k| = A * N / 2 at +f and -f.
        self.amp = np.abs(self.xfft)
        self.power = self.amp ** 2
        self.phase = np.angle(self.xfft)

        # find the peak frequencies using the power; find_peaks returns
        # (indices, properties) and only the indices are kept
        self._idx_peaks = find_peaks(self.power, **peak_args)[0]

        # store the parameters at the peak frequencies
        self.peak_freq = self.xfft_freq[self._idx_peaks]
        self.peak_amp = self.amp[self._idx_peaks]
        self.peak_power = self.power[self._idx_peaks]
        self.peak_phase = self.phase[self._idx_peaks]

        # reconstruct the signal from the peak components alone
        fft_filtered = np.zeros(self.xfft.shape, dtype='complex128')
        fft_filtered[self._idx_peaks] = self.xfft[self._idx_peaks]
        self.x_filtered = ifft(fft_filtered)

        # mark the transformation as a success only once every attribute exists
        self.__fit = True

    def plot_trend(self):
        """
        Interactive plot of the original trend, as well as
        the filtered signal if `self.transform()` has been used.

        The time axis is the sample index multiplied by `time_step`.

        Returns
        -------
        plotly.graph_objects.Figure
        """
        fig = go.Figure()
        title = 'Original Signal'
        # sample times implied by the sample spacing: 0, dt, 2*dt, ...
        t = np.arange(self.x.shape[0]) * self.time_step
        fig.add_trace(go.Scatter(
            x=t,
            y=self.x,
            mode='markers+lines',
            name='Original'
        ))
        if self.__fit:
            # x_filtered has the FFT length, which differs from len(x) when
            # transform() was called with `n`
            t_filtered = np.arange(self.x_filtered.shape[0]) * self.time_step
            fig.add_trace(go.Scatter(
                x=t_filtered,
                y=self.x_filtered.real,
                mode='lines',
                name='IFFT'
            ))
            title = 'Original Signal vs Inverse Fourier Transform'
        fig.update_layout(
            xaxis={
                'title': 'Time Units'
            },
            yaxis={
                'title': 'Signal'
            },
            title=title
        )
        return fig

    def plot_diagnostics(self):
        """
        Interactive diagnostic plot of the power, amplitude, and phase
        found for each frequency in the decomposition, with the detected
        peak frequencies highlighted as markers.

        Returns
        -------
        plotly.graph_objects.Figure

        Raises
        ------
        ValueError
            If `self.transform()` has not been called yet.
        """
        if not self.__fit:
            raise ValueError('Transform the data first with `self.transform()` before plotting diagnostics.')
        diagnostics = {
            'Power': self.power,
            'Amplitude': self.amp,
            'Phase': self.phase
        }
        n_rows = len(diagnostics)
        fig = make_subplots(rows=n_rows, cols=1, shared_xaxes=True)
        for row, (name, diag) in enumerate(diagnostics.items(), start=1):
            fig.add_trace(
                go.Bar(
                    x=self.xfft_freq,
                    y=diag,
                    name=name
                ),
                row=row,
                col=1
            )
            fig.add_trace(
                go.Scatter(
                    x=self.xfft_freq[self._idx_peaks],
                    y=diag[self._idx_peaks],
                    mode='markers',
                    name=f'{name} - Peak Frequency'
                ),
                row=row,
                col=1
            )
            fig.update_yaxes({'title': name}, row=row, col=1)
        # the x-axes are shared, so only the bottom subplot gets a title
        fig.update_xaxes({'title': 'Frequency [Hz]'}, row=n_rows, col=1)
        fig.update_layout(
            title='Fourier Transform'
        )
        return fig
# make an example dataset consisting of three sine curves
time_step = 1
y_arrs = []
amplitudes = (6, 3, 12)
periods = (15, 30, 180)
phases = (np.pi/2, np.pi, 3*np.pi/2)
t = np.arange(0, 365, time_step)

# plot the curves as we build them
fig = go.Figure()
for a, p, phi in zip(amplitudes, periods, phases):
    freq = 1/p
    y = a * np.sin(2*np.pi*freq*t + phi)
    y_arrs.append(y)
    name = f'$A = {a}, f = {freq:.3f}, \\phi = {phi:.3f}$'
    fig.add_trace(go.Scatter(
        x=t,
        y=y,
        mode='markers+lines',
        name=name
    ))

# combine the curves into a single signal
y = np.array(y_arrs).sum(axis=0)
fig.add_trace(go.Scatter(
    x=t,
    y=y,
    mode='markers+lines',
    name='Total'
))
fig.update_layout(
    xaxis={
        'title': 'Time [s]'
    },
    yaxis={
        'title': 'Amplitude'
    },
    title='Original Signal'
)
fig.show()

# use FFT to decompose the signal: the transform needs the signal itself and
# its sample spacing so frequencies come out in cycles per time unit.
# With no find_peaks thresholds every local maximum of the power spectrum
# counts as a peak (none of the periods divides the 365 samples evenly, so
# leakage adds side lobes); pass e.g. `height=` or `prominence=` to
# transform() to keep only the dominant components.
ft = FourierTransform(y, time_step=time_step)
ft.transform()
# visualize peak frequencies
ft.plot_diagnostics().show()
# visualize original vs filtered IFFT
ft.plot_trend().show()
# get the filtered signal from the fit
y_filtered = ft.x_filtered