Last active
November 11, 2024 04:01
-
-
Save a-r-r-o-w/4d9732d17412888c885480c6521a9897 to your computer and use it in GitHub Desktop.
Demonstrates how to use CogVideoX 2B/5B with Diffusers and TorchAO
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/bin/bash
# Sweep cogvideox-torchao-benchmark.py over every combination of quantization,
# torch.compile and fused-QKV flags.
#
# Every command that exits successfully is appended to completed.txt, and commands
# already listed there are skipped, so an interrupted sweep can be resumed simply by
# re-running this script. Delete completed.txt to start over.
compile_flags=("" "--compile")
fuse_qkv_flags=("" "--fuse_qkv")
# quantizations=("fp16" "bf16" "fp8" "fp8_e4m3" "fp8_e5m2" "fp6" "int8wo" "int8dq" "int4dq" "int4wo" "autoquant" "sparsify")
quantizations=("fp16" "bf16" "fp6" "int8wo" "int8dq" "int4dq" "int4wo" "autoquant" "sparsify")
device="cuda"
completed_file="completed.txt"

# Load previously completed commands (one per line) into a lookup table so the
# skip check below is an exact whole-line match instead of a substring match.
declare -A completed_runs=()
if [[ -f "$completed_file" ]]; then
  while IFS= read -r line || [[ -n "$line" ]]; do
    [[ -n "$line" ]] && completed_runs["$line"]=1
  done < "$completed_file"
fi

for quantization in "${quantizations[@]}"; do
  for compile in "${compile_flags[@]}"; do
    for fuse_qkv in "${fuse_qkv_flags[@]}"; do
      # Build the arguments as an array: empty flags are dropped and no eval /
      # word-splitting is needed to execute the command.
      args=()
      [[ -n "$compile" ]] && args+=("$compile")
      [[ -n "$fuse_qkv" ]] && args+=("$fuse_qkv")
      args+=(--dtype "$quantization" --device "$device")

      # Human-readable form of the command; also the key stored in completed.txt.
      cmd="python3 cogvideox-torchao-benchmark.py $compile $fuse_qkv --dtype $quantization --device $device"

      if [[ -n "${completed_runs[$cmd]:-}" ]]; then
        echo "Skipping already completed command: $cmd"
        continue
      fi

      echo "Running command: $cmd"
      if python3 cogvideox-torchao-benchmark.py "${args[@]}"; then
        # Record success so a later re-run of the sweep skips this configuration.
        echo "$cmd" >> "$completed_file"
      else
        echo "Command failed with exit code $?: $cmd" >&2
      fi
      echo -ne "------------------ Finished executing script ------------------\n\n"
    done
  done
done
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse | |
import gc | |
import os | |
import time | |
os.environ["TORCH_LOGS"] = "dynamo,output_code,graph_breaks,recompiles" | |
import torch | |
import torch.utils.benchmark as benchmark | |
from diffusers import AutoencoderKLCogVideoX, CogVideoXPipeline, CogVideoXTransformer3DModel, CogVideoXDDIMScheduler | |
from diffusers.utils import export_to_video | |
from tabulate import tabulate | |
from transformers import T5EncoderModel | |
from torchao.quantization import ( | |
autoquant, | |
quantize_, | |
int8_weight_only, | |
int8_dynamic_activation_int8_weight, | |
int8_dynamic_activation_int4_weight, | |
int8_dynamic_activation_int8_semi_sparse_weight, | |
int4_weight_only, | |
) | |
from torchao.sparsity import sparsify_ | |
from torchao.float8.inference import ActivationCasting, QuantConfig, quantize_to_float8 | |
from torchao.prototype.quant_llm import fp6_llm_weight_only | |
torch.set_float32_matmul_precision("high") | |
torch._inductor.config.conv_1x1_as_mm = True | |
torch._inductor.config.coordinate_descent_tuning = True | |
torch._inductor.config.epilogue_fusion = False | |
torch._inductor.config.coordinate_descent_check_all_directions = True | |
# Maps each --dtype choice to a callable that casts or quantizes a module.
# The casting entries return the converted module; the other entries return whatever
# the torchao helper returns (load_pipeline keeps the existing module when that is None).
DTYPE_CONVERTER = dict(
    # Plain floating-point casts.
    fp32=lambda m: m.to(dtype=torch.float32),
    fp16=lambda m: m.to(dtype=torch.float16),
    bf16=lambda m: m.to(dtype=torch.bfloat16),
    # float8 via torchao with dynamic activation casting.
    fp8=lambda m: quantize_to_float8(m, QuantConfig(ActivationCasting.DYNAMIC)),
    # Raw float8 storage casts (commented out of the sweep list in the accompanying shell script).
    fp8_e4m3=lambda m: m.to(dtype=torch.float8_e4m3fn),
    fp8_e5m2=lambda m: m.to(dtype=torch.float8_e5m2),
    # torchao quantization applied through quantize_.
    fp6=lambda m: quantize_(m, fp6_llm_weight_only()),
    int8wo=lambda m: quantize_(m, int8_weight_only()),
    int8dq=lambda m: quantize_(m, int8_dynamic_activation_int8_weight()),
    # Note: "int4dq" means int8 dynamic activations with int4 weights.
    int4dq=lambda m: quantize_(m, int8_dynamic_activation_int4_weight()),
    int4wo=lambda m: quantize_(m, int4_weight_only()),
    # torchao automatic quantization selection.
    autoquant=lambda m: autoquant(m, error_on_unseen=False),
    # Semi-structured sparse weights combined with int8 dynamic activation quantization.
    sparsify=lambda m: sparsify_(m, int8_dynamic_activation_int8_semi_sparse_weight()),
)
def benchmark_fn(f, *args, **kwargs):
    """Time ``f(*args, **kwargs)`` with ``torch.utils.benchmark``.

    Returns:
        The mean time per call in seconds, formatted as a string with 3 decimals.
    """
    timer = benchmark.Timer(
        stmt="f(*args, **kwargs)",
        globals=dict(f=f, args=args, kwargs=kwargs),
        num_threads=torch.get_num_threads(),
    )
    measurement = timer.blocked_autorange()
    return "{:.3f}".format(measurement.mean)
def reset_memory(device):
    """Release cached CUDA memory and reset the allocator statistics for ``device``.

    Python's garbage collector runs first, so objects kept alive only by reference
    cycles are freed before the CUDA cache is emptied.
    """
    gc.collect()
    torch.cuda.empty_cache()
    for reset_stats in (torch.cuda.reset_peak_memory_stats, torch.cuda.reset_accumulated_memory_stats):
        reset_stats(device)
def print_memory(device):
    """Print the current and peak allocated, and peak reserved, CUDA memory of ``device`` in GiB."""
    gib = 1024**3
    memory = torch.cuda.memory_allocated(device) / gib
    max_memory = torch.cuda.max_memory_allocated(device) / gib
    max_reserved = torch.cuda.max_memory_reserved(device) / gib
    # The `name=` f-string form prints each variable name next to its value.
    for line in (f"{memory=:.3f}", f"{max_memory=:.3f}", f"{max_reserved=:.3f}"):
        print(line)
def pretty_print_results(results, precision: int = 6):
    """Print ``results`` as a one-row markdown (pipe) table.

    Float values are rendered with ``precision`` decimal places; all other values are
    passed to ``tabulate`` unchanged.
    """
    row = {}
    for key, value in results.items():
        row[key] = f"{value:.{precision}f}" if isinstance(value, float) else value
    print(tabulate([row], headers="keys", tablefmt="pipe", stralign="center"))
def _compile_transformer(pipe):
    """Replace ``pipe.transformer`` with a max-autotune, full-graph ``torch.compile`` of itself.

    The VAE is deliberately not compiled; see:
    https://gist.github.com/a-r-r-o-w/5183d75e452a368fd17448fcc810bd3f#file-test_cogvideox_torch_compile-py-L30
    """
    pipe.transformer.to(memory_format=torch.channels_last)
    pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)


def load_pipeline(model_id, dtype, device, quantize_vae, compile, fuse_qkv):
    """Load a CogVideoX pipeline and apply the requested cast/quantization and compilation.

    Args:
        model_id: Hub id or local path passed to ``CogVideoXPipeline.from_pretrained``.
        dtype: Key of ``DTYPE_CONVERTER`` selecting the cast / quantization scheme.
        device: Device the pipeline is moved to before any conversion.
        quantize_vae: Also apply ``dtype`` to the VAE when it is not a plain dtype cast
            (plain casts are always applied to the VAE).
        compile: Wrap the transformer with ``torch.compile``.
        fuse_qkv: Fuse the attention QKV projections (done before quantization).

    Returns:
        The configured pipeline.
    """
    # 1. Load pipeline in bf16 and move it to the target device before converting.
    pipe = CogVideoXPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16).to(device)
    pipe.scheduler = CogVideoXDDIMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
    pipe.set_progress_bar_config(disable=True)

    if fuse_qkv:
        pipe.fuse_qkv_projections()

    # 2. Quantize and compile.
    # For autoquant, the transformer is compiled first and autoquant wraps the compiled
    # module; every other scheme is compiled after conversion (see the end of this function).
    if dtype == "autoquant" and compile:
        _compile_transformer(pipe)

    components = ["text_encoder", "transformer"]
    # Plain dtype casts also go to the VAE; other schemes only when explicitly requested.
    if dtype in ["fp32", "fp16", "bf16", "fp8_e4m3", "fp8_e5m2"] or quantize_vae:
        components.append("vae")

    converter = DTYPE_CONVERTER[dtype]
    for name in components:
        converted = converter(getattr(pipe, name))
        # Converters that modify the module in place may return None; in that case the
        # existing (already modified) component is kept.
        if converted is not None:
            setattr(pipe, name, converted)

    if dtype != "autoquant" and compile:
        _compile_transformer(pipe)

    return pipe
def run_inference(pipe):
    """Generate the fixed benchmark video with ``pipe`` and return the pipeline output.

    The prompt, guidance scale (6), number of steps (50) and seed (3047) are fixed so
    that every configuration runs the same workload.
    """
    prompt = "".join(
        [
            "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. ",
            "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other ",
            "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, ",
            "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. ",
            "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical ",
            "atmosphere of this unique musical performance.",
        ]
    )
    seeded_generator = torch.Generator().manual_seed(3047)  # https://arxiv.org/abs/2109.08203
    return pipe(
        prompt=prompt,
        guidance_scale=6,
        num_inference_steps=50,
        generator=seeded_generator,
    )
def main(dtype, device, quantize_vae, compile, fuse_qkv):
    """Benchmark one CogVideoX configuration, print the results and save the generated video.

    Loads the pipeline, records its memory footprint, runs warmup passes, times
    ``run_inference`` and prints a one-row results table. The last warmup video is written to
    ``output-quantization_{dtype}-compile_{compile}-fuse_qkv_{fuse_qkv}-{model_type}.mp4``.
    """
    # 1. Load pipeline
    # model_id = "THUDM/CogVideoX-5b"  # or "THUDM/CogVideoX-2b"
    model_id = "THUDM/CogVideoX-5b"
    pipe = load_pipeline(model_id, dtype, device, quantize_vae, compile, fuse_qkv)

    reset_memory(device)
    print_memory(device)
    torch.cuda.empty_cache()
    # Query the same device that was reset/printed above, not the current default CUDA device.
    model_memory = round(torch.cuda.memory_allocated(device) / 1024**3, 3)

    # 2. Warmup: absorbs one-time costs (e.g. torch.compile compilation) before timing.
    num_warmups = 2
    for _ in range(num_warmups):
        video = run_inference(pipe)

    # 3. Benchmark. `elapsed` is the mean seconds per call, already formatted as a string.
    # (Not named `time`, which would shadow the imported `time` module.)
    elapsed = benchmark_fn(run_inference, pipe)
    print_memory(device)
    torch.cuda.empty_cache()
    # Peak since reset_memory() above, so it covers both the warmup and the benchmark runs.
    inference_memory = round(torch.cuda.max_memory_allocated(device) / 1024**3, 3)

    # 4. Save results
    model_type = "5B" if "5b" in model_id else "2B"
    info = {
        "model_type": model_type,
        "compile": compile,
        "fuse_qkv": fuse_qkv,
        "quantize_vae": quantize_vae,
        "quantization": dtype,
        "model_memory": model_memory,
        "inference_memory": inference_memory,
        "time": elapsed,
    }
    pretty_print_results(info, precision=3)

    # NOTE(review): the file name does not encode quantize_vae, so runs that differ only in
    # that flag overwrite each other's video.
    export_to_video(
        video.frames[0], f"output-quantization_{dtype}-compile_{compile}-fuse_qkv_{fuse_qkv}-{model_type}.mp4", fps=8
    )
def get_args():
    """Parse the command-line options selecting one benchmark configuration."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dtype",
        type=str,
        default="fp16",
        # Derived from DTYPE_CONVERTER so the CLI and the converter table cannot drift apart.
        choices=list(DTYPE_CONVERTER),
        help="Cast or torchao quantization scheme applied to the pipeline.",
    )
    parser.add_argument("--device", type=str, default="cuda", help="Device to run the benchmark on.")
    parser.add_argument(
        "--quantize_vae",
        action="store_true",
        default=False,
        help="Also apply the quantization scheme to the VAE.",
    )
    parser.add_argument(
        "--compile",
        action="store_true",
        default=False,
        help="Compile the transformer with torch.compile.",
    )
    parser.add_argument(
        "--fuse_qkv",
        action="store_true",
        default=False,
        help="Fuse the attention QKV projections.",
    )
    return parser.parse_args()
# Script entry point: parse the CLI and run a single benchmark configuration.
if __name__ == "__main__":
    args = get_args()
    main(
        dtype=args.dtype,
        device=args.device,
        quantize_vae=args.quantize_vae,
        compile=args.compile,
        fuse_qkv=args.fuse_qkv,
    )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Install torchao from source and Pytorch Nightly | |
# Other environments have not yet been tested. | |
import tempfile | |
import torch | |
from diffusers import CogVideoXTransformer3DModel, CogVideoXPipeline | |
from diffusers.utils import export_to_video | |
from torchao.quantization import ( | |
quantize_, | |
int4_weight_only, | |
int8_dynamic_activation_int4_weight, | |
int8_weight_only, | |
int8_dynamic_activation_int8_weight, | |
) | |
# Either "THUDM/CogVideoX-2b" or "THUDM/CogVideoX-5b"
model_id = "THUDM/CogVideoX-5b"

# 1. Quantize and save the transformer
transformer = CogVideoXTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16)
quantize_(transformer, int8_weight_only())

# Round-trip the quantized state dict through a temporary file to demonstrate serialization.
with tempfile.NamedTemporaryFile() as file:
    torch.save(transformer.state_dict(), file)
    file.seek(0)
    # The state dict holds torchao tensor subclasses, which newer PyTorch releases may refuse
    # to unpickle under their `weights_only=True` default. The file was written by this
    # script just above, so full unpickling is safe here.
    state_dict = torch.load(file, map_location="cpu", weights_only=False)

# 2. Create a new model from the config alone and load the quantized state dict.
# `from_config` expects a config dict; `load_config` fetches it from the repo's subfolder.
transformer_config = CogVideoXTransformer3DModel.load_config(model_id, subfolder="transformer")
transformer = CogVideoXTransformer3DModel.from_config(transformer_config, torch_dtype=torch.bfloat16)
# assign=True installs the loaded (quantized) tensors themselves rather than copying their
# values into the freshly initialized parameters.
transformer.load_state_dict(state_dict, assign=True, strict=True)
# 3. Build the pipeline around the reloaded quantized transformer, generate a video and save it.
pipe = CogVideoXPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.bfloat16)
pipe = pipe.to("cuda")

prompt = "".join(
    [
        "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. ",
        "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other ",
        "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, ",
        "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. ",
        "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical ",
        "atmosphere of this unique musical performance.",
    ]
)

pipeline_output = pipe(
    prompt=prompt,
    guidance_scale=6,
    use_dynamic_cfg=True,
    num_inference_steps=50,
    generator=torch.Generator().manual_seed(3047),  # https://arxiv.org/abs/2109.08203
)
video = pipeline_output.frames[0]
export_to_video(video, "output.mp4", fps=8)
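The following results are from an H100. `model_memory` and `inference_memory` are in GiB, and `time` is the mean wall-clock time in seconds for one 50-step video generation.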
model_type | compile | fuse_qkv | quantize_vae | quantization | model_memory | inference_memory | time |
---|---|---|---|---|---|---|---|
5B | False | True | False | fp16 | 21.978 | 33.988 | 113.945 |
5B | True | True | False | fp16 | 21.979 | 33.99 | 87.155 |
5B | False | True | False | bf16 | 21.979 | 33.988 | 112.398 |
5B | True | True | False | bf16 | 21.979 | 33.987 | 87.455 |
5B | False | True | False | fp8 | 11.374 | 23.383 | 113.167 |
5B | True | True | False | fp8 | 11.374 | 23.383 | 75.255 |
5B | False | True | False | int8wo | 11.414 | 23.422 | 123.144 |
5B | True | True | False | int8wo | 11.414 | 23.423 | 87.026 |
5B | True | True | False | int8dq | 11.412 | 59.355 | 78.945 |
5B | False | True | False | int4dq | 12.785 | 24.793 | 151.242 |
5B | True | True | False | int4dq | 12.785 | 24.795 | 87.403 |
5B | False | True | False | int4wo | 6.824 | 18.829 | 667.125 |
Full list of benchmarks: https://github.com/sayakpaul/diffusers-torchao
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
The following results are from an A100, 80 GB.