Last active
August 11, 2021 15:50
-
-
Save harryscholes/68e5db01545f5e7f442556c7c2b184c6 to your computer and use it in GitHub Desktop.
Thread-safe price holder in Rust
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use std::{ | |
collections::HashMap, | |
sync::{ | |
mpsc::{self, RecvError, Sender}, | |
Arc, Mutex, | |
}, | |
}; | |
/// Per-symbol entry: the latest price (if any) plus the senders of threads
/// waiting to be told about the next one.
struct Price {
    /// Most recently stored price; `None` while only waiters exist for the symbol.
    value: Option<u128>,
    /// One sender per waiting thread; each receives the next price published.
    listeners: Vec<Sender<u128>>,
}

impl Price {
    /// Makes an entry that already holds `value` and has nobody waiting on it.
    fn from(value: u128) -> Self {
        let value = Some(value);
        Self {
            listeners: Vec::new(),
            value,
        }
    }

    /// Makes an entry that has no price yet but already has `listeners` waiting.
    fn new_with_listeners(listeners: Vec<Sender<u128>>) -> Self {
        Self {
            listeners,
            value: None,
        }
    }
}
#[derive(Clone)] | |
struct PriceHolder { | |
prices: Arc<Mutex<HashMap<String, Price>>>, | |
} | |
impl PriceHolder { | |
fn new() -> PriceHolder { | |
Self { | |
prices: Arc::new(Mutex::new(HashMap::new())), | |
} | |
} | |
fn put<S: Into<String> + Clone>(&mut self, symbol: S, price: u128) { | |
let mut map = self.prices.lock().unwrap(); | |
if let Some(p) = map.get(&symbol.clone().into()) { | |
for l in &p.listeners { | |
l.send(price).unwrap(); | |
} | |
} | |
map.insert(symbol.into(), Price::from(price)); | |
} | |
fn get<S: Into<String>>(&self, symbol: S) -> Option<u128> { | |
self.prices.lock().unwrap().get(&symbol.into())?.value | |
} | |
fn wait<S: Into<String>>(&mut self, symbol: S) -> Result<u128, RecvError> { | |
let (tx, rx) = mpsc::channel(); | |
{ | |
let mut map = self.prices.lock().unwrap(); | |
let s: String = symbol.into(); | |
match map.get_mut(&s) { | |
Some(p) => p.listeners.push(tx), | |
None => { | |
map.insert(s, Price::new_with_listeners(vec![tx])); | |
} | |
}; | |
} // mutex unlocks | |
rx.recv() | |
} | |
} | |
#[cfg(test)]
mod tests {
    use super::*;
    use std::{thread, thread::JoinHandle, time};

    /// Number of threads currently waiting on `symbol` (0 if the symbol is unknown).
    fn listener_count(ph: &PriceHolder, symbol: &str) -> usize {
        ph.prices
            .lock()
            .unwrap()
            .get(symbol)
            .map_or(0, |p| p.listeners.len())
    }

    /// Spawns a thread that waits for the next price of `symbol` and asserts it is `expected`.
    fn spawn_waiter(ph: &PriceHolder, symbol: &'static str, expected: u128) -> JoinHandle<()> {
        let mut ph = ph.clone();
        thread::spawn(move || {
            assert_eq!(ph.wait(symbol).unwrap(), expected);
        })
    }

    /// Spawns a thread that puts `price` for `symbol` once `waiters` listeners are registered.
    ///
    /// Polling the listener count, instead of sleeping a fixed time, removes the race
    /// where `put` runs before a waiter has registered. That race would leave the
    /// waiter blocked forever and hang the test.
    fn spawn_putter(
        ph: &PriceHolder,
        symbol: &'static str,
        price: u128,
        waiters: usize,
    ) -> JoinHandle<()> {
        let mut ph = ph.clone();
        thread::spawn(move || {
            while listener_count(&ph, symbol) < waiters {
                thread::sleep(time::Duration::from_millis(1));
            }
            ph.put(symbol, price);
        })
    }

    /// Joins every handle, propagating any panic (failed assertion) from the threads.
    fn join_all(handles: Vec<JoinHandle<()>>) {
        for handle in handles {
            handle.join().unwrap();
        }
    }

    #[test]
    fn put_price() {
        let mut ph = PriceHolder::new();
        assert!(ph.get("c").is_none());
        ph.put("a", 0);
        assert_eq!(ph.get("a").unwrap(), 0);
        ph.put("a", 1);
        assert_eq!(ph.get("a").unwrap(), 1);
    }

    #[test]
    fn get_price() {
        let mut ph = PriceHolder::new();
        ph.put("a", 1);
        ph.put("a", 2);
        assert_eq!(ph.get("a").unwrap(), 2);
        ph.put("b", 3);
        assert_eq!(ph.get("b").unwrap(), 3);
    }

    #[test]
    fn wait_price() {
        let ph = PriceHolder::new();
        join_all(vec![spawn_waiter(&ph, "a", 2), spawn_putter(&ph, "a", 2, 1)]);
    }

    #[test]
    fn wait_price_multiple_waiters() {
        let ph = PriceHolder::new();
        let mut handles = vec![spawn_putter(&ph, "a", 2, 10)];
        handles.extend((0..10).map(|_| spawn_waiter(&ph, "a", 2)));
        join_all(handles);
    }

    #[test]
    fn wait_price_multiple_times() {
        let ph = PriceHolder::new();
        join_all(vec![spawn_waiter(&ph, "a", 2), spawn_putter(&ph, "a", 2, 1)]);
        // Listeners are one-shot, so a second round needs a fresh waiter.
        join_all(vec![spawn_waiter(&ph, "a", 3), spawn_putter(&ph, "a", 3, 1)]);
    }

    #[test]
    fn wait_price_same_price() {
        let ph = PriceHolder::new();
        join_all(vec![spawn_waiter(&ph, "a", 1), spawn_putter(&ph, "a", 1, 1)]);
    }

    #[test]
    fn mutex_read_write() {
        let ph = PriceHolder::new();
        // Hammer `get` concurrently with the wait/put handshake to exercise the lock.
        let reader = {
            let ph = ph.clone();
            thread::spawn(move || {
                for _ in 0..1_000_000 {
                    ph.get("a");
                }
            })
        };
        join_all(vec![
            reader,
            spawn_waiter(&ph, "a", 1),
            spawn_putter(&ph, "a", 1, 1),
        ]);
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment