Skip to content

Instantly share code, notes, and snippets.

@rockerBOO
Created November 11, 2023 17:26
Show Gist options
  • Save rockerBOO/1bb228a06c7d6f03a8c79bc7ff1fa902 to your computer and use it in GitHub Desktop.
from torchmetrics.functional.multimodal import clip_score
from functools import partial
import torch
from datasets import load_dataset
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
import random
from pathlib import Path
import argparse
import numpy
clip_score_fn = partial(clip_score, model_name_or_path="openai/clip-vit-base-patch16")
def calculate_clip_score(images, prompts):
    """Compute the mean CLIP score for a batch of generated images.

    Args:
        images: numpy float array, assumed shape (N, H, W, C) with values in
            [0, 1] (pipeline output with ``output_type="np"``) — TODO confirm.
        prompts: sequence of N text prompts matching the images.

    Returns:
        The CLIP score as a float, rounded to 4 decimal places.
    """
    # Convert float [0, 1] images to uint8, as expected by the CLIP metric.
    images_int = (images * 255).astype("uint8")
    # Renamed from `clip_score`: the original local shadowed the imported
    # `clip_score` function, which would break any further use of the import.
    # permute(0, 3, 1, 2) converts NHWC -> NCHW for the model.
    score = clip_score_fn(
        torch.from_numpy(images_int).permute(0, 3, 1, 2), prompts
    ).detach()
    return round(float(score), 4)
def seed_worker(worker_id):
    """Seed per-worker RNGs for reproducible DataLoader workers.

    Derives a 32-bit seed from torch's per-worker initial seed and applies it
    to both numpy's and Python's ``random`` generators, following the PyTorch
    reproducibility recipe for ``worker_init_fn``.

    Args:
        worker_id: worker index supplied by the DataLoader (unused — the seed
            already differs per worker via ``torch.initial_seed()``).
    """
    worker_seed = torch.initial_seed() % 2**32
    # Restored (was commented out): without it, numpy-based augmentations
    # would not be reseeded per worker.
    numpy.random.seed(worker_seed)
    random.seed(worker_seed)
def main(args):
    """Generate images for 50 Parti prompts and report the CLIP score.

    Args:
        args: parsed CLI namespace with ``seed``,
            ``pretrained_model_name_or_path``, ``lora_file``,
            ``ti_embedding_file``, ``xformers`` and ``device`` attributes.
    """
    seed = args.seed
    # Seed both torch and Python RNGs for reproducible sampling/shuffling.
    torch.manual_seed(seed)
    random.seed(seed)

    device = torch.device(
        args.device if args.device else ("cuda" if torch.cuda.is_available() else "cpu")
    )

    model_ckpt = args.pretrained_model_name_or_path
    scheduler = DPMSolverMultistepScheduler.from_pretrained(
        model_ckpt, subfolder="scheduler"
    )
    # Safety checker is disabled deliberately for benchmarking.
    # NOTE(review): float16 weights will fail on CPU-only setups — confirm a
    # CUDA device is intended when --device is unset.
    sd_pipeline = StableDiffusionPipeline.from_pretrained(
        model_ckpt,
        torch_dtype=torch.float16,
        safety_checker=None,
        use_safetensors=True,
        scheduler=scheduler,
    ).to(device)

    if args.xformers:
        sd_pipeline.enable_xformers_memory_efficient_attention()

    if args.ti_embedding_file is not None:
        ti_embedding_file = Path(args.ti_embedding_file)
        sd_pipeline.load_textual_inversion(
            args.ti_embedding_file, weight_name=ti_embedding_file.name
        )

    if args.lora_file is not None:
        lora_file = Path(args.lora_file)
        sd_pipeline.load_lora_weights(lora_file, weight_name=lora_file.name)

    # Sample a fixed, seed-deterministic subset of Parti prompts.
    num_prompts = 50
    prompts = load_dataset("nateraw/parti-prompts", split="train")
    prompts = prompts.shuffle(seed=seed)
    sample_prompts = [prompts[i]["Prompt"] for i in range(num_prompts)]
    for sample_prompt in sample_prompts:
        print(sample_prompt)

    images = []
    batch_size = 5
    # Step through the prompts batch by batch. Using range(start, stop, step)
    # also covers a final partial batch — the original
    # `range(len(xs) // batch_size)` loop silently dropped any remainder
    # (harmless at 50/5, but brittle if the counts change).
    for start in range(0, len(sample_prompts), batch_size):
        end = min(start + batch_size, len(sample_prompts))
        print(start, end)
        images.append(
            sd_pipeline(
                sample_prompts[start:end],
                num_images_per_prompt=1,
                num_inference_steps=15,
                output_type="np",
                generator=torch.manual_seed(seed),
            ).images
        )

    # Score all generated images (N, H, W, C) against their prompts.
    sd_clip_score = calculate_clip_score(numpy.concatenate(images), sample_prompts)
    print(f"CLIP score: {sd_clip_score}")
    # Reference result: CLIP score: 35.7038
if __name__ == "__main__":
    # CLI entry point: collect options and hand them to main().
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--seed", type=int, default=1234, help="Seed for random and torch"
    )
    parser.add_argument(
        "--pretrained_model_name_or_path",
        default="runwayml/stable-diffusion-v1-5",
        help="Model to load",
    )
    parser.add_argument("--lora_file", default=None, help="Lora model file to load")
    parser.add_argument(
        "--ti_embedding_file", default=None, help="Textual inversion file to load"
    )
    parser.add_argument("--xformers", action="store_true", help="Use XFormers")
    parser.add_argument("--device", default=None, help="Set device to use")
    main(parser.parse_args())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment