Skip to content

Instantly share code, notes, and snippets.

@cloudhan
Last active May 15, 2022 13:15
Show Gist options
  • Save cloudhan/90b57afa920070206180e31179bd6828 to your computer and use it in GitHub Desktop.
Save cloudhan/90b57afa920070206180e31179bd6828 to your computer and use it in GitHub Desktop.
Minimal sender/receiver implementation by Eric Niebler https://twitter.com/ericniebler/status/1525651424951947264
// copied from https://godbolt.org/z/r1jdTY4G8
/*
* Copyright 2022 NVIDIA Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <condition_variable>
#include <cstdio>
#include <exception>
#include <functional>
#include <locale>
#include <mutex>
#include <optional>
#include <sstream>
#include <string>
#include <thread>
#include <type_traits>
#include <utility>
#include <variant>
// For debugging
// Returns the calling thread's id as "0x" followed by the id written with
// std::hex; used only for debug output. The stream is imbued with the classic
// "C" locale so the global locale's numeric formatting (e.g. digit grouping)
// cannot change the output.
std::string get_thread_id() {
  std::stringstream sout;
  sout.imbue(std::locale::classic());
  sout << "0x" << std::hex << std::this_thread::get_id();
  return sout.str();
}
// In this toy implementation, a sender can only complete with a single value, or void.
// The type a sender completes with, read from its nested `result_t`.
template <typename Sender>
using sender_result_t = typename Sender::result_t;

// The operation-state type returned by `connect(Sender, Receiver)`. `connect`
// is an unqualified, dependent call, so it is resolved at instantiation time
// (e.g. via argument-dependent lookup of hidden friends).
template <typename Sender, typename Receiver>
using connect_result_t =
    decltype(connect(std::declval<Sender>(), std::declval<Receiver>()));
///////////////////////////////////////////
// just(T) sender factory
///////////////////////////////////////////
// Operation state for just_sender: stores the value and the receiver, and
// start() hands the stored value to the receiver inline, on the calling thread.
template <class T, class Rcvr>
struct just_operation {
  T value_;
  Rcvr rcvr_;

  friend void start(just_operation& op) {
    Rcvr& receiver = op.rcvr_;
    set_value(receiver, op.value_);
  }
};
// Sender that completes immediately with a copy of `value_` once started.
template <class T>
struct just_sender {
  using result_t = T;
  T value_;

  // Produces the operation state holding a copy of the value and the receiver.
  template <class Rcvr>
  friend just_operation<T, Rcvr> connect(just_sender snd, Rcvr rcvr) {
    return just_operation<T, Rcvr>{snd.value_, rcvr};
  }
};
///////////////////////////////////////////
// then(Sender, Function) sender adaptor
///////////////////////////////////////////
// Receiver adaptor used by then(): applies `fun_` to the upstream value (or to
// nothing, for a void upstream) and forwards the result to `rcvr_`. Error and
// stopped signals are forwarded unchanged.
template <class Fun, class Rcvr>
struct then_receiver {
  Fun fun_;
  Rcvr rcvr_;
  friend void set_value(then_receiver self, auto... val) {
    // then_sender reports result_t = void for a void-returning function, so
    // complete downstream with no value in that case: passing a void
    // expression as an argument to set_value would be ill-formed.
    if constexpr (std::is_void_v<std::invoke_result_t<Fun&, decltype(val)&...>>) {
      self.fun_(val...);
      set_value(self.rcvr_);
    } else {
      set_value(self.rcvr_, self.fun_(val...));
    }
  }
  friend void set_error(then_receiver self, std::exception_ptr err) {
    set_error(self.rcvr_, err);
  }
  friend void set_stopped(then_receiver self) {
    set_stopped(self.rcvr_);
  }
};
// Handle void completions:
// Result of applying F to a sender's completion value(s). A sender whose
// result_t is `void` completes with no value, so for `result<F, void>` F is
// invoked with no arguments. Deriving from std::invoke_result keeps this
// SFINAE-friendly: there is no `type` member when F is not invocable.
template <class F, class... Args>
struct result : public std::invoke_result<F, Args...> {};

template <class F>
struct result<F, void> : public std::invoke_result<F> {};
// Sender returned by then(prev, fun). Connecting it wraps the downstream
// receiver in a then_receiver and connects the upstream sender to that.
template <class PrevSnd, class Fun>
struct then_sender {
  using result_t = typename result<Fun, sender_result_t<PrevSnd>>::type;

  PrevSnd prev_;
  Fun fun_;

  // Operation-state type obtained by connecting this sender to an Rcvr.
  template <class Rcvr>
  using state_for_ = connect_result_t<PrevSnd, then_receiver<Fun, Rcvr>>;

  template <class Rcvr>
  friend state_for_<Rcvr> connect(then_sender snd, Rcvr rcvr) {
    using adapted_receiver = then_receiver<Fun, Rcvr>;
    return connect(snd.prev_, adapted_receiver{snd.fun_, rcvr});
  }
};
// Builds a then_sender that applies `fun` to whatever `prev` produces. It only
// stores its arguments; nothing runs until the result is connected and started.
template <class PrevSnd, class Fun>
then_sender<PrevSnd, Fun> then(PrevSnd prev, Fun fun) {
  return then_sender<PrevSnd, Fun>{prev, fun};
}
///////////////////////////////////////////
// sync_wait() sender consumer
///////////////////////////////////////////
// Receiver used by sync_wait(). It records the outcome in the waiter's locals
// and wakes the waiter. Every completion writes under `mtx_` and sets `done_`,
// so the waiter can wait on a predicate. A completion that happens before the
// waiter reaches wait() is therefore not lost.
template <class T>
struct sync_wait_receiver {
  std::mutex& mtx_;
  std::condition_variable& cv_;
  std::optional<T>& value_;
  std::exception_ptr& err_;
  bool& done_;
  friend void set_value(sync_wait_receiver self, T val) {
    std::lock_guard<std::mutex> lk{self.mtx_};
    self.value_.emplace(val);
    self.done_ = true;
    self.cv_.notify_all();
  }
  friend void set_error(sync_wait_receiver self, std::exception_ptr err) {
    std::lock_guard<std::mutex> lk{self.mtx_};
    self.err_ = err;
    self.done_ = true;
    self.cv_.notify_all();
  }
  friend void set_stopped(sync_wait_receiver self) {
    std::lock_guard<std::mutex> lk{self.mtx_};
    self.done_ = true;
    self.cv_.notify_all();
  }
};

// Connects and starts `snd`, then blocks the calling thread until it completes.
// Returns the value on set_value, or an empty optional on set_stopped.
// Rethrows the exception delivered through set_error.
template <class Snd>
std::optional<sender_result_t<Snd>> sync_wait(Snd snd) {
  using T = sender_result_t<Snd>;
  std::mutex mtx;
  std::condition_variable cv;
  std::optional<T> value;
  std::exception_ptr err;
  bool done = false;
  auto op = connect(snd, sync_wait_receiver<T>{mtx, cv, value, err, done});
  start(op);
  // Lock only after start(): a sender that completes inline (e.g. just_sender)
  // locks mtx inside set_value on this same thread.
  std::unique_lock<std::mutex> lk{mtx};
  // The predicate guards against both an early completion and spurious wakeups.
  cv.wait(lk, [&done] { return done; });
  if (err)
    std::rethrow_exception(err);
  return value;
}
// A FIFO queue of scheduled work, executed by whichever thread calls run().
// The loop owns no thread itself (thread_context supplies one). run() returns
// once finish() has been called and the queue is empty.
struct run_loop {
private:
  // Intrusive queue node; run() performs the queued work.
  struct operation_interface {
    operation_interface* next_ = nullptr;
    virtual ~operation_interface() = default;
    virtual void run() = 0;
  };

  // Operation state produced by connecting the loop's sender to a receiver.
  // start() enqueues it. When run_loop::run() dequeues it, it completes the
  // receiver with no value on the thread executing the loop.
  template <class Rcvr>
  struct operation_model : operation_interface {
    Rcvr rcvr_;
    run_loop& ctx_;
    operation_model(Rcvr rcvr, run_loop& ctx)
      : rcvr_(rcvr), ctx_(ctx) {}
    void run() override {
      std::printf("Running task on thread: %s\n", get_thread_id().c_str());
      set_value(rcvr_);
    }
    friend void start(operation_model& self) {
      self.start_();
    }
    void start_() {
      ctx_.push_back_(this);
    }
  };

  // Singly-linked queue guarded by mtx_. tail_ points at the `next_` slot that
  // the next push writes into (&head_ while the queue is empty).
  operation_interface* head_ = nullptr;
  operation_interface** tail_ = &head_;
  bool finish_ = false;
  std::mutex mtx_;
  std::condition_variable cv_;

  // Appends op to the queue and wakes one thread blocked in pop_front_().
  void push_back_(operation_interface* op) {
    std::unique_lock<std::mutex> lk(mtx_);
    *std::exchange(tail_, &op->next_) = op;
    cv_.notify_one();
  }

  // Blocks until work is available. Returns nullptr once finish() has been
  // called and the queue is empty.
  operation_interface* pop_front_() {
    std::unique_lock<std::mutex> lk(mtx_);
    while (nullptr == head_) {
      if (finish_)
        return nullptr;
      cv_.wait(lk);
    }
    auto* op = std::exchange(head_, head_->next_);
    if (head_ == nullptr)
      tail_ = &head_;  // queue became empty: the next push writes into head_
    return op;
  }

  // Sender returned by schedule(); it completes with no value on the loop.
  struct sender {
    using result_t = void;
    run_loop& ctx_;
    template <class Rcvr>
    friend operation_model<Rcvr> connect(sender self, Rcvr rcvr) {
      return self.connect_(rcvr);
    }
    template <class Rcvr>
    operation_model<Rcvr> connect_(Rcvr rcvr) {
      return {rcvr, ctx_};
    }
  };

  // Lightweight handle whose schedule() yields a sender bound to this loop.
  struct scheduler {
    run_loop& ctx_;
    friend sender schedule(scheduler self) {
      return {self.ctx_};
    }
  };
public:
  // Executes queued operations on the calling thread until finish() has been
  // called and the queue has drained.
  void run() {
    while (auto* op = pop_front_()) {
      op->run();
    }
  }
  scheduler get_scheduler() {
    return {*this};
  }
  // Asks run() to return once the queue is empty; wakes any waiting thread.
  void finish() {
    std::unique_lock<std::mutex> lk(mtx_);
    finish_ = true;
    cv_.notify_all();
  }
};
// Owns a run_loop together with a thread that runs it. Shut down with finish()
// followed by join(). If the owner did not do that (e.g. an exception unwound
// past it), the destructor does both, because destroying a joinable std::thread
// calls std::terminate.
class thread_context {
private:
  run_loop loop_;
  std::thread thread_;  // declared after loop_, so loop_ is constructed first
public:
  thread_context()
    : thread_([this]{ loop_.run(); })
  {}
  ~thread_context() {
    if (thread_.joinable()) {
      finish();
      join();
    }
  }
  // Asks the loop to return once its queue is empty.
  void finish() {
    loop_.finish();
  }
  // Waits for the loop's thread to exit.
  void join() {
    thread_.join();
  }
  auto get_scheduler() {
    return loop_.get_scheduler();
  }
};
int main() {
  thread_context ctx;
  std::printf("main thread: %s\n", get_thread_id().c_str());

  // Pipeline: hop onto ctx's thread, produce 42, then add one twice.
  // (A just_sender<int>{42} could stand in for the scheduling step.)
  auto sched = ctx.get_scheduler();
  auto plus_one = [](int i) { return i + 1; };
  auto pipeline = then(then(then(schedule(sched), [] { return 42; }),
                            plus_one),
                       plus_one);

  const int answer = sync_wait(pipeline).value();
  std::printf("result: %d\n", answer);

  ctx.finish();
  ctx.join();
}
// copied from https://godbolt.org/z/3hETWr6hE
/*
* Copyright 2022 NVIDIA Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// This is a toy implementation of the core parts of the C++ std::execution
// proposal (aka, Executors, http://wg21.link/P2300). It is intended to be a
// learning tool only. THIS CODE IS NOT SUITABLE FOR ANY USE.
#include <condition_variable>
#include <cstdio>
#include <exception>
#include <functional>
#include <mutex>
#include <optional>
#include <sstream>
#include <thread>
#include <utility>
// Some utility code
///////////////////////////////////////////
// Returns "0x" followed by the calling thread's id written with std::hex, for
// debug output. The classic locale keeps the result independent of the global
// locale.
std::string get_thread_id() {
  std::ostringstream out;
  out.imbue(std::locale::classic());
  out << "0x";
  out << std::hex << std::this_thread::get_id();
  return out.str();
}
// Mixin base that makes a derived type neither copyable nor movable. Every
// copy/move constructor and assignment is spelled out as deleted.
struct immovable {
  immovable() = default;
  immovable(const immovable&) = delete;
  immovable(immovable&&) = delete;
  immovable& operator=(const immovable&) = delete;
  immovable& operator=(immovable&&) = delete;
};
// Empty placeholder value. In this file the run_loop's sender completes with
// `none{}` so that it, too, delivers exactly one value.
struct none {};
// In this toy implementation, a sender can only complete with a single value;
// sender_result_t names that value's type (the sender's nested result_t).
template <typename Sender>
using sender_result_t = typename Sender::result_t;
///////////////////////////////////////////
// just(T) sender factory
///////////////////////////////////////////
// Operation state for just_sender (non-copyable, non-movable via immovable).
// start() passes the stored value to the receiver inline, on the calling thread.
template <class T, class Rcvr>
struct just_operation : immovable {
  T value_;
  Rcvr rcvr_;

  friend void start(just_operation& op) {
    Rcvr& receiver = op.rcvr_;
    set_value(receiver, op.value_);
  }
};
// Sender that completes immediately with a copy of `value_` once started.
template <class T>
struct just_sender {
  using result_t = T;
  T value_;

  // The leading `{}` initializes the immovable base. Returning a prvalue relies
  // on C++17 guaranteed copy elision, since the operation cannot be moved.
  template <class Rcvr>
  friend just_operation<T, Rcvr> connect(just_sender snd, Rcvr rcvr) {
    return just_operation<T, Rcvr>{{}, snd.value_, rcvr};
  }
};
///////////////////////////////////////////
// then(Sender, Function) sender adaptor
///////////////////////////////////////////
// Receiver adaptor used by then(): transforms the upstream value with `fun_`
// and forwards the result. Any exception escaping the transform or the
// downstream set_value is routed to the downstream receiver's set_error.
// Errors and stop signals from upstream pass through unchanged.
template <class Fun, class Rcvr>
struct then_receiver {
  Fun fun_;
  Rcvr rcvr_;
  friend void set_value(then_receiver self, auto val) {
    try {
      set_value(self.rcvr_, self.fun_(val));
    } catch (...) {
      set_error(self.rcvr_, std::current_exception());
    }
  }
  friend void set_error(then_receiver self, std::exception_ptr err) {
    set_error(self.rcvr_, err);
  }
  friend void set_stopped(then_receiver self) {
    set_stopped(self.rcvr_);
  }
};
template <class PrevSnd, class Fun>
struct then_sender {
using prev_result_t = sender_result_t<PrevSnd>;
using result_t = std::invoke_result_t<Fun, prev_result_t>;
PrevSnd prev_;
Fun fun_;
template <class Rcvr>
friend auto connect(then_sender self, Rcvr rcvr) {
return connect(self.prev_, then_receiver<Fun, Rcvr>{self.fun_, rcvr});
}
};
// Builds a then_sender that applies `fun` to whatever `prev` produces. It only
// stores its arguments; nothing runs until the result is connected and started.
template <class PrevSnd, class Fun>
then_sender<PrevSnd, Fun> then(PrevSnd prev, Fun fun) {
  return then_sender<PrevSnd, Fun>{prev, fun};
}
///////////////////////////////////////////
// sync_wait() sender consumer
///////////////////////////////////////////
// State shared between sync_wait() and its receiver. `completed` is the
// condition-variable predicate. `err` and `completed` are written under `mtx`.
struct sync_wait_data {
  std::mutex mtx;
  std::condition_variable cv;
  std::exception_ptr err;
  bool completed{false};
};

// Receiver used by sync_wait(). Each completion records its outcome under the
// lock, marks the wait as completed, and wakes the waiting thread.
template <class T>
struct sync_wait_receiver {
  sync_wait_data& data_;
  std::optional<T>& value_;

  // Runs `record` with the mutex held, then flags completion and notifies.
  template <class Record>
  void complete_(Record record) {
    std::unique_lock lk(data_.mtx);
    record();
    data_.completed = true;
    data_.cv.notify_all();
  }

  friend void set_value(sync_wait_receiver self, T val) {
    self.complete_([&] { self.value_.emplace(val); });
  }
  friend void set_error(sync_wait_receiver self, std::exception_ptr err) {
    self.complete_([&] { self.data_.err = err; });
  }
  friend void set_stopped(sync_wait_receiver self) {
    self.complete_([] {});
  }
};
// Connects and starts `snd`, then blocks the calling thread until it completes.
// Returns the value on set_value, or an empty optional on set_stopped.
// Rethrows the exception delivered through set_error. The lock is taken only
// after start() because an inline completion locks the same mutex.
template <class Snd>
std::optional<sender_result_t<Snd>> sync_wait(Snd snd) {
  using T = sender_result_t<Snd>;
  sync_wait_data shared;
  std::optional<T> slot;
  auto op = connect(snd, sync_wait_receiver<T>{shared, slot});
  start(op);
  std::unique_lock lk{shared.mtx};
  while (!shared.completed)
    shared.cv.wait(lk);
  if (shared.err)
    std::rethrow_exception(shared.err);
  return slot;
}
///////////////////////////////////////////
// run_loop execution context
///////////////////////////////////////////
// A FIFO queue of scheduled work, executed by whichever thread calls run().
// Starting an operation obtained from schedule(get_scheduler()) enqueues it.
// run() returns once finish() has been called and the queue is empty.
struct run_loop : private immovable {
private:
  // Queue node. The queue is circular around the sentinel `head_`: an empty
  // queue is one where head_.next_ == &head_.
  struct operation_base : private immovable {
    operation_base* next_ = this;
    virtual void execute() {}
  };

  // Operation state produced by connecting the loop's sender to a receiver.
  // start() enqueues it. When run() dequeues it, it completes the receiver
  // with `none{}` on the thread executing the loop.
  template <class Rcvr>
  struct operation : operation_base {
    Rcvr rcvr_;
    run_loop& ctx_;
    operation(Rcvr rcvr, run_loop& ctx)
      : rcvr_(rcvr), ctx_(ctx) {}
    void execute() override final {
      std::printf("Running task on thread: %s\n", get_thread_id().c_str());
      set_value(rcvr_, none{});
    }
    friend void start(operation& self) {
      self.start_();
    }
    void start_() {
      ctx_.push_back_(this);
    }
  };

  operation_base head_;
  operation_base* tail_ = &head_;  // last queued node, or &head_ when empty
  bool finish_ = false;
  std::mutex mtx_;
  std::condition_variable cv_;

  // Appends op after the current tail and wakes one waiting thread.
  void push_back_(operation_base* op) {
    std::unique_lock lk(mtx_);
    op->next_ = &head_;
    tail_ = tail_->next_ = op;
    cv_.notify_one();
  }

  // Blocks until work is available. Returns nullptr once finish() has been
  // called and the queue is empty.
  operation_base* pop_front_() {
    std::unique_lock lk(mtx_);
    // Wait while the queue is empty:
    while (head_.next_ == &head_) {
      if (finish_)
        return nullptr;
      cv_.wait(lk);
    }
    operation_base* op = head_.next_;
    head_.next_ = op->next_;
    // When the last node is popped, tail_ must point back at the sentinel.
    // Otherwise the next push_back_ would link onto the node just handed out
    // (which may already be destroyed), and the new node would never become
    // reachable from head_.
    if (head_.next_ == &head_)
      tail_ = &head_;
    return op;
  }

  // Sender returned by schedule(); it completes with `none` on the loop.
  struct sender {
    using result_t = none;
    run_loop& ctx_;
    template <class Rcvr>
    friend operation<Rcvr> connect(sender self, Rcvr rcvr) {
      return self.connect_(rcvr);
    }
    template <class Rcvr>
    operation<Rcvr> connect_(Rcvr rcvr) {
      return {rcvr, ctx_};
    }
  };

  // Lightweight handle whose schedule() yields a sender bound to this loop.
  struct scheduler {
    run_loop& ctx_;
    friend sender schedule(scheduler self) {
      return {self.ctx_};
    }
  };
public:
  // Executes queued operations on the calling thread until finish() has been
  // called and the queue has drained.
  void run() {
    while (auto* op = pop_front_())
      op->execute();
  }
  scheduler get_scheduler() {
    return {*this};
  }
  // Asks run() to return once the queue is empty; wakes any waiting thread.
  void finish() {
    std::unique_lock lk(mtx_);
    finish_ = true;
    cv_.notify_all();
  }
};
///////////////////////////////////////////
// thread_context execution context
///////////////////////////////////////////
// Owns a run_loop together with a thread that runs it. Shut down with finish()
// followed by join(). If the owner did not do that (e.g. an exception unwound
// past it), the destructor does both, because destroying a joinable std::thread
// calls std::terminate.
class thread_context : immovable {
  run_loop loop_;
  std::thread thread_;  // declared after loop_, so loop_ is constructed first
public:
  thread_context()
    : thread_([this]{ loop_.run(); })
  {}
  ~thread_context() {
    if (thread_.joinable()) {
      finish();
      join();
    }
  }
  // Asks the loop to return once its queue is empty.
  void finish() {
    loop_.finish();
  }
  // Waits for the loop's thread to exit.
  void join() {
    thread_.join();
  }
  auto get_scheduler() {
    return loop_.get_scheduler();
  }
};
//
// start test code
//
int main() {
  thread_context ctx;
  std::printf("main thread: %s\n", get_thread_id().c_str());

  // Pipeline: hop onto ctx's thread (the scheduler's sender yields `none`),
  // produce 42, then add one twice. A just_sender<int>{42} could stand in for
  // the first step.
  auto on_ctx = then(schedule(ctx.get_scheduler()), [](auto) { return 42; });
  auto plus_one = [](int i) { return i + 1; };
  auto pipeline = then(then(on_ctx, plus_one), plus_one);

  const int answer = sync_wait(pipeline).value();
  std::printf("result: %d\n", answer);

  ctx.finish();
  ctx.join();
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment