@ajbrock
Last active March 14, 2018 23:19
import torch

# Dict to store the registered hooks and the running flop count
data_dict = {'conv_flops': 0, 'hooks': []}

# Forward hook: 'self' is the Conv2d module the hook was registered on
def count_conv_flops(self, input, output):
    # Flop contribution from channelwise connections
    flops_c = self.out_channels * self.in_channels / self.groups
    # Flop contribution from the number of spatial locations we convolve over
    flops_s = output.size(2) * output.size(3)
    # Flop contribution from the number of mult-adds at each location
    flops_f = self.kernel_size[0] * self.kernel_size[1]
    data_dict['conv_flops'] += flops_c * flops_s * flops_f

# Register the counting hook on every Conv2d module in the network
def add_hooks(m):
    if isinstance(m, torch.nn.Conv2d):
        data_dict['hooks'] += [m.register_forward_hook(count_conv_flops)]

# Count the conv mult-adds incurred by a single forward pass of model on input x
def count_flops(model, x):
    data_dict['conv_flops'] = 0
    data_dict['hooks'] = []
    # Note whether we need to return the model to training mode afterwards
    set_train = model.training
    model.eval()
    model.apply(add_hooks)
    # Forward pass without tracking gradients (replaces the pre-0.4 volatile Variable idiom)
    with torch.no_grad():
        model(x)
    # Clean up: remove the hooks and restore the original training mode
    for hook in data_dict['hooks']:
        hook.remove()
    if set_train:
        model.train()
    return data_dict['conv_flops']
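
Below is a minimal usage sketch, not part of the original gist: the small Sequential model and 32x32 input are arbitrary placeholders, chosen so the per-layer contributions are easy to verify by hand.

# --- Usage sketch (hypothetical example, not part of the original gist) ---
if __name__ == '__main__':
    import torch.nn as nn

    # Two 3x3 convs with padding=1, so spatial size stays 32x32 throughout
    model = nn.Sequential(
        nn.Conv2d(3, 16, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.Conv2d(16, 32, kernel_size=3, padding=1),
    )
    x = torch.randn(1, 3, 32, 32)

    # First conv:  (16 * 3)  * (32 * 32) * (3 * 3) =   442,368 mult-adds
    # Second conv: (32 * 16) * (32 * 32) * (3 * 3) = 4,718,592 mult-adds
    # Total: 5,160,960 (returned as a float because of the channel/groups division)
    print('Conv mult-adds:', count_flops(model, x))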