Skip to content

Instantly share code, notes, and snippets.

@SannaPersson
Created September 13, 2022 13:53
Show Gist options
  • Save SannaPersson/c48c87de13f5040b97707bc07b9ee74f to your computer and use it in GitHub Desktop.
Save SannaPersson/c48c87de13f5040b97707bc07b9ee74f to your computer and use it in GitHub Desktop.
variational_autoencoder1
class VariationalAutoEncoder(nn.Module):
    """Variational autoencoder with one hidden linear layer on each side.

    The encoder maps an input vector to two vectors of size ``z_dim``,
    ``mu`` and ``sigma``, which parameterize a Gaussian over the latent code.
    The decoder maps a latent code back to a vector of size ``input_dim``
    whose entries are squashed into (0, 1) by a sigmoid.

    Args:
        input_dim: Size of the last dimension of the input; the
            reconstruction has the same size.
        z_dim: Size of the latent code.
        h_dim: Width of the hidden layer used by both encoder and decoder.
    """

    def __init__(self, input_dim: int, z_dim: int, h_dim: int = 200) -> None:
        super().__init__()

        # Encoder: input -> hidden.
        self.img_2hid = nn.Linear(input_dim, h_dim)

        # Two heads on the hidden layer, one for the mean and one for the
        # standard deviation. Only one value per latent dimension is
        # produced, i.e. only the diagonal of the covariance matrix: the
        # latent dimensions are modeled as independent given the input.
        self.hid_2mu = nn.Linear(h_dim, z_dim)
        self.hid_2sigma = nn.Linear(h_dim, z_dim)

        # Decoder: latent -> hidden -> reconstruction.
        self.z_2hid = nn.Linear(z_dim, h_dim)
        self.hid_2img = nn.Linear(h_dim, input_dim)

    def encode(self, x: torch.Tensor):
        """Return ``(mu, sigma)`` of the latent distribution for ``x``.

        Args:
            x: Tensor whose last dimension is ``input_dim``; presumably a
                batch of flattened images -- confirm against the caller.

        Returns:
            A ``(mu, sigma)`` tuple; both tensors have ``z_dim`` as their
            last dimension.
        """
        h = F.relu(self.img_2hid(x))
        mu = self.hid_2mu(h)
        # NOTE: sigma is the raw output of a linear layer, so nothing here
        # constrains it to be positive.
        sigma = self.hid_2sigma(h)
        return mu, sigma

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """Map a latent code ``z`` to a reconstruction with values in (0, 1).

        Args:
            z: Tensor whose last dimension is ``z_dim``.

        Returns:
            Tensor whose last dimension is ``input_dim``.
        """
        new_h = F.relu(self.z_2hid(z))
        x = torch.sigmoid(self.hid_2img(new_h))
        return x

    def forward(self, x: torch.Tensor):
        """Encode ``x``, sample a latent code, and decode it.

        Args:
            x: Tensor whose last dimension is ``input_dim``.

        Returns:
            A ``(reconstruction, mu, sigma)`` tuple, where ``mu`` and
            ``sigma`` are the encoder outputs used to draw the sample.
        """
        mu, sigma = self.encode(x)

        # Reparameterization: draw the noise independently of the network
        # and express the sample as a function of mu and sigma, so that
        # gradients can flow back into the encoder through both of them.
        epsilon = torch.randn_like(sigma)
        z_reparametrized = mu + sigma * epsilon

        x = self.decode(z_reparametrized)
        return x, mu, sigma
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment