Created
March 16, 2018 06:37
-
-
Save HTLife/b6640af9d6e7d765411f8aa9aa94b837 to your computer and use it in GitHub Desktop.
Pytorch model summary
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def summary(input_size, model):
    """Print a Keras-style, layer-by-layer summary of a PyTorch model.

    Runs one dummy forward pass (batch size 1) through ``model``, recording
    each leaf module's input/output shapes and parameter counts via forward
    hooks, then prints a formatted table.

    Args:
        input_size: Shape of a single input sample *without* the batch
            dimension, e.g. ``(3, 224, 224)``. For multi-input networks,
            a list/tuple of such shapes.
        model: The ``nn.Module`` to summarize. It is called once with random
            input, so it should be side-effect free for this pass.

    Returns:
        OrderedDict mapping ``'<LayerType>-<idx>'`` to an OrderedDict with
        keys ``input_shape``, ``output_shape``, ``trainable`` (only when the
        module has a weight), and ``nb_params`` (int).
    """
    def register_hook(module):
        def hook(module, input, output):
            class_name = str(module.__class__).split('.')[-1].split("'")[0]
            module_idx = len(summary)

            m_key = '%s-%i' % (class_name, module_idx + 1)
            summary[m_key] = OrderedDict()
            summary[m_key]['input_shape'] = list(input[0].size())
            summary[m_key]['input_shape'][0] = -1  # batch dim is variable
            # Some modules (e.g. RNNs) return a tuple/list of tensors;
            # record one shape per output in that case.
            if isinstance(output, (list, tuple)):
                summary[m_key]['output_shape'] = [
                    [-1] + list(o.size())[1:] for o in output
                ]
            else:
                summary[m_key]['output_shape'] = list(output.size())
                summary[m_key]['output_shape'][0] = -1

            params = 0
            # Guard against parameter-less attributes: e.g. Conv2d(bias=False)
            # has module.bias is None, which would crash on .size().
            if hasattr(module, 'weight') and module.weight is not None:
                params += int(th.prod(th.LongTensor(list(module.weight.size()))))
                summary[m_key]['trainable'] = bool(module.weight.requires_grad)
            if hasattr(module, 'bias') and module.bias is not None:
                params += int(th.prod(th.LongTensor(list(module.bias.size()))))
            summary[m_key]['nb_params'] = params

        # Hook only leaf-ish modules: skip containers and the root model
        # so layers are not double-counted.
        if not isinstance(module, (nn.Sequential, nn.ModuleList)) and module is not model:
            hooks.append(module.register_forward_hook(hook))

    # Fall back to CPU tensors when CUDA is unavailable (the original
    # hard-coded th.cuda.FloatTensor and crashed on CPU-only machines).
    dtype = th.cuda.FloatTensor if th.cuda.is_available() else th.FloatTensor

    # Check if there are multiple inputs to the network.
    if isinstance(input_size[0], (list, tuple)):
        x = [Variable(th.rand(1, *in_size)).type(dtype) for in_size in input_size]
    else:
        x = Variable(th.rand(1, *input_size)).type(dtype)

    # Create properties.
    summary = OrderedDict()
    hooks = []
    # Register hooks on every submodule.
    model.apply(register_hook)
    # Make a forward pass; multiple inputs are separate positional args.
    if isinstance(x, list):
        model(*x)
    else:
        model(x)
    # Remove the hooks so the model is left unmodified.
    for h in hooks:
        h.remove()

    print('----------------------------------------------------------------')
    line_new = '{:>20} {:>25} {:>15}'.format('Layer (type)', 'Output Shape', 'Param #')
    print(line_new)
    print('================================================================')
    total_params = 0
    trainable_params = 0
    for layer in summary:
        # str() is required: '{:>25}' is not a valid format spec for a list.
        line_new = '{:>20} {:>25} {:>15}'.format(
            layer,
            str(summary[layer]['output_shape']),
            str(summary[layer]['nb_params']),
        )
        total_params += summary[layer]['nb_params']
        if summary[layer].get('trainable'):
            trainable_params += summary[layer]['nb_params']
        print(line_new)
    print('================================================================')
    print('Total params: ' + str(total_params))
    print('Trainable params: ' + str(trainable_params))
    print('Non-trainable params: ' + str(total_params - trainable_params))
    print('----------------------------------------------------------------')
    return summary
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
I noticed that for some reason https://gist.github.com/HTLife/b6640af9d6e7d765411f8aa9aa94b837#file-summary-py-L60 fails when the output is a list because
{:>25}
is not available for list datatypes. I was able to fix this by just wrapping it with
str(.)
.