@Multihuntr
Created April 24, 2020 10:38
Generic Pytorch Module Wrapper - When nn.Sequential just isn't enough
# I keep properties on my main nn.Modules, e.g. a list of the training statistics the
# model is tracking. I wanted to perform a set of extra actions across multiple
# different modules without having to
# - write those steps into each of the 5+ different model definitions, or
# - explicitly expose those values on the wrapper module.
# It's fairly trivial, but the try/super() in __getattr__ matters: nn.Module stores
# submodules in self._modules rather than __dict__, so without delegating to
# nn.Module.__getattr__ first, the `wrapped` property itself would be unreachable.
import torch
import torch.nn as nn


class Wrapper(nn.Module):
    def __init__(self, wrapped):
        super().__init__()
        self.wrapped = wrapped

    def forward(self, x):
        out = self.wrapped(x)
        # insert fancy logic here
        return out

    def __getattr__(self, name):
        # Delegate to nn.Module.__getattr__ first; it knows how to look up
        # "wrapped" (and any other registered modules/parameters/buffers)
        # in self._modules etc., which plain attribute lookup cannot find.
        try:
            return super().__getattr__(name)
        except AttributeError:
            # Guard against infinite recursion: if "wrapped" itself is missing
            # (e.g. before __init__ has run), accessing self.wrapped below
            # would re-enter __getattr__ forever.
            if name == "wrapped":
                raise AttributeError(name)
            # Everything else is forwarded to the wrapped module.
            return getattr(self.wrapped, name)
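# For illustration, a minimal usage sketch. The nn.Linear model and the `stats`
# attribute are hypothetical stand-ins for the "training statistics" properties
# mentioned above, not part of the gist itself:
model = nn.Linear(4, 2)
model.stats = {"steps": 0}          # an extra property kept on the module

wrapped = Wrapper(model)
y = wrapped(torch.randn(1, 4))      # forward() still runs the wrapped module
print(wrapped.stats)                # {'steps': 0} — forwarded via __getattr__
print(wrapped.weight.shape)         # torch.Size([2, 4]) — parameters too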