Last active
November 3, 2023 13:38
-
-
Save matthewjberger/b008dfe62b3356b2289199c2a8478446 to your computer and use it in GitHub Desktop.
Nodegraph V2
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use petgraph::graphmap::DiGraphMap; | |
use serde::{Deserialize, Serialize}; | |
use std::{collections::HashMap, error::Error, fmt, hash::Hash}; | |
/// Error returned by `NodeGraph` operations that cannot be carried out,
/// e.g. adding an edge whose endpoints are not nodes of the graph.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct NodeGraphError {
    /// Human-readable description of what went wrong.
    details: String,
}

impl NodeGraphError {
    /// Creates an error carrying `msg` as its description.
    ///
    /// Accepts anything convertible into a `String` (`&str`, `String`, ...),
    /// so an already-owned message is moved in instead of being copied.
    fn new(msg: impl Into<String>) -> NodeGraphError {
        NodeGraphError { details: msg.into() }
    }
}

impl fmt::Display for NodeGraphError {
    /// Writes the stored description verbatim.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str(&self.details)
    }
}

impl Error for NodeGraphError {}
#[derive(Clone, Debug, Serialize, Deserialize)] | |
pub struct NodeGraph<ID, K, N> | |
where | |
ID: Copy + Eq + Hash + Clone + Ord, | |
K: Clone + PartialEq, | |
N: Serialize + PartialEq, | |
{ | |
graph: DiGraphMap<ID, K>, | |
node_data: HashMap<ID, N>, | |
} | |
impl<ID, K, N> NodeGraph<ID, K, N> | |
where | |
ID: Copy + Eq + Hash + Clone + Ord, | |
K: Clone + PartialEq, | |
N: Serialize + PartialEq, | |
{ | |
pub fn new() -> Self { | |
NodeGraph { | |
graph: DiGraphMap::new(), | |
node_data: HashMap::new(), | |
} | |
} | |
pub fn add_node(&mut self, id: ID, data: N) { | |
self.graph.add_node(id); | |
self.node_data.insert(id, data); | |
} | |
pub fn remove_node(&mut self, id: ID) -> Option<(ID, N)> { | |
let node_removed = self.graph.remove_node(id); | |
let data_removed = self.node_data.remove(&id); | |
if node_removed { | |
Some((id, data_removed?)) | |
} else { | |
None | |
} | |
} | |
pub fn add_edge(&mut self, from: ID, to: ID, value: K) -> Result<(), NodeGraphError> { | |
if !self.graph.contains_node(from) || !self.graph.contains_node(to) { | |
return Err(NodeGraphError::new("One or both nodes do not exist")); | |
} | |
self.graph.add_edge(from, to, value); | |
Ok(()) | |
} | |
pub fn remove_edge(&mut self, from: ID, to: ID) -> Option<K> { | |
self.graph.remove_edge(from, to) | |
} | |
pub fn contains_node(&self, id: ID) -> bool { | |
self.graph.contains_node(id) | |
} | |
pub fn contains_edge(&self, from: ID, to: ID) -> bool { | |
self.graph.contains_edge(from, to) | |
} | |
pub fn get_node_data(&self, id: ID) -> Option<&N> { | |
self.node_data.get(&id) | |
} | |
pub fn update_node_data(&mut self, id: ID, data: N) -> Option<N> { | |
self.node_data.insert(id, data) | |
} | |
pub fn nodes(&self) -> impl Iterator<Item = &ID> { | |
self.node_data.keys() | |
} | |
pub fn edges(&self) -> impl Iterator<Item = (ID, ID, &K)> { | |
self.graph.all_edges().map(|(a, b, w)| (a, b, w)) | |
} | |
pub fn neighbors(&self, id: ID) -> impl Iterator<Item = ID> + '_ { | |
self.graph.neighbors(id) | |
} | |
pub fn node_data_mut(&mut self, id: ID) -> Option<&mut N> { | |
self.node_data.get_mut(&id) | |
} | |
} | |
impl<ID, K, N> PartialEq for NodeGraph<ID, K, N> | |
where | |
ID: Copy + Eq + Hash + Clone + Ord, | |
K: Clone + PartialEq, | |
N: Serialize + PartialEq, | |
{ | |
fn eq(&self, other: &Self) -> bool { | |
if self.graph.node_count() != other.graph.node_count() | |
|| self.graph.edge_count() != other.graph.edge_count() | |
{ | |
return false; | |
} | |
// Check if the node data is equal | |
for (id, data) in &self.node_data { | |
if let Some(other_data) = other.node_data.get(id) { | |
if data != other_data { | |
return false; | |
} | |
} else { | |
return false; | |
} | |
} | |
// Check if the edge sets are equal | |
for edge in self.graph.all_edges() { | |
if let Some(other_edge_value) = other.graph.edge_weight(edge.0, edge.1) { | |
if edge.2 != other_edge_value { | |
return false; | |
} | |
} else { | |
return false; | |
} | |
} | |
true | |
} | |
} | |
impl<ID, K, N> Eq for NodeGraph<ID, K, N> | |
where | |
ID: Copy + Eq + Hash + Clone + Ord, | |
K: Clone + PartialEq, | |
N: Serialize + PartialEq, | |
{ | |
} | |
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_new_graph_is_empty() {
        let graph: NodeGraph<i32, String, String> = NodeGraph::new();
        assert_eq!(graph.graph.node_count(), 0);
        assert_eq!(graph.graph.edge_count(), 0);
    }

    #[test]
    fn test_add_and_remove_node() {
        let mut graph: NodeGraph<i32, String, String> = NodeGraph::new();
        let node_id = 1;
        let node_data = "Node 1 data".to_string();
        graph.add_node(node_id, node_data.clone());
        assert!(graph.contains_node(node_id));
        assert_eq!(graph.get_node_data(node_id), Some(&node_data));
        let removed = graph.remove_node(node_id).unwrap();
        assert_eq!(removed, (node_id, node_data));
        assert!(!graph.contains_node(node_id));
    }

    #[test]
    fn test_add_and_remove_edge() {
        let mut graph: NodeGraph<i32, String, String> = NodeGraph::new();
        graph.add_node(1, "Node 1".to_string());
        graph.add_node(2, "Node 2".to_string());
        let edge_value = "connects".to_string();
        graph.add_edge(1, 2, edge_value.clone()).unwrap();
        assert!(graph.contains_edge(1, 2));
        let removed_edge_value = graph.remove_edge(1, 2).unwrap();
        assert_eq!(removed_edge_value, edge_value);
        assert!(!graph.contains_edge(1, 2));
    }

    #[test]
    fn test_edge_cases() {
        let mut graph: NodeGraph<i32, String, String> = NodeGraph::new();
        graph.add_node(1, "Node 1".to_string());
        assert!(graph.add_edge(1, 2, "connects".to_string()).is_err()); // Edge to non-existent node
        assert!(graph.remove_edge(1, 2).is_none()); // Remove non-existent edge
        assert!(graph.remove_node(2).is_none()); // Remove non-existent node
    }

    #[test]
    fn test_graph_equality() {
        let mut graph1: NodeGraph<i32, String, String> = NodeGraph::new();
        let mut graph2: NodeGraph<i32, String, String> = NodeGraph::new();
        graph1.add_node(1, "Node 1".to_string());
        graph2.add_node(1, "Node 1".to_string());
        graph1.add_node(2, "Node 2".to_string());
        graph2.add_node(2, "Node 2".to_string());
        graph1.add_edge(1, 2, "connects".to_string()).unwrap();
        graph2.add_edge(1, 2, "connects".to_string()).unwrap();
        assert_eq!(graph1, graph2);

        // Same topology, different edge weight: not equal.
        let mut reweighted = graph1.clone();
        reweighted.remove_edge(1, 2);
        reweighted.add_edge(1, 2, "differs".to_string()).unwrap();
        assert_ne!(graph1, reweighted);

        // Edge direction matters in a directed graph.
        let mut reversed = graph1.clone();
        reversed.remove_edge(1, 2);
        reversed.add_edge(2, 1, "connects".to_string()).unwrap();
        assert_ne!(graph1, reversed);

        graph2.add_node(3, "Node 3".to_string());
        assert_ne!(graph1, graph2);
    }

    #[derive(Default, Debug, Serialize, Deserialize, PartialEq)]
    pub struct Position(u8, u8);

    #[test]
    fn test_serialization_and_deserialization() {
        let mut graph: NodeGraph<i32, String, Position> = NodeGraph::new();
        graph.add_node(1, Position(0, 1));
        graph.add_edge(1, 1, "self-loop".to_string()).unwrap();
        let serialized = serde_json::to_string(&graph).unwrap();
        let deserialized: NodeGraph<i32, String, Position> =
            serde_json::from_str(&serialized).unwrap();
        // A JSON round trip must reproduce the original graph exactly.
        assert_eq!(graph, deserialized);
    }

    #[test]
    fn test_update_node_data() {
        let mut graph: NodeGraph<i32, String, String> = NodeGraph::new();
        graph.add_node(1, "Node 1 data".to_string());
        let updated_data = "Updated Node 1 data".to_string();
        let original_data = graph.update_node_data(1, updated_data.clone());
        assert_eq!(original_data, Some("Node 1 data".to_string()));
        assert_eq!(graph.get_node_data(1), Some(&updated_data));
    }

    #[test]
    fn test_iterate_edges() {
        let mut graph: NodeGraph<i32, String, String> = NodeGraph::new();
        graph.add_node(1, "Node 1".to_string());
        graph.add_node(2, "Node 2".to_string());
        graph.add_edge(1, 2, "connects".to_string()).unwrap();
        let edges: Vec<(i32, i32, &String)> = graph.edges().collect();
        assert_eq!(edges, vec![(1, 2, &"connects".to_string())]);
    }

    #[test]
    fn test_iterate_neighbors() {
        let mut graph: NodeGraph<i32, String, String> = NodeGraph::new();
        graph.add_node(1, "Node 1".to_string());
        graph.add_node(2, "Node 2".to_string());
        graph.add_node(3, "Node 3".to_string());
        graph.add_edge(1, 2, "connects".to_string()).unwrap();
        graph.add_edge(1, 3, "connects".to_string()).unwrap();
        let neighbors: Vec<i32> = graph.neighbors(1).collect();
        assert!(neighbors.contains(&2));
        assert!(neighbors.contains(&3));
    }

    #[test]
    fn test_mutate_node_data() {
        let mut graph: NodeGraph<i32, String, String> = NodeGraph::new();
        graph.add_node(1, "Node 1 data".to_string());
        if let Some(data) = graph.node_data_mut(1) {
            *data = "Mutated Node 1 data".to_string();
        }
        assert_eq!(
            graph.get_node_data(1),
            Some(&"Mutated Node 1 data".to_string())
        );
    }

    #[test]
    fn test_iterate_nodes() {
        let mut graph: NodeGraph<i32, String, String> = NodeGraph::new();
        graph.add_node(1, "Node 1 data".to_string());
        graph.add_node(2, "Node 2 data".to_string());
        let mut nodes: Vec<&i32> = graph.nodes().collect();
        // HashMap iteration order is unspecified; sort for a stable comparison.
        nodes.sort();
        assert_eq!(nodes, vec![&1, &2]);
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment