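// Proof of concept: stackful coroutines built on ucontext that can suspend on
// a CUDA stream and get resumed from a cudaStreamAddCallback callback, so the
// scheduler thread never blocks on the GPU.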
#include <stdio.h>
#include <stdint.h>
#include <stdlib.h>
#include <ucontext.h>
#include <pthread.h>
#include "nnc/gpu/ccv_nnc_compat.h"
union ptr_splitter {
	void *ptr;
	uint32_t part[2];
};
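// makecontext only portably passes int-sized arguments to the entry function,
// so a 64-bit pointer has to be split into two 32-bit halves here and
// reassembled inside _task_entry_point.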
static const int default_stack_size = 65536;
typedef struct schd_s schd_t;
typedef struct task_s task_t;
typedef void (*task_fn_t)(task_t *task);
struct task_s {
	struct task_s* prev; // Links for the run queue, or for a waiter list.
	struct task_s* next;
	schd_t* schd;
	int done;
	struct task_s* waitfor; // Head of the list of tasks waiting on this one.
	// For swapcontext / makecontext / getcontext.
	ucontext_t context;
	char *stack;
	task_fn_t fn;
};
struct schd_s {
	task_t* head; // FIFO run queue.
	task_t* tail;
	struct {
		int suspend; // Tasks parked on a CUDA stream, not on the run queue.
	} count;
	pthread_cond_t cv;
	pthread_mutex_t mutex;
	ucontext_t caller, callee;
};
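// Only one task runs at a time (schdmain swaps into tasks on its own thread),
// which is why a single caller/callee pair of contexts is enough. The mutex
// and condition variable exist solely to coordinate with the CUDA callback
// thread that re-queues suspended tasks.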
static void addtask(schd_t* const schd, task_t* const t)
{
	if (schd->tail)
	{
		schd->tail->next = t;
		t->prev = schd->tail;
	} else {
		schd->head = t;
		t->prev = 0;
	}
	schd->tail = t;
	t->next = 0;
}
static void deltask(schd_t* const schd, task_t* const t)
{
	if (t->prev)
		t->prev->next = t->next;
	else
		schd->head = t->next;
	if (t->next)
		t->next->prev = t->prev;
	else
		schd->tail = t->prev;
}
static void _task_entry_point(uint32_t part0, uint32_t part1)
{
	// Reassemble the task pointer from the two halves makecontext handed us.
	union ptr_splitter p;
	p.part[0] = part0;
	p.part[1] = part1;
	task_t *task = (task_t*)p.ptr;
	task->fn(task);
	task->done = 1;
	swapcontext(&task->schd->callee, &task->schd->caller);
}
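// Note that _task_entry_point never returns: uc_link is 0, and once done is
// set, the scheduler (or taskresume) frees the task instead of swapping back
// into it.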
static task_t* taskcreate(schd_t* const schd, task_fn_t fn)
{
	task_t *task = (task_t*)calloc(1, sizeof(task_t));
	task->schd = schd;
	task->stack = (char*)calloc(1, default_stack_size);
	task->fn = fn;
	getcontext(&task->context);
	task->context.uc_stack.ss_sp = task->stack;
	task->context.uc_stack.ss_size = default_stack_size;
	task->context.uc_link = 0;
	// makecontext only takes int arguments, so pass the pointer in two halves.
	union ptr_splitter p;
	p.ptr = task;
	makecontext(&task->context, (void (*)(void))_task_entry_point, 2, p.part[0], p.part[1]);
	return task;
}
static void taskfree(task_t* const task)
{
	// Re-queue every task that was waiting on this one (see taskwait). Take
	// the lock because the CUDA callback thread also touches the run queue.
	pthread_mutex_lock(&task->schd->mutex);
	task_t* waitfor = task->waitfor;
	while (waitfor)
	{
		task_t* const next = waitfor->next;
		addtask(task->schd, waitfor);
		waitfor = next;
	}
	pthread_mutex_unlock(&task->schd->mutex);
	free(task->stack);
	free(task);
}
static void taskyield(task_t* const task)
{
	// Re-queue ourselves, then hand control back to the scheduler. The lock
	// guards the run queue against the CUDA callback thread.
	pthread_mutex_lock(&task->schd->mutex);
	addtask(task->schd, task);
	pthread_mutex_unlock(&task->schd->mutex);
	swapcontext(&task->schd->callee, &task->schd->caller);
}
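// taskyield parks the task at the tail of the run queue, so every other
// runnable task gets a turn before it is resumed.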
static void* schdmain(void* userdata)
{
	schd_t* const schd = (schd_t*)userdata;
	for (;;) {
		pthread_mutex_lock(&schd->mutex);
		// Nothing queued and nothing suspended on a stream: we are done.
		if (schd->head == 0 && schd->count.suspend == 0)
		{
			pthread_mutex_unlock(&schd->mutex);
			break;
		}
		// Nothing runnable yet; wait for a CUDA callback to re-queue a task.
		if (schd->head == 0)
		{
			pthread_cond_wait(&schd->cv, &schd->mutex);
			pthread_mutex_unlock(&schd->mutex);
			continue;
		}
		task_t* const t = schd->head;
		deltask(schd, t);
		pthread_mutex_unlock(&schd->mutex);
		// Run t until it yields, suspends, or finishes, then save where it left off.
		swapcontext(&schd->caller, &t->context);
		t->context = schd->callee;
		if (t->done)
			taskfree(t);
	}
	return 0;
}
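// schdmain is the event loop: pick a task, run it until it hands control
// back, and exit once the run queue is empty and nothing is suspended on a
// stream. The pthread_cond_wait covers the window where all tasks are parked
// on GPU work.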
static void CUDART_CB taskcudaresume(cudaStream_t stream, cudaError_t status, void* userdata)
{
	// Runs on CUDA's callback thread once the stream reaches this point:
	// put the task back on the run queue and wake the scheduler.
	task_t* const task = (task_t*)userdata;
	pthread_mutex_lock(&task->schd->mutex);
	addtask(task->schd, task);
	--task->schd->count.suspend;
	pthread_cond_signal(&task->schd->cv);
	pthread_mutex_unlock(&task->schd->mutex);
}
static void taskcudawait(task_t* const task, cudaStream_t stream)
{
	pthread_mutex_lock(&task->schd->mutex);
	++task->schd->count.suspend;
	cudaStreamAddCallback(stream, taskcudaresume, task, 0);
	pthread_mutex_unlock(&task->schd->mutex);
	// Unlike taskyield, this doesn't addtask(task->schd, task): the CUDA
	// callback re-queues us once the stream has drained up to this point.
	swapcontext(&task->schd->callee, &task->schd->caller);
}
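// Together, taskcudawait / taskcudaresume turn an asynchronous CUDA stream
// completion into a coroutine suspension: the caller's C stack is preserved
// in its ucontext while the GPU works.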
// Run this task directly, without going through the run queue.
static void taskresume(task_t* const task)
{
	// Save the current caller context so nesting works: the task we swap
	// into will swap back to caller, which here is us, not schdmain.
	ucontext_t old_context = task->schd->caller;
	swapcontext(&task->schd->caller, &task->context);
	task->context = task->schd->callee;
	task->schd->caller = old_context;
	if (task->done) // If the task finished during this resume, free it now.
		taskfree(task);
}
// This method must execute in the context of task.
static void taskwait(task_t* const task, task_t* const waiton)
{
	// Push ourselves onto waiton's waiter list and suspend; taskfree
	// re-queues the whole list when waiton finishes.
	task->prev = 0;
	task->next = waiton->waitfor;
	waiton->waitfor = task;
	swapcontext(&task->schd->callee, &task->schd->caller);
}
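// taskwait reuses prev/next as waiter-list links, which is safe only because
// a waiting task is, by construction, not on the run queue.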
static void g(task_t* const task)
{
	printf("start task %p\n", task);
	taskyield(task);
	printf("back to task %p to finish\n", task);
}
__global__ void _test_kernel(const int batch_size, const float* const a, float* const b)
{
	// CUDA_1D_KERNEL_LOOP / CUDA_GET_BLOCKS / CUDA_NUM_THREADS are helper
	// macros presumably provided by the included ccv_nnc_compat.h.
	CUDA_1D_KERNEL_LOOP(i, batch_size) {
		b[i] = -log(a[i]);
	}
}
static void f(task_t* const task)
{
	cudaStream_t stream0;
	cudaStreamCreate(&stream0);
	const int batch_size = 1024;
	float* a;
	float* b;
	cudaMalloc(&a, sizeof(float) * batch_size);
	cudaMalloc(&b, sizeof(float) * batch_size);
	printf("allocated, launch kernel for %p\n", task);
	// a is deliberately left uninitialized; the kernel's output is irrelevant here.
	_test_kernel<<<CUDA_GET_BLOCKS(batch_size), CUDA_NUM_THREADS, 0, stream0>>>(batch_size, a, b);
	printf("create a new task to resume\n");
	task_t* gtask = taskcreate(task->schd, g);
	taskresume(gtask); // Run gtask directly, bypassing the run queue.
	taskwait(task, gtask); // Suspend until gtask runs to completion.
	printf("resume to task %p\n", task);
	printf("going to wait for stream\n");
	taskcudawait(task, stream0);
	printf("stream done\n");
	_test_kernel<<<CUDA_GET_BLOCKS(batch_size), CUDA_NUM_THREADS, 0, stream0>>>(batch_size, a, b);
	cudaError_t err = cudaGetLastError();
	if (err != cudaSuccess)
		printf("error %s\n", cudaGetErrorString(err));
	printf("done cuda task\n");
	cudaFree(a);
	cudaFree(b);
	cudaStreamDestroy(stream0);
}
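// The choreography in f: launch a kernel asynchronously, hop into gtask until
// it yields, sleep on gtask until it finishes (taskwait), then sleep on the
// stream itself (taskcudawait) before issuing the second kernel.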
int main(void)
{
	schd_t schd = {};
	pthread_cond_init(&schd.cv, 0);
	pthread_mutex_init(&schd.mutex, 0);
	task_t* task = taskcreate(&schd, f);
	addtask(&schd, task);
	// The scheduler runs inline on the main thread; the pthread primitives
	// only synchronize with the thread CUDA uses to deliver stream callbacks.
	schdmain(&schd);
	cudaDeviceSynchronize();
	pthread_cond_destroy(&schd.cv);
	pthread_mutex_destroy(&schd.mutex);
	return 0;
}
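// Build sketch (the file name and include path here are assumptions; point -I
// at the ccv checkout directory that contains nnc/gpu/ccv_nnc_compat.h):
//   nvcc -o task task.cu -I/path/to/ccv/lib -lpthread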