Skip to content

Instantly share code, notes, and snippets.

@Eleobert
Created October 21, 2021 19:15
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Eleobert/6e3dd16f64f63cd927aa7c13da238922 to your computer and use it in GitHub Desktop.
Save Eleobert/6e3dd16f64f63cd927aa7c13da238922 to your computer and use it in GitHub Desktop.
Agglomerative clustering
#include <algorithm>
#include <utility>
#include <vector>

#include <armadillo>
// Similarity score between two clusters, measured through a shared exemplar.
//
// The exemplar is the member of cluster_a ++ cluster_b whose column in the
// members' mutual-similarity submatrix has the largest mean. The score is the
// plain average of
//   (mean similarity of the exemplar to cluster_a's members) and
//   (mean similarity of the exemplar to cluster_b's members),
// so both clusters weigh equally regardless of their sizes.
//
// sim        element-level similarity matrix; clusters hold row/column
//            indices into it.
// cluster_a  member indices of the first cluster (presumably non-empty —
//            TODO confirm behaviour of arma::mean on empty input).
// cluster_b  member indices of the second cluster.
// Returns the score as a double.
auto get_cluster_sim(const arma::mat& sim, const arma::uvec& cluster_a, const arma::uvec& cluster_b)
{
    const arma::uvec members = arma::join_cols(cluster_a, cluster_b);
    const arma::mat within = sim(members, members);

    // Column-wise means; the winning position is relative to `members`, so it
    // is mapped back to an element id of `sim`.
    const arma::rowvec mean_per_member = arma::mean(within);
    const arma::uword exemplar = members(mean_per_member.index_max());

    arma::vec to_exemplar = sim.col(exemplar);
    const double score_a = arma::mean(to_exemplar(cluster_a));
    const double score_b = arma::mean(to_exemplar(cluster_b));
    return (score_a + score_b) / 2.0;
}
// Build the symmetric cluster-by-cluster similarity matrix.
//
// Entry (r, c) for r != c is get_cluster_sim(sim, clusters[r], clusters[c]),
// computed once for r < c and mirrored to (c, r). The diagonal is set to -inf
// so that a cluster is never selected as its own best merge partner.
//
// sim       element-level similarity matrix.
// clusters  current clusters; their order defines the row/column order.
// Returns an n x n arma::mat with n = clusters.size().
auto get_clusters_sim(const arma::mat& sim, const std::vector<arma::uvec>& clusters)
{
    const size_t count = clusters.size();
    arma::mat res(count, count);
    res.fill(arma::datum::nan);

    for(size_t row = 0; row < count; ++row)
    {
        res(row, row) = -arma::datum::inf;

        // Only the upper triangle is computed; the lower one is its mirror.
        for(size_t col = row + 1; col < count; ++col)
        {
            const double pair_sim = get_cluster_sim(sim, clusters[row], clusters[col]);
            res(row, col) = pair_sim;
            res(col, row) = pair_sim;
        }
    }
    return res;
}
// Locate the largest element of `mat`.
//
// Armadillo stores matrices column-major, so linear index k addresses
// row k % n_rows and column k / n_rows. Using n_rows here keeps the result
// correct for non-square matrices, not just square ones.
//
// mat  matrix to search; must not be empty (index_max() has no answer for an
//      empty object, and the division below would be by zero).
// Returns std::pair(row, column) of the maximum element.
auto get_index_max(const arma::mat& mat)
{
    // Member index_max() gives the linear index directly, with no flattened copy.
    const arma::uword index = mat.index_max();
    return std::make_pair(index % mat.n_rows, index / mat.n_rows);
}
auto remove_clusters(std::vector<arma::uvec>& clusters, size_t idx1, size_t idx2)
{
auto [idx_min, idx_max] = std::minmax(idx1, idx2);
clusters.erase(clusters.begin() + idx_max);
clusters.erase(clusters.begin() + idx_min);
}
auto remove_cluster_similarities(arma::mat& clusters_sim, size_t idx1, size_t idx2)
{
auto [idx_min, idx_max] = std::minmax(idx1, idx2);
clusters_sim.shed_col(idx_max);
clusters_sim.shed_col(idx_min);
clusters_sim.shed_row(idx_max);
clusters_sim.shed_row(idx_min);
}
// Grow the (square) cluster similarity matrix by one row and one column for
// `new_cluster`.
//
// sim           element-level similarity matrix.
// clusters_sim  n x n on entry, (n + 1) x (n + 1) on return; the new last
//               row/column holds get_cluster_sim() against each of the first
//               n clusters, and its diagonal entry is -inf.
// clusters      clusters whose first n entries correspond, in order, to the
//               existing rows/columns of clusters_sim. Only read.
// new_cluster   the cluster represented by the added row/column.
void update_clusters_sim(const arma::mat& sim, arma::mat& clusters_sim, const std::vector<arma::uvec>& clusters,
                         const arma::uvec& new_cluster)
{
    const arma::uword j = clusters_sim.n_rows;  // index of the row/column being added

    // resize() keeps existing entries in place. Every new entry is written
    // below, so its initial value does not matter.
    clusters_sim.resize(j + 1, j + 1);
    for(arma::uword i = 0; i < j; ++i)
    {
        clusters_sim(i, j) = get_cluster_sim(sim, clusters[i], new_cluster);
        clusters_sim(j, i) = clusters_sim(i, j);
    }
    // A cluster must never be picked as its own merge partner.
    clusters_sim(j, j) = -arma::datum::inf;
}
// Greedy agglomerative clustering over a precomputed similarity matrix.
//
// Repeatedly merges the pair of clusters with the highest get_cluster_sim()
// score until that best score falls below `cut_height` or a single cluster
// remains.
//
// sim         element-level similarity matrix; clusters hold row/column
//             indices into it.
// clusters    initial clusters (taken by value and merged in place).
// cut_height  merging stops once the best pair similarity is below this.
// Returns the final clusters. Each merged cluster is appended at the end, so
// the input order is not preserved.
auto agcluster(const arma::mat& sim, std::vector<arma::uvec> clusters, float cut_height)
{
    // With fewer than two clusters there is nothing to merge. This also avoids
    // index_max() on an empty matrix, and merging a lone cluster with itself.
    if(clusters.size() < 2)
    {
        return clusters;
    }

    arma::mat clusters_sim = get_clusters_sim(sim, clusters);
    while(true)
    {
        const auto [i_max, j_max] = get_index_max(clusters_sim);
        const double height = clusters_sim(i_max, j_max);

        // i_max == j_max means the maximum is on the -inf diagonal, i.e. no
        // off-diagonal pair beats -inf (reachable only if cut_height is -inf).
        if(height < cut_height || i_max == j_max)
        {
            return clusters;
        }

        arma::uvec new_cluster = arma::join_cols(clusters[i_max], clusters[j_max]);
        // Remove the old clusters and append the merged one at the end.
        remove_clusters(clusters, i_max, j_max);
        clusters.push_back(std::move(new_cluster));
        if(clusters.size() == 1)
        {
            return clusters;
        }

        // Keep clusters_sim aligned with `clusters`: drop the two merged
        // rows/columns, then append one for the new cluster (clusters.back()).
        remove_cluster_similarities(clusters_sim, i_max, j_max);
        update_clusters_sim(sim, clusters_sim, clusters, clusters.back());
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment