Skip to content

Instantly share code, notes, and snippets.

@andres-fr
Last active August 2, 2021 11:03
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save andres-fr/c23da36444ebbcc3343b8829040182aa to your computer and use it in GitHub Desktop.
Save andres-fr/c23da36444ebbcc3343b8829040182aa to your computer and use it in GitHub Desktop.
C++ spectral convolution of 1D signals (OpenMP+FFTW)
// OverlapSaveConvolver: A small C++11 single-file program that performs
// efficient 1D convolution and cross-correlation of two float arrays.
// Copyright (C) 2017 Andres Fernandez (https://github.com/andres-fr)
// This program is free software; you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation; either version 3 of the License, or
// (at your option) any later version.
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
// You should have received a copy of the GNU General Public License
// along with this program; if not, write to the Free Software Foundation,
// Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
// TODO:
// - Signal arrays are aligned, but SIMDization wasn't explicitly benchmarked:
// explicitly SIMDize SpectralConvolution and SpectralCorrelation: https://github.com/VcDevel/Vc
// - Add unit testing with catch
// - Add and use proper benchmarking lib
// - Google styleguide https://google.github.io/styleguide/cppguide.html
// g++ -O3 -std=c++11 -Wall -Wextra overlap_save_convolver.cpp -fopenmp -lfftw3f -o test && valgrind --leak-check=full -v ./test
#define REAL 0
#define IMAG 1
// comment this line to deactivate OpenMP for loop parallelizations, or if you want to debug
// memory management (valgrind reports OMP normal activity as error).
// the number is the minimum size that a 'for' loop needs to get sent to OMP (1=>always sent)
#define WITH_OPENMP_ABOVE 1
//
#include <string.h>
#include <math.h>
#include <iostream>
#include <sstream>
#include <stdexcept>
#include <vector>
#include <list>
#include <initializer_list>
#include <iterator>
#include <algorithm>
//
#include <fftw3.h>
#ifdef WITH_OPENMP_ABOVE
# include <omp.h>
#endif
//
using namespace std;
////////////////////////////////////////////////////////////////////////////////////////////////////
/// TYPECHECK/ANTIBUGGING
////////////////////////////////////////////////////////////////////////////////////////////////////
// Given a container or its beginning and end iterables, converts the container to a string
// of the form {a, b, c} (like a basic version of Python's __str__). Usage example:
// vector<string> c1({"foo", "bar"});
// vector<size_t> c2({1});
// list<double> c3({1,2,3,4,5});
// vector<bool> c4({false, true, false});
// list<int> c5;
// cout << IterableToString({1.23, 4.56, -789.0}) << endl;
// cout << IterableToString(c1) << endl;
// cout << IterableToString({"hello", "hello"}) << endl;
// cout << IterableToString(c2) << endl;
// cout << IterableToString(c3) << endl;
// cout << IterableToString(c4) << endl;
// cout << IterableToString(c5.begin(), c5.end()) << endl;
template<typename T>
string IterableToString(T it, T end){
stringstream ss;
ss << "{";
bool first = true;
for (; it!=end; ++it){
if (first){
ss << *it;
first = false;
} else {
ss << ", " << *it;
}
} ss << "}";
return ss.str();
}
template <class C> // Overload IterableToString to directly accept any Collection like vector<int>
string IterableToString(const C &c){
return IterableToString(c.begin(), c.end());
}
template <class T> // Overload IterableToString to directly accept initializer_lists
string IterableToString(const initializer_list<T> c){
return IterableToString(c.begin(), c.end());
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Given a container or its beginning and end iterables, checks wether all values contained in the
// iterable are equal and raises an exception if not. Usage example:
// vector<size_t> v1({});
// vector<double> v2({123.4, 123.4, 123.4});
// vector<bool> v3({false, false, false});
// vector<size_t> v4({1});
// vector<string> v5({"hello", "hello", "bye"});
// CheckAllEqual({3,3,3,3,3,3,3,3});
// CheckAllEqual(v1);
// CheckAllEqual(v2.begin(), v2.end());
// CheckAllEqual(v3);
// CheckAllEqual(v4);
// CheckAllEqual(v5.begin(), prev(v5.end()));
// CheckAllEqual(v5);
template<class I>
void CheckAllEqual(const I beg, const I end, const string &message="CheckAllEqual"){
I it = beg;
bool all_eq = true;
auto last = (it==end)? end : prev(end);
for(;it!=last; ++it){
all_eq &= (*(it)==*(next(it)));
if (!all_eq) {
throw runtime_error(string("[ERROR] ") + message+" "+IterableToString(beg, end));
}
}
}
template <class C>
void CheckAllEqual(const C &c, const string message="CheckAllEqual"){
CheckAllEqual(c.begin(), c.end(), message);
}
template <class T>
void CheckAllEqual(const initializer_list<T> c, const string message="CheckAllEqual"){
CheckAllEqual(c.begin(), c.end(), message);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Raises an exception if complex_size!=(real_size/2+1), being "/" an integer division.
void CheckRealComplexRatio(const size_t real_size, const size_t complex_size,
const string func_name="CheckRealComplexRatio"){
if(complex_size!=(real_size/2+1)){
throw runtime_error(string("[ERROR] ") + func_name +
": size of ComplexSignal must equal size(FloatSignal)/2+1. " +
" Sizes were (float, complex): " +
IterableToString({real_size, complex_size}));
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Abstract function that performs a comparation between any 2 elements, and if the comparation
// returns a truthy value raises an exception with the given message.
template <class T, class Functor>
void CheckTwoElements(const T a, const T b, const Functor &binary_predicate,
const string message){
if(binary_predicate(a,b)){
throw runtime_error(string("[ERROR] ") + message + " " + IterableToString({a, b}));
}
}
// Raises an exception with the given message if a>b.
void check_a_less_equal_b(const size_t a, const size_t b,
const string message="a was greater than b!"){
CheckTwoElements(a, b, [](const size_t a, const size_t b){return a>b;}, message);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
/// HELPERS
////////////////////////////////////////////////////////////////////////////////////////////////////
size_t Pow2Ceil(size_t x){return pow(2, ceil(log2(x)));}
// This is an abstract base class that provides some basic, type-independent functionality for
// any container that should behave as a signal. It is not intended to be instantiated directly.
template <class T>
class Signal {
protected:
T* data_;
size_t size_;
public:
// Given a size and a reference to an array, it fills the array with <SIZE> zeros.
// Therefore, **IT DELETES THE CONTENTS OF THE ARRAY**. It is intended to be passed a newly
// allocated array by the classes that inherit from Signal, because it isn't an expensive
// operation and avoids memory errors due to non-initialized values.
explicit Signal(T* data, size_t size) : data_(data), size_(size){
memset(data_, 0, sizeof(T)*size);
}
// The destructor is empty because this class didn't allocate the contained array
virtual ~Signal(){}
// getters
size_t &getSize(){return size_;}
const size_t &getSize() const{return size_;}
T* getData(){return data_;}
const T* getData() const{return data_;}
// overloaded operators
T &operator[](size_t idx){return data_[idx];}
T &operator[](size_t idx) const {return data_[idx];}
// basic print function. It may be overriden if, for example, the type <T> is a struct.
void print(const string name="signal"){
cout << endl;
for(size_t i=0; i<size_; ++i){
cout << name << "[" << i << "]\t=\t" << data_[i] << endl;
}
}
};
// This class is a Signal that works on aligned float arrays allocated by FFTW.
// It also overloads some further operators to do basic arithmetic
class FloatSignal : public Signal<float>{
public:
// the basic constructor allocates an aligned, float array, which is zeroed by the superclass
explicit FloatSignal(size_t size)
: Signal(fftwf_alloc_real(size), size){}
explicit FloatSignal(float* data, size_t size) : FloatSignal(size){
memcpy(data_, data, sizeof(float)*size);
}
explicit FloatSignal(float* data, size_t size, size_t pad_bef, size_t pad_aft)
: FloatSignal(size+pad_bef+pad_aft){
memcpy(data_+pad_bef, data, sizeof(float)*size);
}
// the destructor frees the only resource allocated
~FloatSignal() {fftwf_free(data_);}
void operator+=(const float x){for(size_t i=0; i<size_; ++i){data_[i] += x;}}
void operator*=(const float x){for(size_t i=0; i<size_; ++i){data_[i] *= x;}}
void operator/=(const float x){for(size_t i=0; i<size_; ++i){data_[i] /= x;}}
};
// This class is a Signal that works on aligned complex (float[2]) arrays allocated by FFTW.
// It also overloads some further operators to do basic arithmetic
class ComplexSignal : public Signal<fftwf_complex>{
public:
// the basic constructor allocates an aligned, float[2] array, which is zeroed by the superclass
explicit ComplexSignal(size_t size)
: Signal(fftwf_alloc_complex(size), size){}
~ComplexSignal(){fftwf_free(data_);}
void operator*=(const float x){
for(size_t i=0; i<size_; ++i){
data_[i][REAL] *= x;
data_[i][IMAG] *= x;
}
}
void operator+=(const float x){for(size_t i=0; i<size_; ++i){data_[i][REAL] += x;}}
void operator+=(const fftwf_complex x){
for(size_t i=0; i<size_; ++i){
data_[i][REAL] += x[REAL];
data_[i][IMAG] += x[IMAG];
}
}
// override print method to show both fields of the complex number
void print(const string name="signal"){
for(size_t i=0; i<size_; ++i){
printf("%s[%zu]\t=\t(%f, i%f)\n",name.c_str(),i,data_[i][REAL],data_[i][IMAG]);
}
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// This free function takes three complex signals a,b,c of the same size and computes the complex
// element-wise multiplication: a+ib * c+id = ac+iad+ibc-bd = ac-bd + i(ad+bc) The computation
// loop isn't sent to OMP because this function itself is already expected to be called by multiple
// threads, and it would actually slow down the process.
// It throuws an exception if
void SpectralConvolution(const ComplexSignal &a, const ComplexSignal &b, ComplexSignal &result){
const size_t kSize_a = a.getSize();
const size_t kSize_b = b.getSize();
const size_t kSize_result = result.getSize();
CheckAllEqual({kSize_a, kSize_b, kSize_result},
"SpectralConvolution: all sizes must be equal and are");
for(size_t i=0; i<kSize_a; ++i){
// a+ib * c+id = ac+iad+ibc-bd = ac-bd + i(ad+bc)
result[i][REAL] = a[i][REAL]*b[i][REAL] - a[i][IMAG]*b[i][IMAG];
result[i][IMAG] = a[i][IMAG]*b[i][REAL] + a[i][REAL]*b[i][IMAG];
}
}
// This function behaves identically to SpectralConvolution, but computes c=a*conj(b) instead
// of c=a*b: a * conj(b) = a+ib * c-id = ac-iad+ibc+bd = ac+bd + i(bc-ad)
void SpectralCorrelation(const ComplexSignal &a, const ComplexSignal &b, ComplexSignal &result){
const size_t kSize_a = a.getSize();
const size_t kSize_b = b.getSize();
const size_t kSize_result = result.getSize();
CheckAllEqual({kSize_a, kSize_b, kSize_result},
"SpectralCorrelation: all sizes must be equal and are");
for(size_t i=0; i<kSize_a; ++i){
// a * conj(b) = a+ib * c-id = ac-iad+ibc+bd = ac+bd + i(bc-ad)
result[i][REAL] = a[i][REAL]*b[i][REAL] + a[i][IMAG]*b[i][IMAG];
result[i][IMAG] = a[i][IMAG]*b[i][REAL] - a[i][REAL]*b[i][IMAG];
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// This class is a simple wrapper for the memory management of the fftw plans, plus a
// parameterless execute() method which is also a wrapper for FFTW's execute.
// It is not expected to be used directly: rather, to be extended by specific plans, for instance,
// if working with real, 1D signals, only 1D complex<->real plans are needed.
class FftPlan{
private:
fftwf_plan plan_;
public:
explicit FftPlan(fftwf_plan p): plan_(p){}
virtual ~FftPlan(){fftwf_destroy_plan(plan_);}
void execute(){fftwf_execute(plan_);}
};
// This forward plan (1D, R->C) is adequate to process 1D floats (real).
class FftForwardPlan : public FftPlan{
public:
// This constructor creates a real->complex plan that performs the FFT(real) and saves it into the
// complex. As explained in the FFTW docs (http://www.fftw.org/#documentation), the size of
// the complex has to be size(real)/2+1, so the constructor will throw a runtime error if
// this condition doesn't hold. Since the signals and the superclass already have proper
// destructors, no special memory management has to be done.
explicit FftForwardPlan(FloatSignal &fs, ComplexSignal &cs)
: FftPlan(fftwf_plan_dft_r2c_1d(fs.getSize(), fs.getData(), cs.getData(), FFTW_ESTIMATE)){
CheckRealComplexRatio(fs.getSize(), cs.getSize(), "FftForwardPlan");
}
};
// This backward plan (1D, C->R) is adequate to process spectra of 1D floats (real).
class FftBackwardPlan : public FftPlan{
public:
// This constructor creates a complex->real plan that performs the IFFT(complex) and saves it
// complex. As explained in the FFTW docs (http://www.fftw.org/#documentation), the size of
// the complex has to be size(real)/2+1, so the constructor will throw a runtime error if
// this condition doesn't hold. Since the signals and the superclass already have proper
// destructors, no special memory management has to be done.
explicit FftBackwardPlan(ComplexSignal &cs, FloatSignal &fs)
: FftPlan(fftwf_plan_dft_c2r_1d(fs.getSize(), cs.getData(), fs.getData(), FFTW_ESTIMATE)){
CheckRealComplexRatio(fs.getSize(), cs.getSize(), "FftBackwardPlan");
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// This function is a small script that calculates the FFT wisdom for all powers of two (since those
// are the only expected sizes to be used with the FFTs), and exports it to the given path. The
// wisdom is a brute-force search of the most efficient implementations for the FFTs: It takes a
// while to compute, but has to be done only once (per computer), and then it can be quickly loaded
// for faster FFT computation, as explained in the docs (http://www.fftw.org/#documentation).
// See also the docs for different flags. Note that using a wisdom file is optional.
void MakeAndExportFftwWisdom(const string path_out, const size_t min_2pow=0,
const size_t max_2pow=25, const unsigned flag=FFTW_PATIENT){
for(size_t i=min_2pow; i<=max_2pow; ++i){
size_t size = pow(2, i);
FloatSignal fs(size);
ComplexSignal cs(size/2+1);
printf("creating forward and backward plans for size=2**%zu=%zu and flag %u...\n", i, size, flag);
FftForwardPlan fwd(fs, cs);
FftBackwardPlan bwd(cs, fs);
}
fftwf_export_wisdom_to_filename(path_out.c_str());
}
// Given a path to a wisdom file generated with "MakeAndExportFftwWisdom", reads and loads it
// into FFTW to perform faster FFT computations. Using a wisdom file is optional.
void ImportFftwWisdom(const string path_in, const bool throw_exception_if_fail=true){
int result = fftwf_import_wisdom_from_filename(path_in.c_str());
if(result!=0){
cout << "[ImportFftwWisdom] succesfully imported " << path_in << endl;
} else{
string message = "[ImportFftwWisdom] ";
message += "couldn't import wisdom! is this a path to a valid wisdom file? -->"+path_in+"<--\n";
if(throw_exception_if_fail){throw runtime_error(string("ERROR: ") + message);}
else{cout << "WARNING: " << message;}
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
/// PERFORM CONVOLUTION/CORRELATION
////////////////////////////////////////////////////////////////////////////////////////////////////
// This class performs an efficient version of the spectral convolution/cross-correlation between
// two 1D float arrays, <SIGNAL> and <PATCH>, called overlap-save:
// http://www.comm.utoronto.ca/~dkundur/course_info/real-time-DSP/notes/8_Kundur_Overlap_Save_Add.pdf
// This algorithm requires that the length of <PATCH> is less or equal the length of <SIGNAL>,
// so an exception is thrown otherwise. The algorithm works as follows:
// given signal of length S and patch of length P, and being the conv (or xcorr) length U=S+P-1
// 1. pad the patch to X = 2*Pow2Ceil(P). FFTs with powers of 2 are the fastest.
// 2. cut the signal into chunks of size X, with an overlapping section of L=X-(P-1).
// for that, pad the signal with (P-1) before, and with (X-U%L) after, to make it fit exactly.
// 3. Compute the forward FFT of the padded patch and of every chunk of the signal
// 4. Multiply the FFT of the padded patch with every signal chunk.
// 4a. If the operation is a convolution, perform a complex a*b multiplication
// 4b. If the operation is a cross-correlation, perform a complex a*conj(b) multiplication
// 5. Compute the inverse FFT of every result of step 4
// 6. Concatenate the resulting chunks, ignoring (P-1) samples per chunk
// Note that steps 3,4,5 may be parallelized with some significant gain in performance.
// In this class: X = result_chunksize, L = result_stride
class OverlapSaveConvolver {
private:
// grab input lengths
size_t signal_size_;
size_t patch_size_;
size_t result_size_;
// make padded copies of the inputs and get chunk measurements
FloatSignal padded_patch_;
size_t result_chunksize_;
size_t result_chunksize_complex_;
size_t result_stride_;
ComplexSignal padded_patch_complex_;
// padded copy of the signal
FloatSignal padded_signal_;
// the deconstructed signal
vector<FloatSignal*> s_chunks_;
vector<ComplexSignal*> s_chunks_complex_;
// the corresponding chunks holding convs/xcorrs
vector<FloatSignal*> result_chunks_;
vector<ComplexSignal*> result_chunks_complex_;
// the corresponding plans (plus the plan of the patch)
vector<FftForwardPlan*> forward_plans_;
vector<FftBackwardPlan*> backward_plans_;
// Basic state management to prevent getters from being called prematurely.
// Also to adapt the extractResult getter, since Conv and Xcorr padding behaves differently
enum class State {kUninitialized, kConv, kXcorr};
State _state_; // kUninitialized after instantiation, kConv/kXcorr after respective op.
// This private method throws an exception if _state_ is kUninitialized, because that
// means that some "getter" has ben called before any computation has been performed.
void __check_last_executed_not_null(const string method_name){
if(_state_ == State::kUninitialized){
throw runtime_error(string("[ERROR] OverlapSaveConvolver.") + method_name +
"() can't be called before executeXcorr() or executeConv()!" +
" No meaningful data has been computed yet.");
}
}
// This private method implements steps 3,4,5 of the algorithm. If the given flag is false,
// it will perform a convolution (4a), and a cross-correlation (4b) otherwise.
// Note the parallelization with OpenMP, which increases performance in supporting CPUs.
void __execute(const bool cross_correlate){
auto operation = (cross_correlate)? SpectralCorrelation : SpectralConvolution;
// do ffts
#ifdef WITH_OPENMP_ABOVE
#pragma omp parallel for schedule(static, WITH_OPENMP_ABOVE)
#endif
for (size_t i =0; i<forward_plans_.size();i++){
forward_plans_.at(i)->execute();
}
// multiply spectra
#ifdef WITH_OPENMP_ABOVE
#pragma omp parallel for schedule(static, WITH_OPENMP_ABOVE)
#endif
for (size_t i =0; i<result_chunks_.size();i++){
operation(*s_chunks_complex_.at(i), this->padded_patch_complex_, *result_chunks_complex_.at(i));
}
// do iffts
#ifdef WITH_OPENMP_ABOVE
#pragma omp parallel for schedule(static, WITH_OPENMP_ABOVE)
#endif
for (size_t i =0; i<result_chunks_.size();i++){
backward_plans_.at(i)->execute();
*result_chunks_.at(i) /= result_chunksize_;
}
}
public:
// The only constructor for the class, receives two signals and performs steps 1 and 2 of the
// algorithm on them. The signals are passed by reference but the class works with padded copies
// of them, so no care has to be taken regarding memory management.
// The wisdomPath may be empty, or a path to a valid wisdom file.
// Note that len(signal) can never be smaller than len(patch), or an exception is thrown.
OverlapSaveConvolver(FloatSignal &signal, FloatSignal &patch, const string wisdomPath="")
: signal_size_(signal.getSize()),
patch_size_(patch.getSize()),
result_size_(signal_size_+patch_size_-1),
//
padded_patch_(patch.getData(), patch_size_, 0, 2*Pow2Ceil(patch_size_)-patch_size_),
result_chunksize_(padded_patch_.getSize()),
result_chunksize_complex_(result_chunksize_/2+1),
result_stride_(result_chunksize_-patch_size_+1),
padded_patch_complex_(result_chunksize_complex_),
//
padded_signal_(signal.getData(),signal_size_,patch_size_-1, result_chunksize_-(result_size_%result_stride_)),
_state_(State::kUninitialized){
// end of initializer list, now check that len(signal)>=len(patch)
check_a_less_equal_b(patch_size_, signal_size_,
"OverlapSaveConvolver: len(signal) can't be smaller than len(patch)!");
// and load the wisdom if required. If unsuccessful, no exception thrown, just print a warning.
if(!wisdomPath.empty()){ImportFftwWisdom(wisdomPath, false);}
// chunk the signal into strides of same size as padded patch
// and make complex counterparts too, as well as the corresponding xcorr signals
for(size_t i=0; i<=padded_signal_.getSize()-result_chunksize_; i+=result_stride_){
s_chunks_.push_back(new FloatSignal(&padded_signal_[i], result_chunksize_));
s_chunks_complex_.push_back(new ComplexSignal(result_chunksize_complex_));
result_chunks_.push_back(new FloatSignal(result_chunksize_));
result_chunks_complex_.push_back(new ComplexSignal(result_chunksize_complex_));
}
// make one forward plan per signal chunk, and one for the patch
// Also backward plans for the xcorr chunks
forward_plans_.push_back(new FftForwardPlan(padded_patch_, padded_patch_complex_));
for (size_t i =0; i<s_chunks_.size();i++){
forward_plans_.push_back(new FftForwardPlan(*s_chunks_.at(i), *s_chunks_complex_.at(i)));
backward_plans_.push_back(new FftBackwardPlan(*result_chunks_complex_.at(i),
*result_chunks_.at(i)));
}
}
//
void executeConv(){
__execute(false);
_state_ = State::kConv;
}
void executeXcorr(){
__execute(true);
_state_ = State::kXcorr;
}
// getting info from the convolfer
void printChunks(const string name="convolver"){
__check_last_executed_not_null("printChunks");
for (size_t i =0; i<result_chunks_.size();i++){
result_chunks_.at(i)->print(name+"_chunk_"+to_string(i));
}
}
// This method implements step 6 of the overlap-save algorithm. In convolution, the first (P-1)
// samples of each chunk are discarded, in xcorr the last (P-1) ones. Therefore, depending on the
// current _state_, the corresponding method is used. USAGE:
// Every time it is called, this function returns a new FloatSignal instance of size
// len(signal)+len(patch)-1. If the last operation performed was executeConv(), this function
// will return the convolution of signal and patch. If the last operation performed was
// executeXcorr(), the result will contain the cross-correlation. If none of them was performed
// at the moment of calling this function, an exception will be thrown.
// The indexing will start with the most negative relation, and increase accordingly. Which means:
// given S:=len(signal), P:=len(patch), T:=S+P-1
// for 0 <= i < T, result[i] will hold dot_product(patch, signal[i-(P-1) : i])
// where patch will be "reversed" if the convolution was performed. For example:
// Signal := [1 2 3 4 5 6 7] Patch = [1 1 1]
// Result[0] = [1 1 1] => 1*1 = 1 // FIRST ENTRY
// Result[1] = [1 1 1] => 1*1+1*2 = 3
// Result[2] = [1 1 1] => 1*1+1*2+1*3 = 8 // FIRST NON-NEG ENTRY AT P-1
// ...
// Result[8] = [1 1 1] => 1*7 = 7 // LAST ENTRY
// Note that the returned signal object takes care of its own memory, so no management is needed.
FloatSignal extractResult(){
// make sure that an operation was called before
__check_last_executed_not_null("extractResult");
// set the offset for the corresponding operation (0 for xcorr).
size_t discard_offset = 0;
if(_state_==State::kConv){discard_offset = result_chunksize_ - result_stride_;}
// instantiate new signal to be filled with the desired info
FloatSignal result(result_size_);
float* result_arr = result.getData(); // not const because of memcpy
// fill!
static size_t kNumChunks = result_chunks_.size();
for (size_t i=0; i<kNumChunks;i++){
float* xc_arr = result_chunks_.at(i)->getData();
const size_t kBegin = i*result_stride_;
// if the last chunk goes above result_size_, reduce copy size. else copy_size=result_stride_
size_t copy_size = result_stride_;
copy_size -= (kBegin+result_stride_>result_size_)? kBegin+result_stride_-result_size_ : 0;
memcpy(result_arr+kBegin, xc_arr+discard_offset, sizeof(float)*copy_size);
}
return result;
}
~OverlapSaveConvolver(){
// clear vectors holding signals
for (size_t i =0; i<s_chunks_.size();i++){
delete (s_chunks_.at(i));
delete (s_chunks_complex_.at(i));
delete (result_chunks_.at(i));
delete (result_chunks_complex_.at(i));
}
s_chunks_.clear();
s_chunks_complex_.clear();
result_chunks_.clear();
result_chunks_complex_.clear();
// clear vector holding forward FFT plans
for (size_t i =0; i<forward_plans_.size();i++){
delete (forward_plans_.at(i));
}
forward_plans_.clear();
// clear vector holding backward FFT plans
for (size_t i =0; i<backward_plans_.size();i++){
delete (backward_plans_.at(i));
}
backward_plans_.clear();
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
/// MAIN ROUTINE
////////////////////////////////////////////////////////////////////////////////////////////////////
int main(int argc, char** argv){
// // do this just once to configure your system for an optimal FFT
// const string kWisdomPatient = "wisdom_real_dft_pow2_patient";
// MakeAndExportFftwWisdom(kWisdomPatient, 0, 29, FFTW_PATIENT);
// create a test signal
const size_t kSizeS = 10; //44100*10;
float* s_arr = new float[kSizeS]; for(size_t i=0; i<kSizeS; ++i){s_arr[i] = i+1;}
FloatSignal s(s_arr, kSizeS);
s.print("signal");
// create a test patch
const size_t kSizeP = 2;// 44100*1;
float* p_arr = new float[kSizeP]; for(size_t i=0; i<kSizeP; ++i){p_arr[i]=i+1;}
FloatSignal p(p_arr, kSizeP);
p.print("patch");
// Instantiate convolver with both signals (p can't be bigger than s)
OverlapSaveConvolver x(s, p);
// CONV
x.executeConv();
// x.printChunks("conv");
x.extractResult().print("conv");
// XCORR
x.executeXcorr();
// x.printChunks("xcorr");
x.extractResult().print("xcorr");
// clean memory and exit
delete[] s_arr;
delete[] p_arr;
return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment