measure FP8 speed: compare FP16 and FP8 run time of te.Linear with Transformer Engine

import copy

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling


def speedometer(
    module: torch.nn.Module,
    input: torch.Tensor,
    output_grad: torch.Tensor,
    timing_iters: int = 50,
    warmup_iters: int = 50,
) -> None:
    """Measure the average run time of a PyTorch module.

    Performs forward and backward passes and reports the mean
    iteration time measured with CUDA events.
    """
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    # Warmup runs
    torch.cuda.synchronize()
    for _ in range(warmup_iters):
        output = module(input)
        output.backward(output_grad)

    # Timing runs
    start.record()
    for _ in range(timing_iters):
        output = module(input)
        output.backward(output_grad)
    end.record()
    torch.cuda.synchronize()

    print(f"Mean time: {start.elapsed_time(end) / timing_iters} ms")

dim = 1024
print("DIM:", dim)

# Build the Transformer Engine linear layer and synthetic data on the GPU.
m = te.Linear(dim, dim)
m.cuda()
x = torch.rand(dim, dim, device='cuda')
y = m(x)
# Random gradient for the backward pass, matching the output shape.
dy = torch.rand_like(y)

# Independent copies so the FP16 and FP8 runs do not share parameter state.
m1, m2 = [copy.deepcopy(m) for _ in range(2)]
x1, x2 = [x.clone() for _ in range(2)]
dy1, dy2 = [dy.clone() for _ in range(2)]
print("FP16:")
with torch.cuda.amp.autocast():
speedometer(m1, x1, dy1)

# FP8 recipe with delayed scaling: the HYBRID format uses E4M3 for the
# forward pass and E5M2 for gradients in the backward pass.
fp8_format = Format.HYBRID
fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max")
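
# Alternative recipe (a sketch, not used below): pure E4M3 in both the
# forward and backward passes instead of the HYBRID format.
fp8_recipe_e4m3 = DelayedScaling(fp8_format=Format.E4M3, amax_history_len=16, amax_compute_algo="max")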
print("FP8:")
with torch.cuda.amp.autocast():
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
speedometer(m2, x2, dy2)
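
# Optional extension (a sketch, assuming the same environment): sweep larger
# sizes, since the benefit of FP8 Tensor Core GEMMs typically grows with
# matrix size. Variable names here are illustrative.
for d in (2048, 4096):
    m_d = te.Linear(d, d).cuda()
    x_d = torch.rand(d, d, device='cuda')
    dy_d = torch.rand(d, d, device='cuda')
    print(f"DIM: {d}, FP8:")
    with torch.cuda.amp.autocast():
        with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
            speedometer(m_d, x_d, dy_d)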