Skip to content

Instantly share code, notes, and snippets.

@mikesol
Last active April 13, 2024 00:13
Show Gist options
  • Save mikesol/e9c657944509aa75b6473fe7d572f075 to your computer and use it in GitHub Desktop.
Save mikesol/e9c657944509aa75b6473fe7d572f075 to your computer and use it in GitHub Desktop.
Ring Buffer for our ONNX rig
Copyright 2023 Mike Solomon
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
/*
==============================================================================
ONNXificator implementation: streams audio blocks through an ONNX Runtime model using Hann-windowed overlap-add.
==============================================================================
*/
#include "ONNXificator.h"
#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>
#include <onnxruntime_cxx_api.h>
#include <JuceHeader.h>
#include "RingBuffer.h"
namespace {
// Tensor names handed to Ort::Session::Run: the model is expected to expose a
// single input called "input" and a single output called "output".
const char* const input_names[] = { "input" };
const char* const output_names[] = { "output" };
// Length of the shared silence buffer (presumably 2 s at 44.1 kHz).
constexpr int internal_buffer_size = 2 * 44100;
// Has static storage duration, so it starts out all zeros; used only as a
// read-only source of silence when clearing ring-buffer regions.
float lots_of_zeros[internal_buffer_size];
} // namespace
// Returns the audio length (axis 2) of a model tensor shape, or throws if it
// cannot be used. Indexing shape[2] directly is out of bounds for rank < 3
// models, and a dynamic (-1) axis would wrap to an unusable size_t. So both
// cases are rejected with a readable message *before* any buffer is allocated.
// `min_length` lets the caller demand a longer axis (see the output tensor
// below).
// NOTE(review): the original `1 * 1 * shape[2]` sizing suggests a
// (batch = 1, channels = 1, length) layout — confirm against the exported model.
static size_t audioLengthFromShape(const std::vector<int64_t>& shape, const char* which, int64_t min_length)
{
    if (shape.size() < 3) {
        throw std::invalid_argument(std::string("ONNXificator: ") + which
                                    + " tensor must have rank >= 3");
    }
    if (shape[2] < min_length) {
        throw std::invalid_argument(std::string("ONNXificator: ") + which
                                    + " tensor needs a fixed length >= " + std::to_string(min_length)
                                    + " on axis 2");
    }
    return static_cast<size_t>(shape[2]);
}

// Loads the model from `onnxSize` bytes at `onnx` and reads the input/output
// tensor shapes from the session. It then allocates input_audio/output_audio to
// match.
// input_tensor/output_tensor are created over those same arrays. That way go()
// can refill input_audio and read output_audio around each Run() call.
// Latency is the input length minus the output length. The overlap-add stride
// is half the output length, which is why the output must be at least 2
// samples long: a zero stride would make go()'s inference loop never advance.
// The initialiser order must follow the member declaration order in
// ONNXificator.h, because later members are built from earlier ones.
// Throws std::invalid_argument for unsupported tensor shapes.
ONNXificator::ONNXificator(const char* onnx, const int onnxSize):
session(env, onnx, onnxSize, Ort::SessionOptions()),
memory_info(Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU)),
input_shape(session.GetInputTypeInfo(0).GetTensorTypeAndShapeInfo().GetShape()),
output_shape(session.GetOutputTypeInfo(0).GetTensorTypeAndShapeInfo().GetShape()),
input_audio_size(audioLengthFromShape(input_shape, "input", 1)),
output_audio_size(audioLengthFromShape(output_shape, "output", 2)),
output_audio_stride(output_audio_size / 2),
latency_samples(static_cast<int>(input_audio_size - output_audio_size)),
input_audio(new float[input_audio_size]),
output_audio(new float[output_audio_size]),
input_tensor(Ort::Value::CreateTensor<float>(memory_info, input_audio, input_audio_size, input_shape.data(), input_shape.size())),
output_tensor(Ort::Value::CreateTensor<float>(memory_info, output_audio, output_audio_size, output_shape.data(), output_shape.size())),
window(output_audio_size, juce::dsp::WindowingFunction<float>::hann, false)
{
}
// Releases the two scratch arrays that the constructor allocated with new[].
// All other members are destroyed automatically after this body runs.
ONNXificator::~ONNXificator()
{
    delete[] output_audio;
    delete[] input_audio;
}
int ONNXificator::getLatencySamples() {
return latency_samples;
}
/*
    Processes one audio block in place.

    Channel 0 of `buffer` is appended to `audio_to_process`. Whenever at least
    `input_audio_size` samples are pending, the model runs on that window. Its
    output is multiplied by the Hann window and overlap-added into
    `processed_audio`. The input read head and the output write head then
    advance by one stride (half the output length).

    The block is then refilled. The first `zerosToWrite` samples of every channel
    are silenced while `*total_samples_processed` is still below `*latency_samples`.
    The rest of every channel receives the same (mono) processed audio.

    If `ground_truth` is given, the windowed tail of each input window is
    overlap-added into it in lockstep with the model output. If `groundBuffer` is
    also given, it is refilled the same way as `buffer`.

    Note: the `latency_samples` parameter shadows the member of the same name.
    NOTE(review): no capacity check is made on the ring buffers; callers
    presumably size them so the write heads never lap the read heads — confirm.
*/
void ONNXificator::go(juce::AudioBuffer<float>& buffer, int* total_samples_processed, int* latency_samples, RingBuffer<float>& audio_to_process, RingBuffer<float>& processed_audio, RingBuffer<float>* ground_truth, juce::AudioBuffer<float>* groundBuffer)
{
    // Element-wise dest[i] += src[i]; used as the overlap-add primitive.
    auto addFunctionPtr = static_cast<void(*)(float*, const float*, std::size_t)>(&juce::FloatVectorOperations::add);
    // 1. getNSamples
    // 2. write N samples onto the last part of the buffer
    // 3. advance the write-head by N samples
    // 4. determine if the analysis head has enough samples to run inference
    // 4a. copy to inference input buffer
    // 4b. run inference
    // 4c. do windowing on the inference
    // 4d. add inference onto the write-head of the buffer
    // 4e. add stride-worth of zeros for future additions
    // 4f. advance the audio_to_process read head by stride amount
    // 4g. advance the processed_audio write head by the stride amount
    // 4h. increase the total number of samples processed
    // 5. zero the first max(0, min(nSamples, latency_samples - total_samples_processed))
    //    samples of every output channel
    // 6. write the remaining (nSamples - zerosToWrite) samples from processed_audio
    //    directly after those zeros
    // 7. advance the processed_audio (and ground_truth) read heads by the samples written
    // [1]
    const int nSamples = buffer.getNumSamples();
    if (buffer.getNumChannels() <= 0) { return; }
    // [2] only channel 0 is fed to the model
    audio_to_process.asDestSizeOf(memcpy, buffer.getReadPointer(0), nSamples);
    // [3]
    audio_to_process.advanceWriteHead(nSamples);
    // [4]
    while (audio_to_process.dist() >= input_audio_size) {
        // [4a]
        audio_to_process.asSourceSizeOf(memcpy, input_audio, input_audio_size);
        // [4b] input_tensor/output_tensor wrap input_audio/output_audio
        session.Run(Ort::RunOptions(), input_names, &input_tensor, 1, output_names, &output_tensor, 1);
        // [4c]
        window.multiplyWithWindowingTable (output_audio, output_audio_size);
        // [4d]
        processed_audio.asDest(addFunctionPtr, output_audio, output_audio_size);
        // [4e] zero the stride that follows this frame so the next frame's
        //      overlap-add starts from silence there
        processed_audio.advanceWriteHead(output_audio_size);
        processed_audio.asDestSizeOf(memcpy, lots_of_zeros, output_audio_stride);
        processed_audio.rewindWriteHead(output_audio_size);
        // ground truth
        if (ground_truth != nullptr) {
            // We fast-forward to the end of the input audio,
            // under the assumption that our models are
            // _always_ causal, meaning they always capture the
            // end of the signal. If that changes, change this line!
            auto input_audio_pos = input_audio + (input_audio_size - output_audio_size);
            window.multiplyWithWindowingTable (input_audio_pos, output_audio_size);
            ground_truth->asDest(addFunctionPtr, input_audio_pos, output_audio_size);
            ground_truth->advanceWriteHead(output_audio_size);
            ground_truth->asDestSizeOf(memcpy, lots_of_zeros, output_audio_stride);
            ground_truth->rewindWriteHead(output_audio_size);
            ground_truth->advanceWriteHead(output_audio_stride);
        }
        // [4f]
        audio_to_process.advanceReadHead(output_audio_stride);
        // [4g]
        processed_audio.advanceWriteHead(output_audio_stride);
        // [4h]
        *total_samples_processed += output_audio_stride;
    }
    // [5] Silence the part of this block that still falls inside the latency
    //     period. Every channel is cleared (not only channel 0), because every
    //     channel is refilled from processed_audio below.
    const int zerosToWrite = std::max(0, std::min(nSamples, *latency_samples - *total_samples_processed));
    for (int channel = 0; channel < buffer.getNumChannels(); ++channel) {
        buffer.clear(channel, 0, zerosToWrite);
    }
    if (groundBuffer != nullptr) {
        for (int channel = 0; channel < groundBuffer->getNumChannels(); ++channel) {
            groundBuffer->clear(channel, 0, zerosToWrite);
        }
    }
    // [6] Fill the rest of the block starting *after* the zeros. Copying
    //     nSamples from the start would overwrite the silence written in [5]. It
    //     would also read samples that [7] does not consume, so they would be
    //     replayed in the next block. The same mono signal goes to every
    //     channel; the read heads only move in [7].
    const int samples_to_write = nSamples - zerosToWrite;
    for (int channel = 0; channel < buffer.getNumChannels(); ++channel) {
        processed_audio.asSourceSizeOf(memcpy, buffer.getWritePointer(channel) + zerosToWrite,
                                       static_cast<std::size_t>(samples_to_write));
    }
    if (groundBuffer != nullptr && ground_truth != nullptr) {
        for (int channel = 0; channel < groundBuffer->getNumChannels(); ++channel) {
            ground_truth->asSourceSizeOf(memcpy, groundBuffer->getWritePointer(channel) + zerosToWrite,
                                         static_cast<std::size_t>(samples_to_write));
        }
    }
    // [7]
    processed_audio.advanceReadHead(samples_to_write);
    if (ground_truth != nullptr) {
        ground_truth->advanceReadHead(samples_to_write);
    }
}
/*
==============================================================================
ONNXificator: wraps an ONNX Runtime session for block-based streaming audio inference (class declaration).
==============================================================================
*/
#pragma once
#include <onnxruntime_cxx_api.h>
#include <JuceHeader.h>
#include "RingBuffer.h"
//==============================================================================
/**
    Runs an ONNX Runtime model over streaming audio blocks.

    The model is loaded from an in-memory buffer. go() feeds fixed-size input
    windows to the model and overlap-adds Hann-windowed output frames into a
    caller-owned ring buffer.
    NOTE(review): the constructor takes the audio length from axis 2 of each
    tensor shape; models are presumably laid out as (1, 1, length) — confirm.
*/
class ONNXificator
{
public:
//==============================================================================
// Loads the model from `onnxSize` bytes at `onnx` and allocates the input and
// output scratch arrays sized from the model's tensor shapes.
ONNXificator(const char* onnx, const int onnxSize);
// Frees the scratch arrays allocated by the constructor.
~ONNXificator();
// Model input length minus model output length, in samples.
int getLatencySamples();
// Processes `buffer` in place (model input is taken from channel 0). The ring
// buffers and counters are owned by the caller and persist between calls.
// `ground_truth` / `ground_buffer` are optional; when given, they receive the
// windowed input aligned with the model output.
void go (juce::AudioBuffer<float>& buffer, int* total_samples_processed, int* latency_samples, RingBuffer<float>& audio_to_process, RingBuffer<float>& processed_audio, RingBuffer<float> *ground_truth = nullptr, juce::AudioBuffer<float> *ground_buffer = nullptr);
private:
//==============================================================================
// NOTE: declaration order matters. The constructor's initialiser list builds
// later members (sizes, scratch arrays, tensors, window) from earlier ones.
Ort::Env env;
Ort::Session session;
Ort::MemoryInfo memory_info;
// Shapes reported by the session for input 0 / output 0.
std::vector<int64_t> input_shape;
std::vector<int64_t> output_shape;
// Element counts of the model input/output windows (axis 2 of the shapes).
size_t input_audio_size;
size_t output_audio_size;
// Overlap-add hop: half of output_audio_size.
size_t output_audio_stride;
int latency_samples;
// Scratch arrays owned via new[]/delete[]; the tensors below are created over them.
float *const input_audio;
float *const output_audio;
Ort::Value input_tensor;
Ort::Value output_tensor;
// Hann window of output_audio_size samples (not normalised).
juce::dsp::WindowingFunction<float> window;
//==============================================================================
// Deletes copy construction/assignment and adds JUCE's leak detector.
JUCE_DECLARE_NON_COPYABLE_WITH_LEAK_DETECTOR (ONNXificator)
};
//
// RingBuffer.h
// ONNXTest0
//
// Created by Mike Solomon on 22.9.2023.
//
#ifndef RingBuffer_h
#define RingBuffer_h
#include <cstddef> // For std::size_t
template<typename T>
class RingBuffer {
public:
// Constructor to initialize buffer size
explicit RingBuffer(std::size_t size)
: size_(size),
read_head(0),
write_head(0)
{
buffer = new T[size];
}
// Destructor to clean up dynamic memory
~RingBuffer() {
delete[] buffer;
}
void resetHeads() {
read_head = 0;
write_head = 0;
}
void advanceReadHead(std::size_t amount) {
read_head = (read_head + amount) % size_;
}
void advanceWriteHead(std::size_t amount) {
write_head = (write_head + amount) % size_;
}
void rewindReadHead(std::size_t amount) {
read_head = (read_head + size_ - amount) % size_;
}
void rewindWriteHead(std::size_t amount) {
write_head = (write_head + size_ - amount) % size_;
}
template<typename Func>
void asSourceSizeOf(Func fn, T* dest, std::size_t n) {
asSource_(fn, dest, n, sizeof(T));
}
template<typename Func>
void asDestSizeOf(Func fn, const T* src, std::size_t n) {
asDest_(fn, src, n, sizeof(T));
}
template<typename Func>
void asSource(Func fn, T* dest, std::size_t n) {
asSource_(fn, dest, n, 1);
}
template<typename Func>
void asDest(Func fn, const T* src, std::size_t n) {
asDest_(fn, src, n, 1);
}
std::size_t dist() const {
return (write_head >= read_head) ? (write_head - read_head) : (size_ + write_head - read_head);
}
private:
// Function applying fn on the buffer without moving read_head
template<typename Func>
void asSource_(Func fn, T* dest, std::size_t n, std::size_t mult) {
std::size_t tmp_read_head = read_head;
std::size_t processed = 0;
while (processed < n) {
const std::size_t to_process = std::min(n - processed, size_ - tmp_read_head);
fn(dest + processed, buffer + tmp_read_head, to_process * mult);
processed += to_process;
tmp_read_head = (tmp_read_head + to_process) % size_;
}
}
// Function applying fn on the buffer without moving write_head
template<typename Func>
void asDest_(Func fn, const T* src, std::size_t n, std::size_t mult) {
std::size_t tmp_write_head = write_head;
std::size_t processed = 0;
while (processed < n) {
const std::size_t to_process = std::min(n - processed, size_ - tmp_write_head);
fn(buffer + tmp_write_head, src + processed, to_process * mult);
processed += to_process;
tmp_write_head = (tmp_write_head + to_process) % size_;
}
}
private:
T* buffer; // Internal array of templated type
std::size_t size_; // Size of the buffer
std::size_t read_head; // Read head position
std::size_t write_head; // Write head position
};
#endif /* RingBuffer_h */
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment