// Skip to content
//
// Instantly share code, notes, and snippets.
//
// @matthewjberger
// Last active November 3, 2023 13:38
// Show Gist options
//   • Save matthewjberger/b008dfe62b3356b2289199c2a8478446 to your computer and use it in GitHub Desktop.
// Save matthewjberger/b008dfe62b3356b2289199c2a8478446 to your computer and use it in GitHub Desktop.
// Nodegraph V2
use petgraph::graphmap::DiGraphMap;
use serde::{Deserialize, Serialize};
use std::{collections::HashMap, error::Error, fmt, hash::Hash};
/// Error value carrying a textual description of a failed graph operation.
#[derive(Debug)]
pub struct NodeGraphError {
    /// Human-readable description of what went wrong.
    details: String,
}

impl NodeGraphError {
    /// Builds an error whose description is an owned copy of `msg`.
    fn new(msg: &str) -> NodeGraphError {
        Self {
            details: String::from(msg),
        }
    }
}

impl fmt::Display for NodeGraphError {
    /// Emits the stored description verbatim, without applying any width or
    /// fill flags requested by the caller's format string.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str(&self.details)
    }
}

// No underlying cause is tracked, so the default `Error` methods suffice.
impl Error for NodeGraphError {}
/// Directed graph that pairs a petgraph `DiGraphMap` with a table of per-node data.
///
/// `ID` is the node identifier type, `K` is the value stored on each edge, and
/// `N` is the data stored for each node.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct NodeGraph<ID, K, N>
where
ID: Copy + Eq + Hash + Clone + Ord,
K: Clone + PartialEq,
N: Serialize + PartialEq,
{
/// Graph structure: the node ids and the directed edges with their `K` values.
graph: DiGraphMap<ID, K>,
/// Data stored for each node, keyed by node id.
node_data: HashMap<ID, N>,
}
impl<ID, K, N> NodeGraph<ID, K, N>
where
    ID: Copy + Eq + Hash + Clone + Ord,
    K: Clone + PartialEq,
    N: Serialize + PartialEq,
{
    /// Creates an empty graph with no nodes, edges, or node data.
    pub fn new() -> Self {
        NodeGraph {
            graph: DiGraphMap::new(),
            node_data: HashMap::new(),
        }
    }

    /// Registers node `id` in the graph and stores `data` for it.
    ///
    /// If data was already stored under `id`, it is replaced.
    pub fn add_node(&mut self, id: ID, data: N) {
        self.graph.add_node(id);
        self.node_data.insert(id, data);
    }

    /// Removes node `id` and its data.
    ///
    /// Returns `Some((id, data))` when the node was present in the graph and
    /// had data stored for it, and `None` otherwise.
    pub fn remove_node(&mut self, id: ID) -> Option<(ID, N)> {
        let node_removed = self.graph.remove_node(id);
        // Always drop the data entry so the two containers stay in step.
        let data_removed = self.node_data.remove(&id);
        if node_removed {
            data_removed.map(|data| (id, data))
        } else {
            None
        }
    }

    /// Adds a directed edge `from -> to` carrying `value`.
    ///
    /// # Errors
    ///
    /// Returns a [`NodeGraphError`] when either endpoint is not a node of the
    /// graph; the graph is left unchanged in that case.
    pub fn add_edge(&mut self, from: ID, to: ID, value: K) -> Result<(), NodeGraphError> {
        if !self.graph.contains_node(from) || !self.graph.contains_node(to) {
            return Err(NodeGraphError::new("One or both nodes do not exist"));
        }
        self.graph.add_edge(from, to, value);
        Ok(())
    }

    /// Removes the edge `from -> to`, returning its value if it existed.
    pub fn remove_edge(&mut self, from: ID, to: ID) -> Option<K> {
        self.graph.remove_edge(from, to)
    }

    /// Returns `true` when `id` is a node of the graph.
    pub fn contains_node(&self, id: ID) -> bool {
        self.graph.contains_node(id)
    }

    /// Returns `true` when the directed edge `from -> to` exists.
    pub fn contains_edge(&self, from: ID, to: ID) -> bool {
        self.graph.contains_edge(from, to)
    }

    /// Returns a shared reference to the data stored for `id`, if any.
    pub fn get_node_data(&self, id: ID) -> Option<&N> {
        self.node_data.get(&id)
    }

    /// Replaces the data of the existing node `id`, returning the previous data.
    ///
    /// Returns `None` and stores nothing when `id` is not a node of the graph.
    /// Use [`NodeGraph::add_node`] to create nodes; storing data for an
    /// unknown id would make `nodes()` and `get_node_data()` report a node
    /// the graph itself does not contain.
    pub fn update_node_data(&mut self, id: ID, data: N) -> Option<N> {
        if !self.graph.contains_node(id) {
            return None;
        }
        self.node_data.insert(id, data)
    }

    /// Iterates over the ids that have node data, in arbitrary order.
    pub fn nodes(&self) -> impl Iterator<Item = &ID> {
        self.node_data.keys()
    }

    /// Iterates over all edges as `(from, to, &value)` triples.
    pub fn edges(&self) -> impl Iterator<Item = (ID, ID, &K)> {
        self.graph.all_edges()
    }

    /// Iterates over the neighbors of `id` as reported by the underlying graph.
    pub fn neighbors(&self, id: ID) -> impl Iterator<Item = ID> + '_ {
        self.graph.neighbors(id)
    }

    /// Returns a mutable reference to the data stored for `id`, if any.
    pub fn node_data_mut(&mut self, id: ID) -> Option<&mut N> {
        self.node_data.get_mut(&id)
    }
}

impl<ID, K, N> Default for NodeGraph<ID, K, N>
where
    ID: Copy + Eq + Hash + Clone + Ord,
    K: Clone + PartialEq,
    N: Serialize + PartialEq,
{
    /// Equivalent to [`NodeGraph::new`]: an empty graph.
    fn default() -> Self {
        Self::new()
    }
}
impl<ID, K, N> PartialEq for NodeGraph<ID, K, N>
where
    ID: Copy + Eq + Hash + Clone + Ord,
    K: Clone + PartialEq,
    N: Serialize + PartialEq,
{
    /// Two graphs are equal when they have the same node ids, the same data
    /// for every id, and the same directed edges with equal values.
    fn eq(&self, other: &Self) -> bool {
        // Cheap size checks first; they also let the one-directional
        // containment checks below establish full set equality.
        if self.graph.node_count() != other.graph.node_count()
            || self.graph.edge_count() != other.graph.edge_count()
        {
            return false;
        }

        // Compare the complete data tables. `HashMap` equality checks both the
        // length and every entry, so extra entries on either side are caught
        // and the comparison is symmetric.
        if self.node_data != other.node_data {
            return false;
        }

        // Same node count + every node of `self` present in `other`
        // means the node sets are identical.
        if !self.graph.nodes().all(|id| other.graph.contains_node(id)) {
            return false;
        }

        // Same edge count + every edge of `self` present in `other` with an
        // equal value means the edge sets are identical.
        self.graph
            .all_edges()
            .all(|(from, to, value)| other.graph.edge_weight(from, to) == Some(value))
    }
}
/// Marks `NodeGraph` equality as a full equivalence relation.
///
/// NOTE(review): the bounds only require `PartialEq` for `K` and `N`, so this
/// assumes their equality is reflexive (floating-point types are not) —
/// confirm for the types this graph is instantiated with.
impl<ID, K, N> Eq for NodeGraph<ID, K, N>
where
ID: Copy + Eq + Hash + Clone + Ord,
K: Clone + PartialEq,
N: Serialize + PartialEq,
{
}
// Unit tests covering construction, node/edge mutation, equality,
// serialization round-trips, and the iteration helpers of `NodeGraph`.
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json;

    // A freshly constructed graph has no nodes and no edges.
    #[test]
    fn test_new_graph_is_empty() {
        let graph: NodeGraph<i32, String, String> = NodeGraph::new();
        assert_eq!(graph.graph.node_count(), 0);
        assert_eq!(graph.graph.edge_count(), 0);
    }

    // Adding a node makes it and its data visible; removing it returns both.
    #[test]
    fn test_add_and_remove_node() {
        let mut graph: NodeGraph<i32, String, String> = NodeGraph::new();
        let node_id = 1;
        let node_data = "Node 1 data".to_string();
        graph.add_node(node_id, node_data.clone());
        assert!(graph.contains_node(node_id));
        assert_eq!(graph.get_node_data(node_id), Some(&node_data));
        let removed = graph.remove_node(node_id).unwrap();
        assert_eq!(removed, (node_id, node_data));
        assert!(!graph.contains_node(node_id));
    }

    // An edge between existing nodes can be added and removed again.
    #[test]
    fn test_add_and_remove_edge() {
        let mut graph: NodeGraph<i32, String, String> = NodeGraph::new();
        graph.add_node(1, "Node 1".to_string());
        graph.add_node(2, "Node 2".to_string());
        let edge_value = "connects".to_string();
        graph.add_edge(1, 2, edge_value.clone()).unwrap();
        assert!(graph.contains_edge(1, 2));
        let removed_edge_value = graph.remove_edge(1, 2).unwrap();
        assert_eq!(removed_edge_value, edge_value);
        assert!(!graph.contains_edge(1, 2));
    }

    // Operations that reference missing nodes or edges fail gracefully.
    #[test]
    fn test_edge_cases() {
        let mut graph: NodeGraph<i32, String, String> = NodeGraph::new();
        graph.add_node(1, "Node 1".to_string());
        assert!(graph.add_edge(1, 2, "connects".to_string()).is_err()); // Edge to non-existent node
        assert!(graph.remove_edge(1, 2).is_none()); // Remove non-existent edge
        assert!(graph.remove_node(2).is_none()); // Remove non-existent node
    }

    // Identically built graphs compare equal; an extra node breaks equality
    // regardless of which side of the comparison holds it.
    #[test]
    fn test_graph_equality() {
        let mut graph1: NodeGraph<i32, String, String> = NodeGraph::new();
        let mut graph2: NodeGraph<i32, String, String> = NodeGraph::new();
        graph1.add_node(1, "Node 1".to_string());
        graph2.add_node(1, "Node 1".to_string());
        graph1.add_node(2, "Node 2".to_string());
        graph2.add_node(2, "Node 2".to_string());
        graph1.add_edge(1, 2, "connects".to_string()).unwrap();
        graph2.add_edge(1, 2, "connects".to_string()).unwrap();
        assert_eq!(graph1, graph2);
        assert_eq!(graph2, graph1);
        graph2.add_node(3, "Node 3".to_string());
        assert_ne!(graph1, graph2);
        assert_ne!(graph2, graph1);
    }

    // Node payload type used to exercise serialization of non-string data.
    #[derive(Default, Debug, Serialize, Deserialize, PartialEq)]
    pub struct Position(u8, u8);

    // Serializing to JSON and back yields a graph equal to the original.
    #[test]
    fn test_serialization_and_deserialization() {
        let mut graph: NodeGraph<i32, String, Position> = NodeGraph::new();
        graph.add_node(1, Position(0, 1));
        graph.add_edge(1, 1, "self-loop".to_string()).unwrap();
        println!("{graph:#?}");
        let serialized = serde_json::to_string(&graph).unwrap();
        println!("Serialized: {serialized}");
        let deserialized: NodeGraph<i32, String, Position> =
            serde_json::from_str(&serialized).unwrap();
        println!("Deserialized: {deserialized:#?}");
        // Without this check the test only proved that (de)serialization
        // does not fail, not that it preserves the graph.
        assert_eq!(deserialized, graph);
    }

    // New tests for the added methods

    // Updating an existing node returns the old data and stores the new data.
    #[test]
    fn test_update_node_data() {
        let mut graph: NodeGraph<i32, String, String> = NodeGraph::new();
        graph.add_node(1, "Node 1 data".to_string());
        let updated_data = "Updated Node 1 data".to_string();
        let original_data = graph.update_node_data(1, updated_data.clone());
        assert_eq!(original_data, Some("Node 1 data".to_string()));
        assert_eq!(graph.get_node_data(1), Some(&updated_data));
    }

    // `edges` yields every edge as a `(from, to, &value)` triple.
    #[test]
    fn test_iterate_edges() {
        let mut graph: NodeGraph<i32, String, String> = NodeGraph::new();
        graph.add_node(1, "Node 1".to_string());
        graph.add_node(2, "Node 2".to_string());
        graph.add_edge(1, 2, "connects".to_string()).unwrap();
        let edges: Vec<(i32, i32, &String)> = graph.edges().collect();
        assert_eq!(edges, vec![(1, 2, &"connects".to_string())]);
    }

    // `neighbors` yields the targets of a node's outgoing edges.
    #[test]
    fn test_iterate_neighbors() {
        let mut graph: NodeGraph<i32, String, String> = NodeGraph::new();
        graph.add_node(1, "Node 1".to_string());
        graph.add_node(2, "Node 2".to_string());
        graph.add_node(3, "Node 3".to_string());
        graph.add_edge(1, 2, "connects".to_string()).unwrap();
        graph.add_edge(1, 3, "connects".to_string()).unwrap();
        let neighbors: Vec<i32> = graph.neighbors(1).collect();
        assert!(neighbors.contains(&2));
        assert!(neighbors.contains(&3));
    }

    // Data can be modified in place through `node_data_mut`.
    #[test]
    fn test_mutate_node_data() {
        let mut graph: NodeGraph<i32, String, String> = NodeGraph::new();
        graph.add_node(1, "Node 1 data".to_string());
        if let Some(data) = graph.node_data_mut(1) {
            *data = "Mutated Node 1 data".to_string();
        }
        assert_eq!(
            graph.get_node_data(1),
            Some(&"Mutated Node 1 data".to_string())
        );
    }

    // `nodes` yields every node id exactly once.
    #[test]
    fn test_iterate_nodes() {
        let mut graph: NodeGraph<i32, String, String> = NodeGraph::new();
        graph.add_node(1, "Node 1 data".to_string());
        graph.add_node(2, "Node 2 data".to_string());
        let mut nodes: Vec<&i32> = graph.nodes().collect();
        // Sort the nodes to ensure the order
        nodes.sort();
        assert_eq!(nodes, vec![&1, &2]);
    }
}
// Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment