Skip to content

Instantly share code, notes, and snippets.

@thomasahle
Created September 13, 2022 23:14
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save thomasahle/a5e02b1f98d2c219045694539eabac78 to your computer and use it in GitHub Desktop.
Save thomasahle/a5e02b1f98d2c219045694539eabac78 to your computer and use it in GitHub Desktop.
import torch.nn as nn
class TestModule(nn.Module):
    """Small two-level example network used to exercise ``size_tree``.

    Owns one linear layer directly (``a``) and one nested sub-module
    (``b``), so the resulting tree has more than one level.
    """

    def __init__(self):
        super().__init__()
        self.a = nn.Linear(10, 10)
        # ``SubTestModule`` is defined further down the file; the name is
        # only looked up when an instance is created, so the order is fine.
        self.b = SubTestModule()
class SubTestModule(nn.Module):
    """Nested example module: an embedding bag plus a Linear -> ReLU stack.

    ``c`` is an ``nn.EmbeddingBag(10, 10)`` and ``d`` is an
    ``nn.Sequential`` holding ``nn.Linear(10, 10)`` followed by
    ``nn.ReLU()``.
    """

    def __init__(self):
        super().__init__()
        self.c = nn.EmbeddingBag(10, 10)
        self.d = nn.Sequential(
            nn.Linear(10, 10),
            nn.ReLU(),
        )
def _addindent(s_, numSpaces):
    """Indent every line of ``s_`` except the first by ``numSpaces`` spaces.

    Args:
        s_: Text to indent; lines are separated by ``'\\n'``.
        numSpaces: Number of spaces to prepend to each continuation line.

    Returns:
        ``s_`` unchanged when it contains no newline; otherwise the same
        text with every line after the first prefixed by the padding.
    """
    first, *rest = s_.split('\n')
    # Single-line text has no continuation lines to shift.
    if not rest:
        return s_
    pad = numSpaces * ' '
    return '\n'.join([first, *(pad + line for line in rest)])
def size_tree(module):
    """Render a module hierarchy annotated with parameter counts.

    Each module is shown as ``Name(total_size=..., my_size=...)`` where
    ``my_size`` is the number of elements in the parameters registered
    directly on that module and ``total_size`` additionally includes all
    of its descendants. Child modules are listed on their own indented
    lines, prefixed with ``(key): ``.

    Args:
        module: The ``nn.Module`` to describe.

    Returns:
        A ``(text, total_size)`` tuple: the rendered tree and the total
        element count for ``module`` and everything below it.

    Note:
        No de-duplication is performed, so a parameter registered under
        several modules is counted once per registration.
    """
    # Unset parameter slots (e.g. the bias of ``nn.Linear(..., bias=False)``)
    # are stored as None and contribute no elements.
    my_size = sum(
        param.nelement()
        for param in module._parameters.values()
        if param is not None
    )

    child_lines = []
    child_sizes = 0
    for key, sub_module in module._modules.items():
        if sub_module is None:
            # An empty child slot has nothing to recurse into.
            mod_str, child_size = 'None', 0
        else:
            mod_str, child_size = size_tree(sub_module)
        child_lines.append('(' + key + '): ' + _addindent(mod_str, 2))
        child_sizes += child_size

    total_size = my_size + child_sizes
    summary = f'{total_size=}, {my_size=}'
    name = module._get_name()

    # Leaf modules fit on a single line.
    if not child_lines:
        return f'{name}({summary})', total_size

    # Two-space indent matches the shift applied by ``_addindent(..., 2)``
    # so that a child's closing parenthesis lines up with its ``(key):``.
    body = ''.join(f'\n  {line}' for line in [summary, *child_lines])
    return f'{name}({body}\n)', total_size
# Demo: build the sample network, then print its size tree followed by the
# overall parameter count.
module = TestModule()
tree, size = size_tree(module)
print(tree, f'{size=}', sep='\n')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment