Last active
June 9, 2022 22:36
-
-
Save ericniebler/dd3ed2a7661162909dd523da35eee7ed to your computer and use it in GitHub Desktop.
A toy implementation of P2300, the std::execution proposal, for teaching purposes
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/* | |
* Copyright 2022 NVIDIA Corporation | |
* | |
* Licensed under the Apache License, Version 2.0 (the "License"); | |
* you may not use this file except in compliance with the License. | |
* You may obtain a copy of the License at | |
* | |
* http://www.apache.org/licenses/LICENSE-2.0 | |
* | |
* Unless required by applicable law or agreed to in writing, software | |
* distributed under the License is distributed on an "AS IS" BASIS, | |
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
* See the License for the specific language governing permissions and | |
* limitations under the License. | |
*/ | |
// This is a toy implementation of the core parts of the C++ std::execution | |
// proposal (aka, Executors, http://wg21.link/P2300). It is intended to be a | |
// learning tool only. THIS CODE IS NOT SUITABLE FOR ANY USE. | |
#include <condition_variable>
#include <cstdio>
#include <exception>
#include <functional>
#include <locale>
#include <mutex>
#include <optional>
#include <sstream>
#include <string>
#include <thread>
#include <type_traits>
#include <utility>
// Some utility code | |
/////////////////////////////////////////// | |
// Returns the calling thread's id rendered as text with a "0x" prefix.
//
// The id is streamed with std::hex through a stream imbued with the classic
// "C" locale, so the global locale cannot influence the formatting. The exact
// text produced for a std::thread::id is up to the standard library.
std::string get_thread_id() {
  std::ostringstream out;
  out.imbue(std::locale::classic());
  out << "0x";
  out << std::hex << std::this_thread::get_id();
  return out.str();
}
// Base type for objects that must stay at the address they were created at.
//
// Declaring the move constructor as deleted also leaves the implicit copy
// constructor and copy assignment deleted, so anything deriving from this
// type can be default-constructed but neither copied nor moved.
struct immovable {
  immovable() = default;
  immovable(immovable&&) = delete;
};
// Empty placeholder value. run_loop's sender declares it as its result_t and
// passes a `none{}` to the receiver, since it has nothing meaningful to send.
struct none {};
// In this toy implementation, a sender can only complete with a single value. | |
// The type of the single value a sender completes with, read from the
// sender's nested `result_t` alias.
template <class Sender>
using sender_result_t = typename Sender::result_t;
/////////////////////////////////////////// | |
// just(T) sender factory | |
/////////////////////////////////////////// | |
template <class T, class Rcvr> | |
struct just_operation : immovable { | |
T value_; | |
Rcvr rcvr_; | |
friend void start(just_operation& self) { | |
set_value(self.rcvr_, self.value_); | |
} | |
}; | |
template <class T> | |
struct just_sender { | |
using result_t = T; | |
T value_; | |
template <class Rcvr> | |
friend just_operation<T, Rcvr> connect(just_sender self, Rcvr rcvr) { | |
return {{}, self.value_, rcvr}; | |
} | |
}; | |
/////////////////////////////////////////// | |
// then(Sender, Function) sender adaptor | |
/////////////////////////////////////////// | |
// Receiver adaptor behind then(): transforms an incoming value with `fun_`
// and forwards the result to the wrapped receiver. Errors and stop signals
// are passed through unchanged.
template <class Fun, class Rcvr>
struct then_receiver {
  Fun fun_;    // transformation applied to the incoming value
  Rcvr rcvr_;  // downstream receiver

  // Value channel: calls `fun_` on `val` and sends the result downstream.
  //
  // The call goes through std::invoke with both the callable and the argument
  // passed as rvalues, which is exactly the expression that
  // std::invoke_result_t<Fun, prev_result_t> (used by then_sender to compute
  // its result_t) describes. `self` and `val` are by-value parameters, so
  // moving from them affects nobody else.
  //
  // Anything thrown inside the function body is caught by the
  // function-try-block and reported to the downstream receiver via set_error.
  template <class Val>
  friend void set_value(then_receiver self, Val val) try {
    set_value(self.rcvr_, std::invoke(std::move(self.fun_), std::move(val)));
  } catch (...) {
    set_error(self.rcvr_, std::current_exception());
  }

  // Error channel: forwarded to the downstream receiver untouched.
  friend void set_error(then_receiver self, std::exception_ptr err) {
    set_error(self.rcvr_, err);
  }

  // Stopped channel: forwarded to the downstream receiver untouched.
  friend void set_stopped(then_receiver self) {
    set_stopped(self.rcvr_);
  }
};
template <class PrevSnd, class Fun> | |
struct then_sender { | |
using prev_result_t = sender_result_t<PrevSnd>; | |
using result_t = std::invoke_result_t<Fun, prev_result_t>; | |
PrevSnd prev_; | |
Fun fun_; | |
template <class Rcvr> | |
friend auto connect(then_sender self, Rcvr rcvr) { | |
return connect(self.prev_, then_receiver<Fun, Rcvr>{self.fun_, rcvr}); | |
} | |
}; | |
template <class PrevSnd, class Fun> | |
then_sender<PrevSnd, Fun> then(PrevSnd prev, Fun fun) { | |
return {prev, fun}; | |
} | |
/////////////////////////////////////////// | |
// sync_wait() sender consumer | |
/////////////////////////////////////////// | |
// Completion state shared between sync_wait() and its receiver.
struct sync_wait_data {
  std::mutex mtx;               // guards the fields below
  std::condition_variable cv;   // signalled when `completed` becomes true
  std::exception_ptr err;       // set only when the operation ended in error
  bool completed{false};        // true once any completion signal arrived
};
template <class T> | |
struct sync_wait_receiver { | |
sync_wait_data& data_; | |
std::optional<T>& value_; | |
friend void set_value(sync_wait_receiver self, T val) { | |
std::unique_lock lk(self.data_.mtx); | |
self.value_.emplace(val); | |
self.data_.completed = true; | |
self.data_.cv.notify_all(); | |
} | |
friend void set_error(sync_wait_receiver self, std::exception_ptr err) { | |
std::unique_lock lk(self.data_.mtx); | |
self.data_.err = err; | |
self.data_.completed = true; | |
self.data_.cv.notify_all(); | |
} | |
friend void set_stopped(sync_wait_receiver self) { | |
std::unique_lock lk(self.data_.mtx); | |
self.data_.completed = true; | |
self.data_.cv.notify_all(); | |
} | |
}; | |
template <class Snd> | |
std::optional<sender_result_t<Snd>> sync_wait(Snd snd) { | |
using T = sender_result_t<Snd>; | |
sync_wait_data data; | |
std::optional<T> value; | |
auto op = connect(snd, sync_wait_receiver<T>{data, value}); | |
start(op); | |
std::unique_lock lk{data.mtx}; | |
data.cv.wait(lk, [&data]{ return data.completed; }); | |
if (data.err) | |
std::rethrow_exception(data.err); | |
return value; | |
} | |
/////////////////////////////////////////// | |
// run_loop execution context | |
/////////////////////////////////////////// | |
struct run_loop : private immovable { | |
private: | |
struct task : private immovable { | |
task* next_ = this; | |
virtual void execute() {} | |
}; | |
template <class Rcvr> | |
struct operation : task { | |
Rcvr rcvr_; | |
run_loop* ctx_; | |
operation(Rcvr rcvr, run_loop* ctx) | |
: rcvr_(rcvr), ctx_(ctx) {} | |
void execute() override final { | |
std::printf("Running task on thread: %s\n", get_thread_id().c_str()); | |
set_value(rcvr_, none{}); | |
} | |
friend void start(operation& self) { | |
self.start_(); | |
} | |
void start_() { | |
ctx_->push_back_(this); | |
} | |
}; | |
task head_; | |
task* tail_ = &head_; | |
bool finish_ = false; | |
std::mutex mtx_; | |
std::condition_variable cv_; | |
void push_back_(task* op) { | |
std::unique_lock lk(mtx_); | |
op->next_ = &head_; | |
tail_ = tail_->next_= op; | |
cv_.notify_one(); | |
} | |
task* pop_front_() { | |
std::unique_lock lk(mtx_); | |
if (tail_ == head_.next_) | |
tail_ = &head_; | |
cv_.wait(lk, [this]{ return head_.next_ != &head_ || finish_; }); | |
if (head_.next_ == &head_) | |
return nullptr; | |
return std::exchange(head_.next_, head_.next_->next_); | |
} | |
struct sender { | |
using result_t = none; | |
run_loop* ctx_; | |
template <class Rcvr> | |
friend operation<Rcvr> connect(sender self, Rcvr rcvr) { | |
return self.connect_(rcvr); | |
} | |
template <class Rcvr> | |
operation<Rcvr> connect_(Rcvr rcvr) { | |
return {rcvr, ctx_}; | |
} | |
}; | |
struct scheduler { | |
run_loop* ctx_; | |
friend bool operator==(scheduler const&, scheduler const&) = default; | |
friend sender schedule(scheduler self) { | |
return {self.ctx_}; | |
} | |
}; | |
public: | |
void run() { | |
while (auto* op = pop_front_()) | |
op->execute(); | |
} | |
scheduler get_scheduler() { | |
return {this}; | |
} | |
void finish() { | |
std::unique_lock lk(mtx_); | |
finish_ = true; | |
cv_.notify_all(); | |
} | |
}; | |
/////////////////////////////////////////// | |
// thread_context execution context | |
/////////////////////////////////////////// | |
class thread_context : immovable { | |
run_loop loop_; | |
std::thread thread_; | |
public: | |
thread_context() | |
: thread_([this]{ loop_.run(); }) | |
{} | |
void finish() { | |
loop_.finish(); | |
} | |
void join() { | |
thread_.join(); | |
} | |
auto get_scheduler() { | |
return loop_.get_scheduler(); | |
} | |
}; | |
// // | |
// // start test code | |
// // | |
// int main() { | |
// thread_context ctx; | |
// std::printf("main thread: %s\n", get_thread_id().c_str()); | |
// //just_sender<int> first{42}; | |
// auto first = then(schedule(ctx.get_scheduler()), [](auto) { return 42;} ); | |
// auto next = then(first, [](int i) {return i+1;} ); | |
// auto last = then(next, [](int i) {return i+1;} ); | |
// int i = sync_wait(last).value(); | |
// std::printf("result: %d\n", i); | |
// ctx.finish(); | |
// ctx.join(); | |
// } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment