Skip to content

Instantly share code, notes, and snippets.

@adriangb
Last active November 24, 2021 08:59
Show Gist options
  • Save adriangb/9d4561fa9ac04eb59dafc324b0550a60 to your computer and use it in GitHub Desktop.
Save adriangb/9d4561fa9ac04eb59dafc324b0550a60 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 9,
"id": "bf58a19a",
"metadata": {},
"outputs": [],
"source": [
"from typing import *\n",
"\n",
"import di_lib\n",
"import graphlib"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "c43d18d4",
"metadata": {},
"outputs": [],
"source": [
"from random import Random\n",
"import networkx as nx\n",
"\n",
"\n",
"random = Random(42)\n",
"\n",
"\n",
"def get_linear_graph(n: int) -> Dict[int, List[int]]:\n",
" graph: Dict[int, List[int]] = {}\n",
" i = 0\n",
" for i in range(n):\n",
" graph[i] = [i+1] # node 0 depends on node 1, etc.\n",
" graph[i+1] = [] # node n has no dependencies\n",
" return graph\n",
"\n",
"def get_random_graph(n: int) -> Dict[int, List[int]]:\n",
" G = nx.gnp_random_graph(n, 0.5, directed=True)\n",
" DAG = nx.DiGraph([(u,v,{'weight': random.randint(-10, 10)}) for (u,v) in G.edges() if u<v])\n",
" return nx.convert.to_dict_of_lists(DAG)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "fb648e52",
"metadata": {},
"outputs": [],
"source": [
"def run(t: Union[graphlib.TopologicalSorter, di_lib.Graph]) -> None:\n",
" to_remove = t.get_ready()\n",
" while t.is_active():\n",
" t.done(*to_remove)\n",
" to_remove = t.get_ready()\n",
"\n",
"\n",
"def setup_rust(n: int, graph_gen: Callable[[int], Dict[int, List[int]]]) -> di_lib.Graph:\n",
" return di_lib.Graph(graph_gen(n))\n",
"\n",
"\n",
"def setup_python(n: int, graph_gen: Callable[[int], Dict[int, List[int]]]) -> graphlib.TopologicalSorter:\n",
" t = graphlib.TopologicalSorter(graph_gen(n))\n",
" t.prepare()\n",
" return t"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "3191f318",
"metadata": {},
"outputs": [],
"source": [
"from matplotlib import pyplot as plt\n",
"import timeit\n",
"\n",
"lower = 0\n",
"upper = 100_000\n",
"length = 50\n",
"loops = 100\n",
"\n",
"glbls = {\n",
" \"setup_python\": setup_python,\n",
" \"setup_rust\": setup_rust,\n",
" \"run\": run,\n",
" \"get_linear_graph\": get_linear_graph,\n",
" \"get_random_graph\": get_random_graph\n",
"}\n",
"\n",
"\n",
"\n",
"def plot(upper: int, samples: int, graph_factory: str) -> None:\n",
" x = [round(lower + x*(upper-lower)/samples) for x in range(samples)]\n",
" y_python: List[float] = []\n",
" y_rust: List[float] = []\n",
"\n",
" for n in x:\n",
" # Time Python and get the # of loops and match that\n",
" timerpy = timeit.Timer(stmt=\"run(t)\", setup=f\"t = setup_python({n}, {graph_factory})\", globals=glbls)\n",
" pytime = timerpy.timeit(loops)\n",
" y_python.append(pytime)\n",
" timerust = timeit.Timer(stmt=\"run(t)\", setup=f\"t = setup_rust({n}, {graph_factory})\", globals=glbls)\n",
" rustime = timerust.timeit(loops)\n",
" y_rust.append(rustime)\n",
" plt.plot(x, y_python, label=\"python\")\n",
" plt.plot(x, y_rust, label=\"rust\")\n",
" plt.legend(loc=\"upper left\")\n",
" plt.xlabel(\"V (number of vertices)\")\n",
" plt.ylabel(\"Execution time (s)\")"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "fcafdd12",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plot(upper=100_000, samples=50, graph_factory=\"get_linear_graph\")"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "11ba828c",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plot(upper=1_000, samples=50, graph_factory=\"get_random_graph\")"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "58237ce0",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(8.739999902900309e-05,\n",
" 5.859998054802418e-06,\n",
" 2.9899980290792883e-06,\n",
" 5.649999366141856e-06)"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import statistics\n",
"\n",
"py_get_ready: List[float] = []\n",
"py_done: List[float] = []\n",
"\n",
"rust_get_ready: List[float] = []\n",
"rust_done: List[float] = []\n",
"\n",
"for _ in range(10):\n",
" t = setup_python(1_000, get_linear_graph)\n",
" start = timeit.default_timer()\n",
" ready = t.get_ready()\n",
" py_get_ready.append(timeit.default_timer()-start)\n",
" start = timeit.default_timer()\n",
" t.done(*ready)\n",
" py_done.append(timeit.default_timer()-start)\n",
"\n",
" t = setup_rust(1_000, get_linear_graph)\n",
" start = timeit.default_timer()\n",
" ready = t.get_ready()\n",
" rust_get_ready.append(timeit.default_timer()-start)\n",
" start = timeit.default_timer()\n",
" t.done(*ready)\n",
" rust_done.append(timeit.default_timer()-start)\n",
"\n",
"statistics.mean(py_get_ready),statistics.mean(py_done),statistics.mean(rust_get_ready),statistics.mean(rust_done)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "bcc43114",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(1.4799999189563095e-05,\n",
" 0.00011145999887958169,\n",
" 8.730002446100115e-06,\n",
" 6.569000106537715e-05)"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import statistics\n",
"\n",
"py_get_ready: List[float] = []\n",
"py_done: List[float] = []\n",
"\n",
"rust_get_ready: List[float] = []\n",
"rust_done: List[float] = []\n",
"\n",
"for _ in range(10):\n",
" t = setup_python(1_000, get_random_graph)\n",
" start = timeit.default_timer()\n",
" ready = t.get_ready()\n",
" py_get_ready.append(timeit.default_timer()-start)\n",
" start = timeit.default_timer()\n",
" t.done(*ready)\n",
" py_done.append(timeit.default_timer()-start)\n",
"\n",
" t = setup_rust(1_000, get_random_graph)\n",
" start = timeit.default_timer()\n",
" ready = t.get_ready()\n",
" rust_get_ready.append(timeit.default_timer()-start)\n",
" start = timeit.default_timer()\n",
" t.done(*ready)\n",
" rust_done.append(timeit.default_timer()-start)\n",
"\n",
"statistics.mean(py_get_ready),statistics.mean(py_done),statistics.mean(rust_get_ready),statistics.mean(rust_done)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "aacf1192",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
use std::cmp;
use std::hash;
use std::fmt;
use pyo3::basic::CompareOp;
use pyo3::prelude::*;
/// A Python object paired with the `hash()` value it had when it entered Rust.
///
/// `Py<PyAny>` cannot serve as a `HashMap` key by itself, so the Python hash is
/// computed once when the object is extracted from Python and then replayed
/// every time Rust needs to hash the key.
#[derive(Clone)]
pub struct HashedAny {
    /// Owned reference to the wrapped Python object; cloning shares the object.
    pub o: Py<PyAny>,
    /// Python `hash(o)`, cached at extraction time.
    pub hash: isize,
}
/// Formats the node as its Python `repr()`.
///
/// This is used by `Graph.__str__` and in error messages, so it must never panic:
/// if `repr()` raises, a placeholder including the cached hash is written instead,
/// and a repr that is not valid UTF-8 is converted lossily.
impl fmt::Debug for HashedAny {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        Python::with_gil(|py| match self.o.as_ref(py).repr() {
            Ok(repr) => f.write_str(&repr.to_string_lossy()),
            // The PyErr is dropped here; a Debug impl has no way to report it.
            Err(_) => write!(f, "<object with hash {} (repr() raised)>", self.hash),
        })
    }
}
/// Converts any Python object into a `HashedAny`, computing its hash up front.
///
/// Whatever error Python's `hash()` raises (e.g. for unhashable objects) is
/// propagated to the caller unchanged.
impl<'source> FromPyObject<'source> for HashedAny {
    fn extract(ob: &'source PyAny) -> PyResult<Self> {
        let hash = ob.hash()?;
        Ok(Self { o: ob.into(), hash })
    }
}
/// Hashes only the cached Python hash, so Rust-side hashing never touches the interpreter.
impl hash::Hash for HashedAny {
    fn hash<H: hash::Hasher>(&self, state: &mut H) {
        state.write_isize(self.hash);
    }
}
/// Key equality, modelled on Python's dict lookup.
///
/// The cached hashes are compared first, then object identity, and only then
/// Python `==` (`__eq__`).
impl cmp::PartialEq for HashedAny {
    fn eq(&self, other: &Self) -> bool {
        // Python requires equal objects to have equal hashes, so a mismatch means
        // "not equal" without acquiring the GIL or calling into the interpreter.
        if self.hash != other.hash {
            return false;
        }
        Python::with_gil(|py| -> PyResult<bool> {
            let this_ref = self.o.as_ref(py);
            let other_ref = other.o.as_ref(py);
            // pyo3's `PartialEq` for `PyAny` compares object pointers (Python `is`),
            // so an object always equals itself without invoking `__eq__`.
            if this_ref.eq(other_ref) {
                return Ok(true);
            }
            this_ref.rich_compare(other_ref, CompareOp::Eq)?.is_true()
        })
        // `PartialEq::eq` cannot report errors, so an exception from `__eq__` (or from
        // the truth test of its result) is turned into a panic with context.
        .expect("Python __eq__ raised while comparing graph nodes")
    }
}

// Reflexivity holds because identical objects short-circuit to `true` above.
impl cmp::Eq for HashedAny {}
use std::collections::HashMap;
use pyo3::prelude::*;
use pyo3::exceptions;
use pyo3::types::PyTuple;
mod hashedany;
use crate::hashedany::HashedAny;
// Dependency graph exposed to Python as `di_lib.Graph`.
// Plain `//` comments are used deliberately here: pyo3 would turn `///` on this
// struct into the Python class `__doc__`.
#[pyclass(module = "di_lib", subclass)]
#[derive(Debug, Clone)]
struct Graph {
    /// node -> the child nodes it depends on
    children: HashMap<HashedAny, Vec<HashedAny>>,
    /// node -> the nodes that list it as a child
    parents: HashMap<HashedAny, Vec<HashedAny>>,
    /// node -> number of its children not yet marked done (or removed)
    child_counts: HashMap<HashedAny, usize>,
    /// nodes whose children are all done, waiting to be handed out by `get_ready`
    ready_nodes: Vec<Py<PyAny>>,
    /// number of nodes not yet marked done
    not_done_count: usize,
}
impl Graph {
    /// Drops `node` from the graph bookkeeping and queues its children for removal.
    ///
    /// The node's entries in `child_counts`, `parents` and `children` are removed.
    /// Each parent still tracked in `child_counts` has its count reduced by one, and
    /// every child is pushed onto `to_remove` for the caller to process next.
    /// A node that is no longer in `child_counts` is ignored.
    fn remove_node(&mut self, node: &HashedAny, to_remove: &mut Vec<HashedAny>) {
        // Already gone: happens when a parent and its child are both in the
        // caller's list of nodes to remove.
        if self.child_counts.remove(node).is_none() {
            return;
        }
        let node_parents = match self.parents.remove(node) {
            Some(found) => found,
            None => return,
        };
        for parent in &node_parents {
            // A parent missing here was itself already removed.
            if let Some(count) = self.child_counts.get_mut(parent) {
                *count -= 1;
            }
        }
        if let Some(node_children) = self.children.remove(node) {
            to_remove.extend(node_children);
        }
    }
}
#[pymethods]
impl Graph {
    /// Builds a graph from a mapping of node -> list of its child nodes (dependencies).
    ///
    /// Nodes that only appear as children are added automatically with no children
    /// of their own, so they start out ready.
    #[new]
    fn new(graph: HashMap<HashedAny, Vec<HashedAny>>) -> Self {
        let mut children = graph;
        // A child that is never a key would otherwise never become ready, leaving
        // its parents (and `is_active`) blocked forever.
        let missing: Vec<HashedAny> = children
            .values()
            .flatten()
            .filter(|child| !children.contains_key(*child))
            .cloned()
            .collect();
        for node in missing {
            children.entry(node).or_insert_with(Vec::new);
        }

        let mut child_counts: HashMap<HashedAny, usize> = HashMap::with_capacity(children.len());
        let mut parents: HashMap<HashedAny, Vec<HashedAny>> =
            HashMap::with_capacity(children.len());
        let mut ready_nodes: Vec<Py<PyAny>> = Vec::new();
        for (node, node_children) in &children {
            parents.entry(node.clone()).or_insert_with(Vec::new);
            child_counts.insert(node.clone(), node_children.len());
            if node_children.is_empty() {
                ready_nodes.push(node.o.clone());
            }
            for child in node_children {
                parents
                    .entry(child.clone())
                    .or_insert_with(Vec::new)
                    .push(node.clone());
            }
        }
        let not_done_count = children.len();
        Graph {
            children,
            parents,
            child_counts,
            ready_nodes,
            not_done_count,
        }
    }

    /// Returns string representation of the graph
    fn __str__(&self) -> PyResult<String> {
        Ok(format!("Graph({:?})", self.children))
    }

    fn __repr__(&self) -> PyResult<String> {
        self.__str__()
    }

    /// Returns an independent copy of the graph's bookkeeping.
    ///
    /// The node objects themselves are shared with the original, not copied.
    fn copy(&self) -> Graph {
        self.clone()
    }

    /// Marks the given nodes as done.
    ///
    /// Each node must have no children left that are not done. Marking a node done
    /// decrements the child count of each of its parents. Parents whose count reaches
    /// zero are queued and returned by the next `get_ready()` call. A node can be
    /// marked done only once.
    ///
    /// Raises `KeyError` if a node is unknown, was removed, or was already marked
    /// done. Raises `Exception` if a node still has children.
    #[args(args = "*")]
    fn done(&mut self, args: &PyTuple) -> PyResult<()> {
        for obj in args {
            let node = HashedAny::extract(obj)?;
            let remaining = match self.child_counts.get(&node) {
                Some(&count) => count,
                None => {
                    return Err(exceptions::PyKeyError::new_err(format!(
                        "Node {:?} is not in the graph or was already marked as done",
                        node
                    )))
                }
            };
            if remaining != 0 {
                return Err(exceptions::PyException::new_err("Node still has children"));
            }
            // Drop the node from the bookkeeping so a repeated `done` is rejected
            // above instead of decrementing the counters a second time.
            self.child_counts.remove(&node);
            self.not_done_count -= 1;
            // Reduce each parent's dependency count by one; parents left with
            // no outstanding children become ready.
            if let Some(node_parents) = self.parents.get(&node) {
                for parent in node_parents {
                    match self.child_counts.get_mut(parent) {
                        Some(count) => {
                            *count -= 1;
                            if *count == 0 {
                                self.ready_nodes.push(parent.o.clone());
                            }
                        }
                        None => {
                            return Err(exceptions::PyKeyError::new_err(format!(
                                "Parent node {:?} not found",
                                parent
                            )))
                        }
                    }
                }
            }
        }
        Ok(())
    }

    /// Returns True while some node is not yet done or ready nodes are waiting
    /// to be fetched with `get_ready()`.
    fn is_active(&self) -> bool {
        self.not_done_count != 0 || !self.ready_nodes.is_empty()
    }

    /// Removes the given nodes and, transitively, every child reachable from them.
    ///
    /// The nodes are passed as a single tuple argument. Nodes that were already
    /// removed or marked done are skipped.
    /// NOTE(review): this does not adjust `not_done_count` or `ready_nodes`, so
    /// `is_active()` may stay true after removing nodes that were not done — confirm
    /// whether callers rely on that.
    fn remove(&mut self, nodes: &PyTuple) -> PyResult<()> {
        let mut to_remove: Vec<HashedAny> = Vec::new();
        for node in nodes {
            self.remove_node(&HashedAny::extract(node)?, &mut to_remove);
        }
        while let Some(node) = to_remove.pop() {
            self.remove_node(&node, &mut to_remove);
        }
        Ok(())
    }

    /// Returns, and clears, all nodes that became ready since the previous call.
    fn get_ready(&mut self) -> Vec<Py<PyAny>> {
        std::mem::take(&mut self.ready_nodes)
    }
}
// Extension-module initializer: registers `Graph` on `di_lib`.
// (A `//` comment keeps the module's Python `__doc__` unchanged.)
#[pymodule]
fn di_lib(_py: Python, m: &PyModule) -> PyResult<()> {
    m.add_class::<Graph>()
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment