@karpathy
Last active March 29, 2024 09:33
hacky stablediffusion code for generating videos
"""
stable diffusion dreaming
creates hypnotic moving videos by smoothly walking randomly through the sample space
example way to run this script:
$ python stablediffusionwalk.py --prompt "blueberry spaghetti" --name blueberry
to stitch together the images, e.g.:
$ ffmpeg -r 10 -f image2 -s 512x512 -i blueberry/frame%06d.jpg -vcodec libx264 -crf 10 -pix_fmt yuv420p blueberry.mp4
nice slerp def from @xsteenbrugge ty
you have to have access to stablediffusion checkpoints from https://huggingface.co/CompVis
and install all the other dependencies (e.g. diffusers library)
"""
import os
import inspect
import fire
from diffusers import StableDiffusionPipeline
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from time import time
from PIL import Image
from einops import rearrange
import numpy as np
import torch
from torch import autocast
from torchvision.utils import make_grid
# -----------------------------------------------------------------------------
@torch.no_grad()
def diffuse(
        pipe,
        cond_embeddings, # text conditioning, should be (1, 77, 768)
        cond_latents,    # image conditioning, should be (1, 4, 64, 64)
        num_inference_steps,
        guidance_scale,
        eta,
    ):
    torch_device = cond_latents.get_device()

    # classifier-free guidance: add the unconditional embedding
    max_length = cond_embeddings.shape[1] # 77
    uncond_input = pipe.tokenizer([""], padding="max_length", max_length=max_length, return_tensors="pt")
    uncond_embeddings = pipe.text_encoder(uncond_input.input_ids.to(torch_device))[0]
    text_embeddings = torch.cat([uncond_embeddings, cond_embeddings])

    # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
    if isinstance(pipe.scheduler, LMSDiscreteScheduler):
        cond_latents = cond_latents * pipe.scheduler.sigmas[0]

    # init the scheduler
    accepts_offset = "offset" in set(inspect.signature(pipe.scheduler.set_timesteps).parameters.keys())
    extra_set_kwargs = {}
    if accepts_offset:
        extra_set_kwargs["offset"] = 1
    pipe.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)

    # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
    # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
    # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
    # and should be between [0, 1]
    accepts_eta = "eta" in set(inspect.signature(pipe.scheduler.step).parameters.keys())
    extra_step_kwargs = {}
    if accepts_eta:
        extra_step_kwargs["eta"] = eta

    # diffuse!
    for i, t in enumerate(pipe.scheduler.timesteps):

        # expand the latents for classifier-free guidance
        latent_model_input = torch.cat([cond_latents] * 2)
        if isinstance(pipe.scheduler, LMSDiscreteScheduler):
            sigma = pipe.scheduler.sigmas[i]
            latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)

        # predict the noise residual
        noise_pred = pipe.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]

        # cfg
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        # compute the previous noisy sample x_t -> x_t-1
        if isinstance(pipe.scheduler, LMSDiscreteScheduler):
            cond_latents = pipe.scheduler.step(noise_pred, i, cond_latents, **extra_step_kwargs)["prev_sample"]
        else:
            cond_latents = pipe.scheduler.step(noise_pred, t, cond_latents, **extra_step_kwargs)["prev_sample"]

    # scale and decode the image latents with vae
    cond_latents = 1 / 0.18215 * cond_latents
    image = pipe.vae.decode(cond_latents)

    # generate output numpy image as uint8
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.cpu().permute(0, 2, 3, 1).numpy()
    image = (image[0] * 255).astype(np.uint8)

    return image
def slerp(t, v0, v1, DOT_THRESHOLD=0.9995):
    """ helper function to spherically interpolate two arrays v0 v1 """
    inputs_are_torch = False  # default so numpy inputs don't hit a NameError below
    if not isinstance(v0, np.ndarray):
        inputs_are_torch = True
        input_device = v0.device
        v0 = v0.cpu().numpy()
        v1 = v1.cpu().numpy()

    dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
    if np.abs(dot) > DOT_THRESHOLD:
        # nearly parallel vectors: fall back to linear interpolation
        v2 = (1 - t) * v0 + t * v1
    else:
        theta_0 = np.arccos(dot)
        sin_theta_0 = np.sin(theta_0)
        theta_t = theta_0 * t
        sin_theta_t = np.sin(theta_t)
        s0 = np.sin(theta_0 - theta_t) / sin_theta_0
        s1 = sin_theta_t / sin_theta_0
        v2 = s0 * v0 + s1 * v1

    if inputs_are_torch:
        v2 = torch.from_numpy(v2).to(input_device)

    return v2
def run(
        # --------------------------------------
        # args you probably want to change
        prompt = "blueberry spaghetti", # prompt to dream about
        gpu = 0, # id of the gpu to run on
        name = 'blueberry', # name of this project, for the output directory
        rootdir = '/home/ubuntu/dreams',
        num_steps = 200, # number of steps between each pair of sampled points
        max_frames = 10000, # number of frames to write and then exit the script
        num_inference_steps = 50, # more (e.g. 100, 200 etc) can create slightly better images
        guidance_scale = 7.5, # can depend on the prompt. usually somewhere between 3-10 is good
        seed = 1337,
        # --------------------------------------
        # args you probably don't want to change
        quality = 90, # for jpeg compression of the output images
        eta = 0.0,
        width = 512,
        height = 512,
        weights_path = "/home/ubuntu/stable-diffusion-v1-3-diffusers",
        # --------------------------------------
    ):
    assert torch.cuda.is_available()
    assert height % 8 == 0 and width % 8 == 0
    torch.manual_seed(seed)
    torch_device = f"cuda:{gpu}"

    # init the output dir
    outdir = os.path.join(rootdir, name)
    os.makedirs(outdir, exist_ok=True)

    # init all of the models and move them to a given GPU
    lms = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
    pipe = StableDiffusionPipeline.from_pretrained(weights_path, scheduler=lms, use_auth_token=True)

    pipe.unet.to(torch_device)
    pipe.vae.to(torch_device)
    pipe.text_encoder.to(torch_device)

    # get the conditional text embeddings based on the prompt
    text_input = pipe.tokenizer(prompt, padding="max_length", max_length=pipe.tokenizer.model_max_length, truncation=True, return_tensors="pt")
    cond_embeddings = pipe.text_encoder(text_input.input_ids.to(torch_device))[0] # shape [1, 77, 768]

    # sample a source
    init1 = torch.randn((1, pipe.unet.in_channels, height // 8, width // 8), device=torch_device)

    # iterate the loop
    frame_index = 0
    while frame_index < max_frames:

        # sample the destination
        init2 = torch.randn((1, pipe.unet.in_channels, height // 8, width // 8), device=torch_device)

        for i, t in enumerate(np.linspace(0, 1, num_steps)):
            init = slerp(float(t), init1, init2)

            print("dreaming... ", frame_index)
            with autocast("cuda"):
                image = diffuse(pipe, cond_embeddings, init, num_inference_steps, guidance_scale, eta)
            im = Image.fromarray(image)
            outpath = os.path.join(outdir, 'frame%06d.jpg' % frame_index)
            im.save(outpath, quality=quality)
            frame_index += 1

        init1 = init2
if __name__ == '__main__':
    fire.Fire(run)
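In diffuse, classifier-free guidance runs the UNet on a doubled batch (unconditional and text-conditioned copies of the latents) and then blends the two noise predictions using guidance_scale. Below is a minimal numpy sketch of only that blending step, with random arrays standing in for the UNet output. `cfg_combine` is a hypothetical name, not a function from the gist or from diffusers.

```python
import numpy as np


def cfg_combine(noise_pred_uncond, noise_pred_text, guidance_scale):
    """Classifier-free guidance blend, same formula as the '# cfg' lines in diffuse().

    Hypothetical helper for illustration only.
    """
    return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)


rng = np.random.default_rng(0)
# stand-in for the UNet output on the doubled batch: (2, 4, 64, 64) for 512x512 images
noise_pred = rng.standard_normal((2, 4, 64, 64))
# split the batch back into its two halves, like noise_pred.chunk(2) in diffuse()
noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)

guided = cfg_combine(noise_pred_uncond, noise_pred_text, guidance_scale=7.5)
print(guided.shape)  # (1, 4, 64, 64)

# guidance_scale = 1 reduces to the plain text-conditioned prediction,
# guidance_scale = 0 to the unconditional one
assert np.allclose(cfg_combine(noise_pred_uncond, noise_pred_text, 1.0), noise_pred_text)
assert np.allclose(cfg_combine(noise_pred_uncond, noise_pred_text, 0.0), noise_pred_uncond)
```

With guidance_scale in the 3 to 10 range suggested in run(), the blended prediction is pushed beyond the text-conditioned estimate and further away from the unconditional one.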
@patil-suraj

For faster inference, we can wrap the call to diffuse in torch.autocast so the inference will run in half-precision. For example

from torch import autocast

with autocast("cuda"):
    image = diffuse(text_embeddings, init, guidance_scale=10.0)

@karpathy
Author

yes i dropped this accidentally, added, ty

@Chandulbc

We can automate or just create video

@nateraw

nateraw commented Aug 18, 2022

I updated this gist so you can walk over different text prompts. You can check out the revised gist here.

Thanks @karpathy for this starting point ❤️

Here's a video of "blueberry spaghetti" ➡️ "strawberry spaghetti"

berry_good_spaghetti.mp4

@karpathy
Author

made some recent edits:

  • I find that the LMSDiscreteScheduler works better, made it default
  • default quality of 75 for jpeg in im.save can be a bit too low, bumping to 90 as default
  • introduced more options people may want to play with, esp num_inference_steps, which defaults to 50, but can be bumped up a bit too for slightly higher quality in my experience

@MicZet

MicZet commented Aug 21, 2022

Nice!
what is "import fire" ?

@darthdeus

Nice! what is "import fire" ?

It's a python library for CLI interfaces, I just ran pip install fire and everything works.

@TillBeemelmanns

Here is what I got for the prompt "syd mead painting"

Watch the video

https://youtu.be/50MeMvGnm8Q

Thanks @karpathy :)

@Jwrig124

I'm getting a "Sorry, we can't find the page you are looking for" when it's trying to load the API. I've had the token authenticated, anyone know how to resolve this?

@DrakenZA

@Jwrig124

https://www.youtube.com/watch?v=hMG7f6AcrlQ

Hello, could I ask what GPU and VRAM is allocated to CUDA? I can't seem to get anything to work on torch.

@TillBeemelmanns

Hello, could I ask what GPU and VRAM is allocated to CUDA? I can't seem to get anything to work on torch.

approximately 10GB of VRAM for 512x512

@magJ

magJ commented Aug 26, 2022

Edit: Solved this by updating some libraries pip install --upgrade diffusers transformers scipy, thanks to this comment.

I'm trying to get this to run on a 3080, I'm able to get the basic stable diffusion demos to work using the fp16 model variant

I tried running this script with the fp16 variant of the model, modifying the StableDiffusionPipeline like so:

pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", scheduler=lms, torch_dtype=torch.float16, revision="fp16", use_auth_token=True)

However I get the following error, any ideas how to fix this? Is it possible to run this using the fp16 variant?

Traceback (most recent call last):
  File "stablediffusionwalk.py", line 193, in <module>
    fire.Fire(run)
  File "/home/m/miniconda3/envs/ldm/lib/python3.8/site-packages/fire/core.py", line 141, in Fire
    component_trace = _Fire(component, args, parsed_flag_args, context, name)
  File "/home/m/miniconda3/envs/ldm/lib/python3.8/site-packages/fire/core.py", line 466, in _Fire
    component, remaining_args = _CallAndUpdateTrace(
  File "/home/m/miniconda3/envs/ldm/lib/python3.8/site-packages/fire/core.py", line 681, in _CallAndUpdateTrace
    component = fn(*varargs, **kwargs)
  File "stablediffusionwalk.py", line 166, in run
    cond_embeddings = pipe.text_encoder(text_input.input_ids.to(torch_device))[0] # shape [1, 77, 768]

...

  File "/home/m/miniconda3/envs/ldm/lib/python3.8/site-packages/transformers/models/clip/modeling_clip.py", line 257, in forward
    attn_output = torch.bmm(attn_probs, value_states)
RuntimeError: expected scalar type Half but found Float

@gordicaleksa

gordicaleksa commented Aug 27, 2022

Hi Andrej!

Not sure when you wrote this gist (and maybe the purpose was to be explicit about all components) but all of this can be done much more concisely, e.g. (without the max_frames functionality):

# assumes torch_device, height, width, num_steps, num_inference_steps, guidance_scale,
# prompt and outdir are set as in run() above, and slerp is the gist's helper
lms = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", scheduler=lms, use_auth_token=True)
pipe.to(torch_device)

source_latent = torch.randn((1, pipe.unet.in_channels, height // 8, width // 8), device=torch_device)
target_latent = torch.randn((1, pipe.unet.in_channels, height // 8, width // 8), device=torch_device)

frame_index = 0
for t in np.linspace(0, 1, num_steps):
    init_latent = slerp(float(t), source_latent, target_latent)

    with autocast("cuda"):
        # on newer diffusers versions use .images[0] instead of ["sample"][0]
        image = pipe(prompt, num_inference_steps=num_inference_steps, latents=init_latent, guidance_scale=guidance_scale)["sample"][0]
    # zero-padded name so it matches the ffmpeg frame%06d.jpg pattern from the gist
    outpath = os.path.join(outdir, f'frame{frame_index:06d}.jpg')
    image.save(outpath)

    frame_index += 1

You don't even need the diffuse func, that's what pipe is doing behind the scenes (and it supports everything you tried to pass to diffuse func).

Or at least it should work, lol! It's from this notebook but looking at the code I'm not sure latents argument is used...

EDIT: super easy fix to the pipeline call func here:

if 'latents' in kwargs:
    latents = kwargs['latents']
else:
    # get the initial random noise
    latents = torch.randn(
        (batch_size, self.unet.in_channels, height // 8, width // 8),
        generator=generator,
        device=self.device,
    )

Now it's concise and it actually works. 😇

EDIT2: huggingface/diffusers#262 <- according to this the change is actually already integrated you just have to check out the HEAD commit. :))

EDIT3: expanded on this gist here: https://github.com/gordicaleksa/stable_diffusion_playground :)
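The patch in the EDIT above boils down to one rule: if the caller passes `latents`, use them as the starting noise, otherwise sample fresh noise. That rule is what lets the latent walk feed slerp-interpolated latents into `pipe(...)`. Here is a minimal numpy sketch of the rule, with `np.random.Generator.standard_normal` standing in for `torch.randn`. The helper `prepare_latents` and its shape check are hypothetical, not the diffusers code.

```python
import numpy as np


def prepare_latents(shape, rng, latents=None):
    """Hypothetical helper: reuse caller-supplied latents, else sample noise.

    Mirrors the `if 'latents' in kwargs` patch, with numpy in place of torch.
    """
    if latents is None:
        # no latents given: draw the initial random noise
        return rng.standard_normal(shape)
    if latents.shape != tuple(shape):
        # extra safety check (an assumption, not part of the patch above)
        raise ValueError(f"expected latents of shape {tuple(shape)}, got {latents.shape}")
    return latents


rng = np.random.default_rng(0)
shape = (1, 4, 512 // 8, 512 // 8)  # (batch, unet.in_channels, height // 8, width // 8)

fixed = rng.standard_normal(shape)
# supplied latents come back unchanged, so the caller fully controls the starting
# point of each frame, which is what a smooth latent walk needs
assert prepare_latents(shape, rng, latents=fixed) is fixed
fresh = prepare_latents(shape, rng)
print(fresh.shape)  # (1, 4, 64, 64)
```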

@allo-
Copy link

allo- commented Aug 31, 2022

@karpathy Under what license can we modify the script?

@atarashansky

This script ripped out code from huggingface/diffusers so the license is whatever the source repo is using: https://github.com/huggingface/diffusers/blob/main/LICENSE

@allo-

allo- commented Sep 1, 2022

You can move the calculation of cond_embeddings and uncond_embeddings out of diffuse.
This may speed up things a bit. I also tried to delete cond_embeddings, uncond_embeddings and pipe.text_encoder to free up GPU memory, but it didn't get me much memory.

@kartik3ya

why do we specifically use slerp? what happens if we linearly interpolate ?

@enzokro

enzokro commented Nov 17, 2022

@kartik3ya

why do we specifically use slerp? what happens if we linearly interpolate ?

To get some ideas about why SLERP, I like this post by Ferenc Huszár. It has a fantastic description of some of the issues that pop up.

LERP still "works" to produce outputs. It likely only works because these models have been carefully trained on boatloads of data. SLERP is a more proper and grounded way of doing the interpolation.

For a self-plug, here's a post of mine with a native PyTorch version of SLERP. It builds on the one here and a few other implementations floating around: PyTorch SLERP
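One concrete way to see the difference: for high-dimensional Gaussian noise like the (1, 4, 64, 64) latents here, two independent samples are nearly orthogonal and have nearly the same norm. Linear interpolation shrinks the norm at the midpoint by roughly a factor of √2, so the midpoint no longer looks like a typical noise sample, while slerp keeps the norm about constant. The small numpy check below reuses the slerp formula from the gist with the torch round-trip dropped; `lerp` is just a helper defined for the comparison.

```python
import numpy as np


def slerp(t, v0, v1, DOT_THRESHOLD=0.9995):
    """numpy-only version of the gist's slerp (torch conversion removed)."""
    dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
    if np.abs(dot) > DOT_THRESHOLD:
        # nearly parallel: plain linear interpolation is accurate enough
        return (1 - t) * v0 + t * v1
    theta_0 = np.arccos(dot)  # angle between the two vectors
    sin_theta_0 = np.sin(theta_0)
    theta_t = theta_0 * t     # angle swept so far
    s0 = np.sin(theta_0 - theta_t) / sin_theta_0
    s1 = np.sin(theta_t) / sin_theta_0
    return s0 * v0 + s1 * v1


def lerp(t, v0, v1):
    """Plain linear interpolation, for comparison."""
    return (1 - t) * v0 + t * v1


rng = np.random.default_rng(1337)
shape = (1, 4, 64, 64)  # latent shape for 512x512 images
v0 = rng.standard_normal(shape)
v1 = rng.standard_normal(shape)

ref = np.linalg.norm(v0)
for t in (0.0, 0.25, 0.5, 0.75, 1.0):
    lerp_ratio = np.linalg.norm(lerp(t, v0, v1)) / ref
    slerp_ratio = np.linalg.norm(slerp(t, v0, v1)) / ref
    print(f"t={t:.2f}  lerp norm ratio={lerp_ratio:.3f}  slerp norm ratio={slerp_ratio:.3f}")
# around t=0.5 the lerp ratio dips to roughly 0.71, while the slerp ratio stays near 1.0
```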

@olegchomp

Somehow got this error: "ValueError: only one element tensors can be converted to Python scalars"

@IAmCorbin

IAmCorbin commented Feb 6, 2023

I've got CUDA installed now and got past all the previous errors I was getting with the script, but now I'm getting a "CUDA out of memory" error. I only have 6GB of VRAM. I tried lowering the height and width parameters but it didn't help. Any suggestions for getting this to run?

Edit: running a simple script to generate a single image I also get a memory error unless I add: , torch_dtype=torch.float16

Adding torch_dtype=torch.float16 to this script I now get a new error on this line
cond_latents = pipe.scheduler.step(noise_pred, i, cond_latents, **extra_step_kwargs)["prev_sample"]

only one element tensors can be converted to Python scalars

Edit 2: Got it working by using gordicaleksa's recommendation for a simplified version above! Thanks!
One minor change for it to work for me, ["sample"][0] was invalid. I changed image variable name to pipelineOutput and then the save line changed to: pipelineOutput.images[0].save(outpath)

@devxpy

devxpy commented Feb 9, 2023

Interestingly enough, huggingface discourages the use of autocast!

[screenshot]

@lizc126

lizc126 commented Mar 11, 2023

Somehow got this error: "ValueError: only one element tensors can be converted to Python scalars"

Same. Did you manage to solve it?

@Maryammaryam877

I am new to this project; could someone help me? I am getting an error on this line:
image = diffuse(pipe, cond_embeddings, init, num_inference_steps, guidance_scale, eta)
The error mentions cond_embeddings being out of index 51, and the eta variable. What do I need to do to fix this?
