Skip to content

Instantly share code, notes, and snippets.

@a-r-r-o-w
Last active October 2, 2024 12:26
Show Gist options
  • Save a-r-r-o-w/4d9732d17412888c885480c6521a9897 to your computer and use it in GitHub Desktop.
Save a-r-r-o-w/4d9732d17412888c885480c6521a9897 to your computer and use it in GitHub Desktop.
Demonstrates how to use CogVideoX 2B/5B with Diffusers and TorchAO
#!/bin/bash
# Sweep cogvideox-torchao-benchmark.py over every combination of quantization
# scheme, --compile flag and --fuse_qkv flag.
#
# Commands listed in completed.txt (one per line) are skipped. Every command
# that exits successfully is appended to that file, so an interrupted sweep
# can be resumed simply by re-running this script.

compile_flags=("" "--compile")
fuse_qkv_flags=("" "--fuse_qkv")
# quantizations=("fp16" "bf16" "fp8" "fp8_e4m3" "fp8_e5m2" "fp6" "int8wo" "int8dq" "int4dq" "int4wo" "autoquant" "sparsify")
quantizations=("fp16" "bf16" "fp6" "int8wo" "int8dq" "int4dq" "int4wo" "autoquant" "sparsify")
device="cuda"
completed_file="completed.txt"

# Check if completed.txt exists and read it into an array
if [ -f "$completed_file" ]; then
    mapfile -t completed_runs < "$completed_file"
else
    completed_runs=()
fi

for quantization in "${quantizations[@]}"; do
    for compile in "${compile_flags[@]}"; do
        for fuse_qkv in "${fuse_qkv_flags[@]}"; do
            # String form of the command: used for logging and as the key in
            # completed.txt. Empty flags keep their extra space so the key
            # format stays compatible with existing completed.txt files.
            cmd="python3 cogvideox-torchao-benchmark.py $compile $fuse_qkv --dtype $quantization --device $device"

            # Check if the command is in the list of completed runs
            if [[ " ${completed_runs[*]} " =~ " ${cmd} " ]]; then
                echo "Skipping already completed command: $cmd"
                continue
            fi

            # Build the argument vector explicitly instead of eval-ing the
            # string, so no argument is re-parsed by the shell.
            args=()
            if [ -n "$compile" ]; then
                args+=("$compile")
            fi
            if [ -n "$fuse_qkv" ]; then
                args+=("$fuse_qkv")
            fi
            args+=(--dtype "$quantization" --device "$device")

            echo "Running command: $cmd"
            if python3 cogvideox-torchao-benchmark.py "${args[@]}"; then
                # Record the success so a later invocation skips this run.
                echo "$cmd" >> "$completed_file"
                completed_runs+=("$cmd")
            else
                echo "Command failed (exit code $?): $cmd" >&2
            fi
            echo -ne "------------------ Finished executing script ------------------\n\n"
        done
    done
done
import argparse
import gc
import os
import time
os.environ["TORCH_LOGS"] = "dynamo,output_code,graph_breaks,recompiles"
import torch
import torch.utils.benchmark as benchmark
from diffusers import AutoencoderKLCogVideoX, CogVideoXPipeline, CogVideoXTransformer3DModel, CogVideoXDDIMScheduler
from diffusers.utils import export_to_video
from tabulate import tabulate
from transformers import T5EncoderModel
from torchao.quantization import (
autoquant,
quantize_,
int8_weight_only,
int8_dynamic_activation_int8_weight,
int8_dynamic_activation_int4_weight,
int8_dynamic_activation_int8_semi_sparse_weight,
int4_weight_only,
)
from torchao.sparsity import sparsify_
from torchao.float8.inference import ActivationCasting, QuantConfig, quantize_to_float8
from torchao.prototype.quant_llm import fp6_llm_weight_only
torch.set_float32_matmul_precision("high")
torch._inductor.config.conv_1x1_as_mm = True
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.coordinate_descent_check_all_directions = True
def _cast_to(dtype_name):
    """Build a converter that casts a module to ``torch.<dtype_name>``.

    The dtype attribute is looked up when the converter is called, not when
    it is built.
    """

    def convert(model):
        return model.to(dtype=getattr(torch, dtype_name))

    return convert


# Maps each ``--dtype`` choice to a callable that takes a module and applies
# the corresponding cast / quantization / sparsification. Each converter
# returns whatever the underlying call returns; ``load_pipeline`` keeps the
# existing module when that result is ``None``.
DTYPE_CONVERTER = {
    "fp32": _cast_to("float32"),
    "fp16": _cast_to("float16"),
    "bf16": _cast_to("bfloat16"),
    "fp8": lambda model: quantize_to_float8(model, QuantConfig(ActivationCasting.DYNAMIC)),
    "fp8_e4m3": _cast_to("float8_e4m3fn"),
    "fp8_e5m2": _cast_to("float8_e5m2"),
    "fp6": lambda model: quantize_(model, fp6_llm_weight_only()),
    "int8wo": lambda model: quantize_(model, int8_weight_only()),
    "int8dq": lambda model: quantize_(model, int8_dynamic_activation_int8_weight()),
    "int4dq": lambda model: quantize_(model, int8_dynamic_activation_int4_weight()),
    "int4wo": lambda model: quantize_(model, int4_weight_only()),
    "autoquant": lambda model: autoquant(model, error_on_unseen=False),
    "sparsify": lambda model: sparsify_(model, int8_dynamic_activation_int8_semi_sparse_weight()),
}
def benchmark_fn(f, *args, **kwargs):
    """Time ``f(*args, **kwargs)`` with ``torch.utils.benchmark``.

    Returns:
        The mean of the ``blocked_autorange`` measurement (seconds per call),
        formatted as a string with three decimal places.
    """
    timer = benchmark.Timer(
        stmt="f(*args, **kwargs)",
        globals=dict(f=f, args=args, kwargs=kwargs),
        num_threads=torch.get_num_threads(),
    )
    measurement = timer.blocked_autorange()
    return format(measurement.mean, ".3f")
def reset_memory(device):
    """Collect garbage, empty the CUDA cache and reset CUDA memory statistics.

    Args:
        device: Device whose peak and accumulated memory stats are reset.
    """
    gc.collect()
    torch.cuda.empty_cache()
    stat_resetters = (
        torch.cuda.reset_peak_memory_stats,
        torch.cuda.reset_accumulated_memory_stats,
    )
    for reset_stats in stat_resetters:
        reset_stats(device)
def print_memory(device):
    """Print current, peak allocated and peak reserved CUDA memory for ``device``.

    Each value is divided by 1024**3 and printed with three decimals on its
    own line as ``memory=``, ``max_memory=`` and ``max_reserved=``.
    """
    bytes_per_gib = 1024**3
    queries = (
        ("memory", torch.cuda.memory_allocated),
        ("max_memory", torch.cuda.max_memory_allocated),
        ("max_reserved", torch.cuda.max_memory_reserved),
    )
    # Take all readings first, then print, so printing cannot sit between them.
    readings = [(label, query(device) / bytes_per_gib) for label, query in queries]
    for label, gib in readings:
        print(f"{label}={gib:.3f}")
def pretty_print_results(results, precision: int = 6):
    """Print ``results`` as a single-row, pipe-formatted table.

    Args:
        results: Mapping of column name to value; keys become the headers.
        precision: Number of decimals used for ``float`` values. Values of
            any other type are passed through unchanged.
    """
    row = {}
    for column, value in results.items():
        row[column] = f"{value:.{precision}f}" if isinstance(value, float) else value
    table = tabulate([row], headers="keys", tablefmt="pipe", stralign="center")
    print(table)
def load_pipeline(model_id, dtype, device, quantize_vae, compile, fuse_qkv):
    """Load a CogVideoX pipeline and apply the requested conversion and compilation.

    Args:
        model_id: Identifier passed to ``CogVideoXPipeline.from_pretrained``.
        dtype: Key into ``DTYPE_CONVERTER`` selecting the cast / quantization.
        device: Device the pipeline is moved to after loading.
        quantize_vae: Also convert the VAE when ``dtype`` is not one of the
            plain dtype casts (those always convert the VAE).
        compile: Whether to ``torch.compile`` the transformer.
        fuse_qkv: Whether to call ``pipe.fuse_qkv_projections()``.

    Returns:
        The configured pipeline.
    """

    def compile_transformer():
        pipe.transformer.to(memory_format=torch.channels_last)
        pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
        # VAE cannot be compiled due to: https://gist.github.com/a-r-r-o-w/5183d75e452a368fd17448fcc810bd3f#file-test_cogvideox_torch_compile-py-L30

    # 1. Load pipeline (always in bfloat16; `dtype` is applied afterwards)
    pipe = CogVideoXPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16).to(device)
    pipe.scheduler = CogVideoXDDIMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
    pipe.set_progress_bar_config(disable=True)

    if fuse_qkv:
        pipe.fuse_qkv_projections()

    # 2. Quantize and compile
    # For autoquant the transformer is compiled before conversion; for every
    # other dtype it is compiled after conversion.
    is_autoquant = dtype == "autoquant"
    if compile and is_autoquant:
        compile_transformer()

    convert = DTYPE_CONVERTER[dtype]
    converted = {
        "text_encoder": convert(pipe.text_encoder),
        "transformer": convert(pipe.transformer),
        "vae": None,
    }
    if quantize_vae or dtype in ("fp32", "fp16", "bf16", "fp8_e4m3", "fp8_e5m2"):
        converted["vae"] = convert(pipe.vae)

    # A converter that returned None leaves the existing component in place.
    for component_name, component in converted.items():
        if component is not None:
            setattr(pipe, component_name, component)

    if compile and not is_autoquant:
        compile_transformer()

    return pipe
def run_inference(pipe):
    """Call ``pipe`` once with a fixed prompt, settings and seed.

    Args:
        pipe: Callable pipeline accepting ``prompt``, ``guidance_scale``,
            ``num_inference_steps`` and ``generator`` keyword arguments.

    Returns:
        Whatever ``pipe`` returns.
    """
    prompt = (
        "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
        "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
        "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
        "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
        "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
        "atmosphere of this unique musical performance."
    )
    # Fresh generator with a fixed seed for every call.
    seeded_generator = torch.Generator().manual_seed(3047)  # https://arxiv.org/abs/2109.08203
    call_kwargs = {
        "prompt": prompt,
        "guidance_scale": 6,
        "num_inference_steps": 50,
        "generator": seeded_generator,
    }
    return pipe(**call_kwargs)
DEFAULT_MODEL_ID = "THUDM/CogVideoX-5b"  # or "THUDM/CogVideoX-2b"


def main(dtype, device, quantize_vae, compile, fuse_qkv, model_id=DEFAULT_MODEL_ID):
    """Benchmark one CogVideoX configuration, print a result row and export a video.

    Args:
        dtype: Cast / quantization key forwarded to ``load_pipeline``.
        device: Device forwarded to ``load_pipeline`` and the memory helpers.
        quantize_vae: Forwarded to ``load_pipeline``.
        compile: Forwarded to ``load_pipeline``.
        fuse_qkv: Forwarded to ``load_pipeline``.
        model_id: Model to benchmark. Defaults to the 5B checkpoint, which
            was previously hard-coded.
    """
    # 1. Load pipeline
    pipe = load_pipeline(model_id, dtype, device, quantize_vae, compile, fuse_qkv)

    reset_memory(device)
    print_memory(device)
    torch.cuda.empty_cache()
    model_memory = round(torch.cuda.memory_allocated() / 1024**3, 3)

    # 2. Warmup
    # The output of the last warmup run is the video exported below.
    num_warmups = 2
    for _ in range(num_warmups):
        video = run_inference(pipe)

    # 3. Benchmark
    # Named `elapsed` so the imported `time` module is not shadowed.
    elapsed = benchmark_fn(run_inference, pipe)
    print_memory(device)
    torch.cuda.empty_cache()
    inference_memory = round(torch.cuda.max_memory_allocated() / 1024**3, 3)

    # 4. Save results
    model_type = "5B" if "5b" in model_id else "2B"
    info = {
        "model_type": model_type,
        "compile": compile,
        "fuse_qkv": fuse_qkv,
        "quantize_vae": quantize_vae,
        "quantization": dtype,
        "model_memory": model_memory,
        "inference_memory": inference_memory,
        "time": elapsed,
    }
    pretty_print_results(info, precision=3)
    export_to_video(
        video.frames[0], f"output-quantization_{dtype}-compile_{compile}-fuse_qkv_{fuse_qkv}-{model_type}.mp4", fps=8
    )


def get_args(argv=None):
    """Parse the benchmark's command-line options.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            the arguments are taken from ``sys.argv``, as before.

    Returns:
        The parsed ``argparse.Namespace`` with ``dtype``, ``device``,
        ``quantize_vae``, ``compile``, ``fuse_qkv`` and ``model_id``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dtype",
        type=str,
        default="fp16",
        choices=[
            "fp32",
            "fp16",
            "bf16",
            "fp8",
            "fp8_e4m3",
            "fp8_e5m2",
            "fp6",
            "int8wo",
            "int8dq",
            "int4dq",
            "int4wo",
            "autoquant",
            "sparsify",
        ],
    )
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--quantize_vae", action="store_true", default=False)
    parser.add_argument("--compile", action="store_true", default=False)
    parser.add_argument("--fuse_qkv", action="store_true", default=False)
    parser.add_argument(
        "--model_id",
        type=str,
        default=DEFAULT_MODEL_ID,
        help='Model to benchmark, e.g. "THUDM/CogVideoX-5b" or "THUDM/CogVideoX-2b".',
    )
    return parser.parse_args(argv)


if __name__ == "__main__":
    args = get_args()
    main(args.dtype, args.device, args.quantize_vae, args.compile, args.fuse_qkv, model_id=args.model_id)
# Requires torchao installed from source and a PyTorch nightly build.
# Other environments have not yet been tested.
import tempfile
import torch
from diffusers import CogVideoXTransformer3DModel, CogVideoXPipeline
from diffusers.utils import export_to_video
from torchao.quantization import (
quantize_,
int4_weight_only,
int8_dynamic_activation_int4_weight,
int8_weight_only,
int8_dynamic_activation_int8_weight,
)
# Either "THUDM/CogVideoX-2b" or "THUDM/CogVideoX-5b"
model_id = "THUDM/CogVideoX-5b"

# 1. Quantize and save the transformer
# The weights are quantized in place (int8 weight-only), then the state dict
# is written to a temporary file and read back to show the save/load round trip.
transformer = CogVideoXTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16)
quantize_(transformer, int8_weight_only())

with tempfile.NamedTemporaryFile() as file:
    torch.save(transformer.state_dict(), file)
    file.seek(0)  # rewind so the file can be read back from the start
    state_dict = torch.load(file, map_location="cpu")

# 2. Create new model and load quantized state dict
# Passing a model id straight to `from_config` is deprecated in diffusers;
# load the config dict first and build the model from it.
transformer_config = CogVideoXTransformer3DModel.load_config(model_id, subfolder="transformer")
transformer = CogVideoXTransformer3DModel.from_config(transformer_config, torch_dtype=torch.bfloat16)
# assign=True makes the module adopt the loaded tensors rather than copying
# them into the freshly initialised parameters.
transformer.load_state_dict(state_dict, assign=True, strict=True)

# 3. Create pipeline and run inference
pipe = CogVideoXPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.bfloat16).to("cuda")

prompt = (
    "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
    "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
    "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
    "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
    "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
    "atmosphere of this unique musical performance."
)

video = pipe(
    prompt=prompt,
    guidance_scale=6,
    use_dynamic_cfg=True,
    num_inference_steps=50,
    generator=torch.Generator().manual_seed(3047),  # https://arxiv.org/abs/2109.08203
).frames[0]
export_to_video(video, "output.mp4", fps=8)
@a-r-r-o-w
Copy link
Author

The following results are from an A100, 80 GB.

model_type compile fuse_qkv quantize_vae quantization model_memory inference_memory time quality
5B False False False fp16 19.764 31.746 258.962 Good
5B False True False fp16 21.979 33.961 257.761 Good
5B True False False fp16 19.763 31.742 225.998 Good
5B True True False fp16 21.979 33.961 225.814 Good
5B False False False bf16 19.764 31.746 243.312 Good
5B False True False bf16 21.979 33.96 242.519 Good
5B True False False bf16 19.763 31.742 212.022 Good
5B True True False bf16 21.979 33.961 211.377 Good
5B False False False int8wo 10.302 22.288 260.036 Okay
5B False True False int8wo 11.414 23.396 271.627 Okay
5B True False False int8wo 10.301 22.282 205.899 Okay
5B True True False int8wo 11.412 23.397 209.640 Okay
5B False False False int8dq 10.3 22.287 550.239 Okay
5B False True False int8dq 11.414 23.399 530.113 Okay
5B True False False int8dq 10.3 22.286 177.256 Okay
5B True True False int8dq 11.414 23.399 177.666 Okay
5B False False False int4wo 6.237 18.221 1130.86 Bad (Color)
5B False True False int4wo 6.824 18.806 1127.56 Bad (Color)
5B True False False int4wo 6.235 18.217 1068.31 Bad (Color)
5B True True False int4wo 6.825 18.809 1067.26 Bad (Color)
5B False False False int4dq 11.48 23.463 340.204 Okay
5B False True False int4dq 12.785 24.771 323.873 Okay
5B True False False int4dq 11.48 23.466 219.393 Okay
5B True True False int4dq 12.785 24.774 218.592 Okay
5B False False False fp6 7.902 19.886 283.478 Bad (Overflow)
5B False True False fp6 8.734 20.718 281.083 Bad (Overflow)
5B True False False fp6 7.9 19.885 205.123 Bad (Overflow)
5B True True False fp6 8.734 20.719 204.564 Bad (Overflow)
5B False False False autoquant 19.763 24.938 540.621 Good
5B False True False autoquant 21.978 27.1 504.031 Good
5B True False False autoquant 19.763 24.73 176.794 Good
5B True True False autoquant 21.978 26.948 177.122 Good
5B False False False sparsify 6.743 18.727 308.767 Bad (Patched)
5B False True False sparsify 7.439 19.433 300.013 Bad (Patched)
2B False False False fp16 12.535 24.511 96.918 Good
2B False True False fp16 13.169 25.142 96.610 Good
2B True False False fp16 12.524 24.498 83.938 Good
2B True True False fp16 13.169 25.143 84.694 Good
2B False False False bf16 12.55 24.528 93.896 Good
2B False True False bf16 13.194 25.171 93.396 Good
2B True False False bf16 12.486 24.526 81.224 Good
2B True True False bf16 13.13 25.171 81.520 Good
2B False False False fp6 6.125 18.164 95.684 Bad (Overflow)
2B False True False fp6 6.769 18.808 91.698 Bad (Overflow)
2B True False False fp6 6.125 18.164 72.261 Bad (Overflow)
2B True True False fp6 6.767 18.808 90.585 Bad (Overflow)
2B False False False int8wo 6.58 18.621 102.941 Okay
2B False True False int8wo 6.894 18.936 102.403 Okay
2B True False False int8wo 6.577 18.618 81.389 Okay
2B True True False int8wo 6.891 18.93 83.079 Okay
2B False False False int8dq 6.58 18.621 197.254 Good
2B False True False int8dq 6.894 18.936 190.125 Good
2B True False False int8dq 6.58 18.621 75.16 Good
2B True True False int8dq 6.891 18.933 74.981 Good
2B False False False int4dq 7.344 19.385 132.155 Okay
2B False True False int4dq 7.762 19.743 122.657 Okay
2B True False False int4dq 7.395 19.374 83.103 Okay
2B True True False int4dq 7.762 19.741 82.642 Okay
2B False False False int4wo 4.155 16.138 363.792 Okay
2B False True False int4wo 4.345 16.328 361.839 Okay
2B True False False int4wo 4.155 16.139 342.817 Okay
2B True True False int4wo 4.354 16.339 341.48 Okay
2B False False False autoquant 12.55 19.734 185.023 Good
2B False True False autoquant 13.194 20.319 177.602 Good
2B True False False autoquant 12.55 19.565 75.005 Good
2B True True False autoquant 13.195 20.191 74.807 Good
2B False False False sparsify 4.445 16.431 125.59 Bad (Patched)
2B False True False sparsify 4.652 16.635 121.357 Bad (Patched)

@a-r-r-o-w
Copy link
Author

The following results are from an H100.

model_type compile fuse_qkv quantize_vae quantization model_memory inference_memory time
5B False True False fp16 21.978 33.988 113.945
5B True True False fp16 21.979 33.99 87.155
5B False True False bf16 21.979 33.988 112.398
5B True True False bf16 21.979 33.987 87.455
5B False True False fp8 11.374 23.383 113.167
5B True True False fp8 11.374 23.383 75.255
5B False True False int8wo 11.414 23.422 123.144
5B True True False int8wo 11.414 23.423 87.026
5B True True False int8dq 11.412 59.355 78.945
5B False True False int4dq 12.785 24.793 151.242
5B True True False int4dq 12.785 24.795 87.403
5B False True False int4wo 6.824 18.829 667.125

@a-r-r-o-w
Copy link
Author

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment