Last active
August 27, 2021 11:37
-
-
Save kavanmevada/30a959f78d57c546231767d71f0b32b9 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use std::thread; | |
fn main() { | |
let (tx, rx) = mpsc::unbounded::<&str>(); | |
let t1 = thread::spawn(move || { | |
println!("Received1: {:?}", rx.recv()); | |
println!("Received1: {:?}", rx.recv()); | |
}); | |
tx.send("Hello, world!").unwrap(); | |
tx.send("How are you today?").unwrap(); | |
thread::sleep_ms(5000); | |
drop(tx); | |
t1.join().unwrap(); | |
} | |
mod mpsc { | |
use std::{collections::VecDeque, sync::{Arc, Mutex, MutexGuard, atomic::{AtomicBool, AtomicUsize, Ordering}, mpsc::{RecvError, SendError, TrySendError}}, thread, time::{Duration, Instant}}; | |
type SignalVec<T> = VecDeque<Arc<Hook<T, dyn Signal>>>; | |
/// A wake-up mechanism attached to a hook.
pub trait Signal: Send + Sync + 'static {
    /// Triggers the signal and returns a flag that the sender inspects to decide whether
    /// the message must be offered to another waiting receiver.
    fn fire(&self) -> bool;

    /// Returns an address identifying this signal; callers compare these addresses to find
    /// and remove a specific hook from a queue.
    fn as_ptr(&self) -> *const ();
}
impl SyncSignal { | |
pub fn wait(&self) { thread::park(); } | |
pub fn wait_timeout(&self, dur: Duration) { thread::park_timeout(dur); } | |
} | |
pub struct SyncSignal(thread::Thread); | |
impl Signal for SyncSignal { | |
fn fire(&self) -> bool { | |
self.0.unpark(); | |
false | |
} | |
// fn as_any(&self) -> &(dyn Any + 'static) { self } | |
fn as_ptr(&self) -> *const () { self as *const _ as *const () } | |
} | |
impl Default for SyncSignal { | |
fn default() -> Self { | |
Self(thread::current()) | |
} | |
} | |
struct Hook<T, S: ?Sized>(Option<Mutex<Option<T>>>, S); | |
impl<T, S: ?Sized + Signal> Hook<T, S> { | |
pub fn slot(msg: Option<T>, signal: S) -> Arc<Self> where S: Sized { | |
Arc::new(Self(Some(Mutex::new(msg)), signal)) | |
} | |
pub fn try_take(&self) -> Option<T> { | |
self.0.as_ref().and_then(|s| s.lock().unwrap().take()) | |
} | |
pub fn fire_send(&self, msg: T) -> (Option<T>, &S) { | |
let ret = match &self.0 { | |
Some(hook) => { | |
*hook.lock().unwrap() = Some(msg); | |
None | |
}, | |
None => Some(msg), | |
}; | |
(ret, &self.1) | |
} | |
pub fn fire_recv(&self) -> (T, &S) { | |
let msg = self.0.as_ref().unwrap().lock().unwrap().take().unwrap(); | |
(msg, &self.1) | |
} | |
} | |
impl<T> Hook<T, SyncSignal> { | |
pub fn wait_send(&self, abort: &AtomicBool) { | |
loop { | |
let disconnected = abort.load(Ordering::SeqCst); // Check disconnect *before* msg | |
if disconnected || self.0.as_ref().unwrap().lock().is_err() { | |
break; | |
} | |
self.1.wait(); | |
} | |
} | |
pub fn wait_recv(&self, abort: &AtomicBool) -> Option<T> { | |
loop { | |
let disconnected = abort.load(Ordering::SeqCst); // Check disconnect *before* msg | |
let msg = self.0.as_ref().unwrap().lock().unwrap().take(); | |
if let Some(msg) = msg { | |
break Some(msg); | |
} else if disconnected { | |
break None; | |
} else { | |
self.1.wait() | |
} | |
} | |
} | |
// Err(true) if timeout | |
pub fn wait_deadline_send(&self, abort: &AtomicBool, deadline: Instant) -> Result<(), bool> { | |
loop { | |
let disconnected = abort.load(Ordering::SeqCst); // Check disconnect *before* msg | |
if self.0.as_ref().unwrap().lock().is_err() { | |
break Ok(()); | |
} else if disconnected { | |
break Err(false); | |
} else if let Some(dur) = deadline.checked_duration_since(Instant::now()) { | |
self.1.wait_timeout(dur); | |
} else { | |
break Err(true); | |
} | |
} | |
} | |
pub fn wait_deadline_recv(&self, abort: &AtomicBool, deadline: Instant) -> Result<T, bool> { | |
loop { | |
let disconnected = abort.load(Ordering::SeqCst); // Check disconnect *before* msg | |
let msg = self.0.as_ref().unwrap().lock().unwrap().take(); | |
if let Some(msg) = msg { | |
break Ok(msg); | |
} else if disconnected { | |
break Err(false); | |
} else if let Some(dur) = deadline.checked_duration_since(Instant::now()) { | |
self.1.wait_timeout(dur); | |
} else { | |
break Err(true); | |
} | |
} | |
} | |
} | |
#[derive(Clone)] | |
struct Chan<T> { | |
sending: Option<(usize, SignalVec<T>)>, | |
queue: VecDeque<T>, | |
waiting: SignalVec<T>, | |
} | |
impl<T> Chan<T> { | |
fn pull_pending(&mut self, pull_extra: bool) { | |
if let Some((cap, sending)) = &mut self.sending { | |
let effective_cap = *cap + pull_extra as usize; | |
while self.queue.len() < effective_cap { | |
if let Some(s) = sending.pop_front() { | |
let (msg, signal) = s.fire_recv(); | |
signal.fire(); | |
self.queue.push_back(msg); | |
} else { | |
break; | |
} | |
} | |
} | |
} | |
} | |
struct Shared<T> { | |
chan: Mutex<Chan<T>>, | |
disconnected: AtomicBool, | |
sender_count: AtomicUsize, | |
receiver_count: AtomicUsize, | |
} | |
impl<T> Shared<T> { | |
fn new(cap: Option<usize>) -> Self { | |
Self { | |
chan: Mutex::new(Chan { | |
sending: cap.map(|cap| (cap, VecDeque::new())), | |
queue: VecDeque::new(), | |
waiting: VecDeque::new(), | |
}), | |
disconnected: AtomicBool::new(false), | |
sender_count: AtomicUsize::new(1), | |
receiver_count: AtomicUsize::new(1), | |
} | |
} | |
fn send( | |
&self, | |
msg: T, | |
block: Option<Option<Instant>>, | |
) -> Result<(), TrySendTimeoutError<T>> { | |
let should_block = block.is_some(); | |
let make_signal = |msg| Hook::slot(Some(msg), SyncSignal::default()); | |
let do_block = |hook: Arc<Hook<T, SyncSignal>>| if let Some(deadline) = block.unwrap() { | |
hook.wait_deadline_send(&self.disconnected, deadline) | |
.or_else(|timed_out| { | |
if timed_out { // Remove our signal | |
let hook: Arc<Hook<T, dyn Signal>> = hook.clone(); | |
wait_lock(&self.chan).sending | |
.as_mut() | |
.unwrap().1 | |
.retain(|s| s.1.as_ptr() != hook.1.as_ptr()); | |
} | |
hook.try_take().map(|msg| if self.is_disconnected() { | |
Err(TrySendTimeoutError::Disconnected(msg)) | |
} else { | |
Err(TrySendTimeoutError::Timeout(msg)) | |
}) | |
.unwrap_or(Ok(())) | |
}) | |
} else { | |
hook.wait_send(&self.disconnected); | |
match hook.try_take() { | |
Some(msg) => Err(TrySendTimeoutError::Disconnected(msg)), | |
None => Ok(()), | |
} | |
}; | |
let mut chan = wait_lock(&self.chan); | |
if self.is_disconnected() { | |
Err(TrySendTimeoutError::Disconnected(msg)).into() | |
} else if !chan.waiting.is_empty() { | |
let mut msg = Some(msg); | |
loop { | |
let slot = chan.waiting.pop_front(); | |
match slot.as_ref().map(|r| r.fire_send(msg.take().unwrap())) { | |
// No more waiting receivers and msg in queue, so break out of the loop | |
None if msg.is_none() => break, | |
// No more waiting receivers, so add msg to queue and break out of the loop | |
None => { | |
chan.queue.push_front(msg.unwrap()); | |
break; | |
} | |
Some((Some(m), signal)) => { | |
if signal.fire() { | |
// Was async and a stream, so didn't acquire the message. Wake another | |
// receiver, and do not yet push the message. | |
msg.replace(m); | |
continue; | |
} else { | |
// Was async and not a stream, so it did acquire the message. Push the | |
// message to the queue for it to be received. | |
chan.queue.push_front(m); | |
drop(chan); | |
break; | |
} | |
}, | |
Some((None, signal)) => { | |
drop(chan); | |
signal.fire(); | |
break; // Was sync, so it has acquired the message | |
}, | |
} | |
} | |
Ok(()).into() | |
} else if chan.sending.as_ref().map(|(cap, _)| chan.queue.len() < *cap).unwrap_or(true) { | |
chan.queue.push_back(msg); | |
Ok(()).into() | |
} else if should_block { // Only bounded from here on | |
let hook = make_signal(msg); | |
chan.sending.as_mut().unwrap().1.push_back(hook.clone()); | |
drop(chan); | |
do_block(hook) | |
} else { | |
Err(TrySendTimeoutError::Full(msg)).into() | |
} | |
} | |
fn recv(&self, block: Option<Option<Instant>>) -> Result<T, TryRecvTimeoutError> { | |
let should_block = block.is_some(); | |
let make_signal = || Hook::slot(None, SyncSignal::default()); | |
let do_block = |hook: Arc<Hook<T, SyncSignal>>| if let Some(deadline) = block.unwrap() { | |
hook.wait_deadline_recv(&self.disconnected, deadline) | |
.or_else(|timed_out| { | |
if timed_out { // Remove our signal | |
let hook: Arc<Hook<T, dyn Signal>> = hook.clone(); | |
wait_lock(&self.chan).waiting | |
.retain(|s| s.1.as_ptr() != hook.1.as_ptr()); | |
} | |
match hook.try_take() { | |
Some(msg) => Ok(msg), | |
None => { | |
let disconnected = self.is_disconnected(); // Check disconnect *before* msg | |
if let Some(msg) = wait_lock(&self.chan).queue.pop_front() { | |
Ok(msg) | |
} else if disconnected { | |
Err(TryRecvTimeoutError::Disconnected) | |
} else { | |
Err(TryRecvTimeoutError::Timeout) | |
} | |
}, | |
} | |
}) | |
} else { | |
hook.wait_recv(&self.disconnected) | |
.or_else(|| wait_lock(&self.chan).queue.pop_front()) | |
.ok_or(TryRecvTimeoutError::Disconnected) | |
}; | |
let mut chan = wait_lock(&self.chan); | |
chan.pull_pending(true); | |
if let Some(msg) = chan.queue.pop_front() { | |
drop(chan); | |
Ok(msg).into() | |
} else if self.is_disconnected() { | |
drop(chan); | |
Err(TryRecvTimeoutError::Disconnected).into() | |
} else if should_block { | |
let hook = make_signal(); | |
chan.waiting.push_back(hook.clone()); | |
drop(chan); | |
do_block(hook) | |
} else { | |
drop(chan); | |
Err(TryRecvTimeoutError::Empty).into() | |
} | |
} | |
fn is_disconnected(&self) -> bool { | |
self.disconnected.load(Ordering::SeqCst) | |
} | |
} | |
pub struct Sender<T> { | |
shared: Arc<Shared<T>>, | |
} | |
impl<T> Sender<T> { | |
pub fn send(&self, msg: T) -> Result<(), SendError<T>> { | |
self.shared.send(msg, Some(None)).map_err(|err| match err { | |
TrySendTimeoutError::Disconnected(msg) => SendError(msg), | |
_ => unreachable!(), | |
}) | |
} | |
} | |
pub struct Receiver<T> { | |
shared: Arc<Shared<T>>, | |
} | |
impl<T> Receiver<T> { | |
pub fn recv(&self) -> Result<T, RecvError> { | |
self.shared.recv(Some(None)).map_err(|err| match err { | |
TryRecvTimeoutError::Disconnected => RecvError, | |
_ => unreachable!(), | |
}) | |
} | |
} | |
impl<T> Clone for Receiver<T> { | |
fn clone(&self) -> Self { | |
self.shared.receiver_count.fetch_add(1, Ordering::Relaxed); | |
Self { shared: self.shared.clone() } | |
} | |
} | |
pub fn unbounded<T>() -> (Sender<T>, Receiver<T>) { | |
let shared = Arc::new(Shared::new(None)); | |
( | |
Sender { shared: shared.clone() }, | |
Receiver { shared }, | |
) | |
} | |
/// Acquires `lock` by polling `try_lock`, yielding between attempts and sleeping with an
/// exponentially growing pause (capped at 2^20 ns, about 1 ms) after every ten failures.
///
/// A poisoned mutex is recovered rather than retried: `try_lock` keeps returning
/// `Poisoned` for such a mutex, so treating it as contention would spin forever.
fn wait_lock<T>(lock: &Mutex<T>) -> MutexGuard<'_, T> {
    // Exponent of the next sleep in nanoseconds; capped so it cannot grow without bound.
    let mut shift: u32 = 4;
    loop {
        for _ in 0..10 {
            match lock.try_lock() {
                Ok(guard) => return guard,
                Err(std::sync::TryLockError::Poisoned(poisoned)) => return poisoned.into_inner(),
                Err(std::sync::TryLockError::WouldBlock) => thread::yield_now(),
            }
        }
        // Sleep for at most ~1 ms
        thread::sleep(Duration::from_nanos(1u64 << shift));
        shift = (shift + 1).min(20);
    }
}
/// Internal error for every send mode; each variant hands the unsent message back.
enum TrySendTimeoutError<T> {
    /// The channel is at capacity and blocking was not requested.
    Full(T),
    /// The channel is disconnected.
    Disconnected(T),
    /// The deadline passed before the message was taken.
    Timeout(T),
}
/// Internal error for every receive mode.
enum TryRecvTimeoutError {
    /// No message is available and blocking was not requested.
    Empty,
    /// The deadline passed without a message arriving.
    Timeout,
    /// The channel is disconnected and no message is available.
    Disconnected,
}
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#![feature(test)] | |
extern crate test; | |
use std::{thread, time::Duration}; | |
/// A minimal unbounded multi-producer channel built on a mutex-guarded queue and thread
/// parking.
mod mpsc2 {
    use std::{collections::VecDeque, sync::{Arc, Mutex}, thread};

    /// Handle used to wake one specific parked receiver thread.
    #[derive(Debug)]
    struct Parker(thread::Thread);

    impl Parker {
        /// Parks the calling thread (must be the thread that created this parker).
        fn park(&self) { thread::park() }
        /// Wakes the thread stored in this parker.
        fn unpark(&self) { self.0.unpark() }
    }

    /// Mutex-protected channel state.
    #[derive(Debug)]
    struct Channel<T> {
        /// Messages sent and not yet received, oldest first.
        queue: VecDeque<T>,
        /// Receivers currently parked waiting for a message.
        waiting: VecDeque<Arc<Parker>>,
    }

    /// A channel endpoint usable for both sending and receiving.
    #[derive(Debug)]
    pub struct Context<T>(Mutex<Channel<T>>);

    impl<T> Context<T> {
        /// Returns the oldest queued message, parking the calling thread until one exists.
        ///
        /// Implemented as a loop rather than recursion so repeated wake-ups do not grow the
        /// stack. After every wake-up the thread removes its own entry from the waiting
        /// list, so an early or spurious wake-up cannot leave a stale entry that would
        /// swallow a later sender's notification.
        pub fn recv(&self) -> T {
            let mut channel = self.0.lock().unwrap();
            loop {
                // Look for things in queue
                if let Some(msg) = channel.queue.pop_front() {
                    return msg;
                }

                let me = Arc::new(Parker(thread::current()));
                channel.waiting.push_back(Arc::clone(&me));
                // Release the lock before parking so senders can make progress.
                drop(channel);
                me.park();

                channel = self.0.lock().unwrap();
                channel.waiting.retain(|other| !Arc::ptr_eq(other, &me));
            }
        }

        /// Appends `msg` to the queue and wakes one waiting receiver, if any.
        ///
        /// The message is always pushed to the back: pushing to the front whenever a
        /// receiver is waiting lets a later message overtake an earlier one that has not
        /// been picked up yet.
        pub fn send(&self, msg: T) {
            let mut channel = self.0.lock().unwrap();
            channel.queue.push_back(msg);
            let waiter = channel.waiting.pop_front();
            // Unpark outside the lock so the woken receiver can take it immediately.
            drop(channel);
            if let Some(parker) = waiter {
                parker.unpark();
            }
        }
    }

    /// Creates an empty channel shared behind an `Arc`; clone the `Arc` to share it
    /// between threads.
    pub fn unbounded<T>() -> Arc<Context<T>> {
        Arc::new(Context(Mutex::new(Channel::<T> {
            queue: VecDeque::new(),
            waiting: VecDeque::new(),
        })))
    }
}
#[bench] | |
fn bench_mympsc(b: &mut test::Bencher) { | |
b.iter(|| { | |
let start = std::time::Instant::now(); | |
let shared = mpsc2::unbounded(); | |
let shared_ = shared.clone(); | |
thread::spawn(move || { | |
shared_.send("Message 1 from T1 thread!"); | |
shared_.send("Message 2 from T1 thread!"); | |
}); | |
let shared_ = shared.clone(); | |
thread::spawn(move || { | |
shared_.send("Message 1 from T2 thread!"); | |
shared_.send("Message 2 from T2 thread!"); | |
}); | |
for _ in 0..4 { | |
println!("Recieved: {:?}", shared.recv()); | |
} | |
start.elapsed() | |
}); | |
} | |
#[bench] | |
fn bench_flume(b: &mut test::Bencher) { | |
b.iter(|| { | |
let start = std::time::Instant::now(); | |
let (tx, rx) = flume::unbounded(); | |
let shared_ = tx.clone(); | |
thread::spawn(move || { | |
shared_.send("Message 1 from T1 thread!"); | |
shared_.send("Message 2 from T1 thread!"); | |
}); | |
let shared_ = tx.clone(); | |
thread::spawn(move || { | |
shared_.send("Message 1 from T2 thread!"); | |
shared_.send("Message 2 from T2 thread!"); | |
}); | |
for _ in 0..4 { | |
println!("Recieved: {:?}", rx.recv()); | |
} | |
start.elapsed() | |
}); | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment