Created
January 29, 2023 15:28
-
-
Save deep-diver/aa5750692589a92abfe14e2f28954a48 to your computer and use it in GitHub Desktop.
Run a number of experiments comparing different schedulers for Stable Diffusion (SD).
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import wandb | |
import PIL | |
import matplotlib.pyplot as plt | |
import torch | |
from diffusers import StableDiffusionPipeline | |
from diffusers import DDIMScheduler | |
from diffusers import PNDMScheduler | |
from diffusers import LMSDiscreteScheduler | |
# Hub repository (or local path) of the pipeline under test.
model_id = "chansung/dreambooth-dog-to-kerascv_sd_diffusers_pipeline"

# Text prompt rendered for every scheduler / step-count / guidance combination.
prompt = "a photo of sks dog in a colorful bucket"

# Sweep grid. Step counts run from 25 to 100 inclusive in increments of 25.
num_inference_steps_list = list(range(25, 101, 25))
# Values passed to the pipeline's ``guidance_scale`` argument.
guidance_scale_list = [7.5, 15, 30]
def load_pipeline(model_id, *, device="cuda", torch_dtype=None):
    """Load a StableDiffusionPipeline and move it to ``device``.

    Args:
        model_id: Identifier passed to ``StableDiffusionPipeline.from_pretrained``.
        device: Target device for the pipeline. Defaults to ``"cuda"``, which
            matches the previously hard-coded value.
        torch_dtype: Weight dtype passed to ``from_pretrained``. ``None`` (the
            default) means ``torch.float16``, the previously hard-coded value.
            NOTE: half precision is generally intended for GPU use; pass
            ``torch.float32`` when targeting ``"cpu"``.

    Returns:
        The loaded pipeline, already moved to ``device``.
    """
    if torch_dtype is None:
        # Resolved here instead of in the signature so the default is only
        # looked up when the function is actually called.
        torch_dtype = torch.float16
    pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch_dtype)
    return pipe.to(device)
def make_experiment_configs(pipe, scheduler_classes=None):
    """Build one scheduler instance per scheduler class from the pipeline's config.

    Each scheduler is created with ``cls.from_config`` applied to its own copy
    of ``pipe.scheduler.config``. The resulting instances do not share a
    config dict.

    Args:
        pipe: Loaded pipeline whose current ``scheduler.config`` seeds every
            new scheduler.
        scheduler_classes: Iterable of scheduler classes exposing
            ``from_config``. ``None`` (the default) means
            ``(PNDMScheduler, DDIMScheduler, LMSDiscreteScheduler)``, in that
            order.

    Returns:
        A list of scheduler instances, in the same order as
        ``scheduler_classes``.
    """
    if scheduler_classes is None:
        scheduler_classes = (PNDMScheduler, DDIMScheduler, LMSDiscreteScheduler)
    return [cls.from_config(pipe.scheduler.config.copy()) for cls in scheduler_classes]
def run_experiments(pipe, scheduler_list, *, seed=None):
    """Render the prompt grid with each scheduler and log the images to W&B.

    For every scheduler in ``scheduler_list``:
    - A separate W&B run is opened in the ``SD-Scheduler-Explore`` project.
    - The run is named after the scheduler's class.
    - The pipeline is called once per combination of the module-level
      ``num_inference_steps_list`` and ``guidance_scale_list``, rendering the
      module-level ``prompt``.
    - Each call produces 8 images, logged under ``num_steps@<steps>-ugs@<scale>``.

    Args:
        pipe: Loaded pipeline. Its ``scheduler`` attribute is replaced in turn
            by each entry of ``scheduler_list`` and is left set to the last one.
        scheduler_list: Scheduler instances to compare.
        seed: Optional integer seed. When given, every pipeline call gets a
            freshly seeded ``torch.Generator``, so all schedulers and settings
            start from the same initial noise and their images are directly
            comparable. ``None`` (the default) leaves generation unseeded, as
            before.
    """
    for scheduler in scheduler_list:
        pipe.scheduler = scheduler
        # Use the instance's real class for the run name: ``from_config`` may
        # carry the source config's ``_class_name`` over, so
        # ``config['_class_name']`` can name the wrong scheduler.
        scheduler_name = type(scheduler).__name__
        print(scheduler_name)
        # The context manager finishes the run even if a pipeline call raises,
        # so no W&B run is left open.
        with wandb.init(project="SD-Scheduler-Explore", name=scheduler_name) as run:
            for num_inference_steps in num_inference_steps_list:
                for guidance_scale in guidance_scale_list:
                    generator = None
                    if seed is not None:
                        # Re-seed per call so every configuration starts from the same latents.
                        generator = torch.Generator(device=pipe.device).manual_seed(seed)
                    images = pipe(
                        prompt=prompt,
                        num_images_per_prompt=8,
                        num_inference_steps=num_inference_steps,
                        guidance_scale=guidance_scale,
                        generator=generator,
                    ).images
                    run.log(
                        {
                            f"num_steps@{num_inference_steps}-ugs@{guidance_scale}": [
                                wandb.Image(image, caption=f"{i}: {prompt}")
                                for i, image in enumerate(images)
                            ]
                        }
                    )
def run():
    """Load the pipeline, derive the scheduler variants, and run the full sweep."""
    pipe = load_pipeline(model_id)
    run_experiments(pipe, make_experiment_configs(pipe))
if __name__ == "__main__":
    # Script entry point; no command-line arguments are parsed yet.
    run()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment