-
-
Save SabraO/c59edbaeb9141d88db3f1e3a0e4d3ccb to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# 1-D k-means clustering with PyTorch and CPU-side empty-cluster repair.
# Run as a script to cluster random half-precision samples on cuda:0 and print
# per-iteration timings plus the final cluster means.
import time

import numpy as np
import torch

n_clusters = 500  # number of cluster means (k)
n_data_samples = 2950000  # number of 1-D samples (n)
n_iterations = 2  # number of assign/update rounds

# Rows of samples handled per distance-matrix chunk in the assignment step.
# That step's temporary memory is about chunk_size * n_clusters elements.
DEFAULT_CHUNK_SIZE = 1 << 20


def _synchronize(device):
    """Wait for queued CUDA work on *device* so wall-clock timings are meaningful."""
    if device.type == "cuda":
        torch.cuda.synchronize(device)


def _assign_clusters(samples_vec, mean_vec, chunk_size):
    """Return the int64 index of the nearest mean for each row of *samples_vec*.

    *samples_vec* is (n, 1) and *mean_vec* is (1, k); each chunk broadcasts to
    a (chunk, k) distance matrix, so the full n x k matrix is never built.
    """
    # In 1-D, |x - mu| has the same argmin as (x - mu)**2, and unlike the
    # square it does not flush small distances to zero in half precision.
    nearest = [
        (chunk - mean_vec).abs().argmin(dim=1)
        for chunk in samples_vec.split(chunk_size)
    ]
    return torch.cat(nearest)


def _fill_empty_clusters(assignments, num_clusters):
    """Reassign points so that every cluster in ``range(num_clusters)`` is non-empty.

    Empty clusters are processed in ascending index order; each one takes the
    lowest-indexed point of the currently largest cluster (lowest cluster index
    on ties). Callers must guarantee ``len(assignments) >= num_clusters``: then
    the donor always has at least two members, so no new empty cluster appears.

    Returns *assignments* itself when nothing is empty, otherwise a repaired
    copy; the input array is never modified.
    """
    counts = np.bincount(assignments, minlength=num_clusters)
    empty = np.flatnonzero(counts == 0)
    if empty.size == 0:
        return assignments
    fixed = assignments.copy()
    for cluster in empty:
        donor = int(np.argmax(counts))
        point = np.flatnonzero(fixed == donor)[0]
        fixed[point] = cluster
        counts[donor] -= 1
        counts[cluster] += 1
    return fixed


def _update_means(samples_vec, assignments, num_clusters):
    """Return per-cluster sample means as a (1, num_clusters) tensor in samples_vec's dtype.

    Every cluster must have at least one member; an empty cluster yields NaN.
    """
    values = samples_vec.reshape(-1).float()
    # Accumulate in float32: half precision represents integer counts exactly
    # only up to 2048 and loses precision on large sums.
    sums = torch.zeros(num_clusters, dtype=torch.float32, device=values.device)
    sums.index_add_(0, assignments, values)
    counts = torch.bincount(assignments, minlength=num_clusters)
    return (sums / counts).to(samples_vec.dtype).view(1, num_clusters)


def kmeans_1d(samples_vec, mean_vec, n_iterations, *,
              chunk_size=DEFAULT_CHUNK_SIZE, verbose=False):
    """Run Lloyd's k-means on 1-D data and return the updated means.

    Args:
        samples_vec: (n, 1) floating tensor of samples.
        mean_vec: (1, k) initial means on the same device and dtype.
        n_iterations: number of assign/update rounds (0 returns *mean_vec*).
        chunk_size: sample rows per distance-matrix chunk; bounds peak memory.
        verbose: print iteration index, CPU repair time and iteration time (ms).

    Returns:
        (1, k) tensor of means in samples_vec's dtype. Each round recomputes
        memberships from scratch, and empty clusters are repaired before the
        means are updated, so no mean is NaN.

    Raises:
        ValueError: if k < 1, k > n, or chunk_size < 1.
    """
    num_samples = samples_vec.shape[0]
    num_clusters = mean_vec.shape[1]
    if num_clusters < 1 or num_clusters > num_samples:
        raise ValueError(
            f"need 1 <= n_clusters <= n_samples, got {num_clusters} clusters "
            f"for {num_samples} samples"
        )
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")

    device = samples_vec.device
    for iteration in range(n_iterations):
        if verbose:
            print(iteration)

        _synchronize(device)
        start = time.perf_counter()
        assignments = _assign_clusters(samples_vec, mean_vec, chunk_size)
        _synchronize(device)
        assign_seconds = time.perf_counter() - start

        # Empty-cluster repair runs on the CPU (timing includes the copy back).
        start_cpu = time.perf_counter()
        repaired = _fill_empty_clusters(assignments.cpu().numpy(), num_clusters)
        cpu_ms = (time.perf_counter() - start_cpu) * 1000
        if verbose:
            print("CPU time: " + str(cpu_ms))
        assignments = torch.from_numpy(repaired).to(device)

        start_update = time.perf_counter()
        mean_vec = _update_means(samples_vec, assignments, num_clusters)
        _synchronize(device)
        time_elapsed = (assign_seconds + time.perf_counter() - start_update) * 1000
        if verbose:
            print("Iteration: " + str(iteration + 1) + " Time: " + str(time_elapsed))
    return mean_vec


def main():
    """Cluster random half-precision samples on cuda:0 and print the final means."""
    device = torch.device("cuda", 0)
    samples_vec = torch.randn(n_data_samples, 1, device=device).half()
    mean_vec = torch.randn(1, n_clusters, device=device).half()
    mean_vec = kmeans_1d(samples_vec, mean_vec, n_iterations, verbose=True)
    print("Final Mean Vector")
    print(mean_vec)


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment