@drisspg · Created June 28, 2024 23:15
import copy
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F

from float8_experimental.inference import ActivationCasting, QuantConfig, quantize_to_float8
from transformer_nuggets.utils import benchmark_cuda_function_in_microseconds, profiler
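
# Context note (added): the 4096 -> 14336 -> 4096 projections below match the
# SwiGLU feed-forward shapes used in, e.g., Mistral-7B / Llama-3-8B-style blocks.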
class FeedForward(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.w1 = nn.Linear(4096, 14336, bias=False)
        self.w3 = nn.Linear(4096, 14336, bias=False)
        self.w2 = nn.Linear(14336, 4096, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(F.silu(self.w1(x)) * self.w3(x))
if __name__ == "__main__":
    compile_backend = "inductor"

    # bf16 baseline plus three copies to quantize with different fp8 recipes.
    original_mlp = FeedForward().to("cuda").to(torch.bfloat16)
    dynamic_fp8_mlp = copy.deepcopy(original_mlp)
    static_fp8_mlp = copy.deepcopy(original_mlp)
    weight_only_fp8_mlp = copy.deepcopy(original_mlp)

    batch_size = 4
    num_tokens = 1024
    embedding_dim = 4096
    input_tensor = torch.randn(
        batch_size, num_tokens, embedding_dim, device="cuda", dtype=torch.bfloat16
    )

    # bf16 baseline: compile, warm up, capture a profiler trace, then benchmark.
    compiled_original_mlp = torch.compile(original_mlp, backend=compile_backend)
    with torch.no_grad():
        for _ in range(10):
            compiled_original_mlp(input_tensor)
        with profiler(Path("/home/drisspg/meta/scripts/fp8/test_mlp_bf16")):
            compiled_original_mlp(input_tensor)
        bf16_time = benchmark_cuda_function_in_microseconds(compiled_original_mlp, input_tensor)
        print(f"{bf16_time=}us")

    # Dynamic activation casting: activations are scaled and cast to fp8 at runtime.
    with torch.no_grad():
        quantize_to_float8(dynamic_fp8_mlp, QuantConfig(ActivationCasting.DYNAMIC))
        compiled_dynamic_fp8_mlp = torch.compile(dynamic_fp8_mlp, backend=compile_backend)
        for _ in range(10):
            compiled_dynamic_fp8_mlp(input_tensor)
        with profiler(Path("/home/drisspg/meta/scripts/fp8/test_mlp_fp8_dynamic_activations")):
            compiled_dynamic_fp8_mlp(input_tensor)
        fp8_dynamic_activations_time = benchmark_cuda_function_in_microseconds(
            compiled_dynamic_fp8_mlp, input_tensor
        )
        print(f"{fp8_dynamic_activations_time=}us")

    # Static activation casting: a fixed activation scale is supplied up front.
    # The scale of 1.0 here is a placeholder; in practice it would typically
    # come from calibration.
    with torch.no_grad():
        quantize_to_float8(
            static_fp8_mlp,
            QuantConfig(
                ActivationCasting.STATIC,
                torch.tensor([1.0], device="cuda", dtype=torch.float32),
            ),
        )
        compiled_static_fp8_mlp = torch.compile(static_fp8_mlp, backend=compile_backend)
        for _ in range(10):
            compiled_static_fp8_mlp(input_tensor)
        with profiler(Path("/home/drisspg/meta/scripts/fp8/test_mlp_fp8_static_activations")):
            compiled_static_fp8_mlp(input_tensor)
        fp8_static_activations_time = benchmark_cuda_function_in_microseconds(
            compiled_static_fp8_mlp, input_tensor
        )
        print(f"{fp8_static_activations_time=}us")

    # Weight-only: weights are stored in fp8 and upcast for the matmul, so
    # activations stay in bf16.
    with torch.no_grad():
        quantize_to_float8(weight_only_fp8_mlp, QuantConfig(ActivationCasting.WEIGHT_ONLY))
        compiled_weight_only_fp8_mlp = torch.compile(weight_only_fp8_mlp, backend=compile_backend)
        for _ in range(10):
            compiled_weight_only_fp8_mlp(input_tensor)
        with profiler(Path("/home/drisspg/meta/scripts/fp8/weight_only_fp8_mlp")):
            compiled_weight_only_fp8_mlp(input_tensor)
        fp8_weight_only_activations_time = benchmark_cuda_function_in_microseconds(
            compiled_weight_only_fp8_mlp, input_tensor
        )
        print(f"{fp8_weight_only_activations_time=}us")