Last active
May 20, 2022 07:02
-
-
Save liuliu/7366373d0824a915a26ff295c468b6e4 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <stdio.h> | |
#include <stdint.h> | |
#include <stdlib.h> | |
#include <ucontext.h> | |
#include <pthread.h> | |
#include "nnc/gpu/ccv_nnc_compat.h" | |
// Splits a pointer into two 32-bit halves so it can be passed through
// makecontext(), whose extra arguments are ints. Reassembled in
// _task_entry_point(). Assumes pointers are at most 64 bits wide.
union ptr_splitter {
	void *ptr;
	uint32_t part[2];
};
// Fixed per-coroutine stack size (64 KiB). There is no guard page, so deep
// recursion inside a task overflows silently.
static const int default_stack_size = 65536;
typedef struct schd_s schd_t;
typedef struct task_s task_t;
// Entry function executed inside a task's own ucontext.
typedef void (*task_fn_t)(task_t *task);
// A single coroutine ("task"). Tasks are chained into the scheduler's ready
// queue through prev/next; the same next link is reused to chain tasks that
// block on this one (see taskwait / taskfree).
struct task_s {
	struct task_s* prev;    // ready-queue link (toward head)
	struct task_s* next;    // ready-queue link (toward tail), reused for waiter list
	schd_t* schd;           // owning scheduler
	int done;               // set by _task_entry_point once fn returns
	struct task_s* waitfor; // head of the list of tasks waiting on this one
	// For swapcontext / makecontext / getcontext.
	ucontext_t context;
	char *stack;            // heap-allocated coroutine stack (default_stack_size bytes)
	task_fn_t fn;           // task body
};
// Scheduler state. head/tail form the FIFO ready queue; count.suspend tracks
// tasks parked on CUDA stream callbacks; cv/mutex protect the queue and the
// suspend count (the mutex is shared with the CUDA driver callback thread).
struct schd_s {
	task_t* head; // ready-queue head (next task to run)
	task_t* tail; // ready-queue tail (insertion point)
	struct {
		int suspend; // tasks currently suspended in taskcudawait
	} count;
	pthread_cond_t cv;     // signaled by taskcudaresume when a task requeues
	pthread_mutex_t mutex; // guards head/tail/count.suspend
	// caller: the scheduler loop's context; callee: the task side of a swap.
	ucontext_t caller, callee;
};
// Append task t to the tail of the scheduler's ready queue.
// Locking (when required) is the caller's responsibility.
static void addtask(schd_t* const schd, task_t* const t)
{
	t->next = 0;
	t->prev = schd->tail;
	if (schd->tail == 0)
		schd->head = t; // queue was empty: t becomes the head as well
	else
		schd->tail->next = t;
	schd->tail = t;
}
// Unlink task t from the scheduler's ready queue, patching head/tail when t
// sits at either end. t's own links are left stale; callers overwrite them.
static void deltask(schd_t* const schd, task_t* const t)
{
	task_t* const before = t->prev;
	task_t* const after = t->next;
	if (before == 0)
		schd->head = after;
	else
		before->next = after;
	if (after == 0)
		schd->tail = before;
	else
		after->prev = before;
}
// Trampoline installed by makecontext(). Rejoins the task pointer from its
// two 32-bit halves, runs the task body, marks the task done, and swaps back
// to the scheduler. Control never falls off the end of this function.
static void _task_entry_point(uint32_t part0, uint32_t part1)
{
	union ptr_splitter joined;
	joined.part[0] = part0;
	joined.part[1] = part1;
	task_t* const self = (task_t*)joined.ptr;
	self->fn(self);
	self->done = 1;
	swapcontext(&self->schd->callee, &self->schd->caller);
}
// Allocate a new task bound to schd that runs fn when first scheduled.
// The task owns a default_stack_size-byte stack. Aborts on allocation
// failure instead of returning a partially-constructed task (the original
// dereferenced an unchecked calloc result).
static task_t* taskcreate(schd_t* const schd, task_fn_t fn)
{
	task_t* const task = (task_t*)calloc(1, sizeof(task_t));
	if (task == 0)
		abort();
	task->schd = schd;
	task->stack = (char*)calloc(1, default_stack_size);
	if (task->stack == 0)
		abort();
	task->fn = fn;
	getcontext(&task->context);
	task->context.uc_stack.ss_sp = task->stack;
	task->context.uc_stack.ss_size = default_stack_size;
	task->context.uc_link = 0; // no successor context; trampoline swaps out explicitly
	// makecontext only forwards int-sized arguments, so smuggle the (up to
	// 64-bit) task pointer through as two 32-bit halves.
	union ptr_splitter p;
	p.ptr = task;
	makecontext(&task->context, (void (*)(void))_task_entry_point, 2, p.part[0], p.part[1]);
	return task;
}
// Requeue every task that was waiting on this one, then release its memory.
// The ready queue is also mutated by the CUDA driver callback thread
// (taskcudaresume) under schd->mutex, so the requeue must hold the mutex too;
// the original called addtask unlocked, racing that thread.
static void taskfree(task_t* const task)
{
	task_t* waitfor = task->waitfor;
	pthread_mutex_lock(&task->schd->mutex);
	while (waitfor)
	{
		// addtask clobbers waitfor->next, so save the link first.
		task_t* const next = waitfor->next;
		addtask(task->schd, waitfor);
		waitfor = next;
	}
	pthread_mutex_unlock(&task->schd->mutex);
	free(task->stack);
	free(task);
}
// Requeue the running task at the tail of the ready queue and hand control
// back to the scheduler; the scheduler will run it again on a later turn.
// The queue is also mutated from the CUDA callback thread under schd->mutex,
// so the insertion must take the mutex (the original inserted unlocked).
static void taskyield(task_t* const task)
{
	pthread_mutex_lock(&task->schd->mutex);
	addtask(task->schd, task);
	pthread_mutex_unlock(&task->schd->mutex);
	swapcontext(&task->schd->callee, &task->schd->caller);
}
// Scheduler loop: runs ready tasks one at a time on this thread until the
// queue is empty AND no task is suspended on a CUDA stream. Returns 0
// (pthread-compatible signature, though main calls it directly).
static void* schdmain(void* userdata)
{
	schd_t* const schd = (schd_t*)userdata;
	for (;;) {
		pthread_mutex_lock(&schd->mutex);
		// No one is waiting, and no more tasks. exit.
		if (schd->head == 0 && schd->count.suspend == 0)
		{
			pthread_mutex_unlock(&schd->mutex);
			break;
		}
		// Queue empty but tasks are suspended on CUDA streams: sleep until
		// taskcudaresume signals, then re-evaluate from the top (this also
		// tolerates spurious wakeups).
		if (schd->head == 0)
		{
			pthread_cond_wait(&schd->cv, &schd->mutex);
			pthread_mutex_unlock(&schd->mutex);
			continue;
		}
		task_t* const t = schd->head;
		deltask(schd, t);
		pthread_mutex_unlock(&schd->mutex);
		// Run the task until it yields / suspends / finishes; the task side
		// swaps back into schd->caller, leaving its live state in schd->callee.
		swapcontext(&schd->caller, &t->context);
		// Persist where the task stopped so the next swap resumes it there.
		t->context = schd->callee;
		if (t->done)
			taskfree(t);
	}
	return 0;
}
// CUDA stream callback (runs on a driver thread): the stream work recorded
// before taskcudawait has drained, so mark the task runnable again and wake
// the scheduler in case it is blocked in pthread_cond_wait.
static void taskcudaresume(cudaStream_t stream, cudaError_t status, void* userdata)
{
	task_t* const resumed = (task_t*)userdata;
	schd_t* const schd = resumed->schd;
	pthread_mutex_lock(&schd->mutex);
	--schd->count.suspend;
	addtask(schd, resumed);
	pthread_cond_signal(&schd->cv);
	pthread_mutex_unlock(&schd->mutex);
}
// Suspend the current task until all work queued on `stream` so far completes.
// Registers taskcudaresume as a stream callback, then swaps back to the
// scheduler WITHOUT requeueing this task; the callback requeues it later.
// NOTE(review): cudaStreamAddCallback is deprecated; cudaLaunchHostFunc is
// the modern replacement — confirm against the targeted CUDA version.
static void taskcudawait(task_t* const task, cudaStream_t stream)
{
	pthread_mutex_lock(&task->schd->mutex);
	// Bump the suspend count under the mutex so the scheduler never observes
	// "queue empty and nothing suspended" and exits before the callback fires.
	++task->schd->count.suspend;
	cudaStreamAddCallback(stream, taskcudaresume, task, 0);
	pthread_mutex_unlock(&task->schd->mutex);
	// Compare to taskyield, this function doesn't do addtask(task->schd, task);
	swapcontext(&task->schd->callee, &task->schd->caller);
}
// Run this task directly.
// Executes `task` inline from the CURRENT context (which may itself be a
// task) instead of waiting for the scheduler queue. The scheduler's caller
// slot is saved and restored around the swap so that nested resumes still
// return to the right place.
static void taskresume(task_t* const task)
{
	ucontext_t old_context = task->schd->caller;
	swapcontext(&task->schd->caller, &task->context);
	// Persist where the task stopped, then restore the outer caller context.
	task->context = task->schd->callee;
	task->schd->caller = old_context;
	if (task->done) // If the task is done here, we should just remove it.
		taskfree(task);
}
// This method must execute in the context of task.
// Block `task` until `waiton` finishes: push task onto waiton's waiter list
// (reusing the prev/next links — task must not be on the ready queue) and
// swap back to the scheduler. taskfree(waiton) later requeues the waiters.
// NOTE(review): assumes waiton has not already finished; if it completed
// (and was freed) before this call, this writes into freed memory — callers
// must guarantee waiton is still pending, as f() does because g yields once.
static void taskwait(task_t* const task, task_t* const waiton)
{
	task->prev = 0;
	task->next = waiton->waitfor;
	waiton->waitfor = task;
	swapcontext(&task->schd->callee, &task->schd->caller);
}
// Demo task body: prints, yields once back to the scheduler, prints again
// when rescheduled.
static void g(task_t* const self)
{
	printf("start task %p\n", self);
	taskyield(self);
	printf("back to task %p to finish\n", self);
}
// Elementwise negative log: b[i] = -log(a[i]) for i in [0, batch_size).
// Iterates with CUDA_1D_KERNEL_LOOP, so any grid/block configuration covers
// the whole range.
__global__ void _test_kernel(const int batch_size, const float* const a, float* const b)
{
	CUDA_1D_KERNEL_LOOP(i, batch_size) {
		// logf, not log: the double-precision overload silently promotes the
		// float operand and is dramatically slower on consumer GPUs.
		b[i] = -logf(a[i]);
	}
}
// Main demo task: launches a kernel, runs a nested task (g) and waits on it,
// then suspends on the CUDA stream until the kernel completes, launches a
// second kernel, and cleans up. Fixes vs. original: device buffers are now
// freed (they leaked), the first launch is also error-checked, and the
// allocation sizes use batch_size instead of a repeated magic 1024.
static void f(task_t* const task)
{
	cudaStream_t stream0;
	cudaStreamCreate(&stream0);
	const int batch_size = 1024;
	float* a;
	float* b;
	cudaMalloc(&a, sizeof(float) * batch_size);
	cudaMalloc(&b, sizeof(float) * batch_size);
	printf("allocated, launch kernel for %p\n", task);
	_test_kernel<<<CUDA_GET_BLOCKS(batch_size), CUDA_NUM_THREADS, 0, stream0>>>(batch_size, a, b);
	// Catch launch-configuration errors from the first launch immediately.
	cudaError_t err = cudaGetLastError();
	if (err != cudaSuccess)
		printf("error %s\n", cudaGetErrorString(err));
	printf("create a new task to resume\n");
	task_t* gtask = taskcreate(task->schd, g);
	taskresume(gtask); // Run the gtask directly.
	taskwait(task, gtask); // this will wait until return back to g.
	printf("resume to task %p\n", task);
	printf("going to wait for stream\n");
	taskcudawait(task, stream0);
	printf("stream done\n");
	_test_kernel<<<CUDA_GET_BLOCKS(batch_size), CUDA_NUM_THREADS, 0, stream0>>>(batch_size, a, b);
	err = cudaGetLastError();
	if (err != cudaSuccess)
		printf("error %s\n", cudaGetErrorString(err));
	printf("done cuda task\n");
	cudaStreamDestroy(stream0);
	// Release the device buffers; the original leaked both.
	cudaFree(a);
	cudaFree(b);
}
// Driver: build a scheduler, seed it with task f, and run until every task
// (including those suspended on CUDA streams) has completed.
int main(void)
{
	schd_t schd = {};
	pthread_mutex_init(&schd.mutex, 0);
	pthread_cond_init(&schd.cv, 0);
	task_t* const root = taskcreate(&schd, f);
	addtask(&schd, root);
	schdmain(&schd);
	cudaDeviceSynchronize();
	pthread_mutex_destroy(&schd.mutex);
	pthread_cond_destroy(&schd.cv);
	return 0;
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment