Skip to content

Instantly share code, notes, and snippets.

@rrajj
Last active April 25, 2020 10:54
Show Gist options
  • Save rrajj/05d6011bab83236b7de096044064f954 to your computer and use it in GitHub Desktop.
Save rrajj/05d6011bab83236b7de096044064f954 to your computer and use it in GitHub Desktop.
import numpy as np
import torch
def sparse_retain(sparse_matrix, to_retain):
    """Return a copy of a sparse COO tensor that keeps only the selected entries.

    The input is coalesced first, so ``to_retain`` is matched position by
    position against the coalesced (sorted, duplicate-summed) non-zero entries.

    Args:
        sparse_matrix: A ``torch`` sparse COO tensor of any sparse
            dimensionality.
        to_retain: A 1-D mask (bool tensor, NumPy array or Python sequence)
            with one entry per coalesced non-zero of ``sparse_matrix``.
            Entries that convert to ``True`` (i.e. are non-zero) are kept.

    Returns:
        A new sparse COO tensor with the same shape as ``sparse_matrix``. It
        contains only the retained entries, and their values keep the input's
        dtype and device.

    Raises:
        ValueError: If ``to_retain`` is not 1-D, or if its length differs from
            the number of coalesced non-zeros.
    """
    # Coalesce once. The original re-coalesced for every element it touched.
    coalesced = sparse_matrix.coalesce()
    indices = coalesced.indices()  # shape (sparse_dim, nnz), int64
    values = coalesced.values()

    mask = torch.as_tensor(to_retain, device=indices.device).to(torch.bool)
    if mask.dim() != 1 or mask.shape[0] != values.shape[0]:
        raise ValueError("Shape Not Matched!")

    # Boolean indexing picks the kept index columns and value rows in one
    # vectorised step. Values are not cast, so floats are not truncated to
    # int and the dtype and device are preserved.
    return torch.sparse_coo_tensor(
        indices[:, mask], values[mask], sparse_matrix.shape
    )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment