Skip to content

Instantly share code, notes, and snippets.

@shikharbhardwaj
Last active June 12, 2017 17:15
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save shikharbhardwaj/89bf1f1e58b3730b4bfc00bb6373e222 to your computer and use it in GitHub Desktop.
Save shikharbhardwaj/89bf1f1e58b3730b4bfc00bb6373e222 to your computer and use it in GitHub Desktop.
Work on the parallel KMeans implementation in mlpack
#include <algorithm>
#include <iostream>
#include <limits>

#include <mlpack/core.hpp>
#include <mlpack/methods/kmeans/parallel_naive_kmeans.hpp>
// Small hand-written test set: 30 two-dimensional points, written one point
// per row (';' separates rows). The values appear to form three groups, near
// (0, 0), (10, 10) and (-10, 5).
// NOTE(review): mlpack code in this file treats points as columns, so this
// matrix presumably needs transposing before being clustered. It is not
// referenced by the code visible here — confirm where it is used.
arma::mat
kMeansData(" 0.0 0.0; 0.3 0.4; 0.1 0.0; 0.1 0.3;"
" -0.2 -0.2; -0.1 0.3; -0.4 0.1; 0.2 -0.1; 0.3 0.0;"
"-0.3 -0.3; 0.1 -0.1; 0.2 -0.3; -0.3 0.2; 10.0 10.0;"
"10.1 9.9; 9.9 10.0; 10.2 9.7; 10.2 9.8; 9.7 10.3;"
"9.9 10.1;-10.0 5.0; -9.8 5.1; -9.9 4.9;-10.0 4.9;"
"-10.2 5.2;-10.1 5.1;-10.3 5.3;-10.0 4.8; -9.6 5.0;"
"-9.8 5.1;");
using namespace std;
using namespace mlpack;
using namespace mlpack::metric;
using namespace mlpack::kmeans;
/**
 * Convergence monitor: sums, over every point, the Euclidean distance to its
 * nearest centroid.
 *
 * Note that plain (not squared) distances are summed, so this is not the
 * usual k-means sum-of-squared-errors objective.
 *
 * @param data Dataset, one point per column.
 * @param centroids Centroids, one per column (same number of rows as data).
 * @return Sum of nearest-centroid distances. This is infinity if there are
 *     points but no centroids, and 0 for an empty dataset.
 */
double cost(const arma::mat &data, const arma::mat &centroids) {
  double total = 0.0;
  // Each thread accumulates a private partial sum, and the reduction combines
  // them at the end. A plain "total += ..." from many threads is a data race
  // that silently loses additions.
  #pragma omp parallel for reduction(+:total)
  for (size_t i = 0; i < data.n_cols; ++i) {
    double minCost = std::numeric_limits<double>::infinity();
    for (size_t j = 0; j < centroids.n_cols; ++j) {
      minCost = std::min(minCost, EuclideanDistance().Evaluate(
          data.col(i), centroids.col(j)));
    }
    total += minCost;
  }
  return total;
}
int main() {
auto metric = EuclideanDistance();
arma::mat train_data;
mlpack::data::Load("adult_data_rev.csv", train_data);
// Extract labels
arma::Row<size_t> labels(train_data.n_cols);
for (size_t i = 0; i < train_data.n_cols; ++i) {
labels[i] = train_data(train_data.n_rows - 1, i);
}
// Remove the labels from the data
train_data.shed_row(train_data.n_rows - 1);
ParallelNaiveKMeans<EuclideanDistance, arma::mat> kmeans(train_data,
metric);
arma::mat curCentroids(
"88.22962961016027 2.044024102027296 1222550.8100276222 "
"12.903700748511609 4.076422813795493 1.03184756627729 "
"1.4002157705668625 0.8460159809598522 0.7246672042644327 "
"0.761125417886293 39803.65091923931 383.5221088662193 "
"81.72930049509185 6.465427904613903; "
"78.8079205643903 6.214012816167525 352121.1987978431 "
"7.719785410736369 5.422774896451617 4.38621601591405 "
"11.542851182937957 3.166299079076285 1.7278007414295153 "
"0.2328544904113854 34425.9377703859 3928.37645603819 "
"77.38844147164069 19.21207145800073;");
arma::inplace_trans(curCentroids);
cout << "\n";
size_t maxIters = 1e3;
mlpack::Timer::Start("Iteration");
for(size_t i = 0; i < maxIters; ++i){
//cout << "\rCost : " << cost(train_data, curCentroids);
arma::mat nextCentroids;
arma::Col<size_t> assigns;
kmeans.Iterate(curCentroids, nextCentroids, assigns);
curCentroids = nextCentroids;
}
mlpack::Timer::Stop("Iteration");
cout << "Time for iterations : " << Timer::Get("Iteration").count() <<
"us\n";
}
#!/usr/bin/env python
import numpy as np
import matplotlib.pyplot as plt
def showKMeans(data, clusters):
    """Clear the current figure and scatter-plot points and centroids.

    Only the first two columns of each array are drawn. The data points are
    plotted first, then the centroids, each as its own scatter series.
    ``plt.show()`` is called at the end.
    """
    plt.clf()
    for points in (data, clusters):
        plt.scatter(points[:, 0], points[:, 1])
    plt.show()
def cost(data, clusters):
    """Sum of each point's Euclidean distance to its nearest cluster centre.

    Args:
        data: array of points, iterated row by row.
        clusters: array of cluster centres, iterated row by row.

    Returns:
        The summed plain (not squared) nearest-centre distances. Returns 0 for
        empty ``data``. A point contributes ``inf`` when ``clusters`` is empty.
    """
    total = 0
    for point in data:
        nearest = float('Inf')
        for centre in clusters:
            nearest = min(nearest, np.linalg.norm(point - centre))
        total += nearest
    return total
def iterate(data, clusters):
    """Perform one Lloyd (k-means) update step.

    Each row of ``data`` is assigned to the nearest row of ``clusters`` by
    Euclidean distance; on ties the lowest index wins. Each cluster then moves
    to the mean of its assigned points.

    Args:
        data: 2-D array, one point per row.
        clusters: 2-D array, one centroid per row, same width as ``data``.

    Returns:
        A new float array with the same shape as ``clusters``. A cluster that
        receives no points keeps its previous position instead of collapsing
        to the zero vector (which would pull it toward the origin).

    Raises:
        ValueError: if a point cannot be assigned to any cluster. This happens
            when there are no clusters, or when every distance is NaN or
            infinite (e.g. the point contains NaN).
    """
    newClusters = np.zeros(clusters.shape, dtype=float)
    clusterCounts = np.zeros(clusters.shape[0], dtype=int)
    for i in range(data.shape[0]):
        minDist = float('Inf')
        assignedCluster = None
        for j in range(clusters.shape[0]):
            curDist = np.linalg.norm(data[i] - clusters[j])
            if curDist < minDist:
                minDist = curDist
                assignedCluster = j
        if assignedCluster is None:
            raise ValueError(
                'point %d could not be assigned to any cluster' % i)
        newClusters[assignedCluster] += data[i]
        clusterCounts[assignedCluster] += 1
    for i in range(clusters.shape[0]):
        if clusterCounts[i] != 0:
            newClusters[i] /= clusterCounts[i]
        else:
            # Empty cluster: keep the previous centroid rather than returning
            # the zero vector accumulated above.
            newClusters[i] = clusters[i]
    return newClusters
# Get data and perform init
# NOTE(review): the file is named .tsv but is parsed with delimiter=',' —
# confirm the actual separator.
data = np.genfromtxt('test_data.tsv', delimiter=',')
# (min, max) of the first two columns only; this section handles 2-D data.
# get_bounds() below generalizes this to any number of dimensions.
x_range = (np.amin(data[:, 0]), np.amax(data[:, 0]))
y_range = (np.amin(data[:, 1]), np.amax(data[:, 1]))
# Evenly spaced candidate coordinates along each axis (np.linspace default
# count), used as the pool initial centroids are drawn from.
x_space = np.linspace(x_range[0], x_range[1])
y_space = np.linspace(y_range[0], y_range[1])
# Random initial clusters
# Each centroid takes one random x from x_space and one random y from y_space.
num_clusters = 3
clusters = np.array([[np.random.choice(x_space), np.random.choice(y_space)] for
i in range(num_clusters)])
# Plot the data together with the initial centroids.
showKMeans(data, clusters)
# Functions for dealing with arbitrary dimensions
def get_bounds(data):
    """Return the per-dimension value range of a 2-D dataset.

    Args:
        data: 2-D array, one point per row.

    Returns:
        A list with one ``(min, max)`` tuple per column of ``data``.
    """
    bounds = []
    for dim in range(data.shape[1]):
        column = data[:, dim]
        bounds.append((np.amin(column), np.amax(column)))
    return bounds
def get_random_centroids(bounds, num_centroids):
    """Draw ``num_centroids`` points uniformly at random inside ``bounds``.

    Args:
        bounds: sequence of ``(low, high)`` pairs, one per dimension (for
            example, the output of ``get_bounds``).
        num_centroids: number of centroids to generate.

    Returns:
        A list of ``num_centroids`` lists. Each holds one coordinate per
        dimension, drawn with ``np.random.uniform(low, high)``; centroids are
        drawn in order, dimension by dimension.
    """
    centroids = []
    for _ in range(num_centroids):
        centroid = []
        for dim_bound in bounds:
            low, high = dim_bound[0], dim_bound[1]
            centroid.append(np.random.uniform(low, high))
        centroids.append(centroid)
    return centroids
# vim: tabstop=8 expandtab shiftwidth=4 softtabstop=4
/**
 * Performs one Lloyd iteration over the dataset, parallelized with OpenMP.
 *
 * Every point (column of `dataset`) is assigned to the centroid with the
 * smallest `metric` distance. Each thread accumulates per-centroid sums and
 * counts in private buffers, and these are merged into the outputs inside a
 * critical section. The merged sums are then divided by the counts to give
 * the new centroids.
 *
 * @param centroids Current centroids, one per column.
 * @param newCentroids Output: updated centroids, the same size as
 *     `centroids`. A centroid that received no points is left as the zero
 *     vector.
 * @param counts Output: number of points assigned to each centroid.
 * @return Square root of the summed squared distances between each old and
 *     new centroid, i.e. how far the centroids moved this iteration.
 *
 * NOTE(review): metric.Evaluate() is called concurrently from every thread.
 * That is fine for a stateless metric such as EuclideanDistance; confirm
 * thread-safety before using a stateful MetricType.
 */
template<typename MetricType, typename MatType>
double ParallelNaiveKMeans<MetricType, MatType>::Iterate(const arma::mat& centroids,
                                                         arma::mat& newCentroids,
                                                         arma::Col<size_t>& counts)
{
  newCentroids.zeros(centroids.n_rows, centroids.n_cols);
  counts.zeros(centroids.n_cols);

  // Find the closest centroid to each point and update the new centroids.
  // Computed in parallel over the complete dataset.
  #pragma omp parallel
  {
    // Thread-private accumulators, so the hot loop needs no locking.
    arma::mat localCentroids(centroids.n_rows, centroids.n_cols,
        arma::fill::zeros);
    arma::Col<size_t> localCounts(centroids.n_cols, arma::fill::zeros);

    #pragma omp for
    for (size_t i = 0; i < dataset.n_cols; i++)
    {
      // Find the closest centroid to this point.
      double minDistance = std::numeric_limits<double>::infinity();
      size_t closestCluster = centroids.n_cols; // Invalid value.
      for (size_t j = 0; j < centroids.n_cols; j++)
      {
        const double distance = metric.Evaluate(dataset.col(i),
            centroids.col(j));
        if (distance < minDistance)
        {
          minDistance = distance;
          closestCluster = j;
        }
      }

      Log::Assert(closestCluster != centroids.n_cols);
      // Explicit guard in case Log::Assert is inactive in this build
      // (presumably only active in debug builds — confirm). If every distance
      // is NaN/inf, or there are no centroids, closestCluster is still
      // n_cols, and col(n_cols) below would index out of bounds.
      if (closestCluster == centroids.n_cols)
        continue;

      // We now have the minimum distance centroid index. Update that centroid.
      localCentroids.col(closestCluster) += arma::vec(dataset.col(i));
      localCounts(closestCluster)++;
    }

    // Merge this thread's partial sums into the shared outputs one thread at
    // a time.
    #pragma omp critical
    {
      newCentroids += localCentroids;
      counts += localCounts;
    }
  }

  // Now normalize the centroids; empty clusters stay at zero.
  for (size_t i = 0; i < centroids.n_cols; ++i)
    if (counts(i) != 0)
      newCentroids.col(i) /= counts(i);

  distanceCalculations += centroids.n_cols * dataset.n_cols;

  // Calculate cluster distortion for this iteration.
  double cNorm = 0.0;
  for (size_t i = 0; i < centroids.n_cols; ++i)
  {
    const double shift = metric.Evaluate(centroids.col(i),
        newCentroids.col(i));
    cNorm += shift * shift;  // Cheaper than std::pow(shift, 2.0).
  }
  distanceCalculations += centroids.n_cols;

  return std::sqrt(cNorm);
}
---------------------------------------------
Sequential run
---------------------------------------------
λ mlpack_spike/kmeans ∴ perf stat ./a.out
Time for iterations : 1660771us
Performance counter stats for './a.out':
1999.658210 task-clock (msec) # 1.112 CPUs utilized
138 context-switches # 0.069 K/sec
1 cpu-migrations # 0.001 K/sec
2,023 page-faults # 0.001 M/sec
7,38,08,22,326 cycles # 3.691 GHz
19,96,78,28,391 instructions # 2.71 insn per cycle
3,27,83,44,603 branches # 1639.452 M/sec
91,58,549 branch-misses # 0.28% of all branches
1.798456669 seconds time elapsed
---------------------------------------------
Parallel run
---------------------------------------------
λ mlpack_spike/kmeans ∴ perf stat ./a.out
Time for iterations : 713074us
Performance counter stats for './a.out':
3160.809805 task-clock (msec) # 3.602 CPUs utilized
503 context-switches # 0.159 K/sec
1 cpu-migrations # 0.000 K/sec
2,050 page-faults # 0.649 K/sec
11,27,36,59,889 cycles # 3.567 GHz
19,65,34,07,230 instructions # 1.74 insn per cycle
3,22,94,02,213 branches # 1021.701 M/sec
91,82,426 branch-misses # 0.28% of all branches
0.877498263 seconds time elapsed
@shikharbhardwaj
Copy link
Author

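The parallelization adds roughly 1,160 ms of total CPU time (task-clock 3161 ms vs. 2000 ms), much of it presumably synchronisation overhead. Because that work is spread over about 3.6 CPUs (4 threads), wall-clock time still drops by around 50% (1.80 s → 0.88 s elapsed; iteration time 1.66 s → 0.71 s).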

This should scale even better on larger datasets; the Adult dataset has only around 32,000 instances.

@shikharbhardwaj
Copy link
Author

I tested this implementation for the points in src/mlpack/tests/kmeans_test.cpp and the cost converged in the same manner as the sequential implementation. I couldn't find another way to test for the convergence of K-Means. The reference I used : https://stats.stackexchange.com/questions/188087/proof-of-convergence-of-k-means

@mentekid
Copy link

Those are some neat results. Nice job!

The synchronization time is probably due to the critical section; spawning threads should not take that much. You can try replacing the critical section with a #pragma omp barrier followed by a #pragma omp single directive. Those calculations only need to happen in one thread, so that may remove some of the locking.

Take a look at tests/kmeans_test.cpp for more methods used to test kmeans. I think a valid test would be running the same clustering problem twice, once with one thread and once with multiple, and verify we get the same clusters both times.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment