Skip to content

Instantly share code, notes, and snippets.

@rolux

rolux/age.py Secret

Created December 29, 2019 13:05
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save rolux/8073400130c91121561e3dc7cfd5421f to your computer and use it in GitHub Desktop.
Save rolux/8073400130c91121561e3dc7cfd5421f to your computer and use it in GitHub Desktop.
# 1. Set up StyleGAN
import dnnlib
import dnnlib.tflib as tflib
import pretrained_networks

# Pretrained StyleGAN2 FFHQ (config-f) pickle. load_networks returns three
# networks; only Gs is used by this script (_G and _D are deliberately ignored).
network_pkl = 'gdrive:networks/stylegan2-ffhq-config-f.pkl'
_G, _D, Gs = pretrained_networks.load_networks(network_pkl)

# Keyword arguments for running the full generator: convert outputs to uint8
# and move channels last (NCHW -> NHWC) so they can be handed to PIL.
# randomize_noise=False presumably keeps noise inputs fixed between runs.
# NOTE(review): Gs_kwargs is not used in the visible script.
Gs_kwargs = dnnlib.EasyDict(
    output_transform=dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True),
    randomize_noise=False,
)

# Same settings for running only the synthesis network, batched 4 at a time.
Gs_syn_kwargs = dnnlib.EasyDict(
    output_transform=dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True),
    randomize_noise=False,
    minibatch_size=4,
)

# All synthesis-network variables whose names begin with 'noise'.
# NOTE(review): collected but not used in the visible script.
noise_vars = [
    tf_var
    for var_name, tf_var in Gs.components.synthesis.vars.items()
    if var_name.startswith('noise')
]

# Average dlatent stored in the network; the render step pulls sampled
# latents toward it by truncation_psi.
w_avg = Gs.get_var('dlatent_avg')
truncation_psi = 0.75
# 2. Get a vector
# Latent "age" direction, taken from:
# https://github.com/a312863063/generators-with-stylegan2/blob/master/latent_directions/age.npy
# Download that file and save it next to this script as 'age.npy'.
# 3. Render results
import numpy as np
import PIL.Image

# Fixed seed so the same latent z is drawn on every run.
z = np.random.RandomState(5616).randn(1, 512)
# Map z through the mapping network, then apply truncation: pull w toward
# w_avg by the factor truncation_psi.
w = Gs.components.mapping.run(z, None)
w = w_avg + (w - w_avg) * truncation_psi
# NOTE(review): assumes age.npy broadcasts against w (e.g. shape (18, 512)); confirm.
v_age = np.load('age.npy')
n = 5            # number of frames in the output strip
size = 256       # width/height in pixels of each frame
max_offset = 10  # frames sweep the direction from -max_offset to +max_offset
canvas = PIL.Image.new('RGB', (n * size, size))
for i, v in enumerate(np.linspace(-max_offset, max_offset, n)):
    # Shift w by v along the age direction and synthesize one image.
    w_age = w + v * v_age
    image = Gs.components.synthesis.run(w_age, **Gs_syn_kwargs)[0]
    image = PIL.Image.fromarray(image)
    image = image.resize((size, size), PIL.Image.LANCZOS)
    # Frame i occupies the i-th size x size cell, left to right.
    canvas.paste(image, (i * size, 0))
canvas.save('age.png')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment