Created March 14, 2024 12:56
-
-
Save guimarqu/2b4386c65f0270c51de93244e5ec50ef to your computer and use it in GitHub Desktop.
Agglomerative clustering with precomputed metric, complete linkage, and maximum distance threshold
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Adapted from https://docs.rs/linfa-hierarchical/latest/src/linfa_hierarchical/lib.rs.html | |
pub fn agglomerative_clustering( | |
dist_matrix: &mut [f64], | |
num_observations: usize, | |
max_dist: f64, | |
) -> anyhow::Result<Vec<i32>> { | |
let res = linkage(dist_matrix, num_observations, kodama::Method::Complete); | |
let mut clusters = (0..num_observations) | |
.map(|x| (x, vec![x])) | |
.collect::<HashMap<_, _>>(); | |
// counter for new clusters, which are formed as unions of previous ones | |
let mut ct = num_observations; | |
for step in res.steps() { | |
if step.dissimilarity > max_dist { | |
break; | |
} | |
// combine ids from both clusters | |
let mut ids = Vec::with_capacity(2); | |
let mut cl = clusters.remove(&step.cluster1).unwrap(); | |
ids.append(&mut cl); | |
let mut cl = clusters.remove(&step.cluster2).unwrap(); | |
ids.append(&mut cl); | |
// insert into hashmap and increase counter | |
clusters.insert(ct, ids); | |
ct += 1; | |
} | |
// Optional operation to make your result deterministic | |
let mut ord_cluster_sets = clusters.values().into_iter().collect::<Vec<_>>(); | |
ord_cluster_sets.sort_by(|&a, &b| a.iter().min().cmp(&b.iter().min())); | |
// flatten resulting clusters and reverse index | |
let mut cluster_sets = vec![0i32; num_observations]; | |
for (i, ids) in ord_cluster_sets.into_iter().enumerate() { | |
for &id in ids { | |
cluster_sets[id] = i as i32; | |
} | |
} | |
Ok(cluster_sets) | |
} |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.