Last active
February 6, 2023 03:59
-
-
Save deep-diver/0a2deb2cd369ab8c1bf3ee12f47d272a to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import itertools | |
import wandb | |
import PIL | |
import matplotlib.pyplot as plt | |
import torch | |
from diffusers import StableDiffusionPipeline | |
from diffusers import DDIMScheduler | |
from diffusers import PNDMScheduler | |
from diffusers import LMSDiscreteScheduler | |
# Hugging Face Hub id of the DreamBooth-tuned Stable Diffusion pipeline to sample from.
model_id = "chansung/dreambooth-dog-to-kerascv_sd_diffusers_pipeline"
# Text prompt used for every generation in the sweep.
prompt = "a photo of sks dog in a colorful bucket"
# Sampling settings swept for every scheduler variant.
num_inference_steps_list = [25, 50, 75, 100]
guidance_scale_list = [7.5, 15, 30]


def _beta_ranges():
    """Return a fresh list of ``[beta_start, beta_end]`` pairs to sweep."""
    beta_end = 0.02
    return [
        [beta_start, beta_end]
        for beta_start in (0.000001, 0.000005, 0.00001, 0.00005, 0.0001, 0.0005)
    ]


# Per-scheduler hyper-parameter grids. Each value is a list of candidates;
# make_experiment_configs takes the Cartesian product of the lists, in key order.
# NOTE(review): the narrower grids for PNDM (no "sample" prediction type) and
# LMS (no "squaredcos_cap_v2") presumably reflect what those schedulers
# support — confirm against the pinned diffusers version.
scheduler_configs = {
    "DDIMScheduler": dict(
        beta_value=_beta_ranges(),
        beta_schedule=["linear", "scaled_linear", "squaredcos_cap_v2"],
        clip_sample=[True, False],
        set_alpha_to_one=[True, False],
        prediction_type=["epsilon", "sample", "v_prediction"],
    ),
    "PNDMScheduler": dict(
        beta_value=_beta_ranges(),
        beta_schedule=["linear", "scaled_linear", "squaredcos_cap_v2"],
        skip_prk_steps=[True, False],
        set_alpha_to_one=[True, False],
        prediction_type=["epsilon", "v_prediction"],
    ),
    "LMSDiscreteScheduler": dict(
        beta_value=_beta_ranges(),
        beta_schedule=["linear", "scaled_linear"],
        prediction_type=["epsilon", "sample", "v_prediction"],
    ),
}
def load_pipeline(model_id, device="cuda", torch_dtype=torch.float16):
    """Load a Stable Diffusion pipeline and move it to ``device``.

    Args:
        model_id: Hub repo id or local path passed to
            ``StableDiffusionPipeline.from_pretrained``.
        device: target device; defaults to ``"cuda"`` (the previously
            hard-coded value).
        torch_dtype: dtype the weights are loaded in; defaults to
            ``torch.float16``. Half precision is typically only practical on
            GPU, so pass ``torch.float32`` when using ``device="cpu"``.

    Returns:
        The loaded pipeline, placed on ``device``.
    """
    pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch_dtype)
    return pipe.to(device)
def make_experiment_configs(pipe):
    """Build one scheduler instance per hyper-parameter combination.

    For each scheduler family in ``scheduler_configs``, the Cartesian product
    of its candidate lists is enumerated. For every combination, a new
    scheduler is constructed from the pipeline's current scheduler config,
    with the swept values applied as constructor overrides.

    Args:
        pipe: pipeline whose ``scheduler.config`` supplies every setting that
            is not swept.

    Returns:
        A list of scheduler instances, in DDIM, PNDM, LMS order. Within a
        family, combinations follow ``itertools.product`` order of
        ``beta_value`` followed by the family's other axes.
    """
    # (scheduler class, key in scheduler_configs, swept axes after
    # "beta_value"), in the order the axes are combined.
    families = [
        (
            DDIMScheduler,
            "DDIMScheduler",
            ("beta_schedule", "clip_sample", "set_alpha_to_one", "prediction_type"),
        ),
        (
            PNDMScheduler,
            "PNDMScheduler",
            ("beta_schedule", "skip_prk_steps", "set_alpha_to_one", "prediction_type"),
        ),
        (
            LMSDiscreteScheduler,
            "LMSDiscreteScheduler",
            ("beta_schedule", "prediction_type"),
        ),
    ]
    base_config = pipe.scheduler.config
    scheduler_list = []
    for scheduler_cls, family, axes in families:
        configs = scheduler_configs[family]
        value_lists = [configs["beta_value"]] + [configs[axis] for axis in axes]
        # "beta_value" entries are [beta_start, beta_end] pairs.
        for (beta_start, beta_end), *values in itertools.product(*value_lists):
            # Pass the swept values to from_config so they reach __init__.
            # In diffusers the noise schedule (betas, alphas_cumprod, ...) is
            # computed at construction time, so editing scheduler.config
            # afterwards would leave sampling unchanged.
            scheduler_list.append(
                scheduler_cls.from_config(
                    base_config,
                    beta_start=beta_start,
                    beta_end=beta_end,
                    **dict(zip(axes, values)),
                )
            )
    return scheduler_list
def _scheduler_run_name(scheduler):
    """Return the W&B run name that encodes ``scheduler``'s swept settings.

    Format: ``<Class>-b@<beta_start>~<beta_end>w/<beta_schedule><extra>-pt@<prediction_type>``.
    ``<extra>`` depends on the scheduler type:
    - DDIM: ``-c@<clip_sample>-a@<set_alpha_to_one>``
    - PNDM: ``-prk@<skip_prk_steps>-a@<set_alpha_to_one>``
    - any other type: empty
    """
    config = scheduler.config
    # Use the Python class instead of config["_class_name"]. A scheduler built
    # with from_config() can carry over the _class_name of the config it was
    # cloned from, which would mislabel the run.
    # NOTE(review): confirm for the pinned diffusers version.
    name = type(scheduler).__name__
    if isinstance(scheduler, DDIMScheduler):
        extra = f"-c@{config['clip_sample']}-a@{config['set_alpha_to_one']}"
    elif isinstance(scheduler, PNDMScheduler):
        extra = f"-prk@{config['skip_prk_steps']}-a@{config['set_alpha_to_one']}"
    else:
        # LMSDiscreteScheduler (or any other type) has no extra swept knobs.
        extra = ""
    return (
        f"{name}-b@{config['beta_start']}~{config['beta_end']}"
        f"w/{config['beta_schedule']}{extra}-pt@{config['prediction_type']}"
    )


def run_experiments(pipe, scheduler_list, num_images_per_prompt=8):
    """Sweep sampling settings for every scheduler and log images to W&B.

    For each scheduler, a separate W&B run is opened in project
    "SD-Scheduler-Explore", named by ``_scheduler_run_name``. Every
    (num_inference_steps, guidance_scale) pair from the module-level lists
    then generates ``num_images_per_prompt`` images for ``prompt``. The
    images are logged under the key ``num_steps@<steps>-ugs@<scale>``.

    Args:
        pipe: pipeline to sample from. Its ``scheduler`` attribute is replaced
            by each element of ``scheduler_list`` in turn and is left set to
            the last one.
        scheduler_list: schedulers to evaluate.
        num_images_per_prompt: images generated per setting (default 8).
    """
    print(scheduler_list)
    for scheduler in scheduler_list:
        pipe.scheduler = scheduler
        s_id = _scheduler_run_name(scheduler)
        wandb.init(project="SD-Scheduler-Explore", name=s_id)
        print(s_id)
        try:
            for num_inference_steps in num_inference_steps_list:
                for guidance_scale in guidance_scale_list:
                    images = pipe(
                        prompt=prompt,
                        num_images_per_prompt=num_images_per_prompt,
                        num_inference_steps=num_inference_steps,
                        guidance_scale=guidance_scale,
                    ).images
                    wandb.log(
                        {
                            f"num_steps@{num_inference_steps}-ugs@{guidance_scale}": [
                                wandb.Image(image, caption=f"{i}: {prompt}")
                                for i, image in enumerate(images)
                            ]
                        }
                    )
        finally:
            # Close the run even if generation fails, so a failed run is not
            # left open when the exception propagates.
            wandb.finish()
def run():
    """Entry point: load the pipeline, build every scheduler variant, then sweep them."""
    pipeline = load_pipeline(model_id)
    schedulers = make_experiment_configs(pipeline)
    run_experiments(pipeline, schedulers)
if __name__ == "__main__":
    # Run the full scheduler sweep using the module-level settings.
    run()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment