Skip to content

Instantly share code, notes, and snippets.

@pbaylies
Last active October 24, 2021 05:04
Show Gist options
  • Save pbaylies/43cab2c46c4cc46e34c3989313a32d3a to your computer and use it in GitHub Desktop.
Save pbaylies/43cab2c46c4cc46e34c3989313a32d3a to your computer and use it in GitHub Desktop.
WikiArt Example Generation Updated
# Clone the StyleGAN3 fork used by this notebook (presumably the source of the
# `dnnlib` and `legacy` modules imported below — confirm).
!git clone https://github.com/dvschultz/stylegan3
# Move the repository contents into the current working directory.
!mv stylegan3/* .
# Download the pretrained conditional WikiArt network pickle loaded below.
!wget https://archive.org/download/wikiart-stylegan2-conditional-model/WikiArt5.pkl
# NOTE(review): ninja is presumably needed to build the repo's custom ops — confirm.
!pip install ninja
import sys
sys.path.insert(0,'/content/')
import ipywidgets as widgets
import pickle
import math
import random
import PIL.Image
import numpy as np
import torch
import pickle
import dnnlib
import legacy
#
# Location of the pretrained WikiArt network pickle.
# By default this is the file downloaded into the working directory; use the
# commented-out path instead to load a copy stored in Google Drive.
#
# network_pkl = '/content/drive/MyDrive/pretrained/WikiArt5.pkl'
network_pkl = 'WikiArt5.pkl'

# Prefer the GPU, but fall back to the CPU so the notebook does not crash on a
# runtime without CUDA (generation is expected to be much slower there).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Alternative loader for remote URLs:
# with dnnlib.util.open_url(network_pkl) as f:
# NOTE(security): the file is a pickle; loading it can execute arbitrary code,
# so only point `network_pkl` at files from a trusted source.
with open(network_pkl, 'rb') as f:
    # Keep only the 'G_ema' generator entry, frozen (no gradients) and moved to
    # the selected device.
    Gs = legacy.load_network_pkl(f)['G_ema'].requires_grad_(False).to(device)

# Extra keyword arguments forwarded to every Gs.synthesis(...) call.
Gs_syn_kwargs = dnnlib.EasyDict()
# Module-level batch size; display_sample() below uses its own local value.
batch_size = 8
Gs_syn_kwargs.noise_mode = 'random'
# Artist selector. Each option is (label shown in the UI, index of that artist
# in the conditioning vector built by display_sample). Indices 0-128 are artists.
artist = widgets.Dropdown(
    options=[
        ('Unknown Artist', 0),
        ('Boris Kustodiev', 1),
        ('Camille Pissarro', 2),
        ('Childe Hassam', 3),
        ('Claude Monet', 4),
        ('Edgar Degas', 5),
        ('Eugene Boudin', 6),
        ('Gustave Dore', 7),
        ('Ilya Repin', 8),
        ('Ivan Aivazovsky', 9),
        ('Ivan Shishkin', 10),
        ('John Singer Sargent', 11),
        ('Marc Chagall', 12),
        ('Martiros Saryan', 13),
        ('Nicholas Roerich', 14),
        ('Pablo Picasso', 15),
        ('Paul Cezanne', 16),
        ('Pierre Auguste Renoir', 17),
        ('Pyotr Konchalovsky', 18),
        ('Raphael Kirchner', 19),
        ('Rembrandt', 20),
        ('Salvador Dali', 21),
        ('Vincent Van Gogh', 22),
        ('Hieronymus Bosch', 23),
        ('Leonardo Da Vinci', 24),
        ('Albrecht Durer', 25),
        ('Edouard Cortes', 26),
        ('Sam Francis', 27),
        ('Juan Gris', 28),
        ('Lucas Cranach The Elder', 29),
        ('Paul Gauguin', 30),
        ('Konstantin Makovsky', 31),
        ('Egon Schiele', 32),
        ('Thomas Eakins', 33),
        ('Gustave Moreau', 34),
        ('Francisco Goya', 35),
        ('Edvard Munch', 36),
        ('Henri Matisse', 37),
        ('Fra Angelico', 38),
        ('Maxime Maufra', 39),
        ('Jan Matejko', 40),
        ('Mstislav Dobuzhinsky', 41),
        ('Alfred Sisley', 42),
        ('Mary Cassatt', 43),
        ('Gustave Loiseau', 44),
        ('Fernando Botero', 45),
        ('Zinaida Serebriakova', 46),
        ('Georges Seurat', 47),
        ('Isaac Levitan', 48),
        # Label repaired from mis-decoded UTF-8 ('Joaquã­n').
        ('Joaquín Sorolla', 49),
        ('Jacek Malczewski', 50),
        ('Berthe Morisot', 51),
        ('Andy Warhol', 52),
        ('Arkhip Kuindzhi', 53),
        ('Niko Pirosmani', 54),
        ('James Tissot', 55),
        ('Vasily Polenov', 56),
        ('Valentin Serov', 57),
        ('Pietro Perugino', 58),
        ('Pierre Bonnard', 59),
        ('Ferdinand Hodler', 60),
        ('Bartolome Esteban Murillo', 61),
        ('Giovanni Boldini', 62),
        ('Henri Martin', 63),
        ('Gustav Klimt', 64),
        ('Vasily Perov', 65),
        ('Odilon Redon', 66),
        ('Tintoretto', 67),
        ('Gene Davis', 68),
        ('Raphael', 69),
        ('John Henry Twachtman', 70),
        ('Henri De Toulouse Lautrec', 71),
        ('Antoine Blanchard', 72),
        ('David Burliuk', 73),
        ('Camille Corot', 74),
        ('Konstantin Korovin', 75),
        ('Ivan Bilibin', 76),
        ('Titian', 77),
        ('Maurice Prendergast', 78),
        ('Edouard Manet', 79),
        ('Peter Paul Rubens', 80),
        ('Aubrey Beardsley', 81),
        ('Paolo Veronese', 82),
        ('Joshua Reynolds', 83),
        ('Kuzma Petrov Vodkin', 84),
        ('Gustave Caillebotte', 85),
        ('Lucian Freud', 86),
        ('Michelangelo', 87),
        ('Dante Gabriel Rossetti', 88),
        ('Felix Vallotton', 89),
        ('Nikolay Bogdanov Belsky', 90),
        ('Georges Braque', 91),
        ('Vasily Surikov', 92),
        ('Fernand Leger', 93),
        ('Konstantin Somov', 94),
        ('Katsushika Hokusai', 95),
        ('Sir Lawrence Alma Tadema', 96),
        ('Vasily Vereshchagin', 97),
        ('Ernst Ludwig Kirchner', 98),
        ('Mikhail Vrubel', 99),
        ('Orest Kiprensky', 100),
        ('William Merritt Chase', 101),
        ('Aleksey Savrasov', 102),
        ('Hans Memling', 103),
        ('Amedeo Modigliani', 104),
        ('Ivan Kramskoy', 105),
        ('Utagawa Kuniyoshi', 106),
        ('Gustave Courbet', 107),
        ('William Turner', 108),
        ('Theo Van Rysselberghe', 109),
        ('Joseph Wright', 110),
        ('Edward Burne Jones', 111),
        ('Koloman Moser', 112),
        ('Viktor Vasnetsov', 113),
        ('Anthony Van Dyck', 114),
        ('Raoul Dufy', 115),
        ('Frans Hals', 116),
        ('Hans Holbein The Younger', 117),
        ('Ilya Mashkov', 118),
        ('Henri Fantin Latour', 119),
        ('M.C. Escher', 120),
        ('El Greco', 121),
        ('Mikalojus Ciurlionis', 122),
        ('James Mcneill Whistler', 123),
        ('Karl Bryullov', 124),
        ('Jacob Jordaens', 125),
        ('Thomas Gainsborough', 126),
        ('Eugene Delacroix', 127),
        ('Canaletto', 128),
    ],
    value=22,  # default selection: Vincent Van Gogh
    description='Artist: ',
)
# Genre selector: option values 129-139 are the genre slots of the
# conditioning vector built by display_sample.
genre = widgets.Dropdown(
    options=[
        ('Abstract Painting', 129),
        ('Cityscape', 130),
        ('Genre Painting', 131),
        ('Illustration', 132),
        ('Landscape', 133),
        ('Nude Painting', 134),
        ('Portrait', 135),
        ('Religious Painting', 136),
        ('Sketch And Study', 137),
        ('Still Life', 138),
        ('Unknown Genre', 139),
    ],
    value=139,  # default selection: Unknown Genre
    description='Genre: ',
)

# Style selector: option values 140-166 are the style slots of the
# conditioning vector built by display_sample.
style = widgets.Dropdown(
    options=[
        ('Abstract Expressionism', 140),
        ('Action Painting', 141),
        ('Analytical Cubism', 142),
        ('Art Nouveau', 143),
        ('Baroque', 144),
        ('Color Field Painting', 145),
        ('Contemporary Realism', 146),
        ('Cubism', 147),
        ('Early Renaissance', 148),
        ('Expressionism', 149),
        ('Fauvism', 150),
        ('High Renaissance', 151),
        ('Impressionism', 152),
        ('Mannerism Late Renaissance', 153),
        ('Minimalism', 154),
        ('Naive Art Primitivism', 155),
        ('New Realism', 156),
        ('Northern Renaissance', 157),
        ('Pointillism', 158),
        ('Pop Art', 159),
        ('Post Impressionism', 160),
        ('Realism', 161),
        ('Rococo', 162),
        ('Romanticism', 163),
        ('Symbolism', 164),
        ('Synthetic Cubism', 165),
        ('Ukiyo-e', 166),
    ],
    value=160,  # default selection: Post Impressionism
    description='Style: ',
)
# Numeric controls passed to display_sample as keyword arguments.
seed = widgets.IntSlider(
    min=0,
    max=100000,
    step=1,
    value=9,
    description='Seed: ',
)
scale = widgets.FloatSlider(
    min=0,
    max=25,
    step=0.1,
    value=2,
    description='Global Scale: ',
)
truncation = widgets.FloatSlider(
    min=-2,
    max=2,
    step=0.1,
    value=1,
    description='Truncation: ',
)
variance = widgets.FloatSlider(
    min=-1,
    max=1,
    step=0.1,
    value=0.4,
    description='Variance: ',
)
iterations = widgets.IntSlider(
    min=0,
    max=100,
    step=1,
    value=64,
    description='Iterations: ',
)

# Layout: three horizontal rows stacked vertically.
top_box = widgets.HBox([artist, genre, style])
mid_box = widgets.HBox([variance, iterations])
bot_box = widgets.HBox([seed, scale, truncation])
ui = widgets.VBox([top_box, mid_box, bot_box])
def display_sample(artist, genre, style, variance, iterations, seed, scale, truncation):
    """Generate one image for the selected conditioning and display it.

    Args:
        artist: Index of the artist slot to switch on in the label vector.
        genre: Index of the genre slot to switch on in the label vector.
        style: Index of the style slot to switch on in the label vector.
        variance: Multiplier for the Gaussian noise added to the latent and
            the label on each averaging iteration; 0 disables averaging.
        iterations: Number of noisy mappings to average; values below 1
            disable averaging.
        seed: Seed for the base latent vector z.
        scale: Multiplier applied to the whole label vector.
        truncation: Factor applied to (w - w_avg); 1 leaves w unchanged.

    Returns:
        None. The image is shown with IPython's ``display``.
    """
    batch_size = 1
    mapping = Gs.mapping
    # Read the dimensions from the network instead of hard-coding them.
    z_dim, c_dim = mapping.z_dim, mapping.c_dim

    def _truncate(w):
        """Move w towards (or away from) the average latent by `truncation`."""
        if truncation == 1:
            return w
        w_avg = mapping.w_avg
        return w_avg + (w - w_avg) * truncation  # [minibatch, layer, component]

    # Multi-hot label: one slot each for artist, genre and style, then scaled.
    label = np.zeros((1, c_dim))
    for index in (artist, genre, style):
        label[0][index] = 1.0
    label = torch.from_numpy(scale * label).to(device)

    all_seeds = [seed] * batch_size
    all_z = np.stack(
        [np.random.RandomState(s).randn(z_dim) for s in all_seeds]
    )  # [minibatch, component]
    all_z = torch.from_numpy(all_z).to(device)

    if variance == 0 or iterations < 1:
        # Plain path: a single mapping of the clean latent and label.
        final_w = _truncate(mapping(all_z, label))  # [minibatch, layer, component]
    else:
        # Averaging path: map several noisy copies of (z, label) and average
        # the resulting w vectors. The sum is kept in float64.
        final_w = None
        for i in range(iterations):
            # NOTE: both draws are seeded with `i`, so the label noise repeats
            # the leading values of the latent noise. Kept as-is so images
            # stay reproducible for a given set of slider values.
            z_noise = torch.from_numpy(
                variance * np.random.RandomState(i).randn(z_dim)
            ).to(device)
            c_noise = torch.from_numpy(
                variance * np.random.RandomState(i).randn(c_dim)
            ).to(device)
            w = _truncate(
                mapping(all_z + z_noise, torch.tile(label + c_noise, (batch_size, 1)))
            ).to(torch.float64)
            final_w = w if final_w is None else final_w + w
        final_w = final_w / iterations

    all_images = Gs.synthesis(final_w, **Gs_syn_kwargs)
    # NCHW floats -> NHWC uint8 pixels.
    all_images = (
        (all_images.permute(0, 2, 3, 1) * 127.5 + 128)
        .clamp(0, 255)
        .to(torch.uint8)
        .cpu()
        .numpy()
    )
    # Collapse the batch axis; with batch_size == 1 this is just that image.
    display(PIL.Image.fromarray(np.median(all_images, axis=0).astype(np.uint8)))
# Re-run display_sample whenever any control changes; each key is the name of
# the display_sample keyword argument fed by that widget.
out = widgets.interactive_output(
    display_sample,
    {
        'artist': artist,
        'genre': genre,
        'style': style,
        'seed': seed,
        'variance': variance,
        'iterations': iterations,
        'scale': scale,
        'truncation': truncation,
    },
)
# Show the controls followed by the rendered output area.
display(ui, out)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment