Skip to content

Instantly share code, notes, and snippets.

@sayakpaul
Last active January 26, 2023 06:48
Show Gist options
  • Save sayakpaul/0d83d7fd7c3939ce2ddc2292b6d4f173 to your computer and use it in GitHub Desktop.
Save sayakpaul/0d83d7fd7c3939ce2ddc2292b6d4f173 to your computer and use it in GitHub Desktop.
import tensorflow as tf
tf.keras.mixed_precision.set_global_policy("mixed_float16")
import glob
import os

import keras_cv
import numpy as np
import PIL
import PIL.Image
import wandb
from tqdm import tqdm
def download_unet_params(
    run_id,
    run_name,
    *,
    project="experimentation_images",
    source_project="sayakpaul/dreambooth-keras",
    artifact_version="v0",
) -> tuple:
    """Start a W&B run and download the fine-tuned UNet weights of a training run.

    A new W&B run called ``run_name`` is initialised in ``project``. The model
    artifact ``{source_project}/run_{run_id}_model:{artifact_version}`` is
    declared as an input of that run and downloaded locally.

    Args:
        run_id: ID of the training run whose model artifact should be fetched.
        run_name: Name given to the newly created W&B run.
        project: W&B project the new run is created in.
        source_project: ``entity/project`` path that holds the model artifacts.
        artifact_version: Version/alias of the artifact to fetch.

    Returns:
        A ``(run, unet_params_path)`` tuple: the newly initialised W&B run
        (the caller is responsible for calling ``finish()`` on it) and the
        path of the first ``.h5`` file, in sorted order, inside the downloaded
        artifact directory.

    Raises:
        FileNotFoundError: If the downloaded artifact contains no ``.h5`` file.
    """
    run = wandb.init(project=project, name=run_name)
    run_artifact_id = f"{source_project}/run_{run_id}_model:{artifact_version}"
    artifact = run.use_artifact(run_artifact_id, type="model")
    artifact_dir = artifact.download()
    # glob order is filesystem-dependent; sort so the choice is deterministic
    # if the artifact ever holds more than one weights file.
    h5_files = sorted(glob.glob(os.path.join(artifact_dir, "*.h5")))
    if not h5_files:
        raise FileNotFoundError(
            f"No .h5 weights file found in artifact {run_artifact_id!r} "
            f"(downloaded to {artifact_dir!r})."
        )
    return run, h5_files[0]
# Build the 512x512 Stable Diffusion model once. The generation loop reuses
# this single instance and only reloads its diffusion-model weights per run.
img_height = 512
img_width = img_height
sd_model = keras_cv.models.StableDiffusion(
    img_height=img_height,
    img_width=img_width,
    jit_compile=True,
)
# List every run of the DreamBooth training project; each one is visited by
# the generation loop that follows.
training_project_path = "sayakpaul/dreambooth-keras"
api = wandb.Api()
runs = api.runs(training_project_path)
# Generation settings: every (denoising steps, guidance scale) combination
# is sampled for each run, producing `num_images_to_gen` images per combo.
num_steps = list(range(25, 101, 25))
num_images_to_gen = 3
caption = "A photo of sks dog in a bucket"
unconditional_guidance_scales = [7.5, 15, 30]
# Generate example results: for every training run, load its fine-tuned
# diffusion-model weights into the shared model and log a grid of samples
# (one entry per steps/guidance-scale pair) to a dedicated W&B run.
for run in tqdm(runs):
    run_id = run.id
    run_name = run.name
    print(f"Generating images for {run_name}.")
    new_run, unet_params_path = download_unet_params(run_id, run_name)
    sd_model.diffusion_model.load_weights(unet_params_path)
    # NOTE(review): nothing in this loop writes into this directory; it is
    # kept only to preserve the original side effect. Confirm whether the
    # generated images were also meant to be saved here.
    os.makedirs(run_name, exist_ok=True)
    for steps in num_steps:
        for scale in unconditional_guidance_scales:
            images = sd_model.text_to_image(
                caption,
                batch_size=num_images_to_gen,
                num_steps=steps,
                unconditional_guidance_scale=scale,
            )
            # Log through the run created for this model explicitly, rather
            # than relying on whichever run the global `wandb.log` targets.
            new_run.log(
                {
                    f"num_steps@{steps}-ugs@{scale}": [
                        wandb.Image(
                            PIL.Image.fromarray(image),
                            caption=f"{i}: {caption}",
                        )
                        for i, image in enumerate(images)
                    ]
                }
            )
    new_run.finish()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment