Created
September 13, 2022 23:14
-
-
Save thomasahle/a5e02b1f98d2c219045694539eabac78 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch.nn as nn | |
class TestModule(nn.Module):
    """Top-level toy model used to exercise ``size_tree``.

    It owns a ``Linear(10, 10)`` layer and a nested ``SubTestModule``.
    """

    def __init__(self):
        super(TestModule, self).__init__()
        # Children are registered in assignment order (a, then b).
        self.a = nn.Linear(10, 10)
        self.b = SubTestModule()
class SubTestModule(nn.Module):
    """Nested toy module: an ``EmbeddingBag`` plus a small ``Sequential`` stack."""

    def __init__(self):
        super(SubTestModule, self).__init__()
        self.c = nn.EmbeddingBag(10, 10)
        # A container child, so the tree has a further level to descend into.
        self.d = nn.Sequential(nn.Linear(10, 10), nn.ReLU())
def _addindent(s_, numSpaces): | |
s = s_.split('\n') | |
# don't do anything for single-line stuff | |
if len(s) == 1: | |
return s_ | |
first = s.pop(0) | |
s = [(numSpaces * ' ') + line for line in s] | |
s = '\n'.join(s) | |
s = first + '\n' + s | |
return s | |
def size_tree(module, *, indent=2):
    """Render ``module`` as an indented tree annotated with parameter counts.

    The layout follows ``nn.Module.__repr__``. Each node prints as
    ``Name(total_size=..., my_size=...)``:

    - ``my_size`` counts the elements of the parameters registered directly
      on that module.
    - ``total_size`` adds the totals of all its descendants.

    Child modules are listed beneath their parent as ``(key): ...``.

    Args:
        module: the module to describe. ``None`` is accepted and rendered as
            ``None`` with size 0, because ``_modules`` may hold ``None``
            placeholders.
        indent: number of spaces per nesting level (keyword-only).

    Returns:
        A ``(text, total_size)`` tuple.

    NOTE: a parameter registered on several submodules (e.g. tied weights)
    is counted once per module that registers it. ``total_size`` can
    therefore exceed ``sum(p.numel() for p in module.parameters())``.
    """
    if module is None:
        return 'None', 0

    # Direct parameters only. Entries can be None (e.g. nn.Linear(...,
    # bias=False) registers ``bias`` as None), so skip those.
    my_size = sum(
        p.nelement() for p in module._parameters.values() if p is not None
    )

    child_lines = []
    child_sizes = 0
    for key, sub_module in module._modules.items():
        mod_str, child_size = size_tree(sub_module, indent=indent)
        # Shift the child's continuation lines so they nest under "(key): ".
        child_lines.append('(' + key + '): ' + _addindent(mod_str, indent))
        child_sizes += child_size

    total_size = my_size + child_sizes
    summary = f'{total_size=}, {my_size=}'

    main_str = module._get_name() + '('
    if child_lines:
        # Same width as the _addindent call above, so nested closing parens
        # line up with the "(key): " label that opened them.
        sep = '\n' + ' ' * indent
        main_str += sep + sep.join([summary] + child_lines) + '\n'
    else:
        # Leaf module: keep the summary on a single line.
        main_str += summary
    main_str += ')'
    return main_str, total_size
# Demo: describe the example model, then report its total parameter count.
module = TestModule()
tree, size = size_tree(module)
print(tree, f'{size=}', sep='\n')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment