Skip to content

Instantly share code, notes, and snippets.

@kvablack
Created July 7, 2023 20:56
Show Gist options
  • Save kvablack/3e2d3cf849107f76b13d1e3cd31c0bf6 to your computer and use it in GitHub Desktop.
Save kvablack/3e2d3cf849107f76b13d1e3cd31c0bf6 to your computer and use it in GitHub Desktop.
DDPO animation using iceberg
from iceberg import Renderer, Bounds, Colors, PathStyle, Corner, FontStyle, Color, StrokeCap
from iceberg.primitives.layout import Directions, Anchor, Compose, Arrange, Align
from iceberg.arrows import Arrow, ArrowHeadStyle
from iceberg.primitives import Blank, Image, Ellipse, Text, Rectangle, BorderPosition, MathTex
import imageio
import numpy as np
from PIL import Image as PImage
from IPython.display import Video
from tqdm.notebook import tqdm
# Each clip is declared together with its caption so the two derived lists
# always stay aligned index-for-index. Every caption contains one ':'.
_CLIPS = [
    ("good/10_llama.mp4", "Compressibility: llama"),
    ("good/0_bird.mp4", "Incompressibility: bird"),
    ("good/18_dog.mp4", "Aesthetic Quality: dog"),
    ("good/21_a raccoon washing dishes.mp4", "Prompt-Image Alignment: a raccoon washing dishes"),
]

# Video files to load, in display order.
paths = [clip_path for clip_path, _ in _CLIPS]
# Caption for each entry of ``paths``.
titles = [caption for _, caption in _CLIPS]
# Load every clip in ``paths`` into one array indexed as
# (clip, frame, height, width, channel).
all_images = []
for path in paths:
    images = np.array(imageio.mimread(path))
    # Append one constant-255 channel (presumably an opaque alpha channel for
    # RGB frames -- confirm against the input videos). Only the single needed
    # channel is allocated instead of a full-size copy of the clip.
    opaque = np.full(images.shape[:-1] + (1,), 255, dtype=images.dtype)
    images = np.concatenate([images, opaque], axis=-1)
    if "raccoon" in path:
        # Keep every second frame of this clip.
        images = images[::2]
    if "bird" in path:
        # Resample source frames 0..75 onto 100 evenly spaced (repeated) indices.
        images = images[np.linspace(0, 75, 100, dtype=int)]
    all_images.append(images)
    del images

# np.stack requires every clip to have the same shape and raises a ValueError
# otherwise, instead of producing a ragged/object array that only fails later.
all_images = np.stack(all_images)
im_width = all_images.shape[3]
im_height = all_images.shape[2]
num_frames = all_images.shape[1]
def anim(x):
    """Quadratic ease-in-out (easeInOutQuad).

    Args:
        x: Linear progress value; callers in this file pass values in [0, 1].

    Returns:
        The eased progress: ``2 * x**2`` below 0.5, otherwise
        ``1 - (-2 * x + 2)**2 / 2``.
    """
    if x < 0.5:
        # First half: accelerate.
        return 2 * x * x
    # Second half: decelerate.
    return 1 - (-2 * x + 2) ** 2 / 2
# Number of output keyframes generated per source frame (except the last one).
TWEEN_MULT = 3
# Number of image slots laid side by side in one strip.
NUM_IMAGES = 6

total_width = NUM_IMAGES * im_width
# Horizontal distance travelled by the moving image (strip width minus one image).
total_anim_width = total_width - im_width

# Each keyframe is (loc, frame_index) where loc is the eased position in [0, 1].
_tween_steps = (num_frames - 1) * TWEEN_MULT
keyframes = []
for _frame in range(num_frames - 1):
    keyframes.extend(
        (anim((_frame * TWEEN_MULT + _sub) / _tween_steps), _frame)
        for _sub in range(TWEEN_MULT)
    )
if num_frames > 0:
    # The final source frame contributes a single keyframe pinned at loc == 1.
    keyframes.append((1, num_frames - 1))
keyframes = np.array(keyframes)

# Evenly spaced slot positions; each one is paired with the frame index of the
# keyframe whose loc is closest to it.
image_locations = np.linspace(0, 1, NUM_IMAGES)
image_keyframes = [
    (slot, keyframes[np.argmin(np.abs(keyframes[:, 0] - slot)), 1])
    for slot in image_locations
]

renderer = Renderer(gpu=False)
def render_one(i, images, title):
    """Build the captioned image strip for one clip at animation step ``i``.

    Args:
        i: Row index into the module-level ``keyframes`` array selecting the
            current (loc, frame_index) pair.
        images: Indexable collection of frames for one clip; ``images[k]`` is
            passed to ``Image``.
        title: Caption of the form ``"<first>:<rest>"``. It is split at the
            first ':' only, so the rest may itself contain colons.

    Returns:
        The drawable obtained by combining the caption text and the outlined
        strip via ``next_to``.

    Raises:
        ValueError: If ``title`` contains no ':'.
    """
    current_loc = keyframes[i, 0]

    frames = []
    for j, (loc, fi) in enumerate(keyframes):
        image = Image(image=images[int(fi)]).move(loc * total_anim_width, 0)
        if j == i:
            # Only the keyframe being rendered contributes a visible image.
            frames.append(image)
        # Every keyframe position contributes a transparent placeholder with
        # the image's bounds (presumably so the strip keeps the same overall
        # bounds at every step -- confirm against iceberg's Compose).
        frames.append(Blank(image.bounds, background=Colors.TRANSPARENT))

    # Slots the moving image has already passed show their frame; they are
    # inserted at the front of the list, ahead of the entries added above.
    for loc, fi in image_keyframes:
        if loc < current_loc:
            image = Image(image=images[int(fi)]).move(loc * total_anim_width, 0)
            frames.insert(0, image)

    strip = Compose(frames)
    outline = Rectangle(
        strip.bounds,
        border_color=Colors.BLACK,
        fill_color=Colors.TRANSPARENT,
        border_thickness=10,
        border_position=BorderPosition.OUTSIDE,
        border_radius=3.0,
    )
    strip = Anchor([outline, strip])

    # Split at the first colon only: the part after it (the prompt) may contain
    # further colons, which previously made the two-name unpacking fail.
    title_first, title_last = title.split(":", 1)
    text_first = Text(
        title_first + ":",
        font_style=FontStyle(family="Open Sans", size=60),
        align=Text.Align.CENTER,
    )
    text_last = Text(
        title_last,
        font_style=FontStyle(family="Open Sans", size=60, font_style=FontStyle.Style.ITALIC),
        align=Text.Align.CENTER,
    )
    text = text_first.next_to(text_last, Directions.RIGHT)
    scene = text.next_to(strip, Directions.DOWN * 20)
    return scene
def render_all(i, ellipse_alpha=1):
    """Render the full frame for animation step ``i``.

    The frame combines one ``render_one`` strip per clip, a progress arrow
    with a red marker positioned by ``keyframes[i, 0]``, and an
    "RL training" label.

    Args:
        i: Row index into the module-level ``keyframes`` array.
        ellipse_alpha: Multiplied by 255 and passed as the fourth argument of
            ``Color.from_rgba`` for the marker's fill colour.

    Returns:
        Whatever ``renderer.get_rendered_image()`` returns after rendering the
        scene at half scale on a white background.
    """
    # One captioned strip per clip, chained together with a 50-unit offset.
    rows = [render_one(i, clip, caption) for clip, caption in zip(all_images, titles)]
    scene = None
    for row in rows:
        scene = row if scene is None else scene.next_to(row, Directions.DOWN * 50)

    # Horizontal arrow spanning the scene, inset by ``inset`` on both sides.
    inset = 50
    arrow = Arrow(
        start=(scene.bounds.left + inset, 0),
        end=(scene.bounds.right - inset, 0),
        head_length=40,
        line_path_style=PathStyle(color=Colors.BLACK, thickness=20, stroke_cap=StrokeCap.ROUND),
        arrow_head_style=ArrowHeadStyle.FILLED_TRIANGLE,
    )

    # The marker travels along the arrow, excluding the width of the arrow's
    # last child (presumably the arrow head) and one marker radius.
    radius = 25
    arrow_width = arrow.bounds.width
    last_child_width = arrow.children[-1].bounds.width
    arrow_line_width = arrow_width - last_child_width - radius
    marker_x = arrow.bounds.left + arrow_line_width * keyframes[i, 0]
    # NOTE(review): the first and third Bounds arguments are +radius and
    # -radius, exactly as in the original -- confirm the expected ordering.
    marker_bounds = Bounds(
        radius,
        marker_x - radius,
        -radius,
        marker_x + radius,
    )
    marker = Ellipse(
        marker_bounds,
        border_color=Colors.TRANSPARENT,
        fill_color=Color.from_rgba(255, 0, 0, ellipse_alpha * 255),
    )
    progress = Anchor([arrow, marker])
    scene = scene.next_to(progress, Directions.DOWN * 50)

    label = Text(
        "RL training",
        font_style=FontStyle(family="Open Sans", size=70),
        align=Text.Align.CENTER,
    )
    scene = scene.next_to(label, Directions.DOWN * 10).pad(100)

    renderer.render(scene.scale(0.5), background_color=Colors.WHITE)
    return renderer.get_rendered_image()
# Quick single-frame preview (notebook only):
# display(PImage.fromarray(render_all(0)))

# Main pass: one rendered frame per keyframe, with a progress bar.
video = [render_all(step) for step in tqdm(range(len(keyframes)))]

# Outro: hold the last keyframe while the marker's alpha eases from 1 to 0.
NUM_FADE = 40
_last_step = len(keyframes) - 1
_fade_progress = [anim(k / (NUM_FADE - 1)) for k in range(NUM_FADE)]
video.extend(render_all(_last_step, ellipse_alpha=1 - p) for p in _fade_progress)

# Kept for reference: pad frames to a common size if renders ever differ.
# max_width = max([im.shape[1] for im in video])
# max_height = max([im.shape[0] for im in video])
# video = [np.pad(im, ((0, max_height - im.shape[0]), (0, max_width - im.shape[1]), (0, 0))) for im in video]

imageio.mimwrite("video.mp4", video, fps=60)
# Last expression of the cell: shows the embedded video in a notebook.
Video("video.mp4", embed=True)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment