Last active
October 31, 2021 19:59
-
-
Save harryscholes/294515ae2037d9048cc32259e04e1e2b to your computer and use it in GitHub Desktop.
Rust thread-safe hash map with notification of state changes via the observer pattern
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use std::collections::HashMap; | |
use std::sync::mpsc::{sync_channel, Receiver, SyncSender}; | |
use std::sync::{Arc, Mutex}; | |
use num::Unsigned; | |
struct Price<T> { | |
value: Option<T>, | |
waiters: Option<Vec<SyncSender<T>>>, | |
} | |
impl<T> Price<T> | |
where | |
T: Unsigned + Copy, | |
{ | |
fn new() -> Self { | |
Self { | |
value: None, | |
waiters: None, | |
} | |
} | |
fn from(value: T) -> Self { | |
Self { | |
value: Some(value), | |
waiters: None, | |
} | |
} | |
fn from_waiter(waiter: SyncSender<T>) -> Self { | |
Self { | |
value: None, | |
waiters: Some(vec![waiter]), | |
} | |
} | |
fn add_waiter(&mut self, waiter: SyncSender<T>) { | |
match &mut self.waiters { | |
Some(waiters) => waiters.push(waiter), | |
None => self.waiters = Some(vec![waiter]), | |
} | |
} | |
fn update_price(&mut self, value: T) { | |
self.value = Some(value); | |
self.notify_waiters(value); | |
} | |
fn notify_waiters(&mut self, value: T) { | |
if let Some(waiters) = &self.waiters { | |
for waiter in waiters { | |
waiter.send(value).unwrap(); | |
} | |
self.waiters = None; | |
} | |
} | |
} | |
/// Receiving end of the channel on which a registered waiter is sent the next published price.
type PriceReceiver<T> = Receiver<T>;
/// Storage for the latest price of each symbol, with a way to wait for the next update.
trait PriceHolder<T> {
    /// Records `value` as the latest price for `symbol`.
    fn put_price(&mut self, symbol: String, value: T);
    /// Returns the latest recorded price for `symbol`, or `None` if there is none.
    fn get_price(&self, symbol: String) -> Option<T>;
    /// Waits for the next price recorded for `symbol` after this call and returns it.
    ///
    /// Returns `None` when no such price can be delivered; whether (and for how long)
    /// this blocks is up to the implementation.
    fn next_price(&mut self, symbol: String) -> Option<T>;
}
struct SingleThreaded<T>(HashMap<String, Price<T>>); | |
impl<T> SingleThreaded<T> | |
where | |
T: Unsigned + Copy, | |
{ | |
fn new() -> Self { | |
Self(HashMap::new()) | |
} | |
fn price_receiver(&mut self, symbol: String) -> PriceReceiver<T> { | |
let (tx, rx) = sync_channel(1); | |
match self.0.get_mut(&symbol) { | |
Some(price) => price.add_waiter(tx), | |
None => { | |
self.0.insert(symbol, Price::from_waiter(tx)); | |
} | |
} | |
return rx; | |
} | |
} | |
impl<T> PriceHolder<T> for SingleThreaded<T> | |
where | |
T: Unsigned + Copy, | |
{ | |
fn put_price(&mut self, symbol: String, value: T) { | |
match self.0.get_mut(&symbol) { | |
Some(price) => price.update_price(value), | |
None => { | |
self.0.insert(symbol, Price::from(value)); | |
} | |
}; | |
} | |
fn get_price(&self, symbol: String) -> Option<T> { | |
match self.0.get(&symbol) { | |
Some(price) => price.value, | |
None => None, | |
} | |
} | |
fn next_price(&mut self, symbol: String) -> Option<T> { | |
match self.price_receiver(symbol).recv() { | |
Ok(value) => Some(value), | |
Err(_) => None, | |
} | |
} | |
} | |
#[derive(Clone)] | |
struct ThreadSafe<T>(Arc<Mutex<SingleThreaded<T>>>); | |
impl<T> ThreadSafe<T> | |
where | |
T: Unsigned + Copy, | |
{ | |
fn new() -> Self { | |
Self(Arc::new(Mutex::new(SingleThreaded::new()))) | |
} | |
} | |
impl<T> PriceHolder<T> for ThreadSafe<T> | |
where | |
T: Unsigned + Copy, | |
{ | |
fn put_price(&mut self, symbol: String, value: T) { | |
self.0.lock().unwrap().put_price(symbol, value); | |
} | |
fn get_price(&self, symbol: String) -> Option<T> { | |
self.0.lock().unwrap().get_price(symbol) | |
} | |
fn next_price(&mut self, symbol: String) -> Option<T> { | |
let rx = { self.0.lock().unwrap().price_receiver(symbol) }; // unlock mutex | |
match rx.recv() { | |
Ok(x) => Some(x), | |
Err(_) => None, | |
} | |
} | |
} | |
#[cfg(test)]
mod tests {
    use std::{thread, time::Duration};

    // `super` rather than `crate`: a `crate::` path only resolves while this file is the
    // crate root, and breaks as soon as the code is moved into a submodule.
    use super::{PriceHolder, SingleThreaded, ThreadSafe};

    // NOTE(review): the concurrent tests order threads with `thread::sleep`, assuming
    // 100 ms is enough for a waiter to register. On a heavily loaded machine a late
    // waiter can miss its price and hang the test.

    #[test]
    fn put_and_get_price() {
        let mut ph = SingleThreaded::new();
        ph.put_price("symbol".to_string(), 1u32);
        assert_eq!(ph.get_price("symbol".to_string()).unwrap(), 1);
        ph.put_price("symbol".to_string(), 2);
        assert_eq!(ph.get_price("symbol".to_string()).unwrap(), 2);
        ph.put_price("another_symbol".to_string(), 3);
        assert_eq!(ph.get_price("another_symbol".to_string()).unwrap(), 3);
        assert!(ph.get_price("not_a_symbol".to_string()).is_none());
    }

    #[test]
    fn next_price() {
        let mut ph = ThreadSafe::new();
        ph.put_price("symbol".to_string(), 1u64);
        let handle = {
            let mut ph = ph.clone();
            thread::spawn(move || {
                thread::sleep(Duration::from_millis(100));
                ph.put_price("symbol".to_string(), 2);
            })
        };
        let price = ph.next_price("symbol".to_string()).unwrap();
        assert_eq!(price, 2);
        handle.join().unwrap();
    }

    #[test]
    fn wait_for_price_of_new_symbol() {
        let mut ph = ThreadSafe::new();
        let handle = {
            let mut ph = ph.clone();
            thread::spawn(move || {
                thread::sleep(Duration::from_millis(100));
                ph.put_price("symbol".to_string(), 2u64);
            })
        };
        let price = ph.next_price("symbol".to_string()).unwrap();
        assert_eq!(price, 2);
        handle.join().unwrap();
    }

    #[test]
    fn multiple_wait_for_next_price() {
        let mut ph = ThreadSafe::new();
        let mut handles = vec![];
        for _ in 1..=4 {
            let mut ph = ph.clone();
            let handle = thread::spawn(move || {
                let price = ph.next_price("symbol".to_string()).unwrap();
                assert_eq!(price, 1);
            });
            handles.push(handle);
        }
        thread::sleep(Duration::from_millis(100));
        ph.put_price("symbol".to_string(), 1u8);
        for handle in handles {
            handle.join().unwrap();
        }
    }

    #[test]
    fn wait_for_next_price_multiple_times() {
        let mut ph = ThreadSafe::new();
        for p in 1u64..=4 {
            let handle = {
                let mut ph = ph.clone();
                thread::spawn(move || {
                    let price = ph.next_price("symbol".to_string()).unwrap();
                    assert_eq!(price, p);
                })
            };
            thread::sleep(Duration::from_millis(100));
            ph.put_price("symbol".to_string(), p);
            handle.join().unwrap();
        }
    }

    #[test]
    fn next_price_is_the_same() {
        let mut ph = ThreadSafe::new();
        ph.put_price("symbol".to_string(), 1u32);
        let handle = {
            let mut ph = ph.clone();
            thread::spawn(move || {
                let price = ph.next_price("symbol".to_string()).unwrap();
                assert_eq!(price, 1);
            })
        };
        thread::sleep(Duration::from_millis(100));
        ph.put_price("symbol".to_string(), 1);
        handle.join().unwrap();
    }

    #[test]
    fn get_price_whilst_waiting_for_next_price() {
        let mut ph = ThreadSafe::new();
        ph.put_price("symbol".to_string(), 1u64);
        let waiter = {
            let mut ph = ph.clone();
            thread::spawn(move || {
                let price = ph.next_price("symbol".to_string()).unwrap();
                assert_eq!(price, 2);
            })
        };
        // Give the waiter time to block inside `next_price`, so the reads below show
        // that `get_price` is not held up by a pending `next_price`.
        thread::sleep(Duration::from_millis(100));
        let reader = {
            let ph = ph.clone();
            thread::spawn(move || {
                for _ in 0..1_000 {
                    let price = ph.get_price("symbol".to_string()).unwrap();
                    assert_eq!(price, 1);
                }
            })
        };
        reader.join().unwrap();
        ph.put_price("symbol".to_string(), 2);
        // Join the waiter so its assertion is actually checked; a panic in a detached
        // thread is silently lost when the test returns.
        waiter.join().unwrap();
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment