Skip to content

Instantly share code, notes, and snippets.

@harryscholes
Last active October 31, 2021 19:59
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save harryscholes/294515ae2037d9048cc32259e04e1e2b to your computer and use it in GitHub Desktop.
Save harryscholes/294515ae2037d9048cc32259e04e1e2b to your computer and use it in GitHub Desktop.
Rust thread-safe hash map with notification of state changes via the observer pattern
use std::collections::HashMap;
use std::sync::mpsc::{sync_channel, Receiver, SyncSender};
use std::sync::{Arc, Mutex};
use num::Unsigned;
/// Price state kept for a single symbol: the last stored value and the
/// channel senders waiting to be told about the next update.
struct Price<T> {
/// Most recently stored value, or `None` if no value has been stored yet.
value: Option<T>,
/// Senders to notify on the next update; `None` when nobody is waiting.
waiters: Option<Vec<SyncSender<T>>>,
}
impl<T> Price<T>
where
    T: Unsigned + Copy,
{
    /// Creates an entry with no stored value and no registered waiters.
    fn new() -> Self {
        Self {
            value: None,
            waiters: None,
        }
    }

    /// Creates an entry holding `value` with no registered waiters.
    fn from(value: T) -> Self {
        Self {
            value: Some(value),
            waiters: None,
        }
    }

    /// Creates an entry with no stored value and `waiter` already registered.
    fn from_waiter(waiter: SyncSender<T>) -> Self {
        Self {
            value: None,
            waiters: Some(vec![waiter]),
        }
    }

    /// Registers `waiter` so that it is sent the next value passed to
    /// [`Price::update_price`].
    fn add_waiter(&mut self, waiter: SyncSender<T>) {
        self.waiters.get_or_insert_with(Vec::new).push(waiter);
    }

    /// Stores `value` as the current price and notifies every registered waiter.
    fn update_price(&mut self, value: T) {
        self.value = Some(value);
        self.notify_waiters(value);
    }

    /// Sends `value` to every registered waiter and clears the waiter list,
    /// so each waiter is notified at most once.
    ///
    /// A waiter whose receiving end has already been dropped is skipped
    /// instead of causing a panic: one abandoned waiter must not prevent the
    /// remaining waiters from being notified, nor unwind through whatever
    /// lock the caller may be holding.
    fn notify_waiters(&mut self, value: T) {
        // `take` leaves `None` behind, which is the "nobody waiting" state.
        for waiter in self.waiters.take().into_iter().flatten() {
            // `send` only fails when the receiver is gone; there is nobody
            // left to inform in that case, so the error is deliberately ignored.
            let _ = waiter.send(value);
        }
    }
}
/// Receiving half of the channel handed out by `price_receiver`; it yields
/// the value sent when the symbol's price is next updated.
type PriceReceiver<T> = Receiver<T>;
/// Operations offered by a symbol-keyed price store.
trait PriceHolder<T> {
/// Stores `value` as the price for `symbol`.
fn put_price(&mut self, symbol: String, value: T);
/// Returns the price stored for `symbol`, or `None` if there is none.
fn get_price(&self, symbol: String) -> Option<T>;
/// Returns the next price stored for `symbol`, or `None` if it cannot be
/// received. The implementations in this file wait on a channel `recv`.
fn next_price(&mut self, symbol: String) -> Option<T>;
}
/// Price store backed by a plain `HashMap` from symbol to its `Price` entry.
struct SingleThreaded<T>(HashMap<String, Price<T>>);
impl<T> SingleThreaded<T>
where
    T: Unsigned + Copy,
{
    /// Creates a store that holds no symbols.
    fn new() -> Self {
        Self(HashMap::new())
    }

    /// Registers a new waiter for `symbol` and returns the receiving end of
    /// its channel. An entry without a value is created when `symbol` is not
    /// in the store yet.
    fn price_receiver(&mut self, symbol: String) -> PriceReceiver<T> {
        // The channel buffers a single message.
        let (sender, receiver) = sync_channel(1);
        self.0
            .entry(symbol)
            .or_insert_with(Price::new)
            .add_waiter(sender);
        receiver
    }
}
impl<T> PriceHolder<T> for SingleThreaded<T>
where
    T: Unsigned + Copy,
{
    /// Stores `value` for `symbol`, creating the entry when it is missing and
    /// notifying any waiters registered on it.
    fn put_price(&mut self, symbol: String, value: T) {
        self.0
            .entry(symbol)
            .or_insert_with(Price::new)
            .update_price(value);
    }

    /// Returns the value stored for `symbol`; `None` when the symbol is
    /// unknown or its entry has no value yet.
    fn get_price(&self, symbol: String) -> Option<T> {
        self.0.get(&symbol).and_then(|entry| entry.value)
    }

    /// Registers a waiter for `symbol` and blocks on its channel; a receive
    /// error is reported as `None`.
    ///
    /// NOTE(review): the matching sender is stored inside this same map and
    /// `put_price` needs `&mut self`, which this call holds while it waits,
    /// so this looks like it blocks indefinitely on a bare `SingleThreaded`.
    /// Confirm whether that is intended.
    fn next_price(&mut self, symbol: String) -> Option<T> {
        let receiver = self.price_receiver(symbol);
        receiver.recv().ok()
    }
}
/// Handle to a `SingleThreaded` store kept behind `Arc<Mutex<..>>`; cloning
/// the handle yields another handle to the same store.
#[derive(Clone)]
struct ThreadSafe<T>(Arc<Mutex<SingleThreaded<T>>>);
impl<T> ThreadSafe<T>
where
    T: Unsigned + Copy,
{
    /// Creates a handle to a fresh, empty store.
    fn new() -> Self {
        let store = SingleThreaded::new();
        Self(Arc::new(Mutex::new(store)))
    }
}
impl<T> PriceHolder<T> for ThreadSafe<T>
where
    T: Unsigned + Copy,
{
    /// Locks the store and stores `value` for `symbol`.
    ///
    /// # Panics
    /// Panics if the mutex is poisoned.
    fn put_price(&mut self, symbol: String, value: T) {
        let mut store = self.0.lock().unwrap();
        store.put_price(symbol, value);
    }

    /// Locks the store and returns the value stored for `symbol`, if any.
    ///
    /// # Panics
    /// Panics if the mutex is poisoned.
    fn get_price(&self, symbol: String) -> Option<T> {
        let store = self.0.lock().unwrap();
        store.get_price(symbol)
    }

    /// Registers a waiter for `symbol` under the lock, then waits for the
    /// next value with the lock released. A receive error yields `None`.
    ///
    /// # Panics
    /// Panics if the mutex is poisoned.
    fn next_price(&mut self, symbol: String) -> Option<T> {
        let mut store = self.0.lock().unwrap();
        let receiver = store.price_receiver(symbol);
        // Release the mutex before blocking so other handles can publish.
        drop(store);
        receiver.recv().ok()
    }
}
// Unit tests for both price stores. Several tests sleep for 100 ms before
// (or inside the thread that is) publishing a price, giving the other side
// time to start waiting first.
#[cfg(test)]
mod tests {
use std::{thread, time::Duration};
use crate::{PriceHolder, SingleThreaded, ThreadSafe};
// Stored prices can be read back, overwritten, and are kept per symbol;
// an unknown symbol reads as `None`.
#[test]
fn put_and_get_price() {
let mut ph = SingleThreaded::new();
ph.put_price("symbol".to_string(), 1u32);
assert_eq!(ph.get_price("symbol".to_string()).unwrap(), 1);
ph.put_price("symbol".to_string(), 2);
assert_eq!(ph.get_price("symbol".to_string()).unwrap(), 2);
ph.put_price("another_symbol".to_string(), 3);
assert_eq!(ph.get_price("another_symbol".to_string()).unwrap(), 3);
assert!(ph.get_price("not_a_symbol".to_string()).is_none());
}
// `next_price` returns the value published after the call started, not the
// value that was already stored.
#[test]
fn next_price() {
let mut ph = ThreadSafe::new();
ph.put_price("symbol".to_string(), 1u64);
let handle = {
let mut ph = ph.clone();
thread::spawn(move || {
thread::sleep(Duration::from_millis(100));
ph.put_price("symbol".to_string(), 2);
})
};
let price = ph.next_price("symbol".to_string()).unwrap();
assert_eq!(price, 2);
handle.join().unwrap();
}
// Waiting on a symbol that has never been stored still receives its first price.
#[test]
fn wait_for_price_of_new_symbol() {
let mut ph = ThreadSafe::new();
let handle = {
let mut ph = ph.clone();
thread::spawn(move || {
thread::sleep(Duration::from_millis(100));
ph.put_price("symbol".to_string(), 2u64);
})
};
let price = ph.next_price("symbol".to_string()).unwrap();
assert_eq!(price, 2);
handle.join().unwrap();
}
// Four threads wait on the same symbol and each expects the single published value.
#[test]
fn multiple_wait_for_next_price() {
let mut ph = ThreadSafe::new();
let mut handles = vec![];
for _ in 1..=4 {
let mut ph = ph.clone();
let handle = thread::spawn(move || {
let price = ph.next_price("symbol".to_string()).unwrap();
assert_eq!(price, 1);
});
handles.push(handle);
}
thread::sleep(Duration::from_millis(100));
ph.put_price("symbol".to_string(), 1u8);
for handle in handles {
handle.join().unwrap();
}
}
// A fresh waiter is started for each of four successive prices and joined
// before the next round begins.
#[test]
fn wait_for_next_price_multiple_times() {
let mut ph = ThreadSafe::new();
for p in 1u64..=4 {
let handle = {
let mut ph = ph.clone();
thread::spawn(move || {
let price = ph.next_price("symbol".to_string()).unwrap();
assert_eq!(price, p);
})
};
thread::sleep(Duration::from_millis(100));
ph.put_price("symbol".to_string(), p);
handle.join().unwrap();
}
}
// Publishing a value equal to the stored one still wakes the waiter.
#[test]
fn next_price_is_the_same() {
let mut ph = ThreadSafe::new();
ph.put_price("symbol".to_string(), 1u32);
let handle = {
let mut ph = ph.clone();
thread::spawn(move || {
let price = ph.next_price("symbol".to_string()).unwrap();
assert_eq!(price, 1);
})
};
thread::sleep(Duration::from_millis(100));
ph.put_price("symbol".to_string(), 1);
handle.join().unwrap();
}
// Reads keep returning the stored price while another thread is waiting
// for the next one.
#[test]
fn get_price_whilst_waiting_for_next_price() {
let mut ph = ThreadSafe::new();
ph.put_price("symbol".to_string(), 1u64);
{
let mut ph = ph.clone();
// NOTE(review): this thread's JoinHandle is discarded, so the test never
// observes the outcome of the assertion inside it — confirm that is intended.
thread::spawn(move || {
let price = ph.next_price("symbol".to_string()).unwrap();
assert_eq!(price, 2);
});
}
let handle = {
let ph = ph.clone();
thread::spawn(move || {
for _ in 0..1_000 {
let price = ph.get_price("symbol".to_string()).unwrap();
assert_eq!(price, 1);
}
})
};
handle.join().unwrap();
ph.put_price("symbol".to_string(), 2);
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment