Skip to content

Instantly share code, notes, and snippets.

@ericniebler
Last active June 9, 2022 22:36
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ericniebler/dd3ed2a7661162909dd523da35eee7ed to your computer and use it in GitHub Desktop.
Save ericniebler/dd3ed2a7661162909dd523da35eee7ed to your computer and use it in GitHub Desktop.
A toy implementation of P2300, the std::execution proposal, for teaching purposes
/*
* Copyright 2022 NVIDIA Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// This is a toy implementation of the core parts of the C++ std::execution
// proposal (aka, Executors, http://wg21.link/P2300). It is intended to be a
// learning tool only. THIS CODE IS NOT SUITABLE FOR ANY USE.
#include <condition_variable>
#include <cstdio>
#include <exception>
#include <functional>
#include <mutex>
#include <optional>
#include <sstream>
#include <thread>
#include <utility>
// Some utility code
///////////////////////////////////////////
/// Returns the calling thread's std::thread::id as text: "0x" followed by
/// the id written with std::hex, formatted with the classic "C" locale so the
/// output does not depend on the program's global locale.
std::string get_thread_id() {
  std::ostringstream text;
  text.imbue(std::locale::classic());
  text << "0x" << std::hex << std::this_thread::get_id();
  return text.str();
}
/// Mix-in base that makes a deriving type neither copyable nor movable.
/// Used for objects whose address is captured elsewhere (e.g. operation
/// states that enqueue `this`, contexts whose thread captures `this`).
struct immovable {
  immovable() = default;
  immovable(immovable&&) = delete;
  immovable(const immovable&) = delete;
  immovable& operator=(immovable&&) = delete;
  immovable& operator=(const immovable&) = delete;
};
/// Empty placeholder value, used as the result type of senders that carry no
/// meaningful data (the run_loop's schedule() sender completes with none{}).
struct none {
};

/// In this toy implementation a sender completes with exactly one value,
/// whose type the sender advertises through a nested `result_t` alias.
template <typename Snd>
using sender_result_t = typename Snd::result_t;
///////////////////////////////////////////
// just(T) sender factory
///////////////////////////////////////////
/// Operation state produced by connecting a just_sender to a receiver.
/// Holds the value and the receiver; start() completes synchronously by
/// handing the stored value to the receiver's set_value.
template <class T, class Rcvr>
struct just_operation : immovable {
  T value_;
  Rcvr rcvr_;

  friend void start(just_operation& op) {
    Rcvr& receiver = op.rcvr_;
    T& value = op.value_;
    set_value(receiver, value);
  }
};
/// Sender factory result for just(value): completes immediately with a copy
/// of `value_` when its operation is started.
template <class T>
struct just_sender {
  using result_t = T;
  T value_;

  // Build the operation state in place (it is immovable, so it is returned
  // as a prvalue).
  template <class Rcvr>
  friend just_operation<T, Rcvr> connect(just_sender sndr, Rcvr rcvr) {
    return just_operation<T, Rcvr>{{}, sndr.value_, rcvr};
  }
};
///////////////////////////////////////////
// then(Sender, Function) sender adaptor
///////////////////////////////////////////
/// Receiver adaptor created by then(): transforms an incoming value with
/// `fun_` and forwards the result to the downstream receiver `rcvr_`.
/// Errors and stop signals are passed through unchanged.
template <class Fun, class Rcvr>
struct then_receiver {
  Fun fun_;
  Rcvr rcvr_;

  // Apply fun_ and forward its result. Anything thrown in that step
  // (including from the downstream set_value) is delivered as an error.
  template <class Val>
  friend void set_value(then_receiver self, Val val) {
    try {
      set_value(self.rcvr_, self.fun_(val));
    } catch (...) {
      set_error(self.rcvr_, std::current_exception());
    }
  }

  friend void set_error(then_receiver self, std::exception_ptr err) {
    set_error(self.rcvr_, err);
  }

  friend void set_stopped(then_receiver self) {
    set_stopped(self.rcvr_);
  }
};
template <class PrevSnd, class Fun>
struct then_sender {
using prev_result_t = sender_result_t<PrevSnd>;
using result_t = std::invoke_result_t<Fun, prev_result_t>;
PrevSnd prev_;
Fun fun_;
template <class Rcvr>
friend auto connect(then_sender self, Rcvr rcvr) {
return connect(self.prev_, then_receiver<Fun, Rcvr>{self.fun_, rcvr});
}
};
/// Sender adaptor: returns a sender that completes with fun(v), where v is
/// the value produced by `prev`.
template <class PrevSnd, class Fun>
then_sender<PrevSnd, Fun> then(PrevSnd prev, Fun fun) {
  return then_sender<PrevSnd, Fun>{prev, fun};
}
///////////////////////////////////////////
// sync_wait() sender consumer
///////////////////////////////////////////
/// State shared between sync_wait() and its receiver. All fields are read
/// and written while holding `mtx`.
struct sync_wait_data {
  std::mutex mtx;
  std::condition_variable cv;  // notified once `completed` becomes true
  std::exception_ptr err;      // non-null only if the sender completed with an error
  bool completed{false};
};
template <class T>
struct sync_wait_receiver {
sync_wait_data& data_;
std::optional<T>& value_;
friend void set_value(sync_wait_receiver self, T val) {
std::unique_lock lk(self.data_.mtx);
self.value_.emplace(val);
self.data_.completed = true;
self.data_.cv.notify_all();
}
friend void set_error(sync_wait_receiver self, std::exception_ptr err) {
std::unique_lock lk(self.data_.mtx);
self.data_.err = err;
self.data_.completed = true;
self.data_.cv.notify_all();
}
friend void set_stopped(sync_wait_receiver self) {
std::unique_lock lk(self.data_.mtx);
self.data_.completed = true;
self.data_.cv.notify_all();
}
};
/// Connects `snd` to a sync_wait_receiver, starts it, and blocks the calling
/// thread until it completes. Returns the value on success, an empty
/// optional if the sender was stopped, and rethrows the stored exception if
/// it completed with an error.
template <class Snd>
std::optional<sender_result_t<Snd>> sync_wait(Snd snd) {
  using value_t = sender_result_t<Snd>;
  sync_wait_data state;
  std::optional<value_t> result;
  auto op = connect(snd, sync_wait_receiver<value_t>{state, result});
  start(op);
  // Completion may happen inline or on another thread (e.g. a run_loop);
  // either way the receiver sets `completed` under the mutex.
  std::unique_lock guard{state.mtx};
  state.cv.wait(guard, [&state] { return state.completed; });
  if (state.err) {
    std::rethrow_exception(state.err);
  }
  return result;
}
///////////////////////////////////////////
// run_loop execution context
///////////////////////////////////////////
/// A FIFO execution context. Starting an operation obtained from
/// schedule(get_scheduler()) enqueues it; the thread calling run() dequeues
/// and executes it, completing its receiver with none{}. finish() makes run()
/// return once the queue has been drained.
struct run_loop : private immovable {
private:
  /// Intrusive queue node. The queue is a circular singly-linked list with
  /// `head_` as sentinel: empty means head_.next_ == &head_, and the last
  /// queued node's next_ points back at &head_.
  struct task : private immovable {
    task* next_ = this;
    virtual void execute() {}
  };

  /// Operation state for the run_loop's sender: start() enqueues it and
  /// execute() (called from run()) completes the receiver.
  template <class Rcvr>
  struct operation : task {
    Rcvr rcvr_;
    run_loop* ctx_;
    operation(Rcvr rcvr, run_loop* ctx)
      : rcvr_(rcvr), ctx_(ctx) {}
    void execute() override final {
      std::printf("Running task on thread: %s\n", get_thread_id().c_str());
      // The receiver may cause this operation to be destroyed (e.g. when
      // sync_wait() returns), so nothing touches `this` after this call.
      set_value(rcvr_, none{});
    }
    friend void start(operation& self) {
      self.start_();
    }
    void start_() {
      ctx_->push_back_(this);
    }
  };

  task head_;            // sentinel of the circular task queue
  task* tail_ = &head_;  // last queued task, or &head_ when the queue is empty
  bool finish_ = false;
  std::mutex mtx_;
  std::condition_variable cv_;

  /// Appends `op` to the back of the queue and wakes one waiting runner.
  void push_back_(task* op) {
    std::unique_lock lk(mtx_);
    op->next_ = &head_;
    tail_ = tail_->next_ = op;
    cv_.notify_one();
  }

  /// Blocks until a task is queued or finish() was called. Returns the front
  /// task, or nullptr once finishing and the queue is empty.
  task* pop_front_() {
    std::unique_lock lk(mtx_);
    cv_.wait(lk, [this]{ return head_.next_ != &head_ || finish_; });
    if (head_.next_ == &head_)
      return nullptr;
    task* op = std::exchange(head_.next_, head_.next_->next_);
    // If we removed the last task, re-point tail_ at the sentinel. Leaving it
    // on `op` would make the next push_back_ link onto a node that is no
    // longer queued (and may already be destroyed), losing the new task.
    if (op == tail_)
      tail_ = &head_;
    return op;
  }

  /// Sender returned by schedule(); completes with none{} on the thread
  /// running this loop.
  struct sender {
    using result_t = none;
    run_loop* ctx_;
    template <class Rcvr>
    friend operation<Rcvr> connect(sender self, Rcvr rcvr) {
      return self.connect_(rcvr);
    }
    template <class Rcvr>
    operation<Rcvr> connect_(Rcvr rcvr) {
      return {rcvr, ctx_};
    }
  };

  /// Lightweight handle to a run_loop; equal iff it refers to the same loop.
  struct scheduler {
    run_loop* ctx_;
    friend bool operator==(scheduler const&, scheduler const&) = default;
    friend sender schedule(scheduler self) {
      return {self.ctx_};
    }
  };

public:
  /// Executes queued tasks on the calling thread until finish() has been
  /// called and the queue is empty.
  void run() {
    while (auto* op = pop_front_())
      op->execute();
  }

  scheduler get_scheduler() {
    return {this};
  }

  /// Requests that run() return; tasks already queued still execute first.
  void finish() {
    std::unique_lock lk(mtx_);
    finish_ = true;
    cv_.notify_all();
  }
};
///////////////////////////////////////////
// thread_context execution context
///////////////////////////////////////////
/// Execution context that owns a run_loop plus a dedicated worker thread
/// driving it; work scheduled via get_scheduler() runs on that thread.
class thread_context : immovable {
  // Declaration order matters: loop_ must be constructed before thread_
  // starts executing loop_.run().
  run_loop loop_;
  std::thread thread_;
public:
  thread_context()
    : thread_([this]{ loop_.run(); })
  {}

  /// Stops and joins the worker if the owner has not already joined it.
  /// Destroying a joinable std::thread calls std::terminate, so without
  /// this a missing join() would abort the program. Tasks already queued
  /// still run before the worker exits.
  ~thread_context() {
    if (thread_.joinable()) {
      finish();
      thread_.join();
    }
  }

  /// Asks the worker's run loop to exit once its queue is empty.
  void finish() {
    loop_.finish();
  }

  /// Blocks until the worker thread exits; it exits only after finish().
  void join() {
    thread_.join();
  }

  /// Returns a scheduler whose schedule() sender completes on the worker.
  auto get_scheduler() {
    return loop_.get_scheduler();
  }
};
// //
// // start test code
// //
// int main() {
// thread_context ctx;
// std::printf("main thread: %s\n", get_thread_id().c_str());
// //just_sender<int> first{42};
// auto first = then(schedule(ctx.get_scheduler()), [](auto) { return 42;} );
// auto next = then(first, [](int i) {return i+1;} );
// auto last = then(next, [](int i) {return i+1;} );
// int i = sync_wait(last).value();
// std::printf("result: %d\n", i);
// ctx.finish();
// ctx.join();
// }
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment