Skip to content

Instantly share code, notes, and snippets.

@coderforlife
Last active February 28, 2023 04:46
Show Gist options
  • Save coderforlife/17bf471a519a7d092e82f95f10e8f6e4 to your computer and use it in GitHub Desktop.
Save coderforlife/17bf471a519a7d092e82f95f10e8f6e4 to your computer and use it in GitHub Desktop.
Implementation of scipy.signal.lfilter and scipy.signal.filtfilt with cupy
import cupy
def axis_replace(shape, axis, value):
    """Return ``shape`` with the entry at position ``axis`` swapped out.

    Args:
        shape (tuple): Shape tuple to copy from.
        axis (int): Position to replace. A negative value is offset by
            ``len(shape)`` before use.
        value: The entry to place at ``axis``.

    Returns:
        tuple: The entries before ``axis``, then ``value``, then the entries
        after ``axis``.
    """
    pos = axis + len(shape) if axis < 0 else axis
    before, after = shape[:pos], shape[pos + 1:]
    return before + (value,) + after
def axis_seq(axis, ndim, value, default):
    """Build a tuple holding ``value`` at ``axis`` and ``default`` elsewhere.

    Args:
        axis (int): Position that receives ``value``. A negative value is
            offset by ``ndim`` before use.
        ndim (int): Number of dimensions the tuple describes.
        value: Entry placed at ``axis``.
        default: Entry repeated at every other position.

    Returns:
        tuple: ``axis`` copies of ``default``, then ``value``, then
        ``ndim - axis - 1`` further copies of ``default``.
    """
    pos = axis + ndim if axis < 0 else axis
    filler = (default,)
    leading = filler * pos
    trailing = filler * (ndim - pos - 1)
    return leading + (value,) + trailing
def axis_broadcast(a, ndim, axis=-1):
    """Index ``a`` so its data runs along ``axis`` of an ``ndim``-D view.

    ``a`` is indexed with a full slice at position ``axis`` and ``None`` (a
    new length-1 axis) at each of the other ``ndim - 1`` positions, so the
    result broadcasts against ``ndim``-dimensional arrays.

    Args:
        a (cupy.ndarray): Array to index.
        ndim (int): Number of index positions to generate.
        axis (int, optional): Position that keeps the data of ``a``.
            Default is -1.

    Returns:
        cupy.ndarray: The indexed array.
    """
    # For a 1-D ``a`` this gives the same shape as
    # ``a.reshape(axis_seq(axis, ndim, a.size, 1))``.
    index = axis_seq(axis, ndim, slice(None), None)
    return a[index]
def axis_slice(a, start=None, stop=None, step=None, axis=-1):
    """Slice ``a`` along one axis, taking every other axis in full.

    See ``scipy.signal._arraytools.axis_slice`` for the reference
    documentation.

    Args:
        a (cupy.ndarray): Array to slice.
        start, stop, step (int or None, optional): Slice parameters applied
            along ``axis``.
        axis (int, optional): Axis to slice. Default is -1.

    Returns:
        cupy.ndarray: The result of indexing ``a`` with the built slices.
    """
    along_axis = slice(start, stop, step)
    everything = slice(None)
    return a[axis_seq(axis, a.ndim, along_axis, everything)]
def axis_reverse(a, axis=-1):
    """Return ``a`` with the element order along ``axis`` reversed.

    See ``scipy.signal._arraytools.axis_reverse`` for the reference
    documentation.

    Args:
        a (cupy.ndarray): Array to reverse.
        axis (int, optional): Axis to reverse. Default is -1.

    Returns:
        cupy.ndarray: ``a`` sliced with a step of -1 along ``axis``.
    """
    return axis_slice(a, None, None, -1, axis)
def odd_ext(x, n, axis=-1):
    """Extend ``x`` by ``n`` samples at both ends with an odd extension.

    Each extension is twice the end sample minus the ``n`` samples next to
    that end taken in reversed order. See
    ``scipy.signal._arraytools.odd_ext`` for the reference documentation.

    Args:
        x (cupy.ndarray): Array to extend.
        n (int): Number of samples added at each end. ``x`` is returned
            unchanged when ``n < 1``.
        axis (int, optional): Axis along which to extend. Default is -1.

    Returns:
        cupy.ndarray: The extended array (or ``x`` itself when ``n < 1``).

    Raises:
        ValueError: If ``n`` exceeds ``x.shape[axis] - 1``.
    """
    if n < 1:
        return x
    max_ext = x.shape[axis] - 1
    if n > max_ext:
        raise ValueError(("The extension length n (%d) is too big. " +
                          "It must not exceed x.shape[axis]-1, which is %d.")
                         % (n, max_ext))
    first = axis_slice(x, stop=1, axis=axis)
    last = axis_slice(x, start=-1, axis=axis)
    # The n samples after the first / before the last one, reversed.
    mirrored_head = axis_slice(x, start=n, stop=0, step=-1, axis=axis)
    mirrored_tail = axis_slice(x, start=-2, stop=-n-2, step=-1, axis=axis)
    pieces = (2*first - mirrored_head, x, 2*last - mirrored_tail)
    return cupy.concatenate(pieces, axis=axis)
def even_ext(x, n, axis=-1):
    """Extend ``x`` by ``n`` samples at both ends with an even extension.

    Implemented with ``cupy.pad`` in ``'reflect'`` mode along ``axis`` only.
    See ``scipy.signal._arraytools.even_ext`` for the reference
    documentation.

    Args:
        x (cupy.ndarray): Array to extend.
        n (int): Number of samples added at each end. ``x`` is returned
            unchanged when ``n < 1``.
        axis (int, optional): Axis along which to extend. Default is -1.

    Returns:
        cupy.ndarray: The extended array (or ``x`` itself when ``n < 1``).

    Raises:
        ValueError: If ``n`` exceeds ``x.shape[axis] - 1``.
    """
    if n < 1:
        return x
    max_ext = x.shape[axis] - 1
    if n > max_ext:
        raise ValueError(("The extension length n (%d) is too big. " +
                          "It must not exceed x.shape[axis]-1, which is %d.")
                         % (n, max_ext))
    # Pad (n, n) on the chosen axis and nothing on the others.
    pad_width = axis_seq(axis, x.ndim, (n, n), (0, 0))
    return cupy.pad(x, pad_width, 'reflect')
def const_ext(x, n, axis=-1):
    """Extend ``x`` by ``n`` samples at both ends by repeating the end values.

    Implemented with ``cupy.pad`` in ``'edge'`` mode along ``axis`` only.
    See ``scipy.signal._arraytools.const_ext`` for the reference
    documentation.

    Args:
        x (cupy.ndarray): Array to extend.
        n (int): Number of samples added at each end. ``x`` is returned
            unchanged when ``n < 1``.
        axis (int, optional): Axis along which to extend. Default is -1.

    Returns:
        cupy.ndarray: The extended array (or ``x`` itself when ``n < 1``).
    """
    if n < 1:
        return x
    # Pad (n, n) on the chosen axis and nothing on the others.
    pad_width = axis_seq(axis, x.ndim, (n, n), (0, 0))
    return cupy.pad(x, pad_width, 'edge')
def zero_ext(x, n, axis=-1):
    """Extend ``x`` by ``n`` samples at both ends using constant padding.

    Implemented with ``cupy.pad`` in ``'constant'`` mode along ``axis`` only,
    using that mode's default fill value. See
    ``scipy.signal._arraytools.zero_ext`` for the reference documentation.

    Args:
        x (cupy.ndarray): Array to extend.
        n (int): Number of samples added at each end. ``x`` is returned
            unchanged when ``n < 1``.
        axis (int, optional): Axis along which to extend. Default is -1.

    Returns:
        cupy.ndarray: The extended array (or ``x`` itself when ``n < 1``).
    """
    if n < 1:
        return x
    # Pad (n, n) on the chosen axis and nothing on the others.
    pad_width = axis_seq(axis, x.ndim, (n, n), (0, 0))
    return cupy.pad(x, pad_width, 'constant')
import cupy
import ._arraytools as _arr
def lfilter(b, a, x, axis=-1, zi=None):
    """Filter data along one axis with an IIR or FIR filter.

    The numerator ``b`` is applied with a direct convolution; the
    denominator ``a`` (when it has more than one coefficient) is then applied
    in place to the convolution output by a raw CUDA kernel.

    Args:
        b (cupy.ndarray): 1-D numerator coefficient vector.
        a (cupy.ndarray): 1-D denominator coefficient vector. If ``a[0]`` is
            not 1, both ``a`` and ``b`` are divided by ``a[0]``.
        x (cupy.ndarray): Input array, at least 1-D.
        axis (int, optional): Axis of ``x`` along which to filter.
            Default is -1.
        zi (cupy.ndarray, optional): Initial conditions. Must have the same
            number of dimensions as ``x`` and length
            ``max(len(a), len(b)) - 1`` along ``axis``. Every other dimension
            must either match ``x`` or have length 1, in which case it is
            expanded (not copied) to the size of ``x``.

    Returns:
        cupy.ndarray or tuple: The filtered output ``y`` when ``zi`` is None,
        otherwise the pair ``(y, zf)`` where ``zf`` holds the final filter
        conditions. Both are views into one shared work array.

    Raises:
        ValueError: If ``a`` or ``b`` is not 1-D, ``x`` is 0-D, ``a[0]`` is
            zero, or ``zi`` has an unexpected number of dimensions or shape.
    """
    if a.ndim != 1 or b.ndim != 1:
        raise ValueError('object of too small depth for desired array')
    if x.ndim == 0:
        raise ValueError('x must be at least 1-D')
    axis = cupy.util._normalize_axis_index(axis, x.ndim)
    if a[0] == 0:
        raise ValueError("a[0] cannot be 0")

    dtype = cupy.result_type(b, a, x)
    n_filt = max(a.size, b.size)
    n_samples = x.shape[axis]
    n_signals = x.size // n_samples
    z_shape = _arr.axis_replace(x.shape, axis, n_filt - 1)

    if zi is not None:
        # do not broadcast zi, but do expand singleton dims
        if zi.ndim != x.ndim:
            raise ValueError('object of too small depth for desired array')
        # check the trivial case where zi is the right shape first
        if zi.shape != z_shape:
            zi_strides = list(zi.strides)
            for k in range(x.ndim):
                if k != axis and zi.shape[k] == 1 and z_shape[k] != 1:
                    # A zero stride repeats the single entry along this dim.
                    zi_strides[k] = 0
                elif zi.shape[k] != z_shape[k]:
                    raise ValueError(
                        'Unexpected shape for zi: expected {}, found {}'.
                        format(z_shape, zi.shape))
            zi = cupy.lib.stride_tricks.as_strided(zi, z_shape, zi_strides)
        dtype = cupy.result_type(dtype, zi)

    # Normalize arrays; copy only when the in-place division below would
    # otherwise modify the caller's coefficient arrays.
    needs_scaling = a[0] != 1
    b = b.astype(dtype, order='C', copy=needs_scaling)
    a = a.astype(dtype, order='C', copy=needs_scaling)
    if needs_scaling:
        b /= a[0]
        a /= a[0]
    x = x.astype(dtype, copy=False)

    # Apply the numerator coefficients. With initial conditions, b is padded
    # to the length of a so the convolution output has room for all
    # n_filt - 1 state entries past the signal.
    if zi is not None and b.size < a.size:
        b = cupy.pad(b, (0, a.size-b.size))
    # TODO: eventually the next line should be method='auto'
    # TODO: you need to use https://github.com/cupy/cupy/pull/3748 (soon to
    # be merged as of Aug 13 2020)
    # NOTE(review): ``convolve`` is not defined in the visible part of this
    # file; presumably it is imported elsewhere - confirm.
    out_full = convolve(x, _arr.axis_broadcast(b, x.ndim, axis),
                        method='direct')
    if zi is not None:
        _arr.axis_slice(out_full, stop=n_filt-1, axis=axis)[...] += zi

    if len(a) > 1:
        # Apply the denominator coefficients in place on out_full.
        # The kernel computes element offsets into out_full, which is longer
        # than x along ``axis``, so the index type must be sized from
        # out_full rather than from x.
        int_dtype = cupy.int32 if out_full.size < (1 << 31) else cupy.int64
        kernel = _get_lfilter_kernel(
            axis, x.ndim, a.size, zi is not None,
            cupy.core._scalar.get_typename(dtype),
            cupy.core._scalar.get_typename(int_dtype))
        info = cupy.array((n_signals, n_samples) + out_full.shape,
                          dtype=int_dtype)
        kernel(((n_signals+128-1)//128,), (128,), (a, out_full, info))

    out = _arr.axis_slice(out_full, stop=n_samples, axis=axis)
    if zi is None:
        return out
    zf = _arr.axis_slice(out_full, start=n_samples, axis=axis)
    return out, zf
# CUDA source preamble for the lfilter kernel. It is a ``str.format``
# template: ``{val_type}`` and ``{int_type}`` become the element type ``T`` and
# the index type ``idx_t``; literal C braces are doubled so they survive
# formatting.
#
# The ``row`` device helper turns a flat signal number ``i`` into a pointer to
# where that signal begins inside an array of the given ``shape``. It walks
# the dimensions from the last down to index 1, accumulating element strides
# as running products of ``shape`` (i.e. it treats the array as densely packed
# in row-major order) and peeling one coordinate off ``i`` for every dimension
# except ``axis``; whatever is left of ``i`` is applied with the final stride.
_FILTER_GENERAL = '''
#include "cupy/carray.cuh"
typedef unsigned char byte;
typedef {val_type} T;
typedef {int_type} idx_t;
template <typename T>
__device__ T* row(T* ptr, idx_t i, idx_t axis, idx_t ndim, const idx_t* shape) {{
idx_t index = 0, stride = 1;
for (idx_t a = ndim; --a > 0; ) {{
if (a != axis) {{
index += (i % shape[a]) * stride;
i /= shape[a];
}}
stride *= shape[a];
}}
return ptr + index + stride * i;
}}
'''
@cupy.util.memoize(for_each_device=True)
def _get_lfilter_kernel(axis, ndim, n_filt, compute_zf, val_type, int_type):
    """Build the raw CUDA kernel that applies the denominator coefficients.

    The generated ``cupyx_lfilter`` kernel takes ``(a, y, info)``. ``info``
    holds ``n_signals``, ``n_samples`` and then the shape of ``y``. Each
    iteration of the ``CUPY_FOR`` loop handles one signal: it walks along
    ``axis`` and subtracts ``sum(a[f] * y[n - f])`` from ``y[n]`` in place,
    using the already-updated earlier samples.

    Args:
        axis (int): Non-negative axis index baked into the kernel source.
        ndim (int): Number of dimensions of ``y``.
        n_filt (int): Number of denominator coefficients in ``a``.
        compute_zf (bool): Whether to also update the ``n_filt - 1`` entries
            that follow the ``n_samples`` outputs (the final conditions).
        val_type (str): C type name substituted for the element type ``T``.
        int_type (str): C type name substituted for the index type ``idx_t``.

    Returns:
        cupy.RawKernel: The compiled-on-demand kernel object. The decorator
        memoizes the result with ``for_each_device=True``.
    """
    if compute_zf:
        # Tail entry n = n_samples + k needs the terms a[f]*y[n-f] for every
        # f with k < f < n_filt whose sample index n-f is still >= 0. The
        # bound must therefore be "f < n_filt && f <= n"; bounding by
        # min(n_filt, n_samples) drops terms whenever n_samples < n_filt.
        zf_computation = '''
        for (idx_t n = n_samples; n < n_samples + {n_filt} - 1; ++n) {{
            T accuml = 0;
            for (idx_t f = n - n_samples + 1; f < {n_filt} && f <= n; f++) {{
                accuml += a[f]*y_i[-f*y_elem_stride];
            }}
            y_i[0] -= accuml;
            y_i += y_elem_stride;
        }}
'''.format(n_filt=n_filt)
    else:
        zf_computation = ''

    preamble = _FILTER_GENERAL.format(val_type=val_type, int_type=int_type)
    # zf_computation is inserted as an already-formatted value, so its single
    # braces are not interpreted again by this second format call.
    body = '''
extern "C" __global__
void cupyx_lfilter(const T* __restrict__ a, T* __restrict__ y,
                   idx_t* __restrict__ info) {{
    const idx_t n_signals = info[0], n_samples = info[1],
        * __restrict__ shape = info+2;
    const idx_t stop = min({n_filt}, n_samples);
    idx_t y_elem_stride = 1;
    for (int a = {axis}; ++a < {ndim}; ) {{ y_elem_stride *= shape[a]; }}
    CUPY_FOR (i, n_signals) {{
        T* __restrict__ y_i = row(y, i, {axis}, {ndim}, shape);
        y_i += y_elem_stride;
        for (idx_t n = 1; n < stop; ++n) {{
            T accuml = 0;
            for (idx_t f = 1; f <= n; f++) {{
                accuml += a[f]*y_i[-f*y_elem_stride];
            }}
            y_i[0] -= accuml;
            y_i += y_elem_stride;
        }}
        for (idx_t n = {n_filt}; n < n_samples; ++n) {{
            T accuml = 0;
            for (idx_t f = 1; f < {n_filt}; f++) {{
                accuml += a[f]*y_i[-f*y_elem_stride];
            }}
            y_i[0] -= accuml;
            y_i += y_elem_stride;
        }}
        {zf_comp}
    }}
}}'''.format(axis=axis, ndim=ndim, n_filt=n_filt, zf_comp=zf_computation)
    return cupy.RawKernel(preamble + body, 'cupyx_lfilter')
def lfilter_zi(b, a):
    """Construct initial conditions for lfilter for step response steady-state.

    Compute an initial state ``zi`` for the ``lfilter`` function that
    corresponds to the steady state of the step response.

    A typical use of this function is to set the initial state so that the
    output of the filter starts at the same value as the first element of
    the signal to be filtered.

    Args:
        b, a (cupy.ndarray): The IIR filter coefficients. See ``lfilter`` for
            more information.

    Returns:
        cupy.ndarray: The initial state for the filter, of length
        ``max(len(a), len(b)) - 1`` (with ``a`` measured after its leading
        zeros are removed). The dtype is float64 for real coefficients and
        complex for complex coefficients.

    Raises:
        ValueError: If ``b`` or ``a`` is not 1-D, or if ``a`` has fewer than
            two coefficients once leading zeros are removed.

    .. seealso:: :func:`scipy.signal.lfilter_zi`
    .. seealso:: :func:`cupyx.scipy.signal.lfilter`
    .. seealso:: :func:`cupyx.scipy.signal.lfiltic`
    """
    if b.ndim != 1:
        raise ValueError("Numerator b must be 1-D.")
    if a.ndim != 1:
        raise ValueError("Denominator a must be 1-D.")
    a = cupy.trim_zeros(a, 'f')
    if a.size < 2:
        raise ValueError("Must have at least two coefficients after removing "
                         "leading 0s")

    if a[0] != 1.0:
        # Normalize the coefficients so a[0] == 1.
        b = b / a[0]
        a = a / a[0]

    # Pad a or b with zeros so they are the same length.
    n = max(len(a), len(b))
    if len(a) < n:
        a = cupy.r_[a, cupy.zeros(n - len(a))]
    elif len(b) < n:
        b = cupy.r_[b, cupy.zeros(n - len(b))]

    # Solve zi = A*zi + B with explicit formulas instead of a linear solve:
    #   zi[0] = sum(B) / (1 + sum(a[1:]))
    #   zi[k] = (1 + a[1] + ... + a[k]) * zi[0] - (B[0] + ... + B[k-1])
    col = a[1:].copy()
    col[0] += 1
    rhs = b[1:] - a[1:] * b[0]

    # At least float64, but complex when the coefficients are complex so the
    # imaginary part of the state is not lost.
    dtype = cupy.result_type(a, b, cupy.float64)
    zi = cupy.empty(n - 1, dtype=dtype)
    zi_0 = rhs.sum() / col.sum()
    zi[0] = zi_0
    zi[1:] = col[:-1].cumsum() * zi_0 - rhs[:-1].cumsum()
    return zi
def filtfilt(b, a, x, axis=-1, padtype='odd', padlen=None, method='pad',
             irlen=None):
    """Apply a digital filter forward and backward to a signal.

    The linear filter is run over the data once in each direction. The
    combined filter has zero phase and a filter order twice that of the
    original.

    Several options control how the edges of the signal are handled. For
    most filtering tasks ``sosfiltfilt`` (with a filter designed using
    ``output='sos'``) should be preferred over ``filtfilt``, because
    second-order sections have fewer numerical problems.

    Args:
        b (cupy.ndarray): The numerator coefficient vector of the filter.
        a (cupy.ndarray): The denominator coefficient vector of the filter. If
            ``a[0]`` is not 1, then ``a`` and ``b`` are normalized by ``a[0]``.
        x (cupy.ndarray): The array of data to be filtered.
        axis (int, optional): The axis of ``x`` to which the filter is applied.
            Default is -1.
        padtype (str or None, optional): Must be 'odd', 'even', 'constant', or
            None. Selects the kind of extension used to pad the signal before
            filtering. If ``padtype`` is None, no padding is used. The
            default is 'odd'.
        padlen (int or None, optional): The number of elements by which to
            extend ``x`` at both ends of ``axis`` before applying the filter.
            This value must be less than ``x.shape[axis] - 1``. ``padlen=0``
            implies no padding. The default value is ``3*max(len(a), len(b))``.
        method (str, optional): How the edges of the signal are handled,
            either "pad" or "gust". With "pad" the signal is padded as
            described by ``padtype`` and ``padlen``, and ``irlen`` is
            ignored. "gust" (Gustafsson's method) is accepted as a name but
            is not implemented yet.
        irlen (int or None, optional): Impulse response length for the "gust"
            method. Currently unused.

    Returns:
        cupy.ndarray: The filtered output with the same shape as ``x``.

    Raises:
        ValueError: If ``method`` is neither "pad" nor "gust".
        NotImplementedError: If ``method`` is "gust".

    .. seealso:: :func:`scipy.signal.filtfilt`
    .. seealso:: :func:`scipy.signal.sosfilt`
    .. seealso:: :func:`scipy.signal.lfilter`
    .. seealso:: :func:`scipy.signal.lfilter_zi`
    .. seealso:: :func:`scipy.signal.lfiltic`
    .. seealso:: :func:`scipy.signal.savgol_filter`
    """
    if method not in ("pad", "gust"):
        raise ValueError("method must be 'pad' or 'gust'.")
    axis = cupy.util._normalize_axis_index(axis, x.ndim)
    if method == "gust":
        raise NotImplementedError("method='gust' not implemented yet")

    # Only method == "pad" reaches this point.
    n_taps = max(len(a), len(b))
    pad_count, padded = _validate_pad(padtype, padlen, x, axis, ntaps=n_taps)

    # Steady state of the filter's step response, laid out along ``axis`` so
    # that multiplying it by a length-1 slice of the data broadcasts to the
    # shape lfilter expects for its ``zi`` argument.
    steady = _arr.axis_broadcast(lfilter_zi(b, a), x.ndim, axis=axis)

    # Forward pass, started from the first sample of the padded signal.
    first = _arr.axis_slice(padded, stop=1, axis=axis)
    forward, _ = lfilter(b, a, padded, axis=axis, zi=steady * first)

    # Backward pass over the reversed forward output, started from its last
    # sample.
    last = _arr.axis_slice(forward, start=-1, axis=axis)
    flipped = _arr.axis_reverse(forward, axis=axis)
    backward, _ = lfilter(b, a, flipped, axis=axis, zi=steady * last)

    # Undo the reversal, then drop the padding if any was added.
    result = _arr.axis_reverse(backward, axis=axis)
    if pad_count > 0:
        result = _arr.axis_slice(result, start=pad_count, stop=-pad_count,
                                 axis=axis)
    return result
def _validate_pad(padtype, padlen, x, axis, ntaps):
    """Check the padding options and build the padded signal.

    Args:
        padtype (str or None): 'even', 'odd', 'constant', or None (no
            padding).
        padlen (int or None): Requested padding length per end. None selects
            ``3 * ntaps``; it is forced to 0 when ``padtype`` is None.
        x (cupy.ndarray): Signal to pad.
        axis (int): Axis of ``x`` along which to pad.
        ntaps (int): Filter length used for the default padding length.

    Returns:
        tuple: ``(edge, ext)`` where ``edge`` is the padding length actually
        used and ``ext`` is the padded array (``x`` itself when no padding is
        applied).

    Raises:
        ValueError: If ``padtype`` is not recognised, or if ``x`` is not
            longer than the padding length along ``axis``.
    """
    if padtype not in ('even', 'odd', 'constant', None):
        raise ValueError(("Unknown value '%s' given to padtype. padtype "
                          "must be 'even', 'odd', 'constant', or None.") %
                         padtype)

    if padtype is None:
        padlen = 0
    edge = padlen if padlen is not None else ntaps*3

    # x's 'axis' dimension must be bigger than edge.
    if x.shape[axis] <= edge:
        raise ValueError("The length of the input vector x must be greater "
                         "than padlen, which is %d." % edge)

    if padtype is None or not edge > 0:
        return edge, x

    # Extend by ``edge`` samples at each end of the input array.
    if padtype == 'even':
        extend = _arr.even_ext
    elif padtype == 'odd':
        extend = _arr.odd_ext
    else:
        extend = _arr.const_ext
    return edge, extend(x, edge, axis=axis)
@ChenPaulYu
Copy link

Hi author,

This implementation is incredible; however, it only supports filtering a 1-D waveform. Do you have any idea how to add a batch dimension to this implementation?

@coderforlife
Copy link
Author

@ChenPaulYu This version is 1 year old and I have made progress since then, including 2D support and improving the speed quite a bit. Hopefully in the next few months it will be fully published within cupy itself or possibly in rapidsai/cusignal.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment