Skip to content

Instantly share code, notes, and snippets.

@asears
Forked from lzhbrian/gen_video_interpolation.py
Created February 13, 2020 02:14
Show Gist options
  • Save asears/efb0856c203b23ca5b4b524076cc8f7f to your computer and use it in GitHub Desktop.
Save asears/efb0856c203b23ca5b4b524076cc8f7f to your computer and use it in GitHub Desktop.
Generate an interpolation video from StyleGAN2
"""
Author: lzhbrian (https://lzhbrian.me)
Date: 2020.1.20
Note: mainly modified from: https://github.com/tkarras/progressive_growing_of_gans/blob/master/util_scripts.py#L50
"""
import numpy as np
from PIL import Image
import os
import scipy
import pickle
import moviepy
import dnnlib
import dnnlib.tflib as tflib
from tqdm import tqdm
# Module-level setup: initialise TensorFlow via dnnlib and load the network
# snapshot that generate_interpolation_video() renders from.
# NOTE(review): init_tf() presumably has to run before the pickle is loaded,
# since the pickled objects are dnnlib networks -- confirm against dnnlib.
tflib.init_tf()

# Hard-coded snapshot location; edit this path to point at your own network.
fpath = '/nvme/linziheng/projects/stylegan2/results/20200118-stylegan2-all_valid_img_plain_15-8gpu-config-f/network-snapshot-006316.pkl'

# SECURITY: pickle.load() can execute arbitrary code embedded in the file.
# Only load snapshots that come from a trusted source.
with open(fpath, 'rb') as stream:
    # The snapshot unpickles to three objects; only `Gs` is used in this file.
    _G, _D, Gs = pickle.load(stream, encoding='latin1')

# Keyword arguments handed to Gs.run(..., output_transform=fmt): convert the
# network output to uint8 and reorder it from NCHW to NHWC.
fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
def create_image_grid(images, grid_size=None):
    """Tile a batch of images into a single grid image.

    Args:
        images: Array of rank 3 ``(num, height, width)`` or rank 4
            ``(num, channels, height, width)``. The last two axes are the
            spatial axes and the first axis indexes the images.
        grid_size: Optional ``(grid_w, grid_h)`` pair giving the number of
            columns and rows. When ``None`` a near-square layout large
            enough for all images is chosen.

    Returns:
        Array of shape ``images.shape[1:-2] + (grid_h * height,
        grid_w * width)`` with the same dtype as ``images``. Images are
        placed row by row; unused cells stay zero.

    Raises:
        ValueError: If an explicit ``grid_size`` has fewer cells than there
            are images.
    """
    assert images.ndim == 3 or images.ndim == 4
    num, img_w, img_h = images.shape[0], images.shape[-1], images.shape[-2]

    if grid_size is not None:
        grid_w, grid_h = tuple(grid_size)
        # Without this check the overflow rows fall outside the canvas and
        # surface as an opaque NumPy broadcast error (or, for 1-pixel-high
        # images, are silently dropped).
        if num > grid_w * grid_h:
            raise ValueError(
                'grid_size %r has %d cells but %d images were given'
                % (tuple(grid_size), grid_w * grid_h, num))
    else:
        # Near-square layout: ceil(sqrt(num)) columns, as many rows as needed.
        grid_w = max(int(np.ceil(np.sqrt(num))), 1)
        grid_h = max((num - 1) // grid_w + 1, 1)

    # Leading (channel) axes are preserved; only the spatial axes grow.
    grid = np.zeros(
        list(images.shape[1:-2]) + [grid_h * img_h, grid_w * img_w],
        dtype=images.dtype)
    for idx in range(num):
        x = (idx % grid_w) * img_w
        y = (idx // grid_w) * img_h
        grid[..., y:y + img_h, x:x + img_w] = images[idx]
    return grid
def generate_interpolation_video(truncation_psi=0.5,
                                 grid_size=(1, 1), image_shrink=1, image_zoom=1,
                                 duration_sec=60.0, smoothing_sec=1.0,
                                 mp4='test-lerp.mp4', mp4_fps=30,
                                 mp4_codec='libx264', mp4_bitrate='16M',
                                 random_seed=1000):
    """Render a latent-space interpolation video with the module-level ``Gs``.

    Random latents are drawn for every frame, smoothed along the frame axis
    with a Gaussian filter, and fed through ``Gs`` one frame at a time while
    moviepy encodes the result.

    Args:
        truncation_psi: Forwarded to ``Gs.run`` as ``truncation_psi``.
        grid_size: ``(columns, rows)`` of images shown in each frame.
        image_shrink: Accepted for signature compatibility; not used by this
            function.
        image_zoom: Integer-style upscale factor applied to the finished grid
            (nearest-neighbour) when greater than 1.
        duration_sec: Length of the video in seconds.
        smoothing_sec: Gaussian sigma along the frame axis, in seconds.
        mp4: Output file name.
        mp4_fps: Frames per second of the output video.
        mp4_codec: Codec passed to moviepy's ``write_videofile``.
        mp4_bitrate: Bitrate passed to moviepy's ``write_videofile``.
        random_seed: Seed for the latent random number generator.

    Returns:
        The ``moviepy.editor.VideoClip`` after it has been written to ``mp4``.
    """
    # `import scipy` alone does not reliably load the `ndimage` submodule on
    # older SciPy releases, so import it explicitly before use.
    import scipy.ndimage

    num_frames = int(np.rint(duration_sec * mp4_fps))
    random_state = np.random.RandomState(random_seed)

    print('Generating latent vectors...')
    shape = [num_frames, np.prod(grid_size)] + Gs.input_shape[1:]  # [frame, image, channel, component]
    all_latents = random_state.randn(*shape).astype(np.float32)
    # Smooth only along the frame axis (sigma 0 elsewhere); mode='wrap' uses
    # wrap-around boundary handling at the ends of the frame sequence.
    all_latents = scipy.ndimage.gaussian_filter(
        all_latents,
        [smoothing_sec * mp4_fps] + [0] * len(Gs.input_shape),
        mode='wrap')
    # Filtering shrinks the magnitude; rescale to unit RMS over all latents.
    all_latents /= np.sqrt(np.mean(np.square(all_latents)))

    # Frame generation func for moviepy.
    def make_frame(t):
        frame_idx = int(np.clip(np.round(t * mp4_fps), 0, num_frames - 1))
        latents = all_latents[frame_idx]
        images = Gs.run(latents, None, truncation_psi=truncation_psi,
                        randomize_noise=False, output_transform=fmt)
        images = images.transpose(0, 3, 1, 2)  # NHWC -> NCHW
        grid = create_image_grid(images, grid_size).transpose(1, 2, 0)  # HWC
        if image_zoom > 1:
            grid = scipy.ndimage.zoom(grid, [image_zoom, image_zoom, 1], order=0)
        if grid.shape[2] == 1:
            grid = grid.repeat(3, 2)  # grayscale => RGB
        return grid

    # Generate video.
    import moviepy.editor  # pip install moviepy
    c = moviepy.editor.VideoClip(make_frame, duration=duration_sec)
    c.write_videofile(mp4, fps=mp4_fps, codec=mp4_codec, bitrate=mp4_bitrate)
    return c
# Script entry: render the interpolation video using all default arguments.
generate_interpolation_video()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment