
@bentrevett
Created November 18, 2017 01:43
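import numpy as np

def get_num_params(model):
    # Count every element of every parameter tensor in the model
    param_count = 0
    for param in model.parameters():
        # np.product was removed in NumPy 2.0; np.prod is the supported name
        param_count += int(np.prod(param.data.shape))
    return param_count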
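The count above rests on simple arithmetic: each parameter contributes the product of its dimensions, and the model total is the sum over all parameters. The following is a framework-free sketch of that arithmetic. `count_from_shapes` and `linear_shapes` are hypothetical names used only for illustration. The shapes follow PyTorch's `nn.Linear(10, 5)` layout, where the weight is `(out_features, in_features)` and the bias is `(out_features,)`.

```python
from math import prod

def count_from_shapes(shapes):
    # Hypothetical helper: each shape contributes prod(shape) elements;
    # a 0-dim (scalar) shape () contributes prod(()) == 1
    return sum(prod(shape) for shape in shapes)

# Hypothetical example: a fully connected layer mapping 10 inputs to 5 outputs
# has a weight of shape (5, 10) and a bias of shape (5,)
linear_shapes = [(5, 10), (5,)]
print(count_from_shapes(linear_shapes))  # 55
```

The 5 × 10 = 50 weight entries plus the 5 bias entries give 55. `math.prod` requires Python 3.8 or later.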
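
import operator
from functools import reduce
def get_num_params(model):
    # Functional alternative: multiply each parameter's dimensions, then sum over all parameters
    return reduce(operator.add,
                  map(lambda param: reduce(operator.mul, param.size(), 1),
                      model.parameters()),
                  0)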
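PyTorch tensors also expose `numel()`, which returns the number of elements directly. This means the shape products above can be skipped. The following is a minimal sketch assuming a PyTorch-style `model.parameters()`. `count_params` and its `trainable_only` flag are hypothetical names, not part of the original gist. Filtering on `requires_grad` is one common way to count only the parameters an optimizer would update.

```python
def count_params(model, trainable_only=False):
    # Hypothetical helper: Tensor.numel() gives each parameter's element count;
    # with trainable_only=True, parameters with requires_grad=False are skipped
    return sum(p.numel() for p in model.parameters()
               if not trainable_only or p.requires_grad)
```

With PyTorch installed, `count_params(torch.nn.Linear(10, 5))` should give 55 (50 weights plus 5 biases), matching the two `get_num_params` versions.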