Skip to content

Instantly share code, notes, and snippets.

@bh1995
Created June 30, 2020 09:55
Show Gist options
  • Save bh1995/450985d51b9f213bbb7d87f65d3a6ad8 to your computer and use it in GitHub Desktop.
Save bh1995/450985d51b9f213bbb7d87f65d3a6ad8 to your computer and use it in GitHub Desktop.
import torch.nn as nn
import torch.nn.functional as F

# Normalisation-layer factory used by both ResBlock and Generator. Both look
# this name up when they are constructed, so rebinding it before building a
# model changes the normalisation used throughout.
norm_layer = nn.InstanceNorm2d
class ResBlock(nn.Module):
    """Residual block operating on ``f``-channel feature maps.

    The residual branch is conv3x3 -> norm -> ReLU -> conv3x3. Its output is
    added to the block input, and the sum is normalised and passed through a
    ReLU.

    Args:
        f: number of input and output channels. The 3x3 convolutions use
            stride 1 and padding 1, so channel count and spatial size are
            both preserved.
    """

    def __init__(self, f):
        super().__init__()
        branch = [
            nn.Conv2d(f, f, 3, 1, 1),
            norm_layer(f),
            nn.ReLU(),
            nn.Conv2d(f, f, 3, 1, 1),
        ]
        self.conv = nn.Sequential(*branch)
        # Normalisation applied after the skip connection has been added.
        self.norm = norm_layer(f)

    def forward(self, x):
        """Return ``relu(norm(branch(x) + x))``."""
        residual = self.conv(x)
        return F.relu(self.norm(residual + x))
class Generator(nn.Module):
    """Encoder / residual / decoder image-to-image generator.

    Layer stack (each conv is followed by ``norm_layer`` and ReLU unless noted):

    * reflection-padded 7x7 conv: ``in_channels -> f`` (H, W unchanged)
    * two stride-2 3x3 convs: ``f -> 2f -> 4f`` (each halves H and W)
    * ``blocks`` ResBlocks at ``4f`` channels
    * two upsampling stages. Each is a stride-1 3x3 ConvTranspose2d that
      produces 4x the target channels, followed by ``PixelShuffle(2)``, which
      divides channels by 4 and doubles H and W: ``4f -> 2f -> f``
    * reflection-padded 7x7 conv ``f -> out_channels``, then Tanh (no norm),
      so outputs lie in (-1, 1)

    Because the stack downsamples twice by stride 2 and upsamples twice by
    2x pixel shuffle, the output H and W equal the input's only when the
    input H and W are multiples of 4.

    Args:
        f: base channel width.
        blocks: number of residual blocks (coerced with ``int``).
        in_channels: channels of the input image (default 3, as before).
        out_channels: channels of the generated image (default 3, as before).
    """

    def __init__(self, f=64, blocks=6, in_channels=3, out_channels=3):
        super(Generator, self).__init__()
        layers = [
            # Stem: reflection padding keeps H and W unchanged through the 7x7 conv.
            nn.ReflectionPad2d(3),
            nn.Conv2d(in_channels, f, 7, 1, 0), norm_layer(f), nn.ReLU(True),
            # Downsampling: each stride-2 conv halves H and W.
            nn.Conv2d(f, 2 * f, 3, 2, 1), norm_layer(2 * f), nn.ReLU(True),
            nn.Conv2d(2 * f, 4 * f, 3, 2, 1), norm_layer(4 * f), nn.ReLU(True),
        ]
        layers.extend(ResBlock(4 * f) for _ in range(int(blocks)))
        layers.extend([
            # Upsampling: emit 4x the target channels, then PixelShuffle(2)
            # trades a factor of 4 in channels for a factor of 2 in H and W.
            nn.ConvTranspose2d(4 * f, 4 * 2 * f, 3, 1, 1), nn.PixelShuffle(2),
            norm_layer(2 * f), nn.ReLU(True),
            nn.ConvTranspose2d(2 * f, 4 * f, 3, 1, 1), nn.PixelShuffle(2),
            norm_layer(f), nn.ReLU(True),
            # Output head: no normalisation; Tanh bounds values to (-1, 1).
            nn.ReflectionPad2d(3), nn.Conv2d(f, out_channels, 7, 1, 0),
            nn.Tanh(),
        ])
        # Layer order (and therefore state_dict keys "conv.<i>...") matches
        # the original construction, so existing checkpoints still load.
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        """Run the generator on ``x``.

        ``x`` is presumably an image batch of shape (N, in_channels, H, W);
        TODO confirm against callers. The output has ``out_channels``
        channels, and its H and W match the input's when both are multiples
        of 4.
        """
        return self.conv(x)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment