-
-
Save sayakpaul/8cac98d7f22399085a060992f411ecbd to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import datasets | |
import tomesd | |
import torch | |
import wandb | |
from diffusers import StableDiffusionPipeline | |
CONFIG = { | |
"gen": torch.manual_seed(0), | |
"model_id": "runwayml/stable-diffusion-v1-5", | |
"inference_steps": 25, | |
"num_images_per_prompt": 4, | |
"dtype": torch.float16, | |
"resolution": 512, | |
"num_parti_prompts": 100, | |
"challenge": "basic", | |
"seed": 0, | |
"tome_ratio": 0.5, | |
} | |
def log_images(config, wandb_table, prompt, images): | |
images = [ | |
wandb.Image(image, caption=f"{i}: {prompt}") | |
for i, image in enumerate(images) | |
] | |
wandb_table.add_data(config, prompt, images) | |
def main(): | |
wandb.init(project="tomesd-results", config=CONFIG) | |
wandb_table = wandb.Table(columns=["config", "prompt", "images"]) | |
print(f"Loading pipeline with {CONFIG['model_id']}...") | |
pipeline = StableDiffusionPipeline.from_pretrained( | |
CONFIG["model_id"], torch_dtype=CONFIG["dtype"], safety_checker=None | |
).to("cuda") | |
pipeline.set_progress_bar_config(disable=True) | |
print("Loading the Parti prompts dataset...") | |
parti_ds = datasets.load_dataset("nateraw/parti-prompts", split="train") | |
parti_ds = parti_ds.filter(lambda x: x["Challenge"] == "Basic") | |
parti_ds = parti_ds.shuffle(CONFIG["seed"]).select( | |
range(CONFIG["num_parti_prompts"]) | |
) | |
print("Running inference with vanilla pipeline...") | |
for sample in parti_ds: | |
prompt = sample["Prompt"] | |
vanilla_images = pipeline( | |
prompt, | |
height=CONFIG["resolution"], | |
width=CONFIG["resolution"], | |
num_inference_steps=CONFIG["inference_steps"], | |
num_images_per_prompt=CONFIG["num_images_per_prompt"], | |
generator=CONFIG["gen"], | |
).images | |
log_images("vanilla", wandb_table, prompt=prompt, images=vanilla_images) | |
print("Running inference with ToMe...") | |
tomesd.apply_patch(pipeline, ratio=CONFIG["tome_ratio"]) | |
for sample in parti_ds: | |
prompt = sample["Prompt"] | |
tome_images = pipeline( | |
prompt, | |
height=CONFIG["resolution"], | |
width=CONFIG["resolution"], | |
num_inference_steps=CONFIG["inference_steps"], | |
num_images_per_prompt=CONFIG["num_images_per_prompt"], | |
generator=CONFIG["gen"], | |
).images | |
log_images("tome", wandb_table, prompt=prompt, images=tome_images) | |
print("Running inference with xFormers and ToMe...") | |
tomesd.remove_patch(pipeline) | |
pipeline.enable_xformers_memory_efficient_attention() | |
tomesd.apply_patch(pipeline, ratio=CONFIG["tome_ratio"]) | |
for sample in parti_ds: | |
prompt = sample["Prompt"] | |
tome_xformers_images = pipeline( | |
prompt, | |
height=CONFIG["resolution"], | |
width=CONFIG["resolution"], | |
num_inference_steps=CONFIG["inference_steps"], | |
num_images_per_prompt=CONFIG["num_images_per_prompt"], | |
generator=CONFIG["gen"], | |
).images | |
log_images( | |
"tome_xformers", wandb_table, prompt=prompt, images=tome_xformers_images | |
) | |
wandb.log({"results": wandb_table}) | |
wandb.finish() | |
if __name__ == "__main__": | |
main() |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment