Skip to content

Instantly share code, notes, and snippets.

@adriangb
Last active November 24, 2021 08:59
Show Gist options
  • Save adriangb/9d4561fa9ac04eb59dafc324b0550a60 to your computer and use it in GitHub Desktop.
Save adriangb/9d4561fa9ac04eb59dafc324b0550a60 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
use std::cmp;
use std::hash;
use std::fmt;
use pyo3::basic::CompareOp;
use pyo3::prelude::*;
/// A Python object paired with its precomputed Python `hash()`, usable as a `HashMap` key.
///
/// `Py<PyAny>` cannot be used directly as a `HashMap` key, so to hold references to
/// arbitrary Python objects as keys we wrap them in this struct. The `FromPyObject`
/// impl calls Python's `hash()` once, when the object is received from Python, and
/// the `Hash` impl just echoes that cached value whenever Rust needs to hash it.
///
/// NOTE(review): the cached hash is never recomputed, so this assumes the wrapped
/// object's hash does not change while it is used as a key (true for hashable
/// built-ins; confirm for user-defined types).
#[derive(Clone)]
pub struct HashedAny {
/// Owned reference to the wrapped Python object.
pub o: Py<PyAny>,
/// Python `hash(o)`, normally filled in by `FromPyObject::extract`.
pub hash: isize,
}
// Use the result of calling repr() on the Python object as the debug string value.
//
// `Debug` has no way to report a Python exception, and returning `fmt::Error`
// is not an option either: `format!` (used by `Graph.__str__`) panics when a
// formatting impl fails. So if `repr()` raises we write a placeholder instead
// of panicking. The repr is converted lossily so that strings containing lone
// surrogates (not valid UTF-8) are still printable.
impl fmt::Debug for HashedAny {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        Python::with_gil(|py| match self.o.as_ref(py).repr() {
            Ok(pystr) => write!(f, "{}", pystr.to_string_lossy()),
            // The Python error is dropped here; it is not left set on the interpreter.
            Err(_) => write!(f, "<repr() failed>"),
        })
    }
}
impl <'source>FromPyObject<'source> for HashedAny
{
fn extract(ob: &'source PyAny) -> PyResult<Self> {
Ok(HashedAny{ o: ob.into(), hash: ob.hash()? })
}
}
// Feed the cached Python hash to the hasher, so keys that Python considers
// equal (and therefore hash equally) land in the same bucket.
impl hash::Hash for HashedAny {
    fn hash<H: hash::Hasher>(&self, hasher: &mut H) {
        hasher.write_isize(self.hash);
    }
}
// Key equality, mirroring how Python dicts compare keys:
//   * objects with different cached hashes are never equal;
//   * identical objects (`is`) are always equal;
//   * otherwise the truthiness of `a == b` decides.
//
// Checking the cached hashes first keeps `Eq` consistent with the `Hash` impl
// (Rust requires `a == b` to imply equal hashes), even for Python objects whose
// `__eq__` and `__hash__` disagree. It also avoids taking the GIL and calling
// into Python when two keys merely collide in the same hash-table group.
//
// Panics if `__eq__`, or the truth test of its result, raises: `PartialEq`
// cannot report a Python exception.
impl cmp::PartialEq for HashedAny {
    fn eq(&self, other: &Self) -> bool {
        if self.hash != other.hash {
            return false;
        }
        Python::with_gil(|py| -> PyResult<bool> {
            let this_ref = self.o.as_ref(py);
            let other_ref = other.o.as_ref(py);
            // Pointer-identity comparison (PyAny's PartialEq), like Python's `is`.
            if this_ref.eq(other_ref) {
                return Ok(true);
            }
            this_ref.rich_compare(other_ref, CompareOp::Eq)?.is_true()
        })
        .expect("Python __eq__ raised while comparing HashedAny keys")
    }
}
// Reflexivity holds because identical objects short-circuit to `true` above.
impl cmp::Eq for HashedAny {}
use std::collections::HashMap;
use pyo3::prelude::*;
use pyo3::exceptions;
use pyo3::types::PyTuple;
mod hashedany;
use crate::hashedany::HashedAny;
// Dependency-tracking graph exposed to Python (module `di_lib`) as `Graph`.
//
// Constructed from a mapping of node -> list of children. A node becomes
// "ready" once every one of its children has been marked done.
// (Plain `//` comments are used here because pyo3 turns `///` doc comments on a
// `#[pyclass]` into the Python-visible class docstring.)
#[pyclass(module="di_lib",subclass)]
#[derive(Debug,Clone)]
struct Graph {
// Node -> its children, exactly as passed to the constructor.
children: HashMap<HashedAny, Vec<HashedAny>>,
// Node -> nodes that list it as a child. The constructor creates an entry for
// every key and every child of the input mapping.
parents: HashMap<HashedAny, Vec<HashedAny>>,
// Node -> number of its children not yet marked done (or removed). Only keys of
// the input mapping get an entry.
child_counts: HashMap<HashedAny, usize>,
// Nodes whose child count is zero and that `get_ready` has not yet returned.
ready_nodes: Vec<Py<PyAny>>,
// Number of input-mapping keys not yet passed to `done`; `remove` does not
// adjust it.
not_done_count: usize,
}
impl Graph {
    /// Removes `node` from the bookkeeping maps and queues its children for removal.
    ///
    /// Each remaining parent of `node` has its outstanding-child count reduced by
    /// one. If `node` was already removed, this is a no-op; that happens when both
    /// a parent and its child are passed to `remove`.
    ///
    /// # Arguments
    ///
    /// * `node` - The node to remove.
    /// * `to_remove` - Work stack; every child of `node` is pushed onto it.
    fn remove_node(&mut self, node: &HashedAny, to_remove: &mut Vec<HashedAny>) {
        if self.child_counts.remove(node).is_none() {
            // Already removed.
            return;
        }
        let parents = match self.parents.remove(node) {
            Some(parents) => parents,
            // This node was already removed.
            None => return,
        };
        for parent in &parents {
            // Parents that were already removed have no count left to update.
            if let Some(count) = self.child_counts.get_mut(parent) {
                // If `node` was already marked done, `done()` has already decremented
                // this parent, so its count can be zero here. A plain `-= 1` would
                // underflow (panic in debug builds, wrap to a huge count in release).
                *count = count.saturating_sub(1);
            }
        }
        // Push all children onto the stack for removal.
        if let Some(children) = self.children.remove(node) {
            to_remove.extend(children);
        }
    }
}
#[pymethods]
impl Graph {
    // Builds the graph from a mapping of node -> list of its children.
    //
    // Nodes with no children start out ready. Nodes that appear only as children
    // (never as keys) get a `parents` entry but no child count, so they are never
    // reported as ready and `done()` rejects them with KeyError.
    #[new]
    fn new(graph: HashMap<HashedAny, Vec<HashedAny>>) -> Self {
        let mut child_counts: HashMap<HashedAny, usize> = HashMap::with_capacity(graph.len());
        let mut parents: HashMap<HashedAny, Vec<HashedAny>> = HashMap::with_capacity(graph.len());
        let mut ready_nodes: Vec<Py<PyAny>> = Vec::new();
        for (node, children) in &graph {
            // Every key gets a parents entry, even if nothing depends on it.
            parents.entry(node.clone()).or_default();
            child_counts.insert(node.clone(), children.len());
            if children.is_empty() {
                ready_nodes.push(node.o.clone());
            }
            for child in children {
                parents.entry(child.clone()).or_default().push(node.clone());
            }
        }
        // Take the length before moving `graph` in, instead of cloning the whole map.
        let not_done_count = graph.len();
        Graph {
            children: graph,
            parents,
            child_counts,
            ready_nodes,
            not_done_count,
        }
    }

    /// Returns string representation of the graph
    fn __str__(&self) -> PyResult<String> {
        Ok(format!("Graph({:?})", self.children))
    }

    fn __repr__(&self) -> PyResult<String> {
        self.__str__()
    }

    /// Returns a copy of this graph's bookkeeping state.
    ///
    /// The node objects themselves are not copied: both graphs hold references to
    /// the same Python objects.
    fn copy(&self) -> Graph {
        self.clone()
    }

    /// Marks nodes as done.
    ///
    /// Each parent of a node has its outstanding-child count reduced by one. Parents
    /// whose count reaches zero become ready and are returned by the next call to
    /// `get_ready`. Nodes are processed in order; if one fails, the changes made
    /// for earlier nodes (and possibly some of the failing node's parents) remain.
    ///
    /// # Arguments
    ///
    /// * `args` - Nodes in the graph; each must have no children left that are
    ///   not done.
    ///
    /// # Errors
    ///
    /// * `KeyError` if a node, or one of its parents, is not tracked by the graph.
    /// * `Exception` if a node still has children that are not done.
    /// * `ValueError` if a counter would drop below zero, which indicates nodes
    ///   being marked done more than once.
    #[args(args="*")]
    fn done(&mut self, args: &PyTuple) -> PyResult<()> {
        for obj in args {
            let node = HashedAny::extract(obj)?;
            // Check that this node is ready to be marked as done.
            match self.child_counts.get(&node).copied() {
                Some(0) => (),
                Some(_) => {
                    return Err(exceptions::PyException::new_err("Node still has children"))
                }
                None => {
                    return Err(exceptions::PyKeyError::new_err(format!(
                        "Node {:?} not found",
                        node
                    )))
                }
            }
            self.not_done_count = self.not_done_count.checked_sub(1).ok_or_else(|| {
                exceptions::PyValueError::new_err(
                    "More nodes marked as done than exist in the graph",
                )
            })?;
            // Find all parents and reduce their dependency count by one, queueing
            // every parent that has no remaining dependencies.
            let parents = self.parents.get(&node).ok_or_else(|| {
                exceptions::PyKeyError::new_err(format!("Node {:?} not found", node))
            })?;
            for parent in parents {
                let count = self.child_counts.get_mut(parent).ok_or_else(|| {
                    exceptions::PyKeyError::new_err(format!("Parent node {:?} not found", parent))
                })?;
                *count = count.checked_sub(1).ok_or_else(|| {
                    exceptions::PyValueError::new_err(format!(
                        "Parent node {:?} has no outstanding children left",
                        parent
                    ))
                })?;
                if *count == 0 {
                    self.ready_nodes.push(parent.o.clone());
                }
            }
        }
        Ok(())
    }

    /// Returns True while some nodes have not been marked done, or while ready
    /// nodes are still waiting to be collected with `get_ready`.
    fn is_active(&self) -> bool {
        self.not_done_count != 0 || !self.ready_nodes.is_empty()
    }

    /// Removes nodes from the graph, cascading to their children.
    ///
    /// Every child of a removed node is removed as well, transitively.
    ///
    /// # Arguments
    ///
    /// * `nodes` - A tuple of nodes to be removed (a single tuple argument, not `*args`).
    fn remove(&mut self, nodes: &PyTuple) -> PyResult<()> {
        // NOTE(review): parents whose count drops to zero here are not added to
        // `ready_nodes`, and `not_done_count` is not adjusted. Confirm this is intended.
        let mut to_remove: Vec<HashedAny> = Vec::new();
        for node in nodes {
            self.remove_node(&HashedAny::extract(node)?, &mut to_remove);
        }
        // Drain the work stack; removing a node may push its children onto it.
        while let Some(node) = to_remove.pop() {
            self.remove_node(&node, &mut to_remove);
        }
        Ok(())
    }

    /// Returns all nodes that have become ready since the last call, and forgets them.
    fn get_ready(&mut self) -> Vec<Py<PyAny>> {
        // Move the buffer out instead of cloning every reference and then clearing.
        std::mem::take(&mut self.ready_nodes)
    }
}
// Python module initializer: registers the `Graph` class on `di_lib`.
#[pymodule]
fn di_lib(_py: Python, m: &PyModule) -> PyResult<()> {
    m.add_class::<Graph>()
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment