@zackenton
Created November 27, 2017 20:39
Count the number of trainable parameters in a PyTorch model
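from functools import reduce

def pytorch_count_params(model):
    "Count the trainable parameters (those with requires_grad=True) in a PyTorch model."
    # Element count of each tensor = product of its dimensions; initial value 1 handles 0-dim tensors.
    total_params = sum(reduce(lambda a, b: a * b, x.size(), 1) for x in model.parameters() if x.requires_grad)
    return total_params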
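The reduce call multiplies a tensor's dimensions together to get its element count, and summing that over every parameter gives the total. A minimal pure-Python sketch of just that arithmetic (no PyTorch needed), assuming hypothetical parameter shapes for a model made of two linear layers:

```python
from functools import reduce
from math import prod

# Hypothetical parameter shapes for a model built from Linear(10, 5) and Linear(5, 2):
# each layer has an (out, in) weight matrix and an (out,) bias vector.
shapes = [(5, 10), (5,), (2, 5), (2,)]

# Element count of one tensor = product of its dimensions.
# The initial value 1 makes a 0-dim (scalar) shape () count as one element
# instead of raising a TypeError on an empty sequence.
counts = [reduce(lambda a, b: a * b, s, 1) for s in shapes]
total_params = sum(counts)

# math.prod (Python 3.8+) computes the same products and also returns 1 for ().
assert counts == [prod(s) for s in shapes]
print(counts, total_params)  # [50, 5, 10, 2] 67
```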
@ivanvoid
ivanvoid commented Aug 24, 2020

reduce can be imported with from functools import reduce.
Or you can use a simpler version (without reduce):
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
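The numel() version above can be wrapped in a small helper to compare the total count against the trainable count. A sketch, assuming a hypothetical helper name count_params and a toy two-layer model whose first layer is frozen (as in fine-tuning):

```python
import torch.nn as nn


def count_params(model: nn.Module, trainable_only: bool = True) -> int:
    """Hypothetical helper: sum numel() over parameters, optionally only those with requires_grad."""
    return sum(
        p.numel() for p in model.parameters() if p.requires_grad or not trainable_only
    )


# Toy model: Linear(10, 5) has 10*5 + 5 = 55 parameters, Linear(5, 2) has 5*2 + 2 = 12.
model = nn.Sequential(nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 2))

# Freeze the first layer so its parameters are no longer trainable.
for p in model[0].parameters():
    p.requires_grad_(False)

total = count_params(model, trainable_only=False)
trainable = count_params(model)
print(total, trainable)  # 67 12
```

Filtering on requires_grad is what makes the count "trainable"; without it, frozen layers are still included in the total.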
