"""Benchmark matmul throughput (TFLOPS) on a CUDA GPU with PyTorch."""
import argparse

import torch

# Map CLI dtype strings to torch dtypes ("half" and "float16" are aliases).
pt_dtype_mappings = {
    "float": torch.float,
    "half": torch.half,
    "float16": torch.float16,
    "bfloat16": torch.bfloat16,
}


def parse_args():
    """Define command-line arguments."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch-size", type=int, default=1, help="Batch size")
    parser.add_argument(
        "--sequence-length", type=int, default=2048, help="Sequence length"
    )
    parser.add_argument("--hidden-size", type=int, default=8192, help="Hidden size")
    parser.add_argument("--warmup", type=int, default=10, help="Warm-up iterations")
    parser.add_argument(
        "--iterations",
        type=int,
        default=100,
        help="The number of repeated matmul iterations",
    )
    parser.add_argument(
        "--dtype", type=str, default="float", help="Precision of the tensors"
    )
    parser.add_argument(
        "--fp8", action="store_true", default=False, help="Whether to use fp8"
    )
    return parser.parse_args()


def run(
    batch_size,
    sequence_size,
    hidden_size,
    warmup=10,
    iterations=100,
    dtype=torch.float,
    fp8=False,
):
    # Keep everything on one device; the original mixed set_device(1) with
    # cuda:0 tensors, which would time the wrong device.
    torch.cuda.set_device(0)
    a = torch.randn(batch_size, sequence_size, hidden_size, dtype=dtype, device="cuda:0")
    b = torch.randn(hidden_size, hidden_size, dtype=dtype, device="cuda:0")
    c = torch.randn(batch_size, sequence_size, hidden_size, dtype=dtype, device="cuda:0")

    # Warm up to exclude kernel selection and lazy initialization from timing.
    for _ in range(warmup):
        torch.matmul(a, b, out=c)
    torch.cuda.synchronize(device=0)

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    if fp8:
        print("Using fp8")
        import transformer_engine.pytorch as te
        from transformer_engine.common import recipe

        fp8_recipe = recipe.DelayedScaling(
            margin=0, interval=1, fp8_format=recipe.Format.E4M3
        )
        # Note: fp8_autocast only changes the math inside Transformer Engine
        # modules; a plain torch.matmul under it still runs in `dtype`.
        with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
            start.record()
            for _ in range(iterations):
                torch.matmul(a, b, out=c)
    else:
        start.record()
        for _ in range(iterations):
            torch.matmul(a, b, out=c)
    # Record the end event on the stream, then wait for it; synchronizing
    # before end.record() (as the original did) would fold the host-side
    # stall into the measurement.
    end.record()
    torch.cuda.synchronize(device=0)
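
    # To benchmark a real FP8 GEMM one would route the multiply through a
    # Transformer Engine module instead of torch.matmul. A minimal sketch,
    # assuming the te.Linear API of recent transformer_engine releases
    # (exact names/signatures may vary by version):
    #
    #     linear = te.Linear(hidden_size, hidden_size, bias=False)
    #     with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    #         out = linear(a.view(-1, hidden_size))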
    # A (B, S, H) x (H, H) matmul costs 2 * B * S * H^2 FLOPs;
    # elapsed_time() returns milliseconds, so dividing by 10**9 yields TFLOPS.
    tflops = (
        2
        * batch_size
        * sequence_size
        * hidden_size**2
        * iterations
        / start.elapsed_time(end)
        / 10**9
    )
    print(
        f"The TFLOPS for computing matmul between tensors "
        f"({batch_size}, {sequence_size}, {hidden_size}) and "
        f"({hidden_size}, {hidden_size}) is {tflops}"
    )
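
# Worked example for the defaults (batch 1, sequence 2048, hidden 8192):
# each matmul costs 2 * 1 * 2048 * 8192**2 ≈ 2.75e11 FLOPs, so 1 ms per
# iteration would correspond to roughly 275 TFLOPS.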


if __name__ == "__main__":
    args = parse_args()
    batch_size = args.batch_size
    sequence_size = args.sequence_length
    hidden_size = args.hidden_size
    warmup = args.warmup
    iterations = args.iterations
    dtype = pt_dtype_mappings[args.dtype]
    run(batch_size, sequence_size, hidden_size, warmup, iterations, dtype, args.fp8)
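
# Example invocation (the script name is hypothetical; --dtype accepts any
# key of pt_dtype_mappings):
#
#     python bench_matmul.py --batch-size 1 --sequence-length 2048 \
#         --hidden-size 8192 --dtype bfloat16 --iterations 100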