Skip to content

Instantly share code, notes, and snippets.

@SannaPersson
Created September 13, 2022 13:53
Show Gist options
  • Save SannaPersson/20667ad11c0db9348d1ea3ba6ea23c9b to your computer and use it in GitHub Desktop.
Save SannaPersson/20667ad11c0db9348d1ea3ba6ea23c9b to your computer and use it in GitHub Desktop.
variational_autoencoder5
def inference(digit, num_examples=1):
    """
    Generate ``num_examples`` images of a particular digit with the VAE.

    The first sample in ``dataset`` whose label equals ``digit`` is encoded
    into its latent ``(mu, sigma)``. New latent vectors are then drawn as
    ``z = mu + sigma * epsilon`` with ``epsilon`` from a standard normal
    (``torch.randn_like``), decoded, and each result is written to
    ``generated_{digit}_ex{example}.png``.

    Relies on the module-level ``dataset``, ``model``, ``torch`` and
    ``save_image`` defined elsewhere in this script.

    Args:
        digit: Label of the digit to generate (presumably 0-9 for MNIST —
            any label present in ``dataset`` works).
        num_examples: Number of images to sample and save.

    Raises:
        ValueError: If ``dataset`` contains no sample labelled ``digit``.
    """
    # Only the requested digit is needed, so stop at its first occurrence
    # instead of collecting and encoding an example of every digit.
    # int(y) accepts both plain int labels and single-element tensor labels.
    image = next((x for x, y in dataset if int(y) == digit), None)
    if image is None:
        raise ValueError(f"no example of digit {digit!r} found in dataset")

    # Pure inference: no_grad prevents building an autograd graph for both
    # the encoder and every decoder call below.
    with torch.no_grad():
        # Flatten to a single 784-vector (28x28) as the encoder expects.
        mu, sigma = model.encode(image.view(1, 784))
        for example in range(num_examples):
            # Reparameterised sample around this digit's latent encoding.
            epsilon = torch.randn_like(sigma)
            z = mu + sigma * epsilon
            out = model.decode(z).view(-1, 1, 28, 28)
            save_image(out, f"generated_{digit}_ex{example}.png")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment