Skip to content

Instantly share code, notes, and snippets.

@istupakov
Last active December 17, 2023 03:20
Show Gist options
  • Save istupakov/a8e0ed97aacd3d827649a297f2bcabc9 to your computer and use it in GitHub Desktop.
Save istupakov/a8e0ed97aacd3d827649a297f2bcabc9 to your computer and use it in GitHub Desktop.
Online STFT in C++ (FFTW)
#include <algorithm>
#include <cmath>
#include <complex>
#include <cstddef>
#include <limits>
#include <new>
#include <numbers>
#include <span>
#include <stdexcept>
#include <vector>

#include <fftw3.h>
using std::complex;
using std::span;
/// Streaming short-time Fourier transform (STFT) with 50 % overlap, built on FFTW.
///
/// Audio is processed one hop at a time. A hop holds `frame_size` samples for
/// every channel, interleaved: sample i of channel j lives at index
/// `i * num_channels + j`. forward() places the previous hop in front of the new
/// one, applies a Hamming analysis window and runs a real-to-complex FFT of
/// length `2 * frame_size` per channel. backward() runs the inverse FFT, applies
/// the matching synthesis window and overlap-adds with the tail kept from the
/// previous call.
///
/// The spans returned by forward() and backward() view internal buffers; their
/// contents are overwritten by later forward()/backward() calls.
///
/// The object is neither copyable nor movable: it owns FFTW plans that are bound
/// to the addresses of its internal buffers.
class OnlineStft
{
    /// Minimal allocator that obtains storage from fftw_malloc()/fftw_free(),
    /// so the vectors handed to the FFTW planner use FFTW's own allocation.
    template <class T>
    class fftw_allocator
    {
    public:
        using value_type = T;

        fftw_allocator() noexcept = default;

        template <class U>
        fftw_allocator(const fftw_allocator<U> &) noexcept {}

        /// Allocates storage for `n` objects of type T.
        /// @throws std::bad_array_new_length if the byte count overflows.
        /// @throws std::bad_alloc if fftw_malloc() fails.
        T *allocate(std::size_t n)
        {
            if (n > std::numeric_limits<std::size_t>::max() / sizeof(T))
                throw std::bad_array_new_length();
            void *raw = fftw_malloc(n * sizeof(T));
            if (raw == nullptr && n != 0)
                throw std::bad_alloc();
            return static_cast<T *>(raw);
        }

        void deallocate(T *p, std::size_t) noexcept
        {
            fftw_free(static_cast<void *>(p));
        }

        // The allocator is stateless, so all instances are interchangeable.
        friend bool operator==(const fftw_allocator &, const fftw_allocator &) noexcept { return true; }
        friend bool operator!=(const fftw_allocator &, const fftw_allocator &) noexcept { return false; }
    };

    fftw_plan r2c_plan = nullptr;
    fftw_plan c2r_plan = nullptr;
    std::vector<double, fftw_allocator<double>> in_buffer;                 // 2 * frame_size samples per channel
    std::vector<complex<double>, fftw_allocator<complex<double>>> out_buffer; // num_freq bins per channel
    span<double> in_buffer_first; // front half of in_buffer
    span<double> in_buffer_last;  // back half of in_buffer
    std::vector<double> in_win;   // analysis window, repeated for every channel
    std::vector<double> out_win;  // synthesis window, repeated for every channel
    std::vector<double> prev_in_frame;  // previous input hop
    std::vector<double> prev_out_frame; // tail of the previous inverse transform
    int num_freq = 0;

public:
    /// Builds the windows, buffers and FFTW plans.
    ///
    /// @param num_channels number of interleaved channels (must be > 0).
    /// @param frame_size   hop size in samples per channel (must be > 0); the
    ///                     FFT length is `2 * frame_size`.
    /// @throws std::invalid_argument if an argument is out of range.
    /// @throws std::runtime_error    if FFTW cannot create a plan.
    OnlineStft(int num_channels, int frame_size = 2048)
    {
        if (num_channels <= 0 || frame_size <= 0)
            throw std::invalid_argument("OnlineStft: num_channels and frame_size must be positive");
        if (frame_size > std::numeric_limits<int>::max() / 2)
            throw std::invalid_argument("OnlineStft: frame_size is too large");

        int n = 2 * frame_size;
        num_freq = frame_size + 1;

        // Element counts are computed in size_t so the products cannot overflow int.
        const auto channels = static_cast<std::size_t>(num_channels);
        const std::size_t hop_len = static_cast<std::size_t>(frame_size) * channels;
        const std::size_t fft_len = static_cast<std::size_t>(n) * channels;
        const std::size_t spec_len = static_cast<std::size_t>(num_freq) * channels;

        // Everything that can throw std::bad_alloc is allocated before the plans
        // are created, so a failure here cannot leak a plan.
        in_buffer.assign(fft_len, 0.0);
        out_buffer.assign(spec_len, complex<double>{});
        in_buffer_first = span(in_buffer).first(hop_len);
        in_buffer_last = span(in_buffer).last(hop_len);
        prev_in_frame.assign(hop_len, 0.0);
        prev_out_frame.assign(hop_len, 0.0);
        in_win.resize(fft_len);
        out_win.resize(fft_len);

        const auto hamming = [n](int i)
        { return 0.54 - 0.46 * std::cos(2 * std::numbers::pi * i / n); };

        for (int i = 0; i < n; ++i)
        {
            // Index of the sample that overlaps with i in the neighbouring frame,
            // i.e. (i + frame_size) % n written without the risk of int overflow.
            const int partner = (i < frame_size) ? i + frame_size : i - frame_size;
            const double a = hamming(i);
            const double b = hamming(partner);
            // Dividing by (a^2 + b^2) makes analysis * synthesis windows of the two
            // overlapping frames sum to one; dividing by n compensates for FFTW's
            // unnormalized transforms.
            const double synthesis = a / (a * a + b * b) / n;

            const std::size_t base = static_cast<std::size_t>(i) * channels;
            for (std::size_t j = 0; j < channels; ++j)
            {
                in_win[base + j] = a;
                out_win[base + j] = synthesis;
            }
        }

        // One transform per channel; channels are interleaved, hence stride ==
        // num_channels and distance between transforms == 1.
        r2c_plan = fftw_plan_many_dft_r2c(1, &n, num_channels,
                                          in_buffer.data(), nullptr, num_channels, 1,
                                          reinterpret_cast<fftw_complex *>(out_buffer.data()), nullptr, num_channels, 1,
                                          FFTW_MEASURE);
        if (r2c_plan == nullptr)
            throw std::runtime_error("OnlineStft: failed to create the forward FFTW plan");

        c2r_plan = fftw_plan_many_dft_c2r(1, &n, num_channels,
                                          reinterpret_cast<fftw_complex *>(out_buffer.data()), nullptr, num_channels, 1,
                                          in_buffer.data(), nullptr, num_channels, 1,
                                          FFTW_MEASURE);
        if (c2r_plan == nullptr)
        {
            // The destructor does not run when the constructor throws.
            fftw_destroy_plan(r2c_plan);
            r2c_plan = nullptr;
            throw std::runtime_error("OnlineStft: failed to create the inverse FFTW plan");
        }

        // Planning with FFTW_MEASURE overwrites the arrays, so clear them again.
        std::fill(in_buffer.begin(), in_buffer.end(), 0.0);
        std::fill(out_buffer.begin(), out_buffer.end(), complex<double>{});
    }

    // Copying would destroy the plans twice and leave the spans pointing into
    // another object's buffer.
    OnlineStft(const OnlineStft &) = delete;
    OnlineStft &operator=(const OnlineStft &) = delete;

    ~OnlineStft()
    {
        if (c2r_plan != nullptr)
            fftw_destroy_plan(c2r_plan);
        if (r2c_plan != nullptr)
            fftw_destroy_plan(r2c_plan);
    }

    /// Number of frequency bins per channel (`frame_size + 1`).
    int get_num_freq() const
    {
        return num_freq;
    }

    /// Transforms the next hop of input samples.
    ///
    /// @param frame `frame_size * num_channels` interleaved samples.
    /// @return view of `num_freq * num_channels` interleaved spectrum values
    ///         stored in an internal buffer.
    /// @throws std::invalid_argument if `frame` has the wrong size.
    span<complex<double>> forward(span<double> frame)
    {
        if (frame.size() != prev_in_frame.size())
            throw std::invalid_argument("OnlineStft::forward: frame must hold frame_size * num_channels samples");

        // Store the new hop first: `frame` may be the span returned by backward(),
        // which views the front half of in_buffer and is overwritten next.
        std::copy(frame.begin(), frame.end(), in_buffer_last.begin());
        std::copy(prev_in_frame.begin(), prev_in_frame.end(), in_buffer_first.begin());
        // Remember the unwindowed hop for the next call.
        std::copy(in_buffer_last.begin(), in_buffer_last.end(), prev_in_frame.begin());

        for (std::size_t i = 0; i < in_buffer.size(); ++i)
            in_buffer[i] *= in_win[i];

        fftw_execute(r2c_plan);
        return out_buffer;
    }

    /// Reconstructs one hop of output samples from a spectrum.
    ///
    /// @param image `num_freq * num_channels` interleaved spectrum values; may be
    ///              the span returned by forward(), modified in place.
    /// @return view of `frame_size * num_channels` interleaved samples stored in
    ///         an internal buffer.
    /// @throws std::invalid_argument if `image` has the wrong size.
    span<double> backward(span<complex<double>> image)
    {
        if (image.size() != out_buffer.size())
            throw std::invalid_argument("OnlineStft::backward: image must hold num_freq * num_channels values");

        // std::copy must not be given a destination inside the source range, which
        // is exactly what happens when the caller passes forward()'s result back.
        if (image.data() != out_buffer.data())
            std::copy(image.begin(), image.end(), out_buffer.begin());

        fftw_execute(c2r_plan);

        for (std::size_t i = 0; i < in_buffer.size(); ++i)
            in_buffer[i] *= out_win[i];

        // Overlap-add: the front half is completed by the tail saved last time,
        // the back half becomes the tail for the next call.
        for (std::size_t i = 0; i < in_buffer_first.size(); ++i)
            in_buffer_first[i] += prev_out_frame[i];
        std::copy(in_buffer_last.begin(), in_buffer_last.end(), prev_out_frame.begin());

        return in_buffer_first;
    }
};
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment