Skip to content

Instantly share code, notes, and snippets.

@sayakpaul
Created April 24, 2023 07:36
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save sayakpaul/8cac98d7f22399085a060992f411ecbd to your computer and use it in GitHub Desktop.
Save sayakpaul/8cac98d7f22399085a060992f411ecbd to your computer and use it in GitHub Desktop.
import datasets
import tomesd
import torch
import wandb
from diffusers import StableDiffusionPipeline
# Single source of truth for the RNG seed: used both to seed torch below and
# recorded in CONFIG["seed"], so the two values cannot drift apart.
SEED = 0

# Benchmark settings; the whole dict is also passed to wandb.init as the run config.
CONFIG = {
    # torch.manual_seed seeds torch's global RNG at import time and returns the
    # default CPU generator. It is one shared, stateful object: every call that
    # consumes it advances its state. Build a fresh generator from
    # CONFIG["seed"] wherever identical noise is needed across runs.
    "gen": torch.manual_seed(SEED),
    "model_id": "runwayml/stable-diffusion-v1-5",
    "inference_steps": 25,
    "num_images_per_prompt": 4,
    "dtype": torch.float16,
    "resolution": 512,  # used for both height and width
    "num_parti_prompts": 100,
    "challenge": "basic",  # Parti prompts "Challenge" category to keep
    "seed": SEED,
    "tome_ratio": 0.5,  # passed as `ratio` to tomesd.apply_patch
}
def log_images(config, wandb_table, prompt, images):
    """Add one row to ``wandb_table``: the variant label, the prompt, and its images.

    Args:
        config: Label identifying which pipeline variant produced the images
            (stored in the table's first column).
        wandb_table: The ``wandb.Table`` to append the row to.
        prompt: Text prompt the images were generated from.
        images: Iterable of generated images; each one is wrapped in a
            ``wandb.Image`` captioned ``"<index>: <prompt>"``.
    """
    captioned = [
        wandb.Image(img, caption=f"{idx}: {prompt}")
        for idx, img in enumerate(images)
    ]
    wandb_table.add_data(config, prompt, captioned)
def _generate_variant(pipeline, prompts, label, wandb_table):
    """Run ``pipeline`` on every prompt and log the resulting images under ``label``.

    A new CPU generator seeded with ``CONFIG["seed"]`` is created for every
    prompt. For a given prompt, every variant therefore starts from the same
    initial noise. Reusing one generator across variants (as the shared
    ``CONFIG["gen"]`` would) advances its state, so each variant would see
    different noise and the comparison would be confounded.
    """
    for prompt in prompts:
        generator = torch.Generator(device="cpu").manual_seed(CONFIG["seed"])
        images = pipeline(
            prompt,
            height=CONFIG["resolution"],
            width=CONFIG["resolution"],
            num_inference_steps=CONFIG["inference_steps"],
            num_images_per_prompt=CONFIG["num_images_per_prompt"],
            generator=generator,
        ).images
        log_images(label, wandb_table, prompt=prompt, images=images)


def main():
    """Compare vanilla, ToMe, and ToMe + xFormers Stable Diffusion on Parti prompts.

    Loads the pipeline and a seeded, shuffled subset of Parti prompts filtered
    to ``CONFIG["challenge"]``. Generates images for each variant and logs all
    of them to a single W&B table.
    """
    wandb.init(project="tomesd-results", config=CONFIG)
    wandb_table = wandb.Table(columns=["config", "prompt", "images"])

    print(f"Loading pipeline with {CONFIG['model_id']}...")
    pipeline = StableDiffusionPipeline.from_pretrained(
        CONFIG["model_id"], torch_dtype=CONFIG["dtype"], safety_checker=None
    ).to("cuda")
    pipeline.set_progress_bar_config(disable=True)

    print("Loading the Parti prompts dataset...")
    parti_ds = datasets.load_dataset("nateraw/parti-prompts", split="train")
    # Filter by the configured challenge instead of a hardcoded literal. The
    # comparison is case-insensitive because the config stores "basic" while
    # the dataset values are capitalized (e.g. "Basic").
    challenge = CONFIG["challenge"].lower()
    parti_ds = parti_ds.filter(lambda x: x["Challenge"].lower() == challenge)
    # Never request more rows than the filtered split actually contains.
    num_prompts = min(CONFIG["num_parti_prompts"], len(parti_ds))
    parti_ds = parti_ds.shuffle(seed=CONFIG["seed"]).select(range(num_prompts))
    prompts = [sample["Prompt"] for sample in parti_ds]

    print("Running inference with vanilla pipeline...")
    _generate_variant(pipeline, prompts, "vanilla", wandb_table)

    print("Running inference with ToMe...")
    tomesd.apply_patch(pipeline, ratio=CONFIG["tome_ratio"])
    _generate_variant(pipeline, prompts, "tome", wandb_table)

    print("Running inference with xFormers and ToMe...")
    # Unpatch, switch attention to xFormers, then re-apply ToMe with the same ratio.
    tomesd.remove_patch(pipeline)
    pipeline.enable_xformers_memory_efficient_attention()
    tomesd.apply_patch(pipeline, ratio=CONFIG["tome_ratio"])
    _generate_variant(pipeline, prompts, "tome_xformers", wandb_table)

    wandb.log({"results": wandb_table})
    wandb.finish()
if __name__ == "__main__":
    # Only run the benchmark when executed as a script, not when imported.
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment