Skip to content

Instantly share code, notes, and snippets.

@usbuild
Created November 7, 2018 09:21
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save usbuild/ba21ff0079264260a222085e45615a71 to your computer and use it in GitHub Desktop.
Save usbuild/ba21ff0079264260a222085e45615a71 to your computer and use it in GitHub Desktop.
A simple C++ coroutine implementation
#include "co.hpp"
#include <assert.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
#include <map>
#include <mutex>
#include <new>
using namespace pm::common;
// Entry trampoline for a freshly created coroutine. co_jump lands here with
// rsp at the top of the new stack, r12 holding the entry function and r13 its
// Coroutine* argument (the register image built by Coroutine::init_jmp_buf_).
//
// Declared naked so the compiler emits no prologue: the instructions below are
// the whole function, so the stack alignment the entry function observes no
// longer depends on the optimization level (an implicit `push %rbp` at -O0
// used to mask a misalignment at -O2).
static __attribute__((naked)) void co_wrap_main(void) {
  __asm__ __volatile__("\tmovq %r13, %rdi\n" // %rdi is the first argument
                       // Push a null "return address". The SysV ABI expects
                       // rsp % 16 == 8 on function entry (as right after a
                       // call); this assumes the stack top is 16-byte aligned
                       // (malloc alignment, CO_STACK_SIZE a multiple of 16).
                       // The null frame also terminates backtraces here.
                       "\tpushq $0\n"
                       "\tjmpq *%r12\n");
}
// Saves the current execution context into `from` and continues the context
// stored in `to`. Control comes back (at local label 1) only when some later
// co_jump loads `from` as its target.
//
// Buffer layout, in 8-byte slots: rip, rsp, rbp, rbx, r12, r13, r14, r15.
static inline void co_jump(co_jmp_buf from, co_jmp_buf to) {
  __asm__ __volatile__(
      // Save: the resume address is label 1 below.
      "leaq 1f(%%rip), %%rax\n\t"
      "movq %%rax, (%0)\n\t"
      "movq %%rsp, 8(%0)\n\t"
      "movq %%rbp, 16(%0)\n\t"
      "movq %%rbx, 24(%0)\n\t"
      "movq %%r12, 32(%0)\n\t"
      "movq %%r13, 40(%0)\n\t"
      "movq %%r14, 48(%0)\n\t"
      "movq %%r15, 56(%0)\n\t"
      // Restore the target context and jump to its saved rip.
      "movq 56(%1), %%r15\n\t"
      "movq 48(%1), %%r14\n\t"
      "movq 40(%1), %%r13\n\t"
      "movq 32(%1), %%r12\n\t"
      "movq 24(%1), %%rbx\n\t"
      "movq 16(%1), %%rbp\n\t"
      "movq 8(%1), %%rsp\n\t"
      "jmpq *(%1)\n"
      "1:\n"
      // In/out because on arrival at label 1, rsi/rdi hold the values of the
      // context that jumped here, not ours.
      : "+S"(from), "+D"(to)
      :
      // Callee-saved registers are restored by hand above; every caller-saved
      // register may hold another context's value after the switch, including
      // the SSE registers.
      // NOTE(review): x87 st(*) and AVX-512 xmm16-31 are not listed; add them
      // if such values can be live across a switch in your build.
      : "rax", "rcx", "rdx", "r8", "r9", "r10", "r11",
        "xmm0", "xmm1", "xmm2", "xmm3", "xmm4", "xmm5", "xmm6", "xmm7",
        "xmm8", "xmm9", "xmm10", "xmm11", "xmm12", "xmm13", "xmm14", "xmm15",
        "memory", "cc");
}
// Registry of live coroutine stacks, keyed by each stack's lowest address
// (Coroutine::stack_), so GetCoroutine() can find the stack that contains a
// given address. Every access in this file holds g_co_mu. Internal linkage:
// these are implementation details not declared in the header.
// NOTE(review): a Coroutine with static storage duration in another
// translation unit may be constructed before these are initialized.
static std::map<uintptr_t, Coroutine *> g_co_map;
static std::mutex g_co_mu;
// Size in bytes of each coroutine's private stack (allocated with malloc in
// the Coroutine constructor). Kept a multiple of 16 so the stack top stays
// 16-byte aligned.
static constexpr size_t CO_STACK_SIZE = 16 * 1024 * 1024;
// Creates a coroutine that will run `start_routine` on its own stack of
// CO_STACK_SIZE bytes. Nothing runs until the first resume(), whose argument
// is passed to start_routine.
// Throws std::bad_alloc if the stack cannot be allocated; the stack is not
// leaked if registering it fails.
Coroutine::Coroutine(start_routine_t start_routine) {
  stack_ = malloc(CO_STACK_SIZE);
  if (stack_ == nullptr) {
    throw std::bad_alloc();
  }
  // Execution starts at the high end of the buffer; the asm pushes downward.
  sp_ = static_cast<char *>(stack_) + CO_STACK_SIZE;
  start_routine_ = start_routine;
  status_ = Status::CREATE;
  force_unwind_ = false;
  resume_arg_ = nullptr;
  yield_arg_ = nullptr;
  tag_ = nullptr;
  // Register the stack so GetCoroutine() can map an address back to us. The
  // destructor does not run if the constructor throws, so free here on error.
  try {
    std::lock_guard<std::mutex> lk(g_co_mu);
    g_co_map[(uintptr_t)stack_] = this;
  } catch (...) {
    free(stack_);
    throw;
  }
}
// Returns the coroutine whose private stack the caller is currently executing
// on, or nullptr when the caller is not on any registered coroutine stack.
Coroutine *Coroutine::GetCoroutine() {
  // The address of a local lies within the stack currently in use.
  int probe = 0;
  const uintptr_t addr = reinterpret_cast<uintptr_t>(&probe);

  std::lock_guard<std::mutex> lk(g_co_mu);
  // g_co_map is keyed by stack base (lowest address): the only candidate is
  // the stack with the greatest base <= addr. An empty map yields begin().
  auto it = g_co_map.upper_bound(addr);
  if (it == g_co_map.begin()) return nullptr;
  --it;
  Coroutine *co = it->second;
  // Compare as integers: relational comparison of pointers into unrelated
  // objects is unspecified in C++. The stack occupies [stack_, sp_).
  const uintptr_t base = reinterpret_cast<uintptr_t>(co->stack_);
  const uintptr_t top = reinterpret_cast<uintptr_t>(co->sp_);
  return (addr >= base && addr < top) ? co : nullptr;
}
// Destroys the coroutine and releases its stack.
//
// A suspended coroutine is first resumed with force_unwind_ set. Its pending
// yield() then throws ForceUnwind, which the entry function catches, so
// start_routine's frames are unwound before the stack goes away.
// NOTE(review): if start_routine swallows ForceUnwind and yields again, the
// stack is still freed with live frames on it.
// Must not be called from inside this coroutine (it frees the running stack).
Coroutine::~Coroutine() {
  if (status_ == Status::SUSPEND) {
    force_unwind_ = true;
    resume(nullptr);
  }
  // Unregister before freeing. Once free() returns, malloc may hand the same
  // address to a Coroutine being constructed on another thread; erasing after
  // that would remove the new coroutine's entry instead of ours.
  {
    std::lock_guard<std::mutex> lk(g_co_mu);
    g_co_map.erase((uintptr_t)stack_);
  }
  free(stack_);
}
// Switches into the coroutine. The first call starts start_routine; later
// calls continue it from its pending yield().
//   arg: on the first call, start_routine's argument; afterwards, the value
//        returned by the pending yield().
// Returns the value passed to yield(), or start_routine's return value once
// the coroutine has finished (state() == Status::EXIT).
// Returns nullptr without switching if the coroutine is not resumable
// (running or finished); debug builds assert instead.
void *Coroutine::resume(void *arg) {
  assert(status_ == Status::CREATE || status_ == Status::SUSPEND);
  if (status_ != Status::CREATE && status_ != Status::SUSPEND) {
    // Without this guard, release builds would jump through an uninitialized
    // register buffer.
    return nullptr;
  }
  co_jmp_buf target;
  resume_arg_ = arg;
  if (status_ == Status::CREATE) {
    init_jmp_buf_(target);
  } else {
    // While suspended, saved_ctx_ holds the coroutine's context. Copy it out,
    // because co_jump below stores the caller's context into saved_ctx_.
    memcpy(target, saved_ctx_, sizeof(target));
  }
  status_ = Status::RUNNING;
  co_jump(saved_ctx_, target);
  // Execution resumes here once the coroutine yields or finishes.
  if (status_ != Status::EXIT)
    status_ = Status::SUSPEND;
  return yield_arg_;
}
void *Coroutine::yield(void *ret) {
yield_arg_ = ret;
co_jmp_buf target;
memcpy(target, saved_ctx_, sizeof(target));
co_jump(saved_ctx_, target);
if (force_unwind_) {
throw ForceUnwind{};
}
return resume_arg_;
}
// Fills `regs` with the initial register image of a coroutine that has not
// started yet. The first co_jump into it lands in co_wrap_main, which forwards
// r13 (this) as the argument to the entry function held in r12.
void Coroutine::init_jmp_buf_(co_jmp_buf regs) {
  // Runs on the coroutine's own stack for the coroutine's whole lifetime.
  auto *entry = +[](Coroutine *self) {
    try {
      self->yield_arg_ = self->start_routine_(self->resume_arg_);
    } catch (const ForceUnwind &) {
      // Thrown by yield() when the destructor force-unwinds a suspended
      // coroutine; stopping here lets start_routine's frames unwind cleanly.
    }
    self->status_ = Status::EXIT;
    // Return to the last resume(). This context is never continued, so the
    // buffer its registers are saved into is thrown away.
    co_jmp_buf discarded;
    co_jump(discarded, self->saved_ctx_);
  };

  // Slot order must match co_jump: rip, rsp, rbp, rbx, r12, r13, r14, r15.
  void *const initial[8] = {
      reinterpret_cast<void *>(co_wrap_main), // rip: entry trampoline
      sp_,                                    // rsp: top of the private stack
      nullptr,                                // rbp
      nullptr,                                // rbx
      reinterpret_cast<void *>(entry),        // r12: trampoline jump target
      this,                                   // r13: argument for `entry`
      nullptr,                                // r14
      nullptr,                                // r15
  };
  for (int slot = 0; slot < 8; ++slot) {
    regs[slot] = initial[slot];
  }
}
#pragma once
#include <stdint.h>
namespace pm {
namespace common {

/// Saved register set used for context switches, in this slot order:
/// rip, rsp, rbp, rbx, r12, r13, r14, r15.
typedef void *co_jmp_buf[8]; /* rip, rsp, rbp, rbx, r12, r13, r14, r15 */

/// Stackful coroutine (x86-64 only): start_routine runs on a private stack.
/// resume() switches into it, and yield() switches back to the resumer.
class Coroutine {
public:
  /// Coroutine body: receives the first resume() argument; its return value
  /// becomes the return value of the final resume().
  typedef void *(*start_routine_t)(void *);

  enum class Status { CREATE, SUSPEND, RUNNING, EXIT };

  /// Thrown out of yield() when a suspended coroutine is destroyed, so its
  /// frames unwind. start_routine should not swallow it.
  struct ForceUnwind {};

  /// The coroutine whose stack the caller is running on, or nullptr.
  static Coroutine *GetCoroutine();

  Coroutine(start_routine_t start_routine);
  ~Coroutine();

  // Non-copyable and (as a consequence) non-movable: the object owns a raw
  // stack allocation, and its address is stored in the saved register image
  // and in the stack registry, so it must never be duplicated or relocated.
  Coroutine(const Coroutine &) = delete;
  Coroutine &operator=(const Coroutine &) = delete;

  void *resume(void *arg);
  void *yield(void *ret);
  Status state() const { return this->status_; }
  void setTag(const void *tag) { this->tag_ = tag; }
  const void *getTag() const { return this->tag_; }

private:
  co_jmp_buf saved_ctx_;       // resumer's context while running, own while suspended
  void *stack_ = nullptr;      // base (lowest address) of the private stack
  void *sp_ = nullptr;         // one past the end of the stack; initial rsp
  start_routine_t start_routine_ = nullptr;
  Status status_ = Status::CREATE;
  void *resume_arg_ = nullptr; // last value passed to resume()
  void *yield_arg_ = nullptr;  // last value passed to yield() / returned
  bool force_unwind_ = false;  // set by the destructor to unwind the stack
  const void *tag_ = nullptr;  // opaque user data (setTag/getTag)
  void init_jmp_buf_(co_jmp_buf regs);
};

} /* common */
} /* pm */
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment