
@htoyryla
Created October 10, 2022 15:29
Stable diffusion text2image assigning combined embeddings to each UNet block
# stable diffusion tool
# @htoyryla github twitter instagram
# requires diffusers 0.3.0 and a trained model
# relies heavily on code from https://github.com/huggingface/diffusers
# neurokuvatreenit, stable diffusion, example 1b using LDM scheduler and 1c saving image at each iteration
# 1d: multiple subprompts with weights
# 1e: estimate final result at each iteration
# 1e2: updated code for diffusers 0.4.0
'''
1mp3: experimental,
requires modified Unet from https://gist.github.com/htoyryla/dc8c12e3c2bc3543dc5679d56e30c532
For assigning weighted subprompts to blocks, use prompt like
"something here first:10 / more stuff here:15 / still something important:35 ; 123 | something else first:10 / some more stuff here:15 ; 0456 "
where
| embedding separator
/ subprompt separator
:70 relative weight of subprompt
;012 assign this embedding to blocks 0,1, and 2
'''
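# Worked example (illustration only, not part of the original gist): the prompt
#   "a misty forest:60 / oil painting:40 ; 0123 | a city at night:100 ; 456"
# is split at '|' into two embedding specs. The first averages the embeddings of
# "a misty forest" and "oil painting" with weights 0.6 and 0.4 and assigns the result
# to blocks 0-3; the second assigns "a city at night" to blocks 4-6.
# Hypothetical invocation (script filename assumed):
#   python sd_blocks.py --text "a misty forest:60 / oil painting:40 ; 0123 | a city at night:100 ; 456" --steps 50 --g 7.5 --dir out1 --name forestcity --saveiters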
import torch
from torchvision.utils import save_image
from torchvision import transforms
import torch.nn.functional as F
from torch import autocast
import PIL
from PIL import Image
import numpy as np
from tqdm import tqdm
import os
import argparse
import inspect
import sys
# we don't use the readymade pipelines so we need to import the modules for VAE, UNET, scheduler and CLIP
from diffusers import AutoencoderKL, UNet2DConditionModel, LMSDiscreteScheduler
from transformers import CLIPTextModel, CLIPTokenizer
# parse input params
parser = argparse.ArgumentParser()
parser.add_argument('--text', type=str, default="", help='text prompt')
parser.add_argument('--model', type=str, default="./stable-diffusion-v1-4", help='path to sd model')
parser.add_argument('--steps', type=int, default=50, help='diffusion steps')
parser.add_argument('--g', type=float, default=7.5, help='guidance level')
parser.add_argument('--dir', type=str, default="out1", help='base directory for storing images')
parser.add_argument('--name', type=str, default="test", help='basename for storing images')
parser.add_argument('--imageSize', type=int, default=512, help='image size')
parser.add_argument('--h', type=int, default=0, help='image height')
parser.add_argument('--w', type=int, default=0, help='image width')
parser.add_argument('--slices', type=int, default=2, help='attention slices')
parser.add_argument('--seed', type=int, default=0, help='manual seed')
parser.add_argument('--saveiters', action="store_true", help='save intermediate images')
parser.add_argument('--log', type=str, default="sdruns.log", help='path to log file')
raw_args = " ".join(sys.argv)
opt = parser.parse_args()
# settings
num_inference_steps = opt.steps #500
device = "cuda"
if opt.h == 0:
    opt.h = opt.imageSize
if opt.w == 0:
    opt.w = opt.imageSize
name = opt.name
steps = opt.steps
bs = 1
text = opt.text
guidance_scale = opt.g
def parse_inner(text):
    # prepare prompt: split into subprompts and their weights
    plist = []  # a list for subprompts
    wlist = []  # a list for their weights
    wsum = 0
    # separate assignments first
    parts = text.split(";")
    assigns = parts[1].strip()
    text = parts[0]
    parts = text.split("/")  # split into subprompts at each /
    print(parts)
    # separate text and weight for each subprompt
    for p in parts:
        ps = p.split(":")
        plist.append(ps[0].strip())
        w = float(ps[1])
        wlist.append(w)
        wsum += w
    # normalize weights
    for i in range(0, len(wlist)):
        wlist[i] = wlist[i] / wsum
    return plist, wlist, assigns
def parse_outer(text):
    tlist = text.split("|")
    return tlist
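# Illustrative example (values not from the original gist): for
#   text = "red roses:30 / blue sky:70 ; 012 | green hills:100 ; 3456"
# parse_outer(text) returns the two specs split at '|', and parse_inner() on the first
# spec yields plist = ['red roses', 'blue sky'], wlist = [0.3, 0.7], assigns = '012'.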
# utility methods
def numpy_to_pil(images):
    if images.ndim == 3:
        images = images[None, ...]
    images = (images * 255).round().astype("uint8")
    pil_images = [Image.fromarray(image) for image in images]
    return pil_images
to_tensor_tfm = transforms.ToTensor()
def pil_to_latent(input_im):
    with torch.no_grad():
        latent = vae.encode(to_tensor_tfm(input_im).unsqueeze(0).to(device) * 2 - 1)  # note scaling to [-1, 1]
    return 0.18215 * latent.mode()  # or .mean or .sample
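# 0.18215 is the latent scaling constant of the Stable Diffusion v1 autoencoder;
# the decoding steps further down divide by the same constant before vae.decode().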
# load model(s)
pretrained_model_name_or_path = opt.model #"./textual_inversion_set2"
use_auth_token = True
# VAE the imagemaker
print("loadng VAE...")
vae = AutoencoderKL.from_pretrained(
pretrained_model_name_or_path, subfolder="vae", use_auth_token=use_auth_token
)
vae.eval()
vae.cuda()
del vae.encoder
# UNET the denoiser
print("loading UNET...")
unet = UNet2DConditionModel.from_pretrained(
    pretrained_model_name_or_path, subfolder="unet", use_auth_token=use_auth_token
)
unet.eval()
unet.cuda()
# text encoder
print("loadng CLIP...")
tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer", use_auth_token=use_auth_token) #.cuda()
text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder="text_encoder", use_auth_token=use_auth_token).cuda()
# Scheduler the noise manager
print("setting up scheduler...")
scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
scheduler.set_timesteps(num_inference_steps)
eta = 0
accepts_eta = "eta" in set(inspect.signature(scheduler.step).parameters.keys())
extra_step_kwargs = {}
if accepts_eta:
    extra_step_kwargs["eta"] = eta
# attention slicing to save ram
slice_size = unet.config.attention_head_dim // opt.slices
unet.set_attention_slice(slice_size)
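# Attention slicing trades speed for memory: --slices 2 matches the pipelines' "auto"
# setting (attention computed in two passes per layer); larger values lower peak VRAM further.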
# set up random number gen
if opt.seed != 0:
    seed = opt.seed
else:
    seed = torch.Generator().seed()
generator = torch.Generator(device=device).manual_seed(seed)
print("Seed:" + str(generator.initial_seed()))
# save command into log
if opt.log != "":
if "--seed" not in raw_args:
raw_args += " --seed "+str(seed) # add current seed
raw_args = "python "+raw_args+"\n"
with open(os.path.join(opt.dir, opt.log), "a+") as text_file:
text_file.write(raw_args)
# initialize latents randomly
latents = torch.randn(
    (bs, unet.in_channels, opt.h // 8, opt.w // 8), generator=generator, device=device
)
latents = latents * scheduler.init_noise_sigma
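# init_noise_sigma is the largest sigma of the LMS schedule, so the unit-variance
# noise is scaled up to the noise level of the first timestep.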
#latents = latents * scheduler.sigmas[0]
print("Latents shape",latents.shape)
latents = latents.to(device)
# text to tokens
# encode empty prompt
#tokens_length = text_tokens.input_ids.shape[-1]
tokens_length = tokenizer.model_max_length
uncond_tokens = tokenizer(
    [""] * bs, padding="max_length", max_length=tokens_length, return_tensors="pt"
)
with torch.no_grad():
    uncond_embeddings = text_encoder(uncond_tokens.input_ids.to(device))[0]
empty_emb = torch.cat([uncond_embeddings, uncond_embeddings])
# make placeholder for 7 embeddings and fill with uncond embeddings
prep_embeddings = empty_emb.unsqueeze(0).repeat(7,1,1,1)
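# Assumption based on the linked modified UNet: it accepts one (uncond, cond) embedding
# pair per block index 0..6 instead of a single embedding. Slots not listed after ';'
# in the prompt keep the empty-prompt embedding, i.e. those blocks stay unconditioned.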
# parse text into n prompt sets
tlist = parse_outer(text)
for t in tlist:
    plist, wlist, assign = parse_inner(t)
    # we process the list of all subprompts at the same time
    text_tokens = tokenizer(plist, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
    print("Tokens shape", text_tokens.input_ids.shape)
    # tokens to embedding
    with torch.no_grad():
        text_embeddings = text_encoder(text_tokens.input_ids.to(device))[0]
    print("Text embeddings shape ", text_embeddings.shape)
    # now we have embeddings of all subprompts; if there is more than one, calculate their weighted average
    pn = text_embeddings.shape[0]
    if pn > 1:
        tembs = torch.zeros_like(text_embeddings)[0].unsqueeze(0)
        i = 0
        for temb in text_embeddings:
            tembs = tembs + wlist[i] * temb
            i += 1
        comb_embeddings = tembs.detach()
    else:
        # a single subprompt needs no combining
        comb_embeddings = text_embeddings.detach()
    print("Text embeddings shape after combining", comb_embeddings.shape)
    # assign the combined embedding to the requested blocks
    for emb in comb_embeddings:
        temb = torch.cat([uncond_embeddings, emb.unsqueeze(0)])
        for c in assign:
            print(c)
            prep_embeddings[int(c)] = temb
print(prep_embeddings.shape)
# save some ram
del text_encoder
torch.cuda.empty_cache()
j = 0
# start diffusing
with torch.no_grad():
    for i, t in tqdm(enumerate(scheduler.timesteps)):
        # prepare current latent for UNET
        latent_model_input = torch.cat([latents] * 2)
        # adjust latents according to sigmas (current noise level)
        # sigma no longer needed here in 0.4.0 but we use it later
        sigma = scheduler.sigmas[i]
        latent_model_input = scheduler.scale_model_input(latent_model_input, t)
        # estimate the noise
        noise_pred = unet(latent_model_input, t, encoder_hidden_states=prep_embeddings).sample.detach()
        # adjust noise estimate for text guidance
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
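        # Classifier-free guidance: chunk 0 of the doubled batch is the empty-prompt
        # prediction, chunk 1 the text-conditioned one; the combination above,
        # eps = eps_uncond + g * (eps_text - eps_uncond), extrapolates --g times
        # from the unconditional prediction toward the text-conditioned one.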
        # estimate denoised latent
        latents = scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample.detach()
        if opt.saveiters:
            # estimate final image from current state
            los = latents - sigma * noise_pred
            # save an image from current latents
            lats_ = 1 / 0.18215 * los.detach()
            image = vae.decode(lats_.to(vae.dtype)).sample
            image = (image / 2 + 0.5).clamp(0, 1)
            image = image.cpu().permute(0, 2, 3, 1).detach().numpy()
            image = numpy_to_pil(image)[0]
            image.save(opt.dir + os.sep + name + "-t" + str(j) + ".png")
        j += 1
del unet
# now we have the final latent, so decode the image and save it
latents = 1 / 0.18215 * latents.detach()
image = vae.decode(latents.to(vae.dtype)).sample
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).detach().numpy()
image = numpy_to_pil(image)[0]
image.save(opt.dir+os.sep+name +".png")