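"""Compute the Fréchet Inception Distance (FID) between a directory of real images
and a set of Stable Diffusion samples, using torchmetrics.

The fake images are either loaded from --fake_data_dir or generated on the fly with
diffusers from Parti prompts, optionally with a LoRA and/or textual inversion
embedding applied.
"""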
import torch
from datasets import Dataset, Image, load_dataset
from torchmetrics.image.fid import FrechetInceptionDistance
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
from torchvision import transforms
import random
from matplotlib import pyplot as plt
import argparse
from pathlib import Path
import numpy
import math
# from torchvision.transforms import v2


def seed_worker(worker_id):
    # Re-seed each DataLoader worker so results are reproducible across runs.
    worker_seed = torch.initial_seed() % 2**32
    # numpy.random.seed(worker_seed)
    random.seed(worker_seed)


def generate_fake_images(sd_pipeline, sample_prompts, args):
    g = torch.Generator()
    g.manual_seed(args.seed)

    print("Sample prompts:")
    for sample_prompt in sample_prompts:
        print(f"\t{sample_prompt}")

    out_dir = Path(args.save_fake_images_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    image_filenames = []
    batch_size = args.batch_size or 4

    for i in range(math.ceil(len(sample_prompts) / batch_size)):
        batch_prompts = sample_prompts[i * batch_size : (i + 1) * batch_size]
        images = sd_pipeline(
            batch_prompts,
            num_images_per_prompt=1,
            num_inference_steps=args.num_inference_steps or 15,
            generator=g,
        ).images

        for pi, (image, prompt) in enumerate(zip(images, batch_prompts)):
            cleaned_prompt = prompt.replace(" ", "-")
            # Use a global index so files from different batches do not collide.
            filename = str(out_dir / f"{cleaned_prompt}-{i * batch_size + pi}.png")
            image.save(filename)
            image_filenames.append({"image": filename})

    # The images are saved to disk; return the filenames so the caller can build a Dataset.
    return image_filenames


def setup_sd_pipeline(args):
    model_ckpt = args.pretrained_model_name_or_path

    print("Using sampler scheduler: DPMSolverMultistepScheduler")
    scheduler = DPMSolverMultistepScheduler.from_pretrained(
        model_ckpt, subfolder="scheduler"
    )

    print(f"Using SD model: {model_ckpt}")
    sd_pipeline = StableDiffusionPipeline.from_pretrained(
        model_ckpt,
        torch_dtype=torch.float16,
        safety_checker=None,
        use_safetensors=True,
        scheduler=scheduler,
    )

    if args.xformers:
        print("Using XFormers")
        sd_pipeline.enable_xformers_memory_efficient_attention()

    if args.ti_embedding_file is not None:
        ti_embedding_file = Path(args.ti_embedding_file)
        print(f"Using TI Embedding: {ti_embedding_file.name}")
        sd_pipeline.load_textual_inversion(
            args.ti_embedding_file, weight_name=ti_embedding_file.name
        )

    if args.lora_file is not None:
        lora_file = Path(args.lora_file)
        print(f"Using LoRA: {lora_file.name}")
        sd_pipeline.load_lora_weights(lora_file, weight_name=lora_file.name)

    return sd_pipeline


# Inception v3 (used internally by FID) takes 299x299 inputs; PILToTensor keeps uint8.
TRANSFORMS = transforms.Compose(
    [
        transforms.RandomResizedCrop(size=(299, 299), antialias=True),
        transforms.PILToTensor(),
    ]
)


def collate_fn(data):
    # Stack the transformed images into a single uint8 batch tensor.
    return torch.stack([TRANSFORMS(d["image"]) for d in data])


def filter_invalid_images(d):
    # Only RGB images are supported; drop grayscale/RGBA/palette images.
    return d["image"].mode == "RGB"


def process_batch(batch, fid_model, real):
    # Move the batch to the same device as the metric before updating its state.
    fid_model.update(batch.to(fid_model.device), real=real)


def load_fake_dir(fake_data_dir, args):
    g = torch.Generator()
    g.manual_seed(args.seed)

    # split="train" already returns a Dataset, so no extra indexing is needed.
    fake_ds = load_dataset("imagefolder", data_dir=fake_data_dir, split="train").filter(
        filter_invalid_images
    )

    fake_dataloader = torch.utils.data.DataLoader(
        fake_ds,
        batch_size=2,
        worker_init_fn=seed_worker,
        generator=g,
        collate_fn=collate_fn,
    )

    return fake_ds, fake_dataloader


def main(args):
    seed = args.seed
    torch.manual_seed(seed)

    g = torch.Generator()
    g.manual_seed(args.seed)

    device = torch.device(
        args.device if args.device else ("cuda" if torch.cuda.is_available() else "cpu")
    )

    metric = FrechetInceptionDistance()
    metric.to(device)

    # features = Features({"image": Image()})
    real_ds = load_dataset(
        "imagefolder", data_dir=args.real_data_dir, split="train"
    ).filter(filter_invalid_images)

    real_dataloader = torch.utils.data.DataLoader(
        real_ds,
        batch_size=2,
        worker_init_fn=seed_worker,
        generator=g,
        collate_fn=collate_fn,
    )
    if args.fake_data_dir:
        fake_ds, fake_dataloader = load_fake_dir(args.fake_data_dir, args)
    else:
        sd_pipeline = setup_sd_pipeline(args).to(device)

        prompts = load_dataset("nateraw/parti-prompts", split="train")
        prompts = prompts.shuffle(seed=seed)
        sample_prompts = [prompts[i]["Prompt"] for i in range(len(real_ds))]

        images = generate_fake_images(sd_pipeline, sample_prompts, args)

        # The generated images were written to disk, so cast the stored file paths
        # to Image() so the DataLoader yields PIL images for the transforms.
        fake_ds = Dataset.from_list(images).cast_column("image", Image())

        fake_dataloader = torch.utils.data.DataLoader(
            fake_ds,
            batch_size=2,
            worker_init_fn=seed_worker,
            generator=g,
            collate_fn=collate_fn,
        )

    print(f"Real: {len(real_ds)} batches: {len(real_dataloader)}")
    print(f"Fake: {len(fake_ds)} batches: {len(fake_dataloader)}")

    for images in real_dataloader:
        process_batch(images, metric, real=True)

    for images in fake_dataloader:
        process_batch(images, metric, real=False)

    metric.set_dtype(torch.float64)
    fid = metric.compute()

    if args.save_plot:
        fig, ax = metric.plot(fid)
        plt.savefig("fid.png")

    print(f"FID: {fid.item()}")


if __name__ == "__main__":
    argparser = argparse.ArgumentParser()

    argparser.add_argument(
        "--seed", type=int, default=1234, help="Seed for random and torch"
    )
    argparser.add_argument(
        "--pretrained_model_name_or_path",
        default="runwayml/stable-diffusion-v1-5",
        help="Model to load",
    )
    argparser.add_argument(
        "--lora_file",
        default=None,
        help="LoRA model file to load",
    )
    argparser.add_argument(
        "--ti_embedding_file",
        default=None,
        help="Textual inversion embedding file to load",
    )
    argparser.add_argument(
        "--fake_data_dir",
        default=None,
        help="Directory of pre-generated (fake) SD images; if omitted, images are generated",
    )
    argparser.add_argument(
        "--real_data_dir",
        required=True,
        help="Directory of real (non AI-generated) images, e.g. your training or validation set",
    )
    argparser.add_argument(
        "--num_inference_steps",
        type=int,
        default=15,
        help="Number of inference steps when generating fake images",
    )
    argparser.add_argument("--xformers", action="store_true", help="Use xFormers memory-efficient attention")
    argparser.add_argument("--device", default=None, help="Device to use (defaults to cuda if available)")
    argparser.add_argument(
        "--save_fake_images",
        action="store_true",
        help="Save the generated fake image samples",
    )
    argparser.add_argument(
        "--save_fake_images_dir",
        default="./tmp",
        help="Directory to save the generated fake images to",
    )
    argparser.add_argument(
        "--batch_size",
        type=int,
        default=None,
        help="Batch size for generating fake SD images",
    )
    argparser.add_argument(
        "--save_plot",
        action="store_true",
        help="Save a plot of the FID metric to fid.png",
    )

    args = argparser.parse_args()
    main(args)
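
# Example invocation (script name and paths are placeholders, not part of the gist):
#   python fid.py --real_data_dir ./data/validation \
#       --pretrained_model_name_or_path runwayml/stable-diffusion-v1-5 \
#       --num_inference_steps 15 --batch_size 4 --save_plot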