Skip to content

Instantly share code, notes, and snippets.

@rjvdw
Last active March 26, 2019 20:24
Show Gist options
  • Save rjvdw/b22029f48803332a905acc7ecaed6a1c to your computer and use it in GitHub Desktop.
Save rjvdw/b22029f48803332a905acc7ecaed6a1c to your computer and use it in GitHub Desktop.
Pathfinding algorithms in Rust
use std::cmp::min;
use std::collections::HashSet;
use crate::graph::Graph;
use crate::node::Node;
use crate::path::Path;
/// Finds the shortest path from `from` to `to` by exhaustively exploring the graph with
/// [`compute`], then prints either the path or a "no path" message.
pub fn brute_force(graph: &Graph, from: &Node, to: &Node) {
    let mut seen = HashSet::new();
    // The start node is on every route, so mark it as visited before searching. Otherwise an
    // edge whose `n2` is the start node lets the search step back onto the start, yielding
    // paths that begin with that bogus edge (and, on a length tie, one of them may be reported).
    seen.insert(from);
    match compute(graph, from, to, &mut seen, Path { edges: vec![] }) {
        Some(path) => println!("Shortest path: {}", path),
        None => println!("No path found..."),
    }
}
/// Recursively searches for the shortest path from `from` to `to`, extending `current_path`.
///
/// For each edge returned by `graph.get_edges(from)`, the search steps onto `edge.n2` unless
/// that node is already in `seen`. A node stepped onto stays in `seen` for the duration of the
/// recursive call and is removed again afterwards.
///
/// A trace line is printed for every call, and another one for every complete path reached.
///
/// Returns the smallest path found according to `Path`'s ordering, or `None` if `to` is never
/// reached. On a tie, the path found later is kept (`std::cmp::min` returns its first argument
/// when both compare equal).
pub fn compute<'a>(
    graph: &Graph<'a>,
    from: &Node<'a>,
    to: &Node<'a>,
    seen: &mut HashSet<&Node<'a>>,
    current_path: Path<'a>,
) -> Option<Path<'a>> {
    println!("{} -> {}", from, to);

    // Base case: standing on the target means the path walked so far is complete.
    if from == to {
        println!("Found a path: {}", current_path);
        return Some(current_path);
    }

    let mut best: Option<Path<'a>> = None;
    for edge in graph.get_edges(from) {
        let next = edge.n2;
        // Skip nodes already on the route being explored.
        if seen.contains(next) {
            continue;
        }

        let candidate_path = current_path.expand(edge);
        seen.insert(next);
        let found = compute(graph, next, to, seen, candidate_path);
        seen.remove(next);

        if let Some(found) = found {
            best = Some(match best {
                None => found,
                Some(previous) => min(found, previous),
            });
        }
    }
    best
}
use std::collections::{HashMap, HashSet};
use crate::graph::Graph;
use crate::node::Node;
use crate::path::Path;
/// Runs Dijkstra's algorithm from `from` to `to` over `graph` and prints the shortest path
/// found, or a "no path" message when `to` is unreachable.
pub fn dijkstra(graph: &Graph, from: &Node, to: &Node) {
    // Every node of the graph starts out unsettled.
    let mut unvisited: HashSet<_> = graph.nodes.iter().cloned().collect();
    // Initially only the start node has a known path: the empty one.
    let mut paths = HashMap::new();
    paths.insert(from, Path { edges: Vec::new() });

    let result = compute(graph, from, to, &mut paths, &mut unvisited);
    if let Some(path) = result {
        println!("Shortest path: {}", path);
    } else {
        println!("No path found...");
    }
}
/// Runs Dijkstra's algorithm from `from` and returns the shortest path to `to`.
///
/// * `paths` is expected to already map `from` to the empty path. It is updated in place with
///   the best path known so far to every node that has been reached.
/// * `unvisited` holds the nodes that are not settled yet. Nodes are removed as they are
///   settled.
///
/// For each edge returned by `graph.get_edges(current)`, only `edge.n2` is treated as a
/// neighbour.
///
/// Returns `None` when `to` cannot be reached from `from`.
///
/// # Panics
///
/// Panics if a node being settled has no entry in `paths` while it still has an edge to an
/// unvisited node. This cannot happen when `paths` was seeded with `from`.
///
/// NOTE(review): the `&'a mut` borrow is tied to `'a`, so `paths` stays mutably borrowed for as
/// long as the returned path is in use.
fn compute<'a>(
    graph: &Graph<'a>,
    from: &Node<'a>,
    to: &Node<'a>,
    paths: &'a mut HashMap<&Node<'a>, Path<'a>>,
    unvisited: &mut HashSet<&Node<'a>>,
) -> Option<&'a Path<'a>> {
    // Trivial query: the empty path stored for `from` is the answer. Without this check `from`
    // is settled (removed from `unvisited`) before it is ever compared with `to`, so the search
    // would end by reporting that no path exists.
    if from == to {
        return paths.get(from);
    }

    let mut current = from;
    while !unvisited.is_empty() {
        // Relax the edges of the node being settled.
        for edge in graph.get_edges(current) {
            if !unvisited.contains(edge.n2) {
                continue;
            }
            let current_path = paths
                .get(current)
                .expect("the node being settled always has a known path");
            let distance = current_path.get_length() + edge.distance;
            let improves = paths
                .get(edge.n2)
                .map_or(true, |known| known.get_length() > distance);
            if improves {
                paths.insert(edge.n2, current_path.expand(edge));
            }
        }
        unvisited.remove(current);

        // Pick the unvisited node with the shortest known path. Nodes that have not been reached
        // (no entry in `paths`) are skipped. On ties the first candidate in iteration order wins.
        let next = unvisited
            .iter()
            .filter_map(|candidate| {
                paths
                    .get(*candidate)
                    .map(|path| (*candidate, path.get_length()))
            })
            .min_by_key(|&(_, length)| length)
            .map(|(node, _)| node);

        current = match next {
            // Every remaining node is unreachable, so `to` is unreachable too.
            None => return None,
            Some(node) => node,
        };
        if current == to {
            return paths.get(current);
        }
    }
    None
}
use std::fmt::{Display, Error, Formatter};
use std::hash::{Hash, Hasher};
use crate::node::Node;
/// A weighted connection between two nodes, displayed as `n1 -> n2`.
///
/// Two edges are equal, and hash identically, when they join the same ordered pair of nodes.
/// `distance` takes no part in equality or hashing. As a result, a `HashSet` of edges keeps
/// only one of several edges sharing the same `n1` and `n2`.
#[derive(Copy, Clone)]
pub struct Edge<'a> {
    /// First endpoint (shown on the left of `->`).
    pub n1: &'a Node<'a>,
    /// Second endpoint (shown on the right of `->`).
    pub n2: &'a Node<'a>,
    /// Weight of the edge.
    pub distance: u32,
}

impl<'a> PartialEq for Edge<'a> {
    fn eq(&self, other: &Self) -> bool {
        (self.n1, self.n2) == (other.n1, other.n2)
    }
}

impl<'a> Eq for Edge<'a> {}

impl<'a> Hash for Edge<'a> {
    // Must agree with `eq`: only the endpoints contribute, in order.
    fn hash<H: Hasher>(&self, state: &mut H) {
        let endpoints = (self.n1, self.n2);
        endpoints.hash(state);
    }
}

impl<'a> Display for Edge<'a> {
    fn fmt(&self, f: &mut Formatter) -> Result<(), Error> {
        let Edge { n1, n2, distance } = *self;
        write!(f, "{} -> {}, (d={})", n1, n2, distance)
    }
}
use std::collections::{HashMap, HashSet};
use std::fmt::{Display, Error, Formatter};
use std::hash::Hash;
use crate::edge::Edge;
use crate::node::Node;
type NodeSet<'a> = HashSet<&'a Node<'a>>;
type EdgeSet<'a> = HashSet<&'a Edge<'a>>;
type NodeToEdgesMap<'a> = HashMap<&'a Node<'a>, EdgeSet<'a>>;

/// A graph over borrowed nodes and edges, with an index from each node to its incident edges.
pub struct Graph<'a> {
    /// All nodes of the graph.
    pub nodes: NodeSet<'a>,
    /// All edges of the graph. Edges that compare equal are stored once.
    pub edges: EdgeSet<'a>,
    /// For every node, the edges having it as either endpoint (`n1` or `n2`).
    nodes_to_edges: NodeToEdgesMap<'a>,
}

impl<'a> Graph<'a> {
    /// Builds a graph from `nodes` and `edges`.
    ///
    /// Accepts slices, so `&Vec<_>` arguments keep working through deref coercion.
    ///
    /// # Panics
    ///
    /// Panics if an edge has an endpoint that does not appear in `nodes`.
    pub fn new(nodes: &'a [Node<'a>], edges: &'a [Edge<'a>]) -> Graph<'a> {
        let mut nodes_set = HashSet::with_capacity(nodes.len());
        let mut edges_set = HashSet::with_capacity(edges.len());
        let mut nodes_to_edges = HashMap::with_capacity(nodes.len());
        for node in nodes {
            nodes_set.insert(node);
            nodes_to_edges.insert(node, HashSet::new());
        }
        for edge in edges {
            edges_set.insert(edge);
            // Index the edge under both endpoints so it can be found from either side.
            for endpoint in [edge.n1, edge.n2].iter() {
                nodes_to_edges
                    .get_mut(*endpoint)
                    .unwrap_or_else(|| {
                        panic!("edge endpoint {} is not in the node list", endpoint)
                    })
                    .insert(edge);
            }
        }
        Graph {
            nodes: nodes_set,
            edges: edges_set,
            nodes_to_edges,
        }
    }

    /// Returns every edge that has `node` as one of its endpoints.
    ///
    /// # Panics
    ///
    /// Panics if `node` is not part of the graph.
    pub fn get_edges(&self, node: &Node<'a>) -> &EdgeSet<'a> {
        self.nodes_to_edges
            .get(node)
            .unwrap_or_else(|| panic!("node {} is not part of this graph", node))
    }
}

impl<'a> Display for Graph<'a> {
    /// Writes the nodes and edges, each list sorted by its textual form for stable output.
    fn fmt(&self, f: &mut Formatter) -> Result<(), Error> {
        write!(f, "Graph<nodes=[{}], edges=[{}]>",
               hash_set_to_list(&self.nodes),
               hash_set_to_list(&self.edges))
    }
}

/// Renders every element of `set` with `Display`, sorts the strings, and joins them with ", ".
fn hash_set_to_list<T: Display + Eq + Hash>(set: &HashSet<T>) -> String {
    let mut rendered: Vec<String> = set.iter().map(|e| e.to_string()).collect();
    // HashSet iteration order is unspecified; sorting makes the output deterministic.
    rendered.sort_unstable();
    rendered.join(", ")
}
use std::env;
use crate::brute_force::brute_force;
use crate::dijkstra::dijkstra;
use crate::edge::Edge;
use crate::graph::Graph;
use crate::node::Node;
mod dijkstra;
mod node;
mod edge;
mod graph;
mod brute_force;
mod path;
/// Builds the sample graph and runs the algorithm named by the first command-line argument
/// (`brute-force` when omitted, or `dijkstra`), searching from node A to node K.
fn main() {
    let nodes = vec![
        Node { name: "A" },
        Node { name: "B" },
        Node { name: "C" },
        Node { name: "D" },
        Node { name: "E" },
        Node { name: "F" },
        Node { name: "G" },
        Node { name: "H" },
        Node { name: "I" },
        Node { name: "J" },
        Node { name: "K" },
    ];
    let edges = vec![
        Edge { n1: &nodes[0], n2: &nodes[1], distance: 5 },
        Edge { n1: &nodes[0], n2: &nodes[2], distance: 4 },
        Edge { n1: &nodes[1], n2: &nodes[6], distance: 7 },
        Edge { n1: &nodes[1], n2: &nodes[5], distance: 6 },
        Edge { n1: &nodes[1], n2: &nodes[3], distance: 5 },
        Edge { n1: &nodes[2], n2: &nodes[3], distance: 2 },
        Edge { n1: &nodes[2], n2: &nodes[4], distance: 5 },
        Edge { n1: &nodes[3], n2: &nodes[5], distance: 4 },
        Edge { n1: &nodes[4], n2: &nodes[5], distance: 4 },
        Edge { n1: &nodes[4], n2: &nodes[8], distance: 4 },
        Edge { n1: &nodes[5], n2: &nodes[7], distance: 2 },
        Edge { n1: &nodes[5], n2: &nodes[9], distance: 4 },
        Edge { n1: &nodes[6], n2: &nodes[7], distance: 2 },
        Edge { n1: &nodes[7], n2: &nodes[10], distance: 4 },
        Edge { n1: &nodes[8], n2: &nodes[9], distance: 3 },
        Edge { n1: &nodes[9], n2: &nodes[10], distance: 4 },
    ];
    let graph = Graph::new(&nodes, &edges);

    let args: Vec<String> = env::args().collect();
    // Default to the brute-force search when no algorithm is given on the command line.
    let algorithm = args.get(1).map(String::as_str).unwrap_or("brute-force");

    let (start, goal) = (&nodes[0], &nodes[10]);
    match algorithm {
        "brute-force" => brute_force(&graph, start, goal),
        "dijkstra" => dijkstra(&graph, start, goal),
        _ => println!("Unknown algorithm: '{}'", algorithm),
    }
}
use std::fmt::{Display, Error, Formatter};
use std::hash::{Hash, Hasher};
/// A vertex of the graph, identified by its `name`.
///
/// Equality and hashing look only at `name`, so two `Node` values with the same name are
/// interchangeable as set members and map keys.
#[derive(Copy, Clone, Debug)]
pub struct Node<'a> {
    /// Label of the node: its identity and its `Display` output.
    pub name: &'a str,
}

impl<'a> PartialEq for Node<'a> {
    fn eq(&self, other: &Self) -> bool {
        self.name == other.name
    }
}

impl<'a> Eq for Node<'a> {}

impl<'a> Hash for Node<'a> {
    // Hashes exactly the field `eq` compares, keeping `Hash` consistent with `Eq`.
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.name.hash(state);
    }
}

impl<'a> Display for Node<'a> {
    /// Writes the node's name.
    fn fmt(&self, f: &mut Formatter) -> Result<(), Error> {
        write!(f, "{}", self.name)
    }
}
use std::cmp::Ordering;
use std::fmt::{Display, Error, Formatter};
use crate::edge::Edge;
/// A route through the graph, stored as the ordered list of edges that were followed.
///
/// Paths are compared (`==`, `<`, `cmp`, ...) purely by total length, so two different routes
/// of equal length compare as equal.
pub struct Path<'a> {
    /// Edges in the order they were traversed.
    pub edges: Vec<&'a Edge<'a>>,
}

impl<'a> Path<'a> {
    /// Returns the sum of `distance` over all edges (0 for the empty path).
    pub fn get_length(&self) -> u32 {
        self.edges.iter().map(|e| e.distance).sum()
    }

    /// Returns a new path made of this path followed by `edge`. `self` is left unchanged.
    pub fn expand(&self, edge: &'a Edge<'a>) -> Path<'a> {
        // Allocate once for the copied edges plus the new one.
        let mut edges = Vec::with_capacity(self.edges.len() + 1);
        edges.extend_from_slice(&self.edges);
        edges.push(edge);
        Path { edges }
    }

    /// Names of the visited nodes: `n1` of the first edge, then `n2` of every edge.
    /// Empty for the empty path.
    fn get_node_names(&self) -> Vec<String> {
        let start = self.edges.first().map(|e| e.n1.to_string());
        start
            .into_iter()
            .chain(self.edges.iter().map(|e| e.n2.to_string()))
            .collect()
    }
}

impl<'a> PartialEq for Path<'a> {
    fn eq(&self, other: &Self) -> bool {
        self.get_length() == other.get_length()
    }
}

impl<'a> Eq for Path<'a> {}

impl<'a> PartialOrd for Path<'a> {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl<'a> Ord for Path<'a> {
    // Orders by total length only, consistent with `eq`.
    fn cmp(&self, other: &Self) -> Ordering {
        self.get_length().cmp(&other.get_length())
    }
}

impl<'a> Display for Path<'a> {
    /// Writes e.g. `A -> B -> C (d_total=7)`.
    fn fmt(&self, f: &mut Formatter) -> Result<(), Error> {
        write!(
            f,
            "{} (d_total={})",
            self.get_node_names().join(" -> "),
            self.get_length()
        )
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment