Skip to content

Instantly share code, notes, and snippets.

@prnake
Created February 24, 2024 17:04
Show Gist options
  • Save prnake/ab7fe8c62d9edf2f1da4a5ee171fc56d to your computer and use it in GitHub Desktop.
Save prnake/ab7fe8c62d9edf2f1da4a5ee171fc56d to your computer and use it in GitHub Desktop.
UnionFind
#![allow(unused)]
//! Union-find (disjoint set): answers connectivity / association queries between nodes.
use std::fs::File;
use std::io::{self, Read, Write};
use std::path::Path;
use std::slice;
/// Reads the whole file at `path` and returns its contents as a `Vec<T>`.
///
/// The file is treated as a raw dump of `T` values in the in-memory
/// representation of the current platform (native endianness and width),
/// i.e. the bytes obtained by viewing a `&[T]` as `&[u8]`.
///
/// The bytes are copied into a freshly allocated `Vec<T>`, so the returned
/// vector owns an allocation with the layout `Vec<T>` expects. Reusing the
/// `Vec<u8>` allocation through `Vec::from_raw_parts` would not satisfy that
/// function's layout requirements.
///
/// # Errors
///
/// * Any I/O error raised while reading the file.
/// * `io::ErrorKind::InvalidData` when the file length is not a whole
///   multiple of `size_of::<T>()` (for a zero-sized `T`: when the file is
///   not empty).
///
/// NOTE(review): `T: Copy` alone does not guarantee that every bit pattern is
/// a valid `T`. The callers visible in this file use `usize` and `u32`, for
/// which this holds; confirm before using it with other element types.
fn load_vector<T>(path: &Path) -> io::Result<Vec<T>>
where
    T: std::marker::Copy, // elements are produced by a raw byte copy
{
    let bytes = std::fs::read(path)?;
    let elem_size = std::mem::size_of::<T>();

    // A trailing partial element means the file cannot be a dump of `[T]`.
    let has_partial_element = if elem_size == 0 {
        !bytes.is_empty()
    } else {
        bytes.len() % elem_size != 0
    };
    if has_partial_element {
        return Err(io::Error::new(
            io::ErrorKind::InvalidData,
            "Data length is not a multiple of the element size",
        ));
    }
    // Also covers zero-sized `T`, so the division below never divides by zero.
    if bytes.is_empty() {
        return Ok(Vec::new());
    }

    let len = bytes.len() / elem_size;
    let mut values: Vec<T> = Vec::with_capacity(len);
    // SAFETY: `values` owns an allocation of at least `len * elem_size` bytes
    // that is aligned for `T`; the source holds exactly `len * elem_size`
    // initialized bytes and cannot overlap the fresh allocation. After the
    // copy the first `len` elements are initialized, so `set_len(len)` is
    // valid (subject to the bit-pattern note in the function docs).
    unsafe {
        std::ptr::copy_nonoverlapping(bytes.as_ptr(), values.as_mut_ptr() as *mut u8, bytes.len());
        values.set_len(len);
    }
    Ok(values)
}
/// Disjoint-set (union-find) structure over the elements `0..size()`.
///
/// Sets are stored as a forest: every element points at a parent and a root
/// points at itself. Merging uses union by rank and `find` shortens paths as
/// it walks them.
#[derive(Debug, Clone)]
pub struct UnionFind {
    /// `parent[i]` is the parent of element `i`; a root satisfies `parent[i] == i`.
    pub parent: Vec<usize>,
    /// Rank (tree-height estimate) per element, consulted for roots when merging.
    pub rank: Vec<u32>,
}

impl UnionFind {
    /// Creates a structure with `size` elements, each in its own singleton set.
    pub fn new_with_size(size: usize) -> Self {
        Self {
            // Every element starts out as its own root.
            parent: (0..size).collect(),
            rank: vec![0; size],
        }
    }

    /// Loads a structure previously written by [`UnionFind::save`].
    ///
    /// Reads `<path>.parent.bin` and `<path>.rank.bin`.
    ///
    /// # Errors
    ///
    /// * Any error returned by `load_vector` for either file.
    /// * `io::ErrorKind::InvalidData` when the two vectors differ in length,
    ///   or when a stored parent index is not smaller than the number of
    ///   elements (such an index would make `find` index out of bounds).
    pub fn load(path: String) -> io::Result<Self> {
        let parent_path = format!("{}.parent.bin", path);
        let rank_path = format!("{}.rank.bin", path);
        let parent: Vec<usize> = load_vector(Path::new(&parent_path))?;
        let rank: Vec<u32> = load_vector(Path::new(&rank_path))?;

        if parent.len() != rank.len() {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "Parent and rank data have different sizes",
            ));
        }
        // The files are external input: reject indices that point outside the forest.
        let size = parent.len();
        if parent.iter().any(|&p| p >= size) {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "Parent data contains an out-of-range index",
            ));
        }
        Ok(UnionFind { parent, rank })
    }

    /// Saves the structure to `<path>.parent.bin` and `<path>.rank.bin`.
    ///
    /// Both files are created (truncating existing ones) before anything is
    /// written. The contents are the raw in-memory bytes of the two vectors,
    /// so the format depends on the platform's `usize` width and endianness.
    ///
    /// # Errors
    ///
    /// Returns any I/O error from creating or writing either file.
    pub fn save(&self, path: String) -> io::Result<()> {
        let mut parent_file = File::create(format!("{}.parent.bin", path))?;
        let mut rank_file = File::create(format!("{}.rank.bin", path))?;

        // SAFETY: the pointer and element count come from a live `Vec<usize>`
        // borrowed for the whole call, `usize` has no padding bytes, and `u8`
        // has alignment 1, so the byte view covers initialized memory only.
        let parent_bytes = unsafe {
            slice::from_raw_parts(
                self.parent.as_ptr() as *const u8,
                self.parent.len() * std::mem::size_of::<usize>(),
            )
        };
        parent_file.write_all(parent_bytes)?;

        // SAFETY: same reasoning as above, for the `Vec<u32>` of ranks.
        let rank_bytes = unsafe {
            slice::from_raw_parts(
                self.rank.as_ptr() as *const u8,
                self.rank.len() * std::mem::size_of::<u32>(),
            )
        };
        rank_file.write_all(rank_bytes)?;
        Ok(())
    }

    /// Returns the root of the set containing `p`.
    ///
    /// While walking up, every visited node is re-pointed at its grandparent,
    /// which shortens the path for later queries.
    ///
    /// # Errors
    ///
    /// Returns `Err("参数错误")` when `p` is not smaller than `size()`.
    pub fn find(&mut self, p: usize) -> Result<usize, &'static str> {
        if p >= self.parent.len() {
            return Err("参数错误");
        }
        let mut current = p;
        while current != self.parent[current] {
            // Skip one level: point at the grandparent, then continue from there.
            let grandparent = self.parent[self.parent[current]];
            self.parent[current] = grandparent;
            current = grandparent;
        }
        Ok(current)
    }

    /// Reports whether `p` and `q` belong to the same set.
    ///
    /// # Panics
    ///
    /// Panics when `p` or `q` is not smaller than `size()`.
    pub fn is_connected(&mut self, p: usize, q: usize) -> bool {
        let p_root = self.find(p).expect("is_connected: index out of range");
        let q_root = self.find(q).expect("is_connected: index out of range");
        p_root == q_root
    }

    /// Merges the sets containing `p` and `q`; does nothing if they already coincide.
    ///
    /// The root with the smaller rank is attached below the other one. On a
    /// tie, `q`'s root goes below `p`'s root and that root's rank grows by one.
    ///
    /// # Panics
    ///
    /// Panics when `p` or `q` is not smaller than `size()`.
    pub fn union_elements(&mut self, p: usize, q: usize) {
        let p_root = self.find(p).expect("union_elements: index out of range");
        let q_root = self.find(q).expect("union_elements: index out of range");
        if p_root == q_root {
            return;
        }
        match self.rank[p_root].cmp(&self.rank[q_root]) {
            std::cmp::Ordering::Less => self.parent[p_root] = q_root,
            std::cmp::Ordering::Greater => self.parent[q_root] = p_root,
            std::cmp::Ordering::Equal => {
                self.parent[q_root] = p_root;
                self.rank[p_root] += 1;
            }
        }
    }

    /// Returns the number of elements tracked by the structure.
    pub fn size(&self) -> usize {
        self.parent.len()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Merges a few pairs and checks that two elements joined only
    /// transitively (4 and 1) end up in the same set.
    #[test]
    fn it_works() {
        let mut uf = UnionFind::new_with_size(10);
        let merges = [(3, 5), (2, 1), (5, 1), (5, 4)];
        for &(a, b) in merges.iter() {
            uf.union_elements(a, b);
        }
        assert!(uf.is_connected(4, 1));
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment