Skip to content

Instantly share code, notes, and snippets.

@hotohoto
Last active October 12, 2022 02:48
Show Gist options
  • Save hotohoto/fe5dcd1ed25feb7f3a79d9a8607e4ef2 to your computer and use it in GitHub Desktop.
# Example usage: build one grid image per batch, then save the grids as GIF frames.
# grids = [create_grid(generated_batch[j]) for j in range(len(generated_batch))]
# save_as_gif_animation(grids, output_animation_path)
import math
import numpy as np
from PIL import Image
def create_grid(image_arrays: np.ndarray):
    """Tile a batch of images into one square grid image.

    Args:
        image_arrays: Array indexed as ``(batch, height, width, channels)``.

    Returns:
        Array of shape ``(height * n, width * n, channels)`` where
        ``n = ceil(sqrt(batch))``, with the same dtype as ``image_arrays``.
        Images are placed row by row (left to right, top to bottom);
        grid cells without an image stay zero.
    """
    batch_size, image_height, image_width, n_channels = image_arrays.shape
    n_cols = n_rows = math.ceil(batch_size**0.5)

    # Keep the input dtype: np.zeros defaults to float64, and PIL's
    # Image.fromarray cannot build an RGB image from a float64 array, which
    # would break feeding these grids to save_as_gif_animation.
    canvas = np.zeros(
        (image_height * n_rows, image_width * n_cols, n_channels),
        dtype=image_arrays.dtype,
    )
    for i, image in enumerate(image_arrays):
        # Row-major placement: row and column both derive from n_cols.
        row, col = divmod(i, n_cols)
        top = row * image_height
        left = col * image_width
        canvas[top : top + image_height, left : left + image_width, :] = image
    return canvas
def save_as_gif_animation(images, file_path, duration=10, loop=1):
    """Save a sequence of image frames as an animated GIF.

    Args:
        images: Sequence of frames (a list of arrays or one stacked array),
            each indexed as ``(height, width, channels)``. Frames must have a
            dtype PIL's ``Image.fromarray`` accepts (e.g. ``uint8``).
        file_path: Destination passed to PIL's ``Image.save``.
        duration: Per-frame display time forwarded to PIL's GIF writer
            (milliseconds per the Pillow docs).
        loop: Loop count forwarded to PIL's GIF writer.

    Raises:
        ValueError: If ``images`` contains no frames.
    """
    # Accept both lists and stacked arrays; the original read
    # ``images.shape`` and so failed on a plain list of frames.
    frames = [np.asarray(frame) for frame in images]
    if not frames:
        raise ValueError("images must contain at least one frame")

    # Single-channel frames must be 2-D for PIL to treat them as grayscale.
    # Index the channel axis instead of squeeze(), which would also drop a
    # height or width axis of size 1.
    if frames[0].shape[-1] == 1:
        frames = [frame[..., 0] for frame in frames]

    first, *rest = [Image.fromarray(frame) for frame in frames]
    first.save(
        file_path,
        save_all=True,
        duration=duration,
        append_images=rest,
        loop=loop,
    )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment