Skip to content

Instantly share code, notes, and snippets.

@Akash-Rawat
Last active July 2, 2021 10:37
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Akash-Rawat/e34521d55e7cdf1a04c556c7f7c35d93 to your computer and use it in GitHub Desktop.
Defining Decoder
class Decoder(nn.Module):
    """Upsampling decoder that undoes max-pooling using the pooling indices.

    Pipeline (see ``forward``):
        layer4: unflatten dim 1 -> (in_channels, hW, hH), then
                ConvTransposeLeak in_channels -> 128
        MaxUnpool2d(2) with indices_3, layer3: 128 -> 128 -> 84
        MaxUnpool2d(4) with indices_2, layer2:  84 ->  84 -> 48
        MaxUnpool2d(4) with indices_1, layer1:  48 ->  48 ->  3

    ``ConvTransposeLeak`` and ``POOLING_FACTOR`` are defined elsewhere in the
    project. ConvTransposeLeak is presumably a transposed convolution followed
    by a LeakyReLU. TODO: confirm this, including whether it keeps the spatial
    size unchanged.

    Args:
        in_channels: channel count of the feature map produced by the
            unflatten step.
        out_channels: accepted but not used by this class. The last layer
            always emits 3 channels.
        image_dim: pair unpacked as (width, height). Each component is
            floor-divided by ``POOLING_FACTOR`` to get the latent spatial size.
    """

    def __init__(self, in_channels, out_channels, image_dim):
        super().__init__()
        img_w, img_h = image_dim
        # Latent spatial size: the image size shrunk by the total pooling factor.
        feat_w = img_w // POOLING_FACTOR
        feat_h = img_h // POOLING_FACTOR

        def stage(width, narrowed):
            # Two ConvTransposeLeak layers: keep `width` channels, then narrow.
            return nn.Sequential(
                ConvTransposeLeak(width, width),
                ConvTransposeLeak(width, narrowed),
            )

        # NOTE(review): given the names above, the unflattened size is ordered
        # (C, W, H), but PyTorch 2-D layers index spatial dims as (H, W). This
        # is harmless for square images. Confirm it matches the encoder's
        # flatten order.
        self.layer4 = nn.Sequential(
            nn.Unflatten(dim=1, unflattened_size=(in_channels, feat_w, feat_h)),
            ConvTransposeLeak(in_channels=in_channels, out_channels=128),
        )
        self.layer3 = stage(128, 84)
        self.layer2 = stage(84, 48)
        self.layer1 = stage(48, 3)
        # kernel_size 4 / 2; stride defaults to kernel_size, so each call
        # enlarges both spatial dims by that factor.
        self.unpooling = nn.MaxUnpool2d(4)
        self.unpooling_2 = nn.MaxUnpool2d(2)
        # Learnable 1-element parameter, initialised uniformly on [0, 1).
        # NOTE(review): no method of this class reads it; presumably callers
        # pass it to generate_data. Verify.
        self.precision = nn.Parameter(torch.rand(1))

    def generate_data(self, mean, precision):
        """Return a reparameterised sample ``mean + exp(-precision) * eps``.

        ``eps`` is drawn from a standard normal with the same shape, dtype and
        device as ``mean`` (via ``torch.randn_like``, global RNG). So
        ``precision`` acts as the negative log of the noise standard deviation.
        """
        std = torch.exp(-precision)
        noise = torch.randn_like(mean)
        return std * noise + mean

    def forward(self, x, indices_1, indices_2, indices_3):
        """Decode ``x`` back to a 3-channel map.

        ``x`` is unflattened along dim 1, so it presumably has shape
        (batch, in_channels * hW * hH). TODO: confirm. ``indices_3``,
        ``indices_2`` and ``indices_1`` are consumed in that order. They are
        presumably the indices returned by the matching encoder max-pool
        layers (return_indices=True). Verify against the encoder.
        """
        x = self.layer4(x)
        schedule = (
            (self.unpooling_2, indices_3, self.layer3),
            (self.unpooling, indices_2, self.layer2),
            (self.unpooling, indices_1, self.layer1),
        )
        for unpool, idx, block in schedule:
            x = block(unpool(x, idx))
        return x
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment