Skip to content

Instantly share code, notes, and snippets.

@rijulg
Created June 7, 2020 06:31
Show Gist options
  • Save rijulg/38cdeec892a8cfbc3a841da4fbe0d517 to your computer and use it in GitHub Desktop.
Save rijulg/38cdeec892a8cfbc3a841da4fbe0d517 to your computer and use it in GitHub Desktop.
Pruning a network and converting the resulting sparse network to a dense one in PyTorch
#!/usr/bin/python
import copy
import torch
import numpy as np
torch.manual_seed(0)
class Original(torch.nn.Module):
    """Minimal demo network: one fully connected 3 -> 3 layer with a bias."""

    def __init__(self):
        super().__init__()
        # Kept under the attribute name ``l1`` so state-dict keys stay
        # ``l1.weight`` / ``l1.bias``.
        self.l1 = torch.nn.Linear(in_features=3, out_features=3, bias=True)

    def forward(self, x):
        """Return the output of the single linear layer applied to ``x``."""
        out = self.l1(x)
        return out
class SparseLinear(torch.nn.Module):
    """Dense re-packing of a pruned ``torch.nn.Linear`` layer.

    Every output row of ``original`` becomes its own small
    ``Linear(n_kept, 1)`` that stores only the non-zero weights of that row,
    so zeroed weights no longer exist as parameters.

    Each per-row layer carries two buffers:

    * ``weight_mask`` -- boolean mask over the original row (True = kept).
    * ``capture_indices`` -- 1-D indices of the input features that row reads.

    Args:
        original: the (already pruned) linear layer to convert. Its weight
            is read as a 2-D ``(out_features, in_features)`` matrix; the
            bias is optional.
    """

    def __init__(self, original):
        super(SparseLinear, self).__init__()
        nonzero_weight = (original.weight != 0)
        needs_bias = original.bias is not None
        self.linears = torch.nn.ModuleList()
        for row, row_mask in enumerate(nonzero_weight):
            # reshape(-1) instead of squeeze(): the index tensor stays 1-D
            # even when a row keeps exactly one weight.
            capture_indices = torch.nonzero(row_mask, as_tuple=False).reshape(-1)
            unit = torch.nn.Linear(capture_indices.numel(), 1, bias=needs_bias)
            unit = unit.to(device=original.weight.device,
                           dtype=original.weight.dtype)
            # Copy values into the freshly created parameters rather than
            # rebinding ``.data``: shapes stay the standard (1, n_kept) / (1,)
            # and nothing aliases the source layer's storage.
            with torch.no_grad():
                unit.weight.copy_(
                    original.weight[row, capture_indices].unsqueeze(0))
                if needs_bias:
                    unit.bias.copy_(original.bias[row:row + 1])
            unit.register_buffer('weight_mask', row_mask)
            unit.register_buffer('capture_indices', capture_indices)
            self.linears.append(unit)

    def forward(self, x):
        """Apply the layer to ``x``.

        Features are selected on the last axis, so ``x`` may be a single
        sample of shape ``(in_features,)`` or carry leading batch
        dimensions ``(..., in_features)``.

        Returns:
            Tensor of shape ``(..., out_features)``.
        """
        outputs = [unit(x[..., unit.capture_indices]) for unit in self.linears]
        # Each unit yields (..., 1); concatenating restores (..., out_features).
        return torch.cat(outputs, dim=-1)
def prune(model, k=20):
    """Zero out the smallest-magnitude weights of ``model`` in place.

    A single global threshold is taken as the ``k``-th percentile of the
    absolute values of all parameters that are not 1-D (1-D parameters,
    e.g. biases, are neither counted nor pruned). Every such parameter then
    has each entry whose magnitude is not strictly above the threshold set
    to zero.

    Args:
        model: module whose parameters are pruned in place.
        k: percentile in ``[0, 100]`` passed to ``numpy.percentile``;
            defaults to 20.

    Returns:
        None. A model without any non-1-D parameter is left untouched.

    Raises:
        ValueError: from ``numpy.percentile`` when ``k`` is out of range.
    """
    prunable = [p for p in model.parameters() if p.dim() != 1]
    if not prunable:
        # Nothing to rank; a percentile over an empty array is undefined.
        return

    magnitudes = torch.cat(
        [p.detach().abs().reshape(-1).double().cpu() for p in prunable])
    threshold = float(np.percentile(magnitudes.numpy(), k))

    with torch.no_grad():
        for p in prunable:
            keep = (p.abs() > threshold).to(p.dtype)
            p.mul_(keep)
def prune2dense(model):
    """Replace every ``torch.nn.Linear`` inside ``model`` with a ``SparseLinear``.

    The module tree is walked child by child so that a nested layer is
    replaced on its direct parent; assigning a dotted name such as
    ``"block.l1"`` on the root would not replace the nested layer.
    Only ``torch.nn.Linear`` instances are converted, because other modules
    that expose a ``weight`` attribute are not row-wise linear maps.
    ``SparseLinear`` children are skipped, so their internal per-row layers
    are never converted again.

    Args:
        model: module modified in place. If ``model`` itself is a
            ``torch.nn.Linear`` it has no parent to be swapped on and is
            left unchanged.

    Returns:
        None.
    """
    # Snapshot the children first: the loop rebinds attributes on ``model``.
    for name, child in list(model.named_children()):
        if isinstance(child, torch.nn.Linear):
            setattr(model, name, SparseLinear(child))
        elif not isinstance(child, SparseLinear):
            prune2dense(child)
def countparams(model):
    """Print how many parameter elements ``model`` has.

    Two numbers are reported: the number of elements across all parameters,
    and the number belonging to parameters with ``requires_grad`` set.
    """
    total = 0
    withgrad = 0
    for param in model.parameters():
        size = param.numel()
        total += size
        if param.requires_grad:
            withgrad += size
    print(f"Parameters:: total: {total}, withgrad: {withgrad}")
def print_state_dict(model):
    """Print the ``state_dict`` of ``model``, followed by an empty line."""
    print("\tstate_dict:: ", model.state_dict(), end="\n\n")
def _report(model, sample):
    """Print the parameter counts of ``model`` and its output for ``sample``."""
    countparams(model)
    print("y::", model(sample))


# Stage 1: the freshly initialised dense network.
# (Call print_state_dict(<model>) at any stage to inspect the raw tensors.)
original = Original()
original.eval()
x = torch.FloatTensor([1, 2, 3])
_report(original, x)

# Stage 2: the same network after prune() has modified it in place.
prune(original)
_report(original, x)

# Stage 3: a deep copy of the pruned network, converted by prune2dense().
condensed = copy.deepcopy(original)
prune2dense(condensed)
condensed.eval()
_report(condensed, x)
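##### Results
# NOTE: the totals below (9 / 9 / 7) correspond to a 3x3 layer without a bias.
# With bias=True, as Original is defined above, countparams reports
# 12 / 12 / 10 (three extra bias elements at every stage).
# Parameters:: total: 9, withgrad: 9
# y:: tensor([-0.8104, -0.4052, 0.7504], grad_fn=<SqueezeBackward3>)
# Parameters:: total: 9, withgrad: 9
# y:: tensor([-0.8061, -0.4052, 0.7618], grad_fn=<SqueezeBackward3>)
# Parameters:: total: 7, withgrad: 7
# y:: tensor([-0.8061, -0.4052, 0.7618], grad_fn=<StackBackward>)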
##### Results
# Parameters:: total: 9, withgrad: 9
# y:: tensor([-0.8104, -0.4052, 0.7504], grad_fn=<SqueezeBackward3>)
# Parameters:: total: 9, withgrad: 9
# y:: tensor([-0.8061, -0.4052, 0.7618], grad_fn=<SqueezeBackward3>)
# Parameters:: total: 7, withgrad: 7
# y:: tensor([-0.8061, -0.4052, 0.7618], grad_fn=<StackBackward>)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment