@SannaPersson
Created September 13, 2022 14:03
variational_autoencoder5
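```python
import torch
from torchvision.utils import save_image

# Assumes a global `dataset` of (image, label) pairs (e.g. MNIST with ToTensor)
# and a trained VAE `model` with encode()/decode() methods exist in this scope.


def inference(digit, num_examples=1):
    """
    Generate num_examples images of a particular digit.

    We extract one example of each digit from the dataset and encode it
    to get its (mu, sigma) representation. We then sample latent vectors
    from that distribution and run them through the decoder part of the
    VAE to generate new examples.
    """
    # Collect one image per digit, in the order 0, 1, ..., 9.
    images = []
    idx = 0
    for x, y in dataset:
        if y == idx:
            images.append(x)
            idx += 1
        if idx == 10:
            break

    # Encode each representative image (flattened to 1 x 784) into mu, sigma.
    encodings_digit = []
    for d in range(10):
        with torch.no_grad():
            mu, sigma = model.encode(images[d].view(1, 784))
        encodings_digit.append((mu, sigma))

    mu, sigma = encodings_digit[digit]
    for example in range(num_examples):
        # Reparameterization: z = mu + sigma * epsilon, epsilon ~ N(0, I)
        epsilon = torch.randn_like(sigma)
        z = mu + sigma * epsilon
        with torch.no_grad():
            out = model.decode(z)
        out = out.view(-1, 1, 28, 28)  # back to image shape for saving
        save_image(out, f"generated_{digit}_ex{example}.png")
```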
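The image-collection loop only increments `idx` once it has seen digit `idx`, so it walks the dataset looking for a 0, then for the first 1 after that 0, and so on. A minimal single-pass alternative, sketched below under the assumption that the dataset yields `(image, label)` pairs with integer-convertible labels (as torchvision's MNIST does), keeps the first image seen for each label in whatever order the labels appear. `first_example_per_class` is a hypothetical helper name, not part of the original gist.

```python
def first_example_per_class(dataset, num_classes=10):
    """Return a list holding the first image found for each label 0..num_classes-1.

    Hypothetical helper: assumes `dataset` yields (image, label) pairs whose
    labels can be converted with int().
    """
    found = {}
    for x, y in dataset:
        label = int(y)
        if 0 <= label < num_classes and label not in found:
            found[label] = x
            if len(found) == num_classes:
                break  # every class has a representative; stop scanning
    missing = [c for c in range(num_classes) if c not in found]
    if missing:
        raise ValueError(f"no example found for labels {missing}")
    # Index the result by label so result[d] is an image of digit d.
    return [found[c] for c in range(num_classes)]


# Usage with a toy dataset whose labels are not in sorted order:
toy = [("img_a", 3), ("img_b", 0), ("img_c", 3), ("img_d", 1), ("img_e", 2)]
print(first_example_per_class(toy, num_classes=4))
# ['img_b', 'img_d', 'img_e', 'img_a']
```

The result has the same shape as `images` in `inference` (indexed by digit), and it raises a clear error if some label never appears, where the original loop would instead fail later with an `IndexError` on `images[d]`.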
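The sampling step draws `epsilon` from a standard normal and computes `z = mu + sigma * epsilon`, the VAE reparameterization, so each latent dimension of `z` is a sample from a normal distribution with mean `mu` and standard deviation `sigma`. A stdlib-only sketch of the same arithmetic on plain Python lists, assuming `sigma` is a standard deviation rather than a log-variance (which is how the code above uses it); `sample_latent` is a hypothetical name.

```python
import random


def sample_latent(mu, sigma, rng=None):
    """Draw one latent vector z ~ N(mu, diag(sigma**2)) via reparameterization.

    Hypothetical helper mirroring `z = mu + sigma * epsilon` on plain lists;
    `sigma` is a standard deviation, not a log-variance.
    """
    if rng is None:
        rng = random.Random()
    # epsilon ~ N(0, 1), drawn independently for each latent dimension
    epsilon = [rng.gauss(0.0, 1.0) for _ in mu]
    return [m + s * e for m, s, e in zip(mu, sigma, epsilon)]


rng = random.Random(0)
mu, sigma = [0.5, -1.0], [0.1, 0.2]
samples = [sample_latent(mu, sigma, rng) for _ in range(5000)]
mean_0 = sum(z[0] for z in samples) / len(samples)
print(round(mean_0, 1))  # the empirical mean lands close to mu[0] = 0.5
```

Drawing several `z` from one digit's `(mu, sigma)` and decoding each is what yields varied images of the same digit; with `sigma` equal to zero every sample would collapse to `mu` and the decoder would produce the same image each time.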