Skip to content

Instantly share code, notes, and snippets.

@deep-diver
Created January 29, 2023 15:28
Show Gist options
  • Save deep-diver/aa5750692589a92abfe14e2f28954a48 to your computer and use it in GitHub Desktop.
Save deep-diver/aa5750692589a92abfe14e2f28954a48 to your computer and use it in GitHub Desktop.
Run a set of Stable Diffusion (SD) image-generation experiments that compare different diffusers schedulers.
import wandb
import PIL
import matplotlib.pyplot as plt
import torch
from diffusers import StableDiffusionPipeline
from diffusers import DDIMScheduler
from diffusers import PNDMScheduler
from diffusers import LMSDiscreteScheduler
# Model repo id handed to StableDiffusionPipeline.from_pretrained().
model_id = "chansung/dreambooth-dog-to-kerascv_sd_diffusers_pipeline"

# Text prompt rendered in every experiment. "sks" is presumably the identifier
# token used when the model was DreamBooth fine-tuned — confirm with the model card.
prompt = "a photo of sks dog in a colorful bucket"

# Sweep grid: each scheduler is evaluated at every (steps, guidance) pair.
num_inference_steps_list = list(range(25, 101, 25))  # -> [25, 50, 75, 100]
guidance_scale_list = [7.5, 15, 30]
def load_pipeline(model_id, device="cuda", torch_dtype=torch.float16):
    """Load a Stable Diffusion pipeline and move it to a device.

    Args:
        model_id: Repo id or path passed to
            ``StableDiffusionPipeline.from_pretrained``.
        device: Device the pipeline is moved to. Defaults to ``"cuda"``,
            matching the original hard-coded value.
        torch_dtype: Weight dtype used when loading. Defaults to
            ``torch.float16``. Half precision is typically not usable for CPU
            inference, so pass ``torch.float32`` together with
            ``device="cpu"``.

    Returns:
        The loaded pipeline, already placed on ``device``.
    """
    pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch_dtype)
    return pipe.to(device)
def make_experiment_configs(pipe, scheduler_classes=None):
    """Build one scheduler instance per class, all sharing ``pipe``'s config.

    Each scheduler is created with ``cls.from_config`` from a fresh copy of
    ``pipe.scheduler.config``. That way every experiment starts from the
    pipeline's own settings and no two schedulers share a config object.

    Args:
        pipe: Pipeline whose current ``scheduler.config`` seeds every scheduler.
        scheduler_classes: Iterable of scheduler classes exposing
            ``from_config``. ``None`` (the default) selects PNDM, DDIM and
            LMSDiscrete, in that order, which is the original fixed set.

    Returns:
        A list of scheduler instances, in the order of ``scheduler_classes``.
    """
    if scheduler_classes is None:
        # Resolved at call time so the default does not depend on import order.
        scheduler_classes = (PNDMScheduler, DDIMScheduler, LMSDiscreteScheduler)
    return [
        scheduler_cls.from_config(pipe.scheduler.config.copy())
        for scheduler_cls in scheduler_classes
    ]
def run_experiments(pipe, scheduler_list, num_images_per_prompt=8, seed=None):
    """Sweep each scheduler over the step/guidance grid and log images to W&B.

    One W&B run (project ``SD-Scheduler-Explore``) is opened per scheduler and
    named after the scheduler's class. Within that run, the module-level
    ``prompt`` is rendered for every combination of ``num_inference_steps_list``
    and ``guidance_scale_list``. The images are logged under the key
    ``num_steps@<steps>-ugs@<guidance>``, each captioned ``"<index>: <prompt>"``.

    Args:
        pipe: Pipeline to run. Its ``scheduler`` attribute is replaced by each
            entry of ``scheduler_list`` in turn and is left set to the last one.
        scheduler_list: Scheduler instances to compare.
        num_images_per_prompt: Images generated for each (steps, guidance)
            cell. Defaults to 8, the original value.
        seed: When given, every pipeline call gets a freshly seeded
            ``torch.Generator``. This makes outputs reproducible and comparable
            across schedulers and settings. ``None`` keeps the original
            unseeded behavior.
    """
    for scheduler in scheduler_list:
        pipe.scheduler = scheduler
        # Name the run from the Python class rather than config['_class_name']:
        # after from_config() the copied config may still carry the source
        # scheduler's name in some diffusers versions, mislabeling the run.
        scheduler_name = type(scheduler).__name__
        print(scheduler_name)

        wandb.init(project="SD-Scheduler-Explore", name=scheduler_name)
        try:
            for num_inference_steps in num_inference_steps_list:
                for guidance_scale in guidance_scale_list:
                    generator = None
                    if seed is not None:
                        # Re-seed per call so every cell uses the same draw.
                        generator = torch.Generator(device=pipe.device).manual_seed(seed)
                    images = pipe(
                        prompt=prompt,
                        num_images_per_prompt=num_images_per_prompt,
                        num_inference_steps=num_inference_steps,
                        guidance_scale=guidance_scale,
                        generator=generator,
                    ).images
                    wandb.log(
                        {
                            f"num_steps@{num_inference_steps}-ugs@{guidance_scale}": [
                                wandb.Image(image, caption=f"{i}: {prompt}")
                                for i, image in enumerate(images)
                            ]
                        }
                    )
        finally:
            # Always finalize the run, even if generation raises (e.g. OOM).
            wandb.finish()
def run():
    """Load the pipeline for the module-level ``model_id`` and run the sweep.

    Builds the scheduler set from the freshly loaded pipeline's own config,
    then hands both to ``run_experiments``.
    """
    pipeline = load_pipeline(model_id)
    run_experiments(pipeline, make_experiment_configs(pipeline))
if __name__ == "__main__":
    # Script entry point: run the full scheduler sweep using the module-level
    # settings (model_id, prompt, and the steps/guidance grids).
    run()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment