@dfeneyrou, created June 26, 2021 19:17
Adaptation of FiberTaskingLib to Palanteer: an example program plus upgrades of two files (callbacks and task_scheduler) so that the profiler knows when a fiber is suspended.
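// =====================================================================
// File 1/3: ftl/callbacks.h, upgraded to add FiberState and the
// EventCallbacks hooks so a profiler can observe thread and fiber events
// =====================================================================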
/**
 * FiberTaskingLib - A tasking library that uses fibers for efficient task switching
 *
 * This library was created as a proof of concept of the ideas presented by
 * Christian Gyrling in his 2015 GDC Talk 'Parallelizing the Naughty Dog Engine Using Fibers'
 *
 * http://gdcvault.com/play/1022186/Parallelizing-the-Naughty-Dog-Engine
 *
 * FiberTaskingLib is the legal property of Adrian Astley
 * Copyright Adrian Astley 2015 - 2018
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once

namespace ftl {

enum class FiberState : int {
    // The fiber has started (or resumed) execution on a worker thread
    Attached,
    // The fiber is no longer being executed on a worker thread
    Detached
};

// Called once at Init() time, with the worker thread count / fiber pool size
using ThreadCreationCallback = void (*)(void *context, unsigned threadCount);
using FiberCreationCallback = void (*)(void *context, unsigned fiberCount);

// Called from the worker thread itself when it starts up / shuts down
using ThreadEventCallback = void (*)(void *context, unsigned threadIndex);

// Called around every fiber switch. For a Detached fiber, isSuspended tells
// whether it is parked waiting on a counter (true) or returned to the pool (false)
using FiberEventCallback = void (*)(void *context, unsigned fiberIndex, FiberState newState, bool isSuspended);

struct EventCallbacks {
    void *Context = nullptr;

    ThreadCreationCallback OnThreadsCreated = nullptr;
    FiberCreationCallback OnFibersCreated = nullptr;

    ThreadEventCallback OnWorkerThreadStarted = nullptr;
    ThreadEventCallback OnWorkerThreadEnded = nullptr;

    FiberEventCallback OnFiberStateChanged = nullptr;
};

} // End of namespace ftl
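// =====================================================================
// File 2/3: example program, running a producer/consumer task storm with
// Palanteer fiber instrumentation enabled
// =====================================================================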
#include "ftl/task_counter.h"
#include "ftl/task_scheduler.h"
#define USE_PL 1
#define PL_VIRTUAL_THREADS 1
#define PL_IMPLEMENTATION 1
#define PL_IMPL_COLLECTION_BUFFER_BYTE_QTY 20000000
#include "palanteer.h"
#include <assert.h>
#include <stdint.h>
// ===========================
// Fiber support for Palanteer
// ===========================

void fiberWorkerThreadStarted(void *context, unsigned threadIndex)
{
    (void)context;
    char tmpStr[64];
    snprintf(tmpStr, sizeof(tmpStr), "Workers/Fiber worker %u", threadIndex);
    plDeclareThreadDyn(tmpStr);
    plMarkerDyn("thread start", tmpStr);
}

void fiberCreated(void *context, unsigned fiberCount)
{
    (void)context;
    char tmpStr[64];
    for (unsigned fiberId = 0; fiberId < fiberCount; ++fiberId) {
        snprintf(tmpStr, sizeof(tmpStr), "Fibers/Fiber %u", fiberId + 1);
        plDeclareVirtualThread(fiberId, tmpStr);
    }
}

void fiberStateChanged(void *context, unsigned fiberIndex, ftl::FiberState newState, bool isSuspended)
{
    (void)context;
    if (newState == ftl::FiberState::Attached) {
        plAttachVirtualThread(fiberIndex);
    } else {
        plDetachVirtualThread(isSuspended);
    }
}
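// Design note: each fiber is declared once to Palanteer as a "virtual thread".
// Every scheduling event then attaches or detaches that virtual thread on the
// worker thread running it, so the timeline shows fibers migrating across
// workers, and the isSuspended flag lets a detached fiber that waits on a
// counter be displayed differently from one simply returned to the pool.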
// ===========================
// Test program
// ===========================
constexpr static unsigned kNumProducerTasks = 100U;
constexpr static unsigned kNumConsumerTasks = 10000U;
void Consumer(ftl::TaskScheduler * /*scheduler*/, void *arg)
{
    plBegin("Consumer");
    auto *globalCounter = reinterpret_cast<std::atomic<unsigned> *>(arg);
    globalCounter->fetch_add(1);
    plEnd("Consumer");
}
void Producer(ftl::TaskScheduler *taskScheduler, void *arg)
{
    plScope("Producer");

    plBegin("Creating consumer tasks");
    auto *tasks = new ftl::Task[kNumConsumerTasks];
    for (unsigned i = 0; i < kNumConsumerTasks; ++i) {
        tasks[i] = {Consumer, arg};
    }
    plEnd("Creating consumer tasks");

    plBegin("Inserting tasks");
    ftl::TaskCounter counter(taskScheduler);
    taskScheduler->AddTasks(kNumConsumerTasks, tasks, ftl::TaskPriority::Low, &counter);
    delete[] tasks;
    plEnd("Inserting tasks");

    plBegin("Waiting...");
    taskScheduler->WaitForCounter(&counter);
    plEnd("Waiting...");  // Name matches the plBegin() above
}
int main(int /*argc*/, char ** /*argv*/)
{
    plInitAndStart("Fiber Palanteer test");
    plDeclareThread("Main");

    {
        // Create the task scheduler and bind the main thread to it
        ftl::TaskScheduler taskScheduler;
        ftl::TaskSchedulerInitOptions opt;
        opt.FiberPoolSize = 10;
        opt.ThreadPoolSize = 3;
        opt.Callbacks.OnWorkerThreadStarted = fiberWorkerThreadStarted;
        opt.Callbacks.OnFibersCreated = fiberCreated;
        opt.Callbacks.OnFiberStateChanged = fiberStateChanged;
        taskScheduler.Init(opt);

        std::atomic<unsigned> globalCounter(0U);
        std::array<ftl::Task, kNumProducerTasks> tasks{};
        for (auto &&task : tasks) {
            task = {Producer, &globalCounter};
        }

        ftl::TaskCounter counter(&taskScheduler);
        taskScheduler.AddTasks(kNumProducerTasks, tasks.data(), ftl::TaskPriority::Low, &counter);
        taskScheduler.WaitForCounter(&counter);

        // Check that all tasks finished
        assert(kNumProducerTasks * kNumConsumerTasks == globalCounter.load());
    }

    plStopAndUninit();
    return 0;
}
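// Build note (a sketch, not the author's command line; paths are assumptions):
// compile this example together with the modified FiberTaskingLib below, with
// the FTL and Palanteer include paths set. Palanteer itself is header-only and
// is compiled in via PL_IMPLEMENTATION above. For instance, on Linux:
//   g++ -O2 -pthread -I<FiberTaskingLib>/include -I<palanteer>/c++ example.cpp <ftl sources or lib> -o example

// =====================================================================
// File 3/3: ftl/task_scheduler.cpp, upgraded to invoke the new
// EventCallbacks around thread start/end and every fiber switch
// =====================================================================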
/**
 * FiberTaskingLib - A tasking library that uses fibers for efficient task switching
 *
 * This library was created as a proof of concept of the ideas presented by
 * Christian Gyrling in his 2015 GDC Talk 'Parallelizing the Naughty Dog Engine Using Fibers'
 *
 * http://gdcvault.com/play/1022186/Parallelizing-the-Naughty-Dog-Engine
 *
 * FiberTaskingLib is the legal property of Adrian Astley
 * Copyright Adrian Astley 2015 - 2018
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "ftl/task_scheduler.h"
#include "ftl/atomic_counter.h"
#include "ftl/callbacks.h"
#include "ftl/task_counter.h"
#include "ftl/thread_abstraction.h"
#if defined(FTL_WIN32_THREADS)
# ifndef WIN32_LEAN_AND_MEAN
# define WIN32_LEAN_AND_MEAN
# endif
# ifndef NOMINMAX
# define NOMINMAX
# endif
# include <windows.h>
#elif defined(FTL_POSIX_THREADS)
# include <pthread.h>
#endif
namespace ftl {
constexpr static unsigned kFailedPopAttemptsHeuristic = 5;
constexpr static int kInitErrorDoubleCall = -30;
constexpr static int kInitErrorFailedToCreateWorkerThread = -60;
struct ThreadStartArgs {
TaskScheduler *Scheduler;
unsigned ThreadIndex;
};
FTL_THREAD_FUNC_RETURN_TYPE TaskScheduler::ThreadStartFunc(void *const arg) {
auto *const threadArgs = reinterpret_cast<ThreadStartArgs *>(arg);
TaskScheduler *taskScheduler = threadArgs->Scheduler;
unsigned const index = threadArgs->ThreadIndex;
// Clean up
delete threadArgs;
// Spin wait until everything is initialized
while (!taskScheduler->m_initialized.load(std::memory_order_acquire)) {
// Spin
FTL_PAUSE();
}
// Execute user thread start callback, if set
const EventCallbacks &callbacks = taskScheduler->m_callbacks;
if (callbacks.OnWorkerThreadStarted != nullptr) {
callbacks.OnWorkerThreadStarted(callbacks.Context, index);
}
// Get a free fiber to switch to
unsigned const freeFiberIndex = taskScheduler->GetNextFreeFiberIndex();
// Initialize tls
taskScheduler->m_tls[index].CurrentFiberIndex = freeFiberIndex;
// Switch
taskScheduler->m_tls[index].ThreadFiber.SwitchToFiber(&taskScheduler->m_fibers[freeFiberIndex]);
// And we've returned
// Execute user thread end callback, if set
if (callbacks.OnWorkerThreadEnded != nullptr) {
callbacks.OnWorkerThreadEnded(callbacks.Context, index);
}
// Cleanup and shutdown
EndCurrentThread();
FTL_THREAD_FUNC_END;
}
// This Task is never used directly
// However, a function pointer to it is the signal that the task is a Ready fiber, not a "real" task
// See @FiberStartFunc() for more details
static void ReadyFiberDummyTask(TaskScheduler *taskScheduler, void *arg) {
(void)taskScheduler;
(void)arg;
}
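// A sketch of how the sentinel is consumed (see TaskIsReadyToExecute() and the
// high-priority path in FiberStartFunc() below): the function pointer doubles
// as a type tag, telling us that ArgData is really a ReadyFiberBundle*.
//
//   if (bundle.TaskToExecute.Function == ReadyFiberDummyTask) {
//       auto *ready = reinterpret_cast<ReadyFiberBundle *>(bundle.TaskToExecute.ArgData);
//       // ready->FiberIndex identifies the fiber to resume
//   }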
void TaskScheduler::FiberStartFunc(void *const arg) {
TaskScheduler *taskScheduler = reinterpret_cast<TaskScheduler *>(arg);
if (taskScheduler->m_callbacks.OnFiberStateChanged != nullptr) {
taskScheduler->m_callbacks.OnFiberStateChanged(taskScheduler->m_callbacks.Context, taskScheduler->GetCurrentFiberIndex(), ftl::FiberState::Attached, false);
}
// If we just started from the pool, we may need to clean up from another fiber
taskScheduler->CleanUpOldFiber();
std::vector<TaskBundle> taskBuffer;
// Process tasks infinitely, until quit
while (!taskScheduler->m_quit.load(std::memory_order_acquire)) {
unsigned waitingFiberIndex = kInvalidIndex;
ThreadLocalStorage *tls = &taskScheduler->m_tls[taskScheduler->GetCurrentThreadIndex()];
bool readyWaitingFibers = false;
// Check if there is a ready pinned waiting fiber
{
std::lock_guard<std::mutex> guard(tls->PinnedReadyFibersLock);
for (auto bundle = tls->PinnedReadyFibers.begin(); bundle != tls->PinnedReadyFibers.end(); ++bundle) {
readyWaitingFibers = true;
if (!(*bundle)->FiberIsSwitched.load(std::memory_order_acquire)) {
// The wait condition is ready, but the "source" thread hasn't switched away from the fiber yet
// Skip this fiber until the next round
continue;
}
waitingFiberIndex = (*bundle)->FiberIndex;
tls->PinnedReadyFibers.erase(bundle);
break;
}
}
TaskBundle nextTask{};
bool foundTask = false;
// If nothing was found, check if there is a high priority task to run
if (waitingFiberIndex == kInvalidIndex) {
foundTask = taskScheduler->GetNextHiPriTask(&nextTask, &taskBuffer);
// Check if the found task is a ReadyFiber dummy task
if (foundTask && nextTask.TaskToExecute.Function == ReadyFiberDummyTask) {
// Get the waiting fiber index
ReadyFiberBundle *readyFiberBundle = reinterpret_cast<ReadyFiberBundle *>(nextTask.TaskToExecute.ArgData);
waitingFiberIndex = readyFiberBundle->FiberIndex;
}
}
if (waitingFiberIndex != kInvalidIndex) {
// Found a waiting task that is ready to continue
tls->OldFiberIndex = tls->CurrentFiberIndex;
tls->CurrentFiberIndex = waitingFiberIndex;
tls->OldFiberDestination = FiberDestination::ToPool;
const EventCallbacks &callbacks = taskScheduler->m_callbacks;
if (callbacks.OnFiberStateChanged != nullptr) {
callbacks.OnFiberStateChanged(callbacks.Context, tls->OldFiberIndex, FiberState::Detached, false);
}
// Switch
taskScheduler->m_fibers[tls->OldFiberIndex].SwitchToFiber(&taskScheduler->m_fibers[tls->CurrentFiberIndex]);
if (callbacks.OnFiberStateChanged != nullptr) {
callbacks.OnFiberStateChanged(callbacks.Context, taskScheduler->GetCurrentFiberIndex(), FiberState::Attached, true);
}
// And we're back
taskScheduler->CleanUpOldFiber();
// Get a fresh instance of TLS, since we could be on a new thread now
tls = &taskScheduler->m_tls[taskScheduler->GetCurrentThreadIndex()];
if (taskScheduler->m_emptyQueueBehavior.load(std::memory_order_relaxed) == EmptyQueueBehavior::Sleep) {
tls->FailedQueuePopAttempts = 0;
}
} else {
// If we didn't find a high priority task, look for a low priority task
if (!foundTask) {
foundTask = taskScheduler->GetNextLoPriTask(&nextTask);
}
EmptyQueueBehavior const behavior = taskScheduler->m_emptyQueueBehavior.load(std::memory_order_relaxed);
if (foundTask) {
if (behavior == EmptyQueueBehavior::Sleep) {
tls->FailedQueuePopAttempts = 0;
}
nextTask.TaskToExecute.Function(taskScheduler, nextTask.TaskToExecute.ArgData);
if (nextTask.Counter != nullptr) {
nextTask.Counter->Decrement();
}
} else {
// We failed to find a Task from any of the queues
// What we do now depends on m_emptyQueueBehavior, which we loaded above
switch (behavior) {
case EmptyQueueBehavior::Yield:
YieldThread();
break;
case EmptyQueueBehavior::Sleep: {
// If we have a ready waiting fiber, prevent sleep
if (!readyWaitingFibers) {
++tls->FailedQueuePopAttempts;
// Go to sleep if we've failed to find a task kFailedPopAttemptsHeuristic times
if (tls->FailedQueuePopAttempts >= kFailedPopAttemptsHeuristic) {
std::unique_lock<std::mutex> lock(taskScheduler->ThreadSleepLock);
// Acquire the pinned ready fibers lock here and check if there are any pinned fibers ready
// Acquiring the lock here prevents a race between readying a pinned fiber (on another thread) and going to sleep
// Either this thread wins, then notify_*() will wake it
// Or the other thread wins, then this thread will observe the pinned fiber, and will not go to sleep
std::unique_lock<std::mutex> readyfiberslock(tls->PinnedReadyFibersLock);
if (tls->PinnedReadyFibers.empty()) {
// Unlock before going to sleep (the other lock is released by the CV wait)
readyfiberslock.unlock();
taskScheduler->ThreadSleepCV.wait(lock);
}
tls->FailedQueuePopAttempts = 0;
}
}
break;
}
case EmptyQueueBehavior::Spin:
default:
// Just fall through and continue the next loop
break;
}
}
}
}
// Switch to the quit fibers
if (taskScheduler->m_callbacks.OnFiberStateChanged != nullptr) {
taskScheduler->m_callbacks.OnFiberStateChanged(taskScheduler->m_callbacks.Context, taskScheduler->GetCurrentFiberIndex(), ftl::FiberState::Detached, false);
}
unsigned index = taskScheduler->GetCurrentThreadIndex();
taskScheduler->m_fibers[taskScheduler->m_tls[index].CurrentFiberIndex].SwitchToFiber(&taskScheduler->m_quitFibers[index]);
// We should never get here
printf("Error: FiberStart should never return");
}
void TaskScheduler::ThreadEndFunc(void *arg) {
TaskScheduler *taskScheduler = reinterpret_cast<TaskScheduler *>(arg);
// Wait for all other threads to quit
taskScheduler->m_quitCount.fetch_add(1, std::memory_order_seq_cst);
while (taskScheduler->m_quitCount.load(std::memory_order_seq_cst) != taskScheduler->m_numThreads) {
SleepThread(50);
}
// Jump to the thread fibers
unsigned threadIndex = taskScheduler->GetCurrentThreadIndex();
if (threadIndex == 0) {
// Special case for the main thread fiber
taskScheduler->m_quitFibers[threadIndex].SwitchToFiber(&taskScheduler->m_fibers[0]);
} else {
taskScheduler->m_quitFibers[threadIndex].SwitchToFiber(&taskScheduler->m_tls[threadIndex].ThreadFiber);
}
// We should never get here
printf("Error: ThreadEndFunc should never return");
}
TaskScheduler::TaskScheduler() {
FTL_VALGRIND_HG_DISABLE_CHECKING(&m_initialized, sizeof(m_initialized));
FTL_VALGRIND_HG_DISABLE_CHECKING(&m_quit, sizeof(m_quit));
FTL_VALGRIND_HG_DISABLE_CHECKING(&m_quitCount, sizeof(m_quitCount));
}
int TaskScheduler::Init(TaskSchedulerInitOptions options) {
// Sanity check to make sure the user doesn't double init
if (m_initialized.load()) {
return kInitErrorDoubleCall;
}
m_callbacks = options.Callbacks;
// Initialize the flags
m_emptyQueueBehavior.store(options.Behavior);
if (options.ThreadPoolSize == 0) {
// 1 thread for each logical processor
m_numThreads = GetNumHardwareThreads();
} else {
m_numThreads = options.ThreadPoolSize;
}
// Create and populate the fiber pool
m_fiberPoolSize = options.FiberPoolSize;
m_fibers = new Fiber[options.FiberPoolSize];
m_freeFibers = new std::atomic<bool>[options.FiberPoolSize];
FTL_VALGRIND_HG_DISABLE_CHECKING(m_freeFibers, sizeof(std::atomic<bool>) * m_fiberPoolSize);
m_readyFiberBundles = new ReadyFiberBundle[options.FiberPoolSize];
// Leave the first slot for the bound main thread
for (unsigned i = 1; i < options.FiberPoolSize; ++i) {
m_fibers[i] = Fiber(524288, FiberStartFunc, this);
m_freeFibers[i].store(true, std::memory_order_release);
}
m_freeFibers[0].store(false, std::memory_order_release);
// Initialize threads and TLS
m_threads = new ThreadType[m_numThreads];
#ifdef _MSC_VER
# pragma warning(push)
# pragma warning(disable : 4316) // I know this won't be allocated to the right alignment, this is okay as we're using alignment for padding.
#endif // _MSC_VER
m_tls = new ThreadLocalStorage[m_numThreads];
#ifdef _MSC_VER
# pragma warning(pop)
#endif // _MSC_VER
#if defined(FTL_WIN32_THREADS)
// Temporarily set the main thread ID to -1, so when the worker threads start up, they don't accidentally use it
// I don't know if Windows thread IDs can ever be 0, but just in case.
m_threads[0].Id = static_cast<DWORD>(-1);
#endif
if (m_callbacks.OnThreadsCreated != nullptr) {
m_callbacks.OnThreadsCreated(m_callbacks.Context, m_numThreads);
}
if (m_callbacks.OnFibersCreated != nullptr) {
m_callbacks.OnFibersCreated(m_callbacks.Context, options.FiberPoolSize);
}
// Set the properties for the current thread
SetCurrentThreadAffinity(0);
m_threads[0] = GetCurrentThread();
#if defined(FTL_WIN32_THREADS)
// Set the thread handle to INVALID_HANDLE_VALUE
// ::GetCurrentThread() returns a pseudo handle that always references the current thread.
// That is, if we tried to use this handle from another thread to reference the main thread,
// it would instead reference that other thread. We don't currently use the handle anywhere.
// Therefore, we set this to INVALID_HANDLE_VALUE, so any future usages can take this into account
// Reported by @rtj
m_threads[0].Handle = INVALID_HANDLE_VALUE;
#endif
// Set the fiber index
m_tls[0].CurrentFiberIndex = 0;
// Create the worker threads
for (unsigned i = 1; i < m_numThreads; ++i) {
auto *const threadArgs = new ThreadStartArgs();
threadArgs->Scheduler = this;
threadArgs->ThreadIndex = i;
char threadName[256];
snprintf(threadName, sizeof(threadName), "FTL Worker Thread %u", i);
if (!CreateThread(524288, ThreadStartFunc, threadArgs, threadName, &m_threads[i])) {
return kInitErrorFailedToCreateWorkerThread;
}
}
// Manually invoke callback for 'main' fiber
if (m_callbacks.OnFiberStateChanged != nullptr) {
m_callbacks.OnFiberStateChanged(m_callbacks.Context, 0, FiberState::Attached, false);
}
// Signal the worker threads that we're fully initialized
m_initialized.store(true, std::memory_order_release);
return 0;
}
TaskScheduler::~TaskScheduler() {
// Create the quit fibers
m_quitFibers = new Fiber[m_numThreads];
for (unsigned i = 0; i < m_numThreads; ++i) {
m_quitFibers[i] = Fiber(524288, ThreadEndFunc, this);
}
// Request that all the threads quit
m_quit.store(true, std::memory_order_release);
// Signal any waiting threads so they can finish
if (m_emptyQueueBehavior.load(std::memory_order_relaxed) == EmptyQueueBehavior::Sleep) {
ThreadSleepCV.notify_all();
}
// Jump to the quit fiber
// Create a scope so index isn't used after we come back from the switch. It will be wrong if we started on a non-main thread
{
if (m_callbacks.OnFiberStateChanged != nullptr) {
m_callbacks.OnFiberStateChanged(m_callbacks.Context, GetCurrentFiberIndex(), FiberState::Detached, false);
}
unsigned index = GetCurrentThreadIndex();
m_fibers[m_tls[index].CurrentFiberIndex].SwitchToFiber(&m_quitFibers[index]);
}
// We're back. We should be on the main thread now
// Wait for the worker threads to finish
for (unsigned i = 1; i < m_numThreads; ++i) {
JoinThread(m_threads[i]);
}
// Cleanup
delete[] m_tls;
delete[] m_threads;
delete[] m_readyFiberBundles;
delete[] m_freeFibers;
delete[] m_fibers;
delete[] m_quitFibers;
}
void TaskScheduler::AddTask(Task const task, TaskPriority priority, TaskCounter *const counter) {
FTL_ASSERT("Task given to TaskScheduler:AddTask has a nullptr Function", task.Function != nullptr);
if (counter != nullptr) {
counter->Add(1);
}
const TaskBundle bundle = {task, counter};
if (priority == TaskPriority::High) {
m_tls[GetCurrentThreadIndex()].HiPriTaskQueue.Push(bundle);
} else if (priority == TaskPriority::Low) {
m_tls[GetCurrentThreadIndex()].LoPriTaskQueue.Push(bundle);
}
const EmptyQueueBehavior behavior = m_emptyQueueBehavior.load(std::memory_order_relaxed);
if (behavior == EmptyQueueBehavior::Sleep) {
// Wake a sleeping thread
ThreadSleepCV.notify_one();
}
}
void TaskScheduler::AddTasks(unsigned const numTasks, Task const *const tasks, TaskPriority priority, TaskCounter *const counter) {
if (counter != nullptr) {
counter->Add(numTasks);
}
WaitFreeQueue<TaskBundle> *queue = nullptr;
if (priority == TaskPriority::High) {
queue = &m_tls[GetCurrentThreadIndex()].HiPriTaskQueue;
} else if (priority == TaskPriority::Low) {
queue = &m_tls[GetCurrentThreadIndex()].LoPriTaskQueue;
} else {
FTL_ASSERT("Unknown task priority", false);
return;
}
for (unsigned i = 0; i < numTasks; ++i) {
FTL_ASSERT("Task given to TaskScheduler:AddTasks has a nullptr Function", tasks[i].Function != nullptr);
const TaskBundle bundle = {tasks[i], counter};
queue->Push(bundle);
}
const EmptyQueueBehavior behavior = m_emptyQueueBehavior.load(std::memory_order_relaxed);
if (behavior == EmptyQueueBehavior::Sleep) {
// Wake all the threads
ThreadSleepCV.notify_all();
}
}
#if defined(FTL_WIN32_THREADS)
FTL_NOINLINE unsigned TaskScheduler::GetCurrentThreadIndex() const {
DWORD const threadId = ::GetCurrentThreadId();
for (unsigned i = 0; i < m_numThreads; ++i) {
if (m_threads[i].Id == threadId) {
return i;
}
}
return kInvalidIndex;
}
#elif defined(FTL_POSIX_THREADS)
FTL_NOINLINE unsigned TaskScheduler::GetCurrentThreadIndex() const {
pthread_t const currentThread = pthread_self();
for (unsigned i = 0; i < m_numThreads; ++i) {
if (pthread_equal(currentThread, m_threads[i]) != 0) {
return i;
}
}
return kInvalidIndex;
}
#endif
unsigned TaskScheduler::GetCurrentFiberIndex() const {
ThreadLocalStorage &tls = m_tls[GetCurrentThreadIndex()];
return tls.CurrentFiberIndex;
}
inline bool TaskScheduler::TaskIsReadyToExecute(TaskBundle *bundle) const {
// "Real" tasks are always ready to execute
if (bundle->TaskToExecute.Function != ReadyFiberDummyTask) {
return true;
}
// If it's a ready fiber task, the arg is a ReadyFiberBundle
ReadyFiberBundle *readyFiberBundle = reinterpret_cast<ReadyFiberBundle *>(bundle->TaskToExecute.ArgData);
return readyFiberBundle->FiberIsSwitched.load(std::memory_order_acquire);
}
bool TaskScheduler::GetNextHiPriTask(TaskBundle *nextTask, std::vector<TaskBundle> *taskBuffer) {
unsigned const currentThreadIndex = GetCurrentThreadIndex();
ThreadLocalStorage &tls = m_tls[currentThreadIndex];
bool result = false;
// Try to pop from our own queue
while (tls.HiPriTaskQueue.Pop(nextTask)) {
if (TaskIsReadyToExecute(nextTask)) {
result = true;
// Break to cleanup
goto cleanup; // NOLINT(cppcoreguidelines-avoid-goto)
}
// It's a ReadyTask whose fiber hasn't switched away yet
// Add it to the buffer
taskBuffer->emplace_back(*nextTask);
}
// Force a scope so the `goto cleanup` above doesn't skip initialization
{
// Ours is empty, try to steal from the others'
const unsigned threadIndex = tls.HiPriLastSuccessfulSteal;
for (unsigned i = 0; i < m_numThreads; ++i) {
const unsigned threadIndexToStealFrom = (threadIndex + i) % m_numThreads;
if (threadIndexToStealFrom == currentThreadIndex) {
continue;
}
ThreadLocalStorage &otherTLS = m_tls[threadIndexToStealFrom];
while (otherTLS.HiPriTaskQueue.Steal(nextTask)) {
tls.HiPriLastSuccessfulSteal = threadIndexToStealFrom;
if (TaskIsReadyToExecute(nextTask)) {
result = true;
// Break to cleanup
goto cleanup;
}
// It's a ReadyTask whose fiber hasn't switched away yet
// Add it to the buffer
taskBuffer->emplace_back(*nextTask);
}
}
}
cleanup:
if (!taskBuffer->empty()) {
// Re-push all the tasks we found that weren't ready to execute
// We (or another thread) will get them next round
do {
// Push them in the opposite order we popped them, to restore the order
tls.HiPriTaskQueue.Push(taskBuffer->back());
taskBuffer->pop_back();
} while (!taskBuffer->empty());
// If we're using Sleep mode, we need to wake up the other threads
// They may have checked the queues while we were holding all the tasks in our temp buffer,
// found nothing, and gone to sleep
EmptyQueueBehavior const behavior = m_emptyQueueBehavior.load(std::memory_order_relaxed);
if (behavior == EmptyQueueBehavior::Sleep) {
// Wake all the threads
ThreadSleepCV.notify_all();
}
}
return result;
}
bool TaskScheduler::GetNextLoPriTask(TaskBundle *nextTask) {
unsigned const currentThreadIndex = GetCurrentThreadIndex();
ThreadLocalStorage &tls = m_tls[currentThreadIndex];
// Try to pop from our own queue
if (tls.LoPriTaskQueue.Pop(nextTask)) {
return true;
}
// Ours is empty, try to steal from the others'
const unsigned threadIndex = tls.LoPriLastSuccessfulSteal;
for (unsigned i = 0; i < m_numThreads; ++i) {
const unsigned threadIndexToStealFrom = (threadIndex + i) % m_numThreads;
if (threadIndexToStealFrom == currentThreadIndex) {
continue;
}
ThreadLocalStorage &otherTLS = m_tls[threadIndexToStealFrom];
if (otherTLS.LoPriTaskQueue.Steal(nextTask)) {
tls.LoPriLastSuccessfulSteal = threadIndexToStealFrom;
return true;
}
}
return false;
}
unsigned TaskScheduler::GetNextFreeFiberIndex() const {
for (unsigned j = 0;; ++j) {
for (unsigned i = 0; i < m_fiberPoolSize; ++i) {
// Double-checked load: a cheap relaxed peek first, then an acquire load before attempting the CAS
if (!m_freeFibers[i].load(std::memory_order_relaxed)) {
continue;
}
if (!m_freeFibers[i].load(std::memory_order_acquire)) {
continue;
}
bool expected = true;
if (std::atomic_compare_exchange_weak_explicit(&m_freeFibers[i], &expected, false, std::memory_order_release, std::memory_order_relaxed)) {
return i;
}
}
if (j > 10) {
printf("No free fibers in the pool. Possible deadlock");
}
}
}
void TaskScheduler::CleanUpOldFiber() {
// Clean up from the last Fiber to run on this thread
//
// Explanation:
// When switching between fibers, there's the innate problem of tracking the fibers.
// For example, let's say we discover a waiting fiber that's ready. We need to put the currently
// running fiber back into the fiber pool, and then switch to the waiting fiber. However, we can't
// just do the equivalent of:
// m_fibers.Push(currentFiber)
// currentFiber.SwitchToFiber(waitingFiber)
// In the time between us adding the current fiber to the fiber pool and switching to the waiting fiber, another
// thread could come along and pop the current fiber from the fiber pool and try to run it.
// This leads to stack corruption and/or other undefined behavior.
//
// In the previous implementation of TaskScheduler, we used helper fibers to do this work for us.
// AKA, we stored currentFiber and waitingFiber in TLS, and then switched to the helper fiber. The
// helper fiber then did:
// m_fibers.Push(currentFiber)
// helperFiber.SwitchToFiber(waitingFiber)
// If we have 1 helper fiber per thread, we can guarantee that currentFiber is free to be executed by any thread
// once it is added back to the fiber pool
//
// This solution works well, however, we actually don't need the helper fibers
// The code structure guarantees that between any two fiber switches, the code will always end up in WaitForCounter
// or FiberStart. Therefore, instead of using a helper fiber and immediately pushing the fiber to the fiber pool or
// waiting list, we defer the push until the next fiber gets to one of those two places
//
// Proof:
// There are only two places where we switch fibers:
// 1. When we're waiting for a counter, we pull a new fiber from the fiber pool and switch to it.
// 2. When we found a counter that's ready, we put the current fiber back in the fiber pool, and switch to the
// waiting fiber.
//
// Case 1:
// A fiber from the pool will always either be completely new or just come back from switching to a waiting fiber
// In both places, we call CleanUpOldFiber()
// QED
//
// Case 2:
// A waiting fiber will always resume in WaitForCounter()
// Here, we call CleanUpOldFiber()
// QED
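// Condensed, the hand-off protocol implemented below (and in WaitForCounterInternal()) is:
//
//   tls.OldFiberIndex = currentFiberIndex;  // remember the fiber we are leaving
//   tls.OldFiberDestination = ToPool;       // or ToWaiting
//   SwitchToFiber(newFiber);                // jump away; the old stack is no longer touched
//   // ...the first thing the resumed fiber does:
//   CleanUpOldFiber();                      // only now is the old fiber published as free/ready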
ThreadLocalStorage &tls = m_tls[GetCurrentThreadIndex()];
switch (tls.OldFiberDestination) {
case FiberDestination::ToPool:
// In this specific implementation, the fiber pool is a flat array signaled by atomics
// So in order to "Push" the fiber to the fiber pool, we just set its corresponding atomic to true
m_freeFibers[tls.OldFiberIndex].store(true, std::memory_order_release);
tls.OldFiberDestination = FiberDestination::None;
tls.OldFiberIndex = kInvalidIndex;
break;
case FiberDestination::ToWaiting:
// The waiting fibers are stored directly in their counters
// They have an atomic<bool> that signals whether the waiting fiber can be consumed if it's ready
// We just have to set it to true
tls.OldFiberStoredFlag->store(true, std::memory_order_release);
tls.OldFiberDestination = FiberDestination::None;
tls.OldFiberIndex = kInvalidIndex;
break;
case FiberDestination::None:
default:
break;
}
}
void TaskScheduler::AddReadyFiber(unsigned const pinnedThreadIndex, ReadyFiberBundle *bundle) {
if (pinnedThreadIndex == kNoThreadPinning) {
ThreadLocalStorage *tls = &m_tls[GetCurrentThreadIndex()];
// Push a dummy task to the high priority queue
Task task{};
task.Function = ReadyFiberDummyTask;
task.ArgData = bundle;
TaskBundle taskBundle{};
taskBundle.TaskToExecute = task;
taskBundle.Counter = nullptr;
tls->HiPriTaskQueue.Push(taskBundle);
// If we're using EmptyQueueBehavior::Sleep, the other threads could be sleeping
// Therefore, we need to kick a thread awake to ensure that the readied task is taken
const EmptyQueueBehavior behavior = m_emptyQueueBehavior.load(std::memory_order_relaxed);
if (behavior == EmptyQueueBehavior::Sleep) {
ThreadSleepCV.notify_one();
}
} else {
ThreadLocalStorage *tls = &m_tls[pinnedThreadIndex];
{
std::lock_guard<std::mutex> guard(tls->PinnedReadyFibersLock);
tls->PinnedReadyFibers.emplace_back(bundle);
}
// If the Task is pinned, we add the Task to the pinned thread's PinnedReadyFibers queue
// Normally, this works fine; the other thread will pick it up next time it
// searches for a Task to run.
//
// However, if we're using EmptyQueueBehavior::Sleep, the other thread could be sleeping
// Therefore, we need to kick all the threads so that the pinned-to thread can take it
const EmptyQueueBehavior behavior = m_emptyQueueBehavior.load(std::memory_order_relaxed);
if (behavior == EmptyQueueBehavior::Sleep) {
if (GetCurrentThreadIndex() != pinnedThreadIndex) {
std::unique_lock<std::mutex> lock(ThreadSleepLock);
// Kick all threads
ThreadSleepCV.notify_all();
}
}
}
}
void TaskScheduler::WaitForCounter(TaskCounter *counter, bool pinToCurrentThread) {
WaitForCounterInternal(counter, 0, pinToCurrentThread);
}
void TaskScheduler::WaitForCounter(AtomicFlag *counter, bool pinToCurrentThread) {
WaitForCounterInternal(counter, 0, pinToCurrentThread);
}
void TaskScheduler::WaitForCounter(FullAtomicCounter *counter, unsigned value, bool pinToCurrentThread) {
WaitForCounterInternal(counter, value, pinToCurrentThread);
}
void TaskScheduler::WaitForCounterInternal(BaseCounter *counter, unsigned value, bool pinToCurrentThread) {
// Fast out
if (counter->m_value.load(std::memory_order_relaxed) == value) {
// wait for threads to drain from counter logic, otherwise we might continue too early
while (counter->m_lock.load() > 0) {
}
return;
}
ThreadLocalStorage &tls = m_tls[GetCurrentThreadIndex()];
unsigned const currentFiberIndex = tls.CurrentFiberIndex;
unsigned pinnedThreadIndex;
if (pinToCurrentThread) {
pinnedThreadIndex = GetCurrentThreadIndex();
} else {
pinnedThreadIndex = kNoThreadPinning;
}
// Create the ready fiber bundle and attempt to add it to the waiting list
ReadyFiberBundle *readyFiberBundle = &m_readyFiberBundles[currentFiberIndex];
readyFiberBundle->FiberIndex = currentFiberIndex;
readyFiberBundle->FiberIsSwitched.store(false);
bool const alreadyDone = counter->AddFiberToWaitingList(readyFiberBundle, value, pinnedThreadIndex);
// The counter finished while we were trying to put it in the waiting list
// Just trivially return
if (alreadyDone) {
return;
}
// Get a free fiber
unsigned const freeFiberIndex = GetNextFreeFiberIndex();
// Fill in tls
tls.OldFiberIndex = currentFiberIndex;
tls.CurrentFiberIndex = freeFiberIndex;
tls.OldFiberDestination = FiberDestination::ToWaiting;
tls.OldFiberStoredFlag = &readyFiberBundle->FiberIsSwitched;
if (m_callbacks.OnFiberStateChanged != nullptr) {
m_callbacks.OnFiberStateChanged(m_callbacks.Context, currentFiberIndex, FiberState::Detached, true);
}
// Switch
m_fibers[currentFiberIndex].SwitchToFiber(&m_fibers[freeFiberIndex]);
if (m_callbacks.OnFiberStateChanged != nullptr) {
m_callbacks.OnFiberStateChanged(m_callbacks.Context, GetCurrentFiberIndex(), FiberState::Attached, false);
}
// And we're back
CleanUpOldFiber();
}
} // End of namespace ftl