Skip to content

Instantly share code, notes, and snippets.

@Akash-Rawat
Last active July 2, 2021 10:38
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Akash-Rawat/d7587e1653b3a01e905beef1a0ccc200 to your computer and use it in GitHub Desktop.
Save Akash-Rawat/d7587e1653b3a01e905beef1a0ccc200 to your computer and use it in GitHub Desktop.
Defining Encoder
class Encoder(nn.Module):
    """Convolutional encoder that samples a code and exposes max-pool indices.

    The input goes through three conv stages, each followed by max pooling
    (kernel 4, 4 and 2 respectively), then a final conv stage whose output is
    flattened. A two-layer MLP (LeakyReLU, then Tanh) maps the flattened
    features to ``latent_dim``; two linear heads map that back to
    ``vec_dim`` values used as a mean and a log standard deviation, and a
    code of that same size is sampled from them.

    Args:
        in_channels: channel count of the input fed to ``layer1``.
        out_channels: channel count produced by ``layer4``.
        image_dim: pair unpacked as ``(iW, iH)``; each side is integer-divided
            by the module-level ``POOLING_FACTOR`` to size the flattened
            features.
        latent_dim: width of the MLP bottleneck.

    NOTE(review): the flattened size assumes ``ConvLeak`` keeps spatial size
    unchanged and that ``POOLING_FACTOR`` equals the total pooling stride
    4 * 4 * 2 = 32; neither is visible here, so confirm against their
    definitions.
    """

    def __init__(self, in_channels, out_channels, image_dim, latent_dim):
        super().__init__()
        img_w, img_h = image_dim
        pooled_w = img_w // POOLING_FACTOR
        pooled_h = img_h // POOLING_FACTOR
        # Length of the flattened layer4 output; also the size of the code.
        flat_dim = out_channels * pooled_w * pooled_h

        def conv_pair(c_in, c_out):
            # Two ConvLeak blocks: c_in -> c_out, then c_out -> c_out.
            return nn.Sequential(
                ConvLeak(in_channels=c_in, out_channels=c_out),
                ConvLeak(in_channels=c_out, out_channels=c_out),
            )

        # Attribute names and creation order are deliberate: the names are
        # the state_dict keys, and the order fixes parameter initialisation.
        self.layer1 = conv_pair(in_channels, 48)
        self.layer2 = conv_pair(48, 84)
        self.layer3 = conv_pair(84, 128)
        self.layer4 = nn.Sequential(
            ConvLeak(in_channels=128, out_channels=out_channels),
            nn.Flatten(),
        )
        # return_indices=True so forward() can hand the argmax positions to
        # its caller (presumably a decoder that un-pools with them).
        self.pooling = nn.MaxPool2d(4, return_indices=True)
        self.pooling_2 = nn.MaxPool2d(2, return_indices=True)
        self.hidden = nn.Sequential(
            nn.Linear(in_features=flat_dim, out_features=latent_dim),
            nn.LeakyReLU(),
            nn.Linear(in_features=latent_dim, out_features=latent_dim),
            nn.Tanh(),
        )
        self.encoder_mean = nn.Linear(in_features=latent_dim, out_features=flat_dim)
        self.encoder_logstd = nn.Linear(in_features=latent_dim, out_features=flat_dim)

    def generate_code(self, mean, log_std):
        """Return ``mean + exp(log_std) * eps`` with ``eps ~ N(0, 1)``.

        This is the reparameterisation trick: the randomness lives in ``eps``,
        so the result stays differentiable w.r.t. ``mean`` and ``log_std``.
        Fresh noise is drawn on every call, whether or not the module is in
        training mode.
        """
        noise = torch.randn_like(mean)
        return mean + torch.exp(log_std) * noise

    def forward(self, x):
        """Encode ``x`` and sample a code from the predicted distribution.

        Returns:
            ``(c, indices_1, indices_2, indices_3, mean, log_std)``: the
            sampled code, the index tensors from the three max-pool steps in
            order, and the mean / log-std the code was drawn from.
        """
        stages = (
            (self.layer1, self.pooling),
            (self.layer2, self.pooling),
            (self.layer3, self.pooling_2),
        )
        pool_indices = []
        for conv_stage, pool in stages:
            x, idx = pool(conv_stage(x))
            pool_indices.append(idx)
        hidden = self.hidden(self.layer4(x))
        mean = self.encoder_mean(hidden)
        log_std = self.encoder_logstd(hidden)
        code = self.generate_code(mean, log_std)
        return (code, *pool_indices, mean, log_std)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment