# Gist by @anijain2305, created July 15, 2024.
# Benchmarks eager vs. torch.compile'd inference of an MPLumiere model.
import logging
import time as time_module

import torch
from denoising_diffusion_pytorch import KarrasUnet
from lumiere_pytorch import MPLumiere
karras_unet = KarrasUnet(
    image_size = 256,
    dim = 8,
    channels = 3,
    dim_max = 768
)
lumiere = MPLumiere(
    karras_unet,
    image_size = 256,
    unet_time_kwarg = 'time',
    conv_module_names = [
        'downs.1',
        'ups.1'
    ],
    attn_module_names = [
        'mids.0'
    ],
    upsample_module_names = [
        'ups.1'
    ],
    downsample_module_names = [
        'downs.1'
    ]
)
USE_CUDA = True
LOGS = ""
EAGER_TIME = None
COMPILED_TIME = None
def timed(fn):
    # Decorator: runs fn several times, records per-run and average timings
    # in LOGS, and stashes the averages for the final speedup computation.
    def wrapper(*args, **kwargs):
        global LOGS
        times = 5
        if fn.__name__ == "first_compiled_run":
            # the first compiled run includes compilation, so time it only once
            times = 1
        if times != 1:
            # warmup runs, excluded from the timings
            for i in range(5):
                out = fn(*args, **kwargs)
        durations = []
        for i in range(times):
            t0 = time_module.perf_counter()
            out = fn(*args, **kwargs)
            t1 = time_module.perf_counter()
            duration = t1 - t0
            LOGS += f"Run {i} of {fn.__name__} took {duration:.4f} seconds\n"
            durations.append(duration)
        avg = sum(durations) / len(durations)
        LOGS += f"==> {fn.__name__} took {avg:.4f} seconds on average over {times} runs\n"
        if fn.__name__ == "eager_run":
            global EAGER_TIME
            EAGER_TIME = avg
        elif fn.__name__ == "second_compiled_run":
            global COMPILED_TIME
            COMPILED_TIME = avg
        return out
    return wrapper
"""
To squeeze max compile perf, we need to graph break on the
functions that are causing many recompiles
lumiere_pytorch/lumiere.py:511, auto_repeat_tensors_for_time
denoising-diffusion-pytorch/karras_unet.py:127, normalize_weight
should get 1.9x with aot_eager on H100, 40s compile time
"""
with torch.no_grad():
    # 2 videos, 3 channels, 8 frames, 256x256 resolution
    noised_video = torch.randn(2, 3, 8, 256, 256)
    time = torch.ones(2,)

    if USE_CUDA:
        torch.set_default_device("cuda")
        noised_video = noised_video.to("cuda")
        time = time.to("cuda")
        lumiere = lumiere.to("cuda")

    # eager baseline
    @timed
    def eager_run(model, input, time):
        return model(input, time = time)

    denoised_video = eager_run(lumiere, noised_video, time = time)
    assert noised_video.shape == denoised_video.shape
    @timed
    def first_compiled_run(model, input, time):
        return model(input, time = time)

    # torch._logging.set_logs(dynamo=logging.INFO)
    # compile only the wrapped submodule (lumiere.model); compiling the full
    # MPLumiere is the commented-out alternative below
    lumiere.model = torch.compile(lumiere.model, backend="eager")
    out1 = first_compiled_run(lumiere, noised_video, time = time)
    # model = torch.compile(lumiere, backend="eager")
    # out1 = first_compiled_run(model, noised_video, time = time)
    @timed
    def second_compiled_run(model, input, time):
        return model(input, time = time)

    # torch._dynamo.run reuses already-compiled code without triggering
    # any new compiles
    frozen_model = torch._dynamo.run(lumiere)
    out2 = second_compiled_run(frozen_model, noised_video, time = time)

    LOGS += f"speedup: {(EAGER_TIME / COMPILED_TIME):.4f}x"
    print(LOGS)

    # accuracy: compiled outputs should match the eager output
    torch.testing.assert_close(denoised_video, out1)
    torch.testing.assert_close(denoised_video, out2)