Skip to content

Instantly share code, notes, and snippets.

@adriangb
Last active November 27, 2021 10:07
Show Gist options
  • Save adriangb/1352d711966db89b94395e8f6fb83de6 to your computer and use it in GitHub Desktop.
Save adriangb/1352d711966db89b94395e8f6fb83de6 to your computer and use it in GitHub Desktop.
Hashable Python objects in PyO3
use std::cmp;
use std::hash;
use pyo3::basic::CompareOp;
use pyo3::prelude::*;
/// A Python object paired with the `hash()` value Python computed for it.
///
/// `Py<PyAny>` cannot be used as a `HashMap` key directly. This wrapper
/// therefore captures the object's Python hash once, when the value is
/// extracted from Python. It then replays that cached value whenever Rust
/// needs to hash it. Equality comparisons are delegated back to Python.
///
/// `Debug` is derived so that containers holding `HashedAny` can derive
/// `Debug` themselves. An example is a `#[derive(Debug)]` struct with a
/// `HashMap<HashedAny, _>` field.
#[derive(Clone, Debug)]
pub struct HashedAny(pub Py<PyAny>, isize);
impl <'source>FromPyObject<'source> for HashedAny
{
fn extract(ob: &'source PyAny) -> PyResult<Self> {
Ok(
HashedAny(ob.into(), ob.hash()?)
)
}
}
impl hash::Hash for HashedAny {
fn hash<H: hash::Hasher>(&self, state: &mut H) {
self.1.hash(state)
}
}
impl cmp::PartialEq for HashedAny {
fn eq(&self, other: &Self) -> bool {
Python::with_gil(|py| -> PyResult<bool> {
Ok(self.0.as_ref(py).rich_compare(other.0.as_ref(py), CompareOp::Eq)?.is_true()?)
}).unwrap()
}
}
impl cmp::Eq for HashedAny {}
// Just a HashMap wrapper for testing purposes
use std::collections::HashMap;
use pyo3::{Py, PyAny, Python, exceptions};
use pyo3::prelude::*;
mod hashedany;
use crate::hashedany::HashedAny;
/// Python-visible mapping backed by a Rust `HashMap`.
///
/// Keys are `HashedAny` values (a Python object plus its cached Python hash).
/// Values are arbitrary Python objects held as owned `Py<PyAny>` references.
#[pyclass]
#[derive(Clone, Debug)]
struct _PyHashMap {
    /// Backing storage for the mapping's entries.
    map: HashMap<HashedAny, Py<PyAny>>,
}
#[pymethods]
impl _PyHashMap {
#[new]
fn new(map: HashMap<HashedAny, Py<PyAny>>) -> Self {
_PyHashMap { map }
}
fn __setitem__(&mut self, k: HashedAny, v: Py<PyAny>) -> () {
self.map.insert(k, v);
}
fn __getitem__(&self, k: HashedAny) -> PyResult<Py<PyAny>> {
match self.map.get(&k) {
Some(v) => Ok(v.clone()),
None => Err(exceptions::PyKeyError::new_err(format!("KeyError: {:?}", k))),
}
}
fn __delitem__(&mut self, k: HashedAny) -> PyResult<()> {
match self.map.remove(&k) {
Some(_) => Ok(()),
None => Err(exceptions::PyKeyError::new_err(format!("KeyError: {:?}", k))),
}
}
}
/// Initializer for the `hashedpyany` extension module; registers `_PyHashMap`.
#[pymodule]
fn hashedpyany(_py: Python, m: &PyModule) -> PyResult<()> {
    m.add_class::<_PyHashMap>()
}
# Python tests using hypothesis
from __future__ import annotations
import unittest
from typing import Any, Dict, Hashable
import hypothesis.strategies as st
from hypothesis.stateful import Bundle, RuleBasedStateMachine, rule
from hashedpyany import PyHashMap
class HashMapComparison(RuleBasedStateMachine):
    """Stateful property test: ``PyHashMap`` must agree with a plain ``dict``.

    Hypothesis generates random sequences of insert/get/delete operations.
    Each operation is applied to both mappings. Every lookup and every
    deletion asserts that the two produced the same outcome.
    """

    def __init__(self):
        super().__init__()
        # Reference model and system under test, mutated in lock-step.
        self.python: Dict[Hashable, Any] = {}
        self.rust: PyHashMap[Hashable, Any] = PyHashMap({})

    # Pools of previously generated keys/values that later rules draw from.
    keys = Bundle("keys")
    values = Bundle("values")

    @staticmethod
    def _lookup(mapping, key):
        """Return ``mapping[key]``, or ``None`` when it raises ``KeyError``."""
        try:
            return mapping[key]
        except KeyError:
            return None

    @staticmethod
    def _delete_was_missing(mapping, key):
        """Delete ``key``; return ``True`` if it raised ``KeyError``, else ``False``."""
        try:
            del mapping[key]
        except KeyError:
            return True
        return False

    @rule(target=keys, k=st.tuples(st.integers()))
    def add_key(self, k):
        return k

    @rule(target=values, v=st.binary())
    def add_value(self, v):
        return v

    @rule(k=keys, v=values)
    def insert(self, k, v):
        self.python[k] = v
        self.rust[k] = v

    @rule(k=keys)
    def get(self, k):
        expected = self._lookup(self.python, k)
        actual = self._lookup(self.rust, k)
        assert actual == expected

    @rule(k=keys)
    def delete(self, k):
        expected_missing = self._delete_was_missing(self.python, k)
        actual_missing = self._delete_was_missing(self.rust, k)
        assert actual_missing == expected_missing


# Expose the state machine as a unittest.TestCase so test runners discover it.
TestHashMapComparison = HashMapComparison.TestCase

if __name__ == "__main__":
    unittest.main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment