Skip to content

Instantly share code, notes, and snippets.

@kavanmevada
Last active August 27, 2021 11:37
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save kavanmevada/30a959f78d57c546231767d71f0b32b9 to your computer and use it in GitHub Desktop.
Save kavanmevada/30a959f78d57c546231767d71f0b32b9 to your computer and use it in GitHub Desktop.
use std::thread;
fn main() {
let (tx, rx) = mpsc::unbounded::<&str>();
let t1 = thread::spawn(move || {
println!("Received1: {:?}", rx.recv());
println!("Received1: {:?}", rx.recv());
});
tx.send("Hello, world!").unwrap();
tx.send("How are you today?").unwrap();
thread::sleep_ms(5000);
drop(tx);
t1.join().unwrap();
}
mod mpsc {
use std::{collections::VecDeque, sync::{Arc, Mutex, MutexGuard, atomic::{AtomicBool, AtomicUsize, Ordering}, mpsc::{RecvError, SendError, TrySendError}}, thread, time::{Duration, Instant}};
/// FIFO of type-erased hooks, used for the lists of parked senders and receivers.
type SignalVec<T> = VecDeque<Arc<Hook<T, dyn Signal>>>;

/// A wake-up mechanism attached to a [`Hook`].
pub trait Signal: Send + Sync + 'static {
    /// Wakes whoever is waiting on this signal.
    ///
    /// `Shared::send` interprets a `true` return as "the woken party did not take
    /// the message, try the next waiter"; [`SyncSignal`] always returns `false`.
    fn fire(&self) -> bool;

    /// Address of the signal object, used as an identity key when a timed-out
    /// waiter removes its own hook from a queue.
    fn as_ptr(&self) -> *const ();
}

/// Signal for blocking callers: remembers a thread handle and unparks it on `fire`.
pub struct SyncSignal(thread::Thread);

impl SyncSignal {
    /// Parks the *calling* thread until it is unparked (or wakes spuriously).
    pub fn wait(&self) {
        thread::park();
    }

    /// Parks the *calling* thread for at most `dur`.
    pub fn wait_timeout(&self, dur: Duration) {
        thread::park_timeout(dur);
    }
}

impl Signal for SyncSignal {
    fn fire(&self) -> bool {
        self.0.unpark();
        false
    }

    fn as_ptr(&self) -> *const () {
        self as *const _ as *const ()
    }
}

impl Default for SyncSignal {
    /// Captures the handle of the thread that creates the signal.
    fn default() -> Self {
        Self(thread::current())
    }
}

/// A waiter registration: an optional message slot plus the signal used to wake
/// the waiter. The outer `Option` is `None` for hooks without a slot; the inner
/// `Option` is the message currently stored in the slot, if any.
struct Hook<T, S: ?Sized>(Option<Mutex<Option<T>>>, S);

impl<T, S: ?Sized + Signal> Hook<T, S> {
    /// Creates a hook that owns a slot, optionally pre-filled with `msg`.
    pub fn slot(msg: Option<T>, signal: S) -> Arc<Self>
    where
        S: Sized,
    {
        Arc::new(Self(Some(Mutex::new(msg)), signal))
    }

    /// Removes and returns the message in the slot, if there is a slot and it is full.
    pub fn try_take(&self) -> Option<T> {
        self.0.as_ref().and_then(|s| s.lock().unwrap().take())
    }

    /// Delivers `msg` to a waiting receiver.
    ///
    /// Returns `(None, signal)` when the message was stored in the slot, or
    /// `(Some(msg), signal)` handing the message back when the hook has no slot.
    pub fn fire_send(&self, msg: T) -> (Option<T>, &S) {
        let ret = match &self.0 {
            Some(hook) => {
                *hook.lock().unwrap() = Some(msg);
                None
            }
            None => Some(msg),
        };
        (ret, &self.1)
    }

    /// Takes the message a parked sender left in the slot.
    ///
    /// # Panics
    /// Panics if the hook has no slot or the slot is empty.
    pub fn fire_recv(&self) -> (T, &S) {
        let msg = self.0.as_ref().unwrap().lock().unwrap().take().unwrap();
        (msg, &self.1)
    }
}

impl<T> Hook<T, SyncSignal> {
    /// Returns `true` once the message stored in this hook's slot has been taken.
    ///
    /// The previous code tested `lock().is_err()`, which only reports mutex
    /// poisoning, so a sender was never released when its message was consumed.
    fn message_taken(&self) -> bool {
        self.0.as_ref().unwrap().lock().unwrap().is_none()
    }

    /// Blocks the sending thread until its message has been taken or `abort` is set.
    pub fn wait_send(&self, abort: &AtomicBool) {
        loop {
            let disconnected = abort.load(Ordering::SeqCst); // Check disconnect *before* msg
            if disconnected || self.message_taken() {
                break;
            }
            self.1.wait();
        }
    }

    /// Blocks the receiving thread until a message lands in the slot.
    ///
    /// Returns `None` if `abort` was set and the slot is still empty.
    pub fn wait_recv(&self, abort: &AtomicBool) -> Option<T> {
        loop {
            let disconnected = abort.load(Ordering::SeqCst); // Check disconnect *before* msg
            let msg = self.0.as_ref().unwrap().lock().unwrap().take();
            if let Some(msg) = msg {
                break Some(msg);
            } else if disconnected {
                break None;
            } else {
                self.1.wait()
            }
        }
    }

    /// Like [`Hook::wait_send`], but gives up at `deadline`.
    ///
    /// Returns `Ok(())` once the message has been taken, `Err(false)` if `abort`
    /// was set first, and `Err(true)` on timeout.
    pub fn wait_deadline_send(&self, abort: &AtomicBool, deadline: Instant) -> Result<(), bool> {
        loop {
            let disconnected = abort.load(Ordering::SeqCst); // Check disconnect *before* msg
            if self.message_taken() {
                break Ok(());
            } else if disconnected {
                break Err(false);
            } else if let Some(dur) = deadline.checked_duration_since(Instant::now()) {
                self.1.wait_timeout(dur);
            } else {
                break Err(true);
            }
        }
    }

    /// Like [`Hook::wait_recv`], but gives up at `deadline`.
    ///
    /// Returns `Ok(msg)` on delivery, `Err(false)` if `abort` was set while the
    /// slot was empty, and `Err(true)` on timeout.
    pub fn wait_deadline_recv(&self, abort: &AtomicBool, deadline: Instant) -> Result<T, bool> {
        loop {
            let disconnected = abort.load(Ordering::SeqCst); // Check disconnect *before* msg
            let msg = self.0.as_ref().unwrap().lock().unwrap().take();
            if let Some(msg) = msg {
                break Ok(msg);
            } else if disconnected {
                break Err(false);
            } else if let Some(dur) = deadline.checked_duration_since(Instant::now()) {
                self.1.wait_timeout(dur);
            } else {
                break Err(true);
            }
        }
    }
}
/// Mutable channel state; `Shared` keeps it behind a mutex.
#[derive(Clone)]
struct Chan<T> {
    /// `Some((capacity, parked_senders))` for a bounded channel, `None` otherwise.
    sending: Option<(usize, SignalVec<T>)>,
    /// Messages that have been sent but not yet received.
    queue: VecDeque<T>,
    /// Hooks of receivers that are parked waiting for a message.
    waiting: SignalVec<T>,
}

impl<T> Chan<T> {
    /// Moves messages held by parked senders into the queue until the queue
    /// reaches its capacity (plus one when `pull_extra` is set), waking each
    /// sender whose message was taken. Does nothing for an unbounded channel.
    fn pull_pending(&mut self, pull_extra: bool) {
        let Chan { sending, queue, .. } = self;
        let (cap, parked) = match sending {
            Some((cap, parked)) => (*cap, parked),
            None => return,
        };
        let limit = cap + usize::from(pull_extra);

        while queue.len() < limit {
            match parked.pop_front() {
                Some(hook) => {
                    let (msg, signal) = hook.fire_recv();
                    signal.fire();
                    queue.push_back(msg);
                }
                None => break,
            }
        }
    }
}
/// State shared by the `Sender` and `Receiver` handles of one channel.
struct Shared<T> {
/// Message queue plus the lists of parked senders/receivers.
chan: Mutex<Chan<T>>,
/// Disconnect flag; read by `send`/`recv` and by the hook wait loops.
disconnected: AtomicBool,
/// Sender handle counter; initialised to 1 in `new`.
sender_count: AtomicUsize,
/// Receiver handle counter; initialised to 1 in `new` and incremented by `Receiver::clone`.
receiver_count: AtomicUsize,
}
impl<T> Shared<T> {
/// Creates the shared state. `Some(cap)` sets up a bounded channel (with a list
/// for parked senders); `None` leaves `sending` empty, which `send` treats as
/// "always room in the queue".
fn new(cap: Option<usize>) -> Self {
Self {
chan: Mutex::new(Chan {
sending: cap.map(|cap| (cap, VecDeque::new())),
queue: VecDeque::new(),
waiting: VecDeque::new(),
}),
disconnected: AtomicBool::new(false),
sender_count: AtomicUsize::new(1),
receiver_count: AtomicUsize::new(1),
}
}
/// Sends `msg`.
///
/// `block` selects the blocking mode: `None` never blocks (a full bounded
/// channel yields `Full`), `Some(None)` blocks without a deadline, and
/// `Some(Some(deadline))` blocks until `deadline` at the latest.
///
/// Every error variant hands the unsent message back to the caller.
fn send(
&self,
msg: T,
block: Option<Option<Instant>>,
) -> Result<(), TrySendTimeoutError<T>> {
let should_block = block.is_some();
// Builds the hook a parked sender waits on; the message lives in the hook's slot.
let make_signal = |msg| Hook::slot(Some(msg), SyncSignal::default());
// Parks the calling thread on `hook`. Only invoked when `block` is `Some`, so the
// `unwrap` below cannot fail on the paths that reach it.
let do_block = |hook: Arc<Hook<T, SyncSignal>>| if let Some(deadline) = block.unwrap() {
hook.wait_deadline_send(&self.disconnected, deadline)
.or_else(|timed_out| {
if timed_out { // Remove our signal
let hook: Arc<Hook<T, dyn Signal>> = hook.clone();
// Hooks are identified by the address of their signal.
wait_lock(&self.chan).sending
.as_mut()
.unwrap().1
.retain(|s| s.1.as_ptr() != hook.1.as_ptr());
}
// If the message is still in the slot it was not delivered: report why.
// An empty slot means a receiver took it in the meantime, which counts as success.
hook.try_take().map(|msg| if self.is_disconnected() {
Err(TrySendTimeoutError::Disconnected(msg))
} else {
Err(TrySendTimeoutError::Timeout(msg))
})
.unwrap_or(Ok(()))
})
} else {
hook.wait_send(&self.disconnected);
// A message left in the slot after waking means the wait ended by disconnect.
match hook.try_take() {
Some(msg) => Err(TrySendTimeoutError::Disconnected(msg)),
None => Ok(()),
}
};
let mut chan = wait_lock(&self.chan);
if self.is_disconnected() {
Err(TrySendTimeoutError::Disconnected(msg)).into()
} else if !chan.waiting.is_empty() {
// At least one receiver is parked: hand the message to the first one that accepts it.
// NOTE(review): both `push_front` calls below place the message ahead of anything
// already queued — confirm that ordering is intended rather than `push_back`.
let mut msg = Some(msg);
loop {
let slot = chan.waiting.pop_front();
match slot.as_ref().map(|r| r.fire_send(msg.take().unwrap())) {
// No more waiting receivers and msg in queue, so break out of the loop
None if msg.is_none() => break,
// No more waiting receivers, so add msg to queue and break out of the loop
None => {
chan.queue.push_front(msg.unwrap());
break;
}
Some((Some(m), signal)) => {
if signal.fire() {
// Was async and a stream, so didn't acquire the message. Wake another
// receiver, and do not yet push the message.
msg.replace(m);
continue;
} else {
// Was async and not a stream, so it did acquire the message. Push the
// message to the queue for it to be received.
chan.queue.push_front(m);
drop(chan);
break;
}
},
Some((None, signal)) => {
// Release the channel lock before waking the receiver.
drop(chan);
signal.fire();
break; // Was sync, so it has acquired the message
},
}
}
Ok(()).into()
} else if chan.sending.as_ref().map(|(cap, _)| chan.queue.len() < *cap).unwrap_or(true) {
// Unbounded channel, or bounded with room left: just enqueue.
chan.queue.push_back(msg);
Ok(()).into()
} else if should_block { // Only bounded from here on
// Queue is full: register as a parked sender, release the lock, then wait.
let hook = make_signal(msg);
chan.sending.as_mut().unwrap().1.push_back(hook.clone());
drop(chan);
do_block(hook)
} else {
Err(TrySendTimeoutError::Full(msg)).into()
}
}
/// Receives one message.
///
/// `block` has the same meaning as in `send`: `None` never blocks (an empty
/// channel yields `Empty`), `Some(None)` blocks without a deadline, and
/// `Some(Some(deadline))` blocks until `deadline` at the latest.
fn recv(&self, block: Option<Option<Instant>>) -> Result<T, TryRecvTimeoutError> {
let should_block = block.is_some();
// Builds the hook a parked receiver waits on; its slot starts empty.
let make_signal = || Hook::slot(None, SyncSignal::default());
// Parks the calling thread on `hook`. Only invoked when `block` is `Some`.
let do_block = |hook: Arc<Hook<T, SyncSignal>>| if let Some(deadline) = block.unwrap() {
hook.wait_deadline_recv(&self.disconnected, deadline)
.or_else(|timed_out| {
if timed_out { // Remove our signal
let hook: Arc<Hook<T, dyn Signal>> = hook.clone();
wait_lock(&self.chan).waiting
.retain(|s| s.1.as_ptr() != hook.1.as_ptr());
}
// A sender may have filled the slot after the wait gave up; prefer that
// message, then anything in the queue, before reporting an error.
match hook.try_take() {
Some(msg) => Ok(msg),
None => {
let disconnected = self.is_disconnected(); // Check disconnect *before* msg
if let Some(msg) = wait_lock(&self.chan).queue.pop_front() {
Ok(msg)
} else if disconnected {
Err(TryRecvTimeoutError::Disconnected)
} else {
Err(TryRecvTimeoutError::Timeout)
}
},
}
})
} else {
// `wait_recv` returns `None` only after a disconnect; drain the queue before giving up.
hook.wait_recv(&self.disconnected)
.or_else(|| wait_lock(&self.chan).queue.pop_front())
.ok_or(TryRecvTimeoutError::Disconnected)
};
let mut chan = wait_lock(&self.chan);
// `true` lets parked senders fill the queue to one message beyond `cap`,
// presumably because one message is popped right below.
chan.pull_pending(true);
if let Some(msg) = chan.queue.pop_front() {
drop(chan);
Ok(msg).into()
} else if self.is_disconnected() {
drop(chan);
Err(TryRecvTimeoutError::Disconnected).into()
} else if should_block {
// Nothing queued: register as a parked receiver, release the lock, then wait.
let hook = make_signal();
chan.waiting.push_back(hook.clone());
drop(chan);
do_block(hook)
} else {
drop(chan);
Err(TryRecvTimeoutError::Empty).into()
}
}
/// Reads the disconnect flag with sequentially consistent ordering.
fn is_disconnected(&self) -> bool {
self.disconnected.load(Ordering::SeqCst)
}
}
pub struct Sender<T> {
shared: Arc<Shared<T>>,
}
impl<T> Sender<T> {
pub fn send(&self, msg: T) -> Result<(), SendError<T>> {
self.shared.send(msg, Some(None)).map_err(|err| match err {
TrySendTimeoutError::Disconnected(msg) => SendError(msg),
_ => unreachable!(),
})
}
}
pub struct Receiver<T> {
shared: Arc<Shared<T>>,
}
impl<T> Receiver<T> {
pub fn recv(&self) -> Result<T, RecvError> {
self.shared.recv(Some(None)).map_err(|err| match err {
TryRecvTimeoutError::Disconnected => RecvError,
_ => unreachable!(),
})
}
}
impl<T> Clone for Receiver<T> {
fn clone(&self) -> Self {
self.shared.receiver_count.fetch_add(1, Ordering::Relaxed);
Self { shared: self.shared.clone() }
}
}
/// Creates a channel without a capacity limit and returns its two halves,
/// both backed by the same shared state.
pub fn unbounded<T>() -> (Sender<T>, Receiver<T>) {
    let receiver = Receiver { shared: Arc::new(Shared::new(None)) };
    let sender = Sender { shared: receiver.shared.clone() };
    (sender, receiver)
}
/// Acquires `lock` by polling `try_lock`, yielding between attempts and backing
/// off with exponentially growing sleeps (capped at roughly 1 ms).
///
/// A poisoned mutex is recovered rather than retried: `try_lock` keeps returning
/// an error for a poisoned lock, so retrying would spin forever.
fn wait_lock<T>(lock: &Mutex<T>) -> MutexGuard<'_, T> {
    use std::sync::TryLockError;

    // Exponent of the next sleep in nanoseconds: 2^4 ns up to 2^20 ns (~1 ms).
    let mut backoff_shift: u32 = 4;
    loop {
        for _ in 0..10 {
            match lock.try_lock() {
                Ok(guard) => return guard,
                Err(TryLockError::Poisoned(poisoned)) => return poisoned.into_inner(),
                Err(TryLockError::WouldBlock) => thread::yield_now(),
            }
        }
        // Sleep for at most ~1 ms
        thread::sleep(Duration::from_nanos(1u64 << backoff_shift));
        // Capping here keeps the counter from ever overflowing.
        backoff_shift = (backoff_shift + 1).min(20);
    }
}
/// Internal error type of `Shared::send`; every variant carries the unsent message.
enum TrySendTimeoutError<T> {
/// Bounded channel is full and the caller asked not to block.
Full(T),
/// The channel's disconnect flag was set.
Disconnected(T),
/// The deadline passed before the message was taken.
Timeout(T),
}
/// Internal error type of `Shared::recv`.
enum TryRecvTimeoutError {
/// No message queued and the caller asked not to block.
Empty,
/// The deadline passed before a message arrived.
Timeout,
/// The channel's disconnect flag was set and no message was left.
Disconnected,
}
}
#![feature(test)]
extern crate test;
use std::{thread, time::Duration};
/// Minimal unbounded multi-producer/multi-consumer channel: a mutex-protected
/// queue plus a list of parked receiver threads.
mod mpsc2 {
    use std::{collections::VecDeque, sync::{Arc, Mutex}, thread};

    /// Handle used to wake one specific blocked receiver thread.
    #[derive(Debug)]
    struct Parker(thread::Thread);

    impl Parker {
        /// Parks the *calling* thread, which is the thread this parker was created on.
        fn park(&self) { thread::park() }
        /// Wakes the thread recorded in this parker.
        fn unpark(&self) { self.0.unpark() }
    }

    #[derive(Debug)]
    struct Channel<T> {
        /// Messages in send order.
        queue: VecDeque<T>,
        /// Receivers currently parked, oldest first.
        waiting: VecDeque<Arc<Parker>>,
    }

    /// Shared channel state; both sending and receiving go through this type.
    #[derive(Debug)]
    pub struct Context<T>(Mutex<Channel<T>>);

    impl<T> Context<T> {
        /// Blocks until a message is available and returns the oldest one.
        ///
        /// Implemented as a loop rather than recursion, so spurious wake-ups
        /// neither grow the stack nor leave duplicate waiter entries behind.
        pub fn recv(&self) -> T {
            let mut channel = self.0.lock().unwrap();
            // Fast path: something is already queued, no registration needed.
            if let Some(msg) = channel.queue.pop_front() {
                return msg;
            }

            let parker = Arc::new(Parker(thread::current()));
            loop {
                // Register at most once: after a spurious wake-up our entry is still listed.
                if !channel.waiting.iter().any(|w| Arc::ptr_eq(w, &parker)) {
                    channel.waiting.push_back(Arc::clone(&parker));
                }
                // Registration and the emptiness check happen under the same lock
                // hold, so a sender cannot slip in between them unnoticed.
                drop(channel);
                parker.park();

                channel = self.0.lock().unwrap();
                if let Some(msg) = channel.queue.pop_front() {
                    // We may have been woken spuriously and still be listed; do not
                    // leave a stale entry that would absorb a later wake-up.
                    channel.waiting.retain(|w| !Arc::ptr_eq(w, &parker));
                    return msg;
                }
            }
        }

        /// Enqueues `msg` and wakes the longest-waiting receiver, if any.
        pub fn send(&self, msg: T) {
            let mut channel = self.0.lock().unwrap();
            // Always append: pushing to the front reversed message order when
            // several receivers were parked.
            channel.queue.push_back(msg);
            let waiter = channel.waiting.pop_front();
            // Release the lock before waking so the receiver can take it immediately.
            drop(channel);
            if let Some(waiter) = waiter {
                waiter.unpark();
            }
        }
    }

    /// Creates an empty channel; clone the returned `Arc` to share it between threads.
    pub fn unbounded<T>() -> Arc<Context<T>> {
        Arc::new(Context(Mutex::new(Channel {
            queue: VecDeque::new(),
            waiting: VecDeque::new(),
        })))
    }
}
/// Benchmarks the hand-rolled `mpsc2` channel: two producer threads send two
/// messages each and the benchmark thread receives all four.
///
/// Received values go through `test::black_box` instead of `println!` so the
/// measurement is not dominated by stdout, and producers are joined so no
/// threads outlive their iteration. `Bencher::iter` does the timing itself.
#[bench]
fn bench_mympsc(b: &mut test::Bencher) {
    b.iter(|| {
        let shared = mpsc2::unbounded();

        let shared_ = shared.clone();
        let t1 = thread::spawn(move || {
            shared_.send("Message 1 from T1 thread!");
            shared_.send("Message 2 from T1 thread!");
        });

        let shared_ = shared.clone();
        let t2 = thread::spawn(move || {
            shared_.send("Message 1 from T2 thread!");
            shared_.send("Message 2 from T2 thread!");
        });

        for _ in 0..4 {
            test::black_box(shared.recv());
        }

        t1.join().unwrap();
        t2.join().unwrap();
    });
}
/// Benchmarks `flume::unbounded` with the same workload as `bench_mympsc`: two
/// producer threads send two messages each and the benchmark thread receives
/// all four.
///
/// Received values go through `test::black_box` instead of `println!` so the
/// measurement is not dominated by stdout, send results are checked rather
/// than discarded, and producers are joined so no threads outlive their
/// iteration. `Bencher::iter` does the timing itself.
#[bench]
fn bench_flume(b: &mut test::Bencher) {
    b.iter(|| {
        let (tx, rx) = flume::unbounded();

        let shared_ = tx.clone();
        let t1 = thread::spawn(move || {
            shared_.send("Message 1 from T1 thread!").unwrap();
            shared_.send("Message 2 from T1 thread!").unwrap();
        });

        let shared_ = tx.clone();
        let t2 = thread::spawn(move || {
            shared_.send("Message 1 from T2 thread!").unwrap();
            shared_.send("Message 2 from T2 thread!").unwrap();
        });

        for _ in 0..4 {
            test::black_box(rx.recv());
        }

        t1.join().unwrap();
        t2.join().unwrap();
    });
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment