# kmeans.py -- GitHub gist by @SabraO (secret gist), created Nov 6, 2017.
# Captured from the gist web page; the page's navigation/UI text
# ("Skip to content", "Embed", ...) has been reduced to this header.
import torch
import numpy as np
from datetime import datetime
# NOTE(review): the original "Max possible" notes presumably refer to GPU-memory
# limits for the (n_data_samples x n_clusters) membership matrix when stored as
# 32-bit floats; this script stores it in half precision. TODO confirm.
n_clusters = 500  # Original note: "Max Possible 485 (32-bit)"
n_data_samples = 2950000  # Original note: "Max possible 2M (32-bit)"
n_iterations = 2


def _assign_to_nearest(samples_vec, mean_vec):
    """Return the index of the nearest mean for every sample.

    samples_vec is (n, 1) and mean_vec is (1, k); broadcasting their difference
    gives the (n, k) distance matrix. The result is an (n, 1) integer tensor,
    the shape ``scatter_`` expects for its index argument.
    """
    dist = samples_vec - mean_vec
    # Nearest = smallest squared distance along the cluster axis.
    return (dist * dist).min(1)[1].view(-1, 1)


def _fix_empty_clusters(assignments, n_clusters):
    """Give every empty cluster one point taken from the largest cluster.

    assignments: 1-D integer numpy array of cluster indices, modified in place.
    The donor is the first cluster with the maximum count; its lowest-indexed
    points are handed, in order, to the empty clusters in ascending order.

    Returns True if there were no empty clusters or all could be filled, and
    False (leaving ``assignments`` untouched) when the donor cannot give one
    point to every empty cluster while keeping at least one for itself.
    """
    counts = np.bincount(assignments, minlength=n_clusters)
    empty = np.flatnonzero(counts == 0)
    if empty.size == 0:
        return True
    donor = int(np.argmax(counts))  # first largest cluster
    # The donor must keep at least one of its own points.
    if empty.size > counts[donor] - 1:
        return False
    donated = np.flatnonzero(assignments == donor)[:empty.size]
    assignments[donated] = empty
    return True


def _update_means(samples_vec, assignments, counts, membership, old_means):
    """Return the (1, k) cluster means for the given hard assignments.

    samples_vec: (n, 1) tensor of samples.
    assignments: (n, 1) integer tensor of cluster indices.
    counts: length-k numpy array with the number of points per cluster
        (must agree with ``assignments``).
    membership: preallocated (n, k) buffer; overwritten with the one-hot
        membership matrix.
    old_means: (1, k) tensor. The result has its dtype, and clusters whose
        count is zero keep their old mean instead of becoming 0/0 = NaN.
    """
    # scatter_ works in place and only writes ones, so the reused buffer must
    # be cleared first; otherwise memberships from earlier calls are counted
    # again (stale assignments would leak into the sums and counts).
    membership.zero_().scatter_(1, assignments, 1)
    # Per-cluster sums of the assigned samples, shape (k,).
    sums = torch.sum(membership * samples_vec, 0).float()
    # Divide in float32: half precision represents counts above 2048 inexactly
    # and overflows to inf above 65504.
    n_points = torch.from_numpy(counts).float().to(sums.device)
    means = sums / n_points.clamp(min=1)
    means = torch.where(n_points > 0, means, old_means.view(-1).float())
    return means.view(1, -1).type_as(old_means)


def main():
    """Run ``n_iterations`` of 1-D k-means on random data on CUDA device 0.

    Prints the iteration number, the CPU-section time and the GPU time (in
    milliseconds) of every iteration, then the final means. The GPU time
    covers the assignment step and the mean update; the CPU section and the
    device-to-host/host-to-device copies around it are reported separately
    or not at all, as in the original script.
    """
    with torch.cuda.device(0):
        # Synthetic 1-D data, one sample per row: (n_data_samples, 1).
        samples_vec = torch.randn(n_data_samples, 1).cuda().half()
        # Random initial means, one per column: (1, n_clusters).
        mean_vec = torch.randn(1, n_clusters).cuda().half()
        # Membership buffer, reused every iteration (cleared before use).
        U_zeros = torch.zeros(n_data_samples, n_clusters).cuda().half()

        for iteration in range(n_iterations):
            # GPU section 1: assignment step.
            start = datetime.now()
            print(iteration)
            min_indices = _assign_to_nearest(samples_vec, mean_vec)
            torch.cuda.synchronize()  # kernels are async; wait before timing
            end = datetime.now()

            # CPU section: repair empty clusters and count points per cluster.
            start2 = datetime.now()
            min_indices_arr = min_indices.cpu().numpy().ravel()
            del min_indices
            if not _fix_empty_clusters(min_indices_arr, n_clusters):
                print("Error")
            counts = np.bincount(min_indices_arr, minlength=n_clusters)
            end2 = datetime.now()
            time2 = (end2 - start2).total_seconds() * 1000
            min_indices = torch.from_numpy(min_indices_arr).cuda().view(-1, 1)
            print(f"CPU time: {time2}")

            # GPU section 2: update the means from the repaired assignments.
            start1 = datetime.now()
            mean_vec = _update_means(
                samples_vec, min_indices, counts, U_zeros, mean_vec
            )
            del min_indices
            torch.cuda.synchronize()
            end1 = datetime.now()

            time_elapsed = ((end - start) + (end1 - start1)).total_seconds() * 1000
            print(f"Iteration: {iteration + 1} Time: {time_elapsed}")

        print("Final Mean Vector")
        print(mean_vec)


if __name__ == "__main__":
    main()
# (GitHub page footer: "Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment")