Skip to content

Instantly share code, notes, and snippets.

@Guitaricet
Last active December 8, 2022 06:01
Show Gist options
  • Save Guitaricet/19f8ed789b22d508be353cf3d169af0c to your computer and use it in GitHub Desktop.
Save Guitaricet/19f8ed789b22d508be353cf3d169af0c to your computer and use it in GitHub Desktop.
Read PyTorch module descriptions more easily.
from torch.nn import ModuleList
def _addindent(s_, numSpaces):
s = s_.split('\n')
# don't do anything for single-line stuff
if len(s) == 1:
return s_
first = s.pop(0)
s = [(numSpaces * ' ') + line for line in s]
s = '\n'.join(s)
s = first + '\n' + s
return s
def iterable_repr(module):
    """A custom repr for ModuleList that compresses repeated module representations.

    Consecutive children whose ``repr`` strings are identical are merged into
    a single ``"<count> x <child repr>"`` entry. For example, a stack of many
    identical layers is printed once with its count.

    Args:
        module: An iterable container of modules that also provides
            ``_get_name()`` (e.g. a ``torch.nn.ModuleList``).

    Returns:
        The formatted representation string. An empty container renders as
        ``"<Name>()"``, the same as ``torch.nn.Module.__repr__``, instead of
        raising ``IndexError``.
    """
    from itertools import groupby

    lines = []
    # groupby over the repr strings run-length-encodes consecutive duplicates;
    # an empty container simply yields no groups.
    for child_repr, run in groupby(repr(item) for item in module):
        count = sum(1 for _ in run)
        # Indent continuation lines of multi-line child reprs by two spaces so
        # they line up under the two-space entry prefix added below.
        lines.append(_addindent(f"{count} x {child_repr}", 2))

    main_str = module._get_name() + '('
    if lines:
        # Two-space prefix matches the _addindent width above (and torch's repr).
        main_str += '\n  ' + '\n  '.join(lines) + '\n'
    main_str += ')'
    return main_str
# Install the compact repr on the ModuleList class itself. __repr__ is looked
# up on the type, so this affects every ModuleList instance in the process.
setattr(ModuleList, '__repr__', iterable_repr)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.