Last active
January 21, 2020 04:52
-
-
Save danielegrattarola/8296b9fd29116443da74d0aa2519d7c3 to your computer and use it in GitHub Desktop.
Implementation of the Telestrations Neural Networks
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np
import torch
from PIL import Image
from pytorch_pretrained_biggan import BigGAN, one_hot_from_int, truncated_noise_sample, convert_to_images
from torchvision import models
from torchvision import transforms
def draw(label, truncation=1.):
    """Generate a picture of ImageNet class ``label`` with the global BigGAN.

    Args:
        label: Integer class index the generator is conditioned on.
        truncation: Truncation value; used both to sample the noise vector
            and as the third argument of the generator call.

    Returns:
        Whatever ``convert_to_images`` returns for the generator output.
        Callers in this file index it with ``[0]`` to get the single image.
    """
    # Follow the generator's own device instead of hard-coding 'cuda', so the
    # model placement is decided in exactly one place (where `gan` is moved).
    device = next(gan.parameters()).device

    # Create the inputs for the GAN
    class_vector = torch.from_numpy(
        one_hot_from_int([label], batch_size=1)
    ).to(device)
    noise_vector = torch.from_numpy(
        truncated_noise_sample(truncation=truncation, batch_size=1)
    ).to(device)

    # Generate image (inference only, no gradients needed)
    with torch.no_grad():
        output = gan(noise_vector, class_vector, truncation)

    # Bring the result back to the CPU before converting it to PIL image(s)
    return convert_to_images(output.to('cpu'))
def guess(img, top=5):
    """Classify a drawing and return its ``top`` most likely classes.

    Args:
        img: Indexable whose first element is the image to classify
            (this file passes the value returned by ``draw``).
        top: How many of the highest-scoring classes to return.

    Returns:
        Tuple ``(idxs, labs, probs)`` of NumPy arrays, ordered from the
        highest to the lowest classifier score: the class indices, the
        matching entries of the global ``labels`` list, and the softmax
        probabilities (float64).
    """
    # Pre-process image and add the batch dimension the classifier expects
    batch = transform(img[0]).unsqueeze(0)

    # Classify image. This is pure inference, so skip autograd bookkeeping.
    with torch.no_grad():
        logits = classifier(batch)
        probabilities = torch.nn.functional.softmax(logits, dim=1)[0]

    # Take the k best scores directly instead of sorting every class and
    # walking the full index list three times. Clamp k so that asking for
    # more classes than exist returns all of them rather than raising.
    k = min(top, logits.shape[1])
    _, best = torch.topk(logits[0], k)

    # Get the global ImageNet class, labels, and the predicted probabilities
    idxs = best.cpu().numpy()
    labs = np.array([labels[idx] for idx in idxs])
    # float64 keeps later renormalisation precise enough for
    # np.random.choice, which checks that the probabilities sum to 1.
    probs = probabilities[best].cpu().numpy().astype(np.float64)
    return idxs, labs, probs
iterations = 8  # How many players there are
standard_noise = 0.3  # Some random noise because people are not perfect

# The secret source word is random. randint's upper bound is exclusive, so
# 1000 (not 1001) yields the valid class indices 0..999; an index of 1000
# would fall outside the 1000-class ImageNet label space.
# NOTE(review): assumes imagenet_classes.txt has one line per class (1000).
current_class = np.random.randint(0, 1000)

# Load ImageNet class list (one label per line)
with open('imagenet_classes.txt') as f:
    labels = [line.strip() for line in f]

# Generator used by draw()
gan = BigGAN.from_pretrained('biggan-deep-256')
gan.to('cuda')

# Classifier used by guess()
classifier = models.resnet50(pretrained=True)
classifier.eval()  # Do this to set the model to inference mode

# Pre-processing applied to each drawing before it is classified
transform = transforms.Compose([
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )])

output_imgs = []  # Stores the drawings
output_labels = []  # Stores the guesses
output_labels.append(labels[current_class])  # The game starts from the secret word
# Main game loop: every player draws the current word, then the drawing is
# guessed and the guess becomes the word for the next player.
for i in range(iterations):
    # Draw an image of the current word and keep it
    img = draw(current_class)
    output_imgs.append(img[0])

    # Guess what the image is
    idxs, labs, probs = guess(img)

    # Add noise to the predicted probabilities, then re-normalize so they
    # sum to one again
    noise = np.random.uniform(0, standard_noise, size=probs.shape)
    probs = probs + noise
    probs = probs / probs.sum()

    # Choose one of the predictions at random, weighted by the noisy
    # probabilities; it becomes the next word to draw
    candidates = np.arange(len(labs))
    choice = np.random.choice(candidates, p=probs)
    current_class = idxs[choice]
    output_labels.append(labs[choice])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment