Skip to content

Instantly share code, notes, and snippets.

@rasmusbergpalm
Last active September 21, 2021 07:29
Show Gist options
  • Save rasmusbergpalm/a50e413fd0c2e083ff99502f96db7572 to your computer and use it in GitHub Desktop.
Save rasmusbergpalm/a50e413fd0c2e083ff99502f96db7572 to your computer and use it in GitHub Desktop.
PyTorch Residual module for use with nn.Sequential
class Residual(t.nn.Module):
    """Skip-connection wrapper usable as a layer inside ``nn.Sequential``.

    The modules passed to the constructor are chained, in order, into a
    ``t.nn.Sequential`` stored as ``self.delegate``.  ``forward`` runs the
    input through that chain and adds the original input to the result, so
    the chain's output must be something that can be added to its input.
    """

    def __init__(self, *args: t.nn.Module):
        """Build the wrapped chain from ``args`` (applied left to right)."""
        super().__init__()
        self.delegate = t.nn.Sequential(*args)

    def forward(self, inputs):
        """Return ``self.delegate(inputs) + inputs``."""
        transformed = self.delegate(inputs)
        # Skip connection: add the untouched input back onto the chain output.
        return transformed + inputs
# Layer stack: a 3x3 convolution (padding=1) from z_size to hid channels and
# an ELU, followed by six identical residual blocks, and finally a 1x1
# convolution from hid back to z_size channels.
net = t.nn.Sequential(
    t.nn.Conv2d(z_size, hid, 3, padding=1),
    t.nn.ELU(),
    # Six residual blocks, each Conv2d(1x1) -> ELU -> Conv2d(1x1).  The
    # generator is unpacked at this argument position, so the blocks are
    # constructed in order between the layers above and the one below.
    *(
        Residual(
            t.nn.Conv2d(hid, hid, 1),
            t.nn.ELU(),
            t.nn.Conv2d(hid, hid, 1),
        )
        for _ in range(6)
    ),
    t.nn.Conv2d(hid, z_size, 1),
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment