Created
August 23, 2022 11:55
-
-
Save ditwoo/c4b5e3d77da72d91ec16fcba1e790327 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse
import inspect
import re
import warnings
from pathlib import Path
from typing import List, Optional, Union

import torch
from diffusers import StableDiffusionPipeline
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from torch import autocast
from tqdm.auto import tqdm
class NSFW_Enabled_DiffusionPipeline(StableDiffusionPipeline):
    """Stable Diffusion text-to-image pipeline that skips the safety checker.

    ``__call__`` follows the parent pipeline's sampling loop (written against the
    API where ``unet(...)`` and ``scheduler.step(...)`` return dicts indexed by
    ``"sample"`` / ``"prev_sample"``). However, it never calls ``self.safety_checker``,
    so decoded images are returned unfiltered.
    """

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        height: Optional[int] = 512,
        width: Optional[int] = 512,
        num_inference_steps: Optional[int] = 50,
        guidance_scale: Optional[float] = 7.5,
        eta: Optional[float] = 0.0,
        generator: Optional[torch.Generator] = None,
        output_type: Optional[str] = "pil",
        **kwargs,
    ):
        """Generate one image per prompt.

        Args:
            prompt: A single prompt string, or a list of prompts (batch size = list length).
            height: Output height in pixels; must be divisible by 8.
            width: Output width in pixels; must be divisible by 8.
            num_inference_steps: Number of timesteps passed to ``scheduler.set_timesteps``.
            guidance_scale: Classifier-free guidance weight; values <= 1.0 disable guidance.
            eta: Forwarded to ``scheduler.step`` only if that method accepts an ``eta`` argument.
            generator: Optional generator used to draw the initial latent noise.
            output_type: ``"pil"`` converts the result with ``numpy_to_pil``; any other
                value returns the NumPy array.
            **kwargs: Only the deprecated ``torch_device`` key is consumed; other keys are ignored.

        Returns:
            A dict with ``"sample"`` (the images) and ``"nsfw_content_detected"``, which
            is always ``False`` because no safety check is performed.

        Raises:
            ValueError: If ``prompt`` is neither ``str`` nor ``list``, or if ``height``/``width``
                are not divisible by 8.
        """
        if "torch_device" in kwargs:
            device = kwargs.pop("torch_device")
            warnings.warn(
                "`torch_device` is deprecated as an input argument to `__call__` and will be removed in v0.3.0."
                " Consider using `pipe.to(torch_device)` instead."
            )
            # Set device as before (to be removed in 0.3.0)
            if device is None:
                device = "cuda" if torch.cuda.is_available() else "cpu"
            self.to(device)

        if isinstance(prompt, str):
            batch_size = 1
        elif isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        # The latent grid is (height // 8, width // 8), so both sizes must divide evenly.
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        # get prompt text embeddings
        text_input = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]

        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance:
            max_length = text_input.input_ids.shape[-1]
            uncond_input = self.tokenizer(
                [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
            )
            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
            # Classifier-free guidance needs both an unconditional and a conditional prediction.
            # Concatenating the two embeddings into one batch lets a single UNet forward
            # pass produce both; the result is split again with `chunk(2)` below.
            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        # get the initial random noise
        latents = torch.randn(
            (batch_size, self.unet.in_channels, height // 8, width // 8),
            generator=generator,
            device=self.device,
        )

        # set timesteps; pass offset=1 only to schedulers whose set_timesteps accepts it
        accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
        extra_set_kwargs = {}
        if accepts_offset:
            extra_set_kwargs["offset"] = 1
        self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)

        # if we use LMSDiscreteScheduler, make sure the latents are multiplied by the first sigma
        if isinstance(self.scheduler, LMSDiscreteScheduler):
            latents = latents * self.scheduler.sigmas[0]

        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]
        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        timesteps = self.scheduler.timesteps
        # `enumerate` has no length, so pass `total` explicitly to get a real progress bar / ETA.
        for i, t in tqdm(enumerate(timesteps), total=len(timesteps)):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            if isinstance(self.scheduler, LMSDiscreteScheduler):
                sigma = self.scheduler.sigmas[i]
                latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)

            # predict the noise residual
            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            # LMSDiscreteScheduler.step is driven by the step index, the others by the timestep value.
            if isinstance(self.scheduler, LMSDiscreteScheduler):
                latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs)["prev_sample"]
            else:
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)["prev_sample"]

        # scale and decode the image latents with vae
        latents = 1 / 0.18215 * latents
        image = self.vae.decode(latents)

        # Map the decoder output from [-1, 1] to [0, 1], move channels last (NCHW -> NHWC).
        # Cast to float32 before `.numpy()`: NumPy has no bfloat16, so a bf16 pipeline would
        # otherwise fail here; float32 is also safe for the float16 autocast path.
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()

        # NOTE: the parent pipeline runs `self.safety_checker` on the images at this point;
        # this subclass intentionally skips that step, which is why no NSFW detection is reported.
        if output_type == "pil":
            image = self.numpy_to_pil(image)

        return {"sample": image, "nsfw_content_detected": False}
def _prompt_to_slug(prompt: str, max_bytes: int = 200) -> str:
    """Turn ``prompt`` into a fragment that is safe to embed in a file name.

    Whitespace runs are joined with ``_``. Path separators, Windows-reserved
    characters and control characters are replaced with ``_``, so the prompt
    cannot escape the output directory or produce an invalid name. The result
    is truncated to ``max_bytes`` UTF-8 bytes to stay below the common 255-byte
    file-name limit.
    """
    slug = "_".join(prompt.split())
    slug = re.sub(r'[\\/:*?"<>|\x00-\x1f]', "_", slug)
    # Truncate on bytes (not characters) and drop any multi-byte character cut in half.
    return slug.encode("utf-8")[:max_bytes].decode("utf-8", errors="ignore")


def main():
    """Command-line entry point: generate one image for ``--prompt`` and save it as a PNG.

    The file is written to ``<output-dir>/<seed>-<steps>-<prompt slug>.png``. The
    output directory is created if it does not exist yet.
    """
    parser = argparse.ArgumentParser(
        description="Generate an image with Stable Diffusion (safety checker disabled)."
    )
    parser.add_argument("--seed", type=int, default=123)
    parser.add_argument("--steps", type=int, default=50)
    parser.add_argument("--prompt", type=str, required=True)
    parser.add_argument(
        "--output-dir",
        type=Path,
        default=Path("outputs"),
        help="Directory for the generated PNG (created if missing). Default: outputs",
    )
    args = parser.parse_args()

    model_id = "CompVis/stable-diffusion-v1-4"
    device = "cuda"

    pipe = NSFW_Enabled_DiffusionPipeline.from_pretrained(model_id, use_auth_token=True)
    pipe = pipe.to(device)

    with autocast(device):
        # Seed the global RNGs so the initial latent noise is reproducible for a given --seed.
        torch.manual_seed(args.seed)
        image = pipe(args.prompt, num_inference_steps=args.steps)["sample"][0]

    # Saving into a missing directory raises FileNotFoundError, so create it first.
    args.output_dir.mkdir(parents=True, exist_ok=True)
    out_path = args.output_dir / f"{args.seed}-{args.steps}-{_prompt_to_slug(args.prompt)}.png"
    image.save(out_path)


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment