Created
May 15, 2024 03:27
-
-
Save anuradhawick/165d440cc76787552cada8ca75b71aa7 to your computer and use it in GitHub Desktop.
A simple concurrent hash table in Rust
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use std::collections::HashMap; | |
use std::sync::{RwLock, Arc}; | |
use std::sync::atomic::{AtomicU32, Ordering}; | |
use std::cell::UnsafeCell; | |
/// A thread-safe map from `u64` keys to `u32` values.
///
/// The map structure is guarded by an [`RwLock`], and each value is an
/// [`AtomicU32`]. Operations that only touch an existing value need just the
/// shared read lock and can run concurrently with each other:
/// - [`get`](Self::get)
/// - [`increment`](Self::increment)
/// - overwriting an existing key in [`insert`](Self::insert)
///
/// The exclusive write lock is taken only when a new key must be added.
///
/// `RwLock<HashMap<..>>` already provides interior mutability and is
/// automatically `Send + Sync`, so no `UnsafeCell`, `unsafe` block or manual
/// `unsafe impl` is needed.
#[derive(Debug, Default)]
struct MyHashTable {
    map: RwLock<HashMap<u64, AtomicU32>>,
}

impl MyHashTable {
    /// Creates an empty table.
    fn new() -> MyHashTable {
        MyHashTable {
            map: RwLock::new(HashMap::new()),
        }
    }

    /// Sets `key` to `value`, adding the key if it is not yet present.
    ///
    /// # Panics
    /// Panics if the lock was poisoned by a thread that panicked while holding it.
    fn insert(&self, key: u64, value: u32) {
        // Fast path: the key already exists, so its atomic can be overwritten
        // through a shared reference while other readers proceed concurrently.
        {
            let map = self.map.read().expect("MyHashTable lock poisoned");
            if let Some(slot) = map.get(&key) {
                slot.store(value, Ordering::SeqCst);
                return;
            }
        } // read guard dropped here, before the write lock is requested

        // Slow path: the key was absent. Another thread may have inserted it
        // between releasing the read lock and acquiring the write lock, so
        // use the entry API to either overwrite or insert under exclusive access.
        let mut map = self.map.write().expect("MyHashTable lock poisoned");
        map.entry(key)
            .and_modify(|slot| slot.store(value, Ordering::SeqCst))
            .or_insert_with(|| AtomicU32::new(value));
    }

    /// Returns the current value stored for `key`, or `None` if it is absent.
    ///
    /// # Panics
    /// Panics if the lock was poisoned by a thread that panicked while holding it.
    fn get(&self, key: u64) -> Option<u32> {
        let map = self.map.read().expect("MyHashTable lock poisoned");
        map.get(&key).map(|slot| slot.load(Ordering::SeqCst))
    }

    /// Atomically adds 1 to the value stored for `key`.
    ///
    /// Does nothing if `key` is not present. Like [`AtomicU32::fetch_add`],
    /// the value wraps around to 0 on overflow.
    ///
    /// # Panics
    /// Panics if the lock was poisoned by a thread that panicked while holding it.
    fn increment(&self, key: u64) {
        // Only the map *structure* needs protecting here; the value itself is
        // atomic, so a shared read lock is enough and increments can run in parallel.
        let map = self.map.read().expect("MyHashTable lock poisoned");
        if let Some(counter) = map.get(&key) {
            counter.fetch_add(1, Ordering::SeqCst);
        }
    }
}
fn main() { | |
let table = Arc::new(MyHashTable::new()); | |
// Insert values into the hash table | |
table.insert(1, 10); | |
table.insert(2, 20); | |
// Retrieve values | |
println!("Value for key 1: {:?}", table.get(1)); | |
println!("Value for key 2: {:?}", table.get(2)); | |
// Increment values | |
table.increment(1); | |
println!("Value for key 1 after increment: {:?}", table.get(1)); | |
// Example of concurrent usage | |
let table_clone = Arc::clone(&table); | |
std::thread::spawn(move || { | |
table_clone.insert(3, 30); | |
println!("Value for key 3: {:?}", table_clone.get(3)); | |
}).join().unwrap(); | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
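This could be improved further by taking the write lock only when the key is missing, so that updates to existing keys need only a read lock.
One way to do this is an upgradable read lock: acquire the lock for reading, check the key, and only if it is absent upgrade to a write lock, insert, and release. parking_lot provides such a lock. (With std's RwLock alone you can instead drop the read lock, take the write lock, and re-check via the entry API, since another thread may have inserted the key in between.)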
https://docs.rs/parking_lot/0.12.2/parking_lot/index.html