Last active: April 23, 2024 13:09
Save garrett361/7161235587a2ff51306764fe488b9431 to your computer and use it in GitHub Desktop.
XPU and CUDA matmul timing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from time import perf_counter | |
from typing import Optional, Union | |
import torch | |
# Select the accelerator backend: CUDA when it is available, otherwise Intel XPU.
# ``accel`` is the torch device module (synchronize / Event / empty_cache / device_count) and
# ``device`` is the matching device string; both are used by the timing code below.
if not torch.cuda.is_available():
    # Imported only for its side effects (presumably registering the xpu backend / xpu.optimize).
    import intel_extension_for_pytorch as ipex  # noqa
    from torch import xpu as accel

    device = "xpu"
else:
    from torch import cuda as accel

    device = "cuda"

# Debugging printing:
print(f"Using {device=}. {accel.device_count()=}", flush=True)

# Precision of every benchmarked weight and input tensor.
DTYPE = torch.bfloat16
def time_matmul_perf_counter(
    lin: torch.nn.Linear,
    t: torch.Tensor,
    num_iters: int,
    cache: Optional[torch.Tensor],
    clear_cache: bool,
) -> list[float]:
    """Measure ``num_iters`` forward passes of ``lin`` on ``t`` with host wall-clock time.

    Each pass is bracketed by device synchronizations so the host timer covers exactly the
    queued matmul work.

    Args:
        lin: layer whose forward pass is timed.
        t: input passed to ``lin``.
        num_iters: number of timed passes.
        cache: buffer zeroed before each pass when ``clear_cache`` is true (must then be
            non-None).
        clear_cache: whether to zero ``cache`` before every pass.

    Returns:
        One duration per pass, in seconds.
    """

    def _single_pass() -> float:
        if clear_cache:
            cache.zero_()
        # Drain any pending work (including the cache flush) so it is not timed.
        accel.synchronize()
        began = perf_counter()
        lin(t)
        # Wait for the asynchronous matmul to actually finish before reading the clock.
        accel.synchronize()
        return perf_counter() - began

    return [_single_pass() for _ in range(num_iters)]
def time_matmul_events(
    lin: torch.nn.Linear,
    t: torch.Tensor,
    num_iters: int,
    cache: Optional[torch.Tensor],
    clear_cache: bool,
) -> list[float]:
    """Measure ``num_iters`` forward passes of ``lin`` on ``t`` with device timing events.

    Every pass is bracketed by a pair of timing-enabled events. All passes are enqueued
    back to back and the device is synchronized once at the end before reading the timings.

    Args:
        lin: layer whose forward pass is timed.
        t: input passed to ``lin``.
        num_iters: number of timed passes.
        cache: buffer zeroed before each pass when ``clear_cache`` is true (must then be
            non-None).
        clear_cache: whether to zero ``cache`` before every pass.

    Returns:
        One duration per pass, in seconds (``elapsed_time`` reports milliseconds).
    """
    begin_marks = [accel.Event(enable_timing=True) for _ in range(num_iters)]
    end_marks = [accel.Event(enable_timing=True) for _ in range(num_iters)]
    accel.synchronize()
    for begin, end in zip(begin_marks, end_marks):
        if clear_cache:
            # Enqueued before ``begin`` is recorded, so the flush itself is not measured.
            cache.zero_()
        begin.record()
        lin(t)
        end.record()
    # Every event must have completed before elapsed_time may be queried.
    accel.synchronize()
    return [begin.elapsed_time(end) / 1e3 for begin, end in zip(begin_marks, end_marks)]
def benchmark(
    batch_size: int,
    m: int,
    k: int,
    n: int,
    warmups: int,
    num_iters: int,
    time_with_events: bool = False,
    cache_size_MiB: int = 256,
    clear_cache: bool = True,
    optimize_xpu: bool = False,
) -> dict[str, Union[int, float]]:
    """
    Benchmark a batched matmul: a (batch_size, m, k) input through a bias-free k -> n Linear.

    Args:
        batch_size, m, k: shape of the random input ``(batch_size, m, k)``.
        n: output features of the ``nn.Linear(k, n, bias=False)`` layer.
        warmups: untimed forward passes run before measuring; must be positive.
        num_iters: timed forward passes; must be positive.
        time_with_events: time with device events instead of ``perf_counter`` (see NOTE below).
        cache_size_MiB: size of the int8 buffer zeroed before each timed pass when
            ``clear_cache`` is set.
        clear_cache: zero that buffer before every timed pass (presumably to evict operands
            from the device cache so each pass starts cold).
        optimize_xpu: run ``torch.xpu.optimize`` on the layer; only valid when on XPU.

    Returns:
        Dict with "time"/"time_err" (mean/std seconds per pass), "TFLOPs" (work of one pass),
        "TFLOP/s"/"TFLOP/s_err" (mean/std of per-pass throughput) and the dims "m", "n", "k".
        The ``*_err`` entries are NaN when ``num_iters == 1`` (sample std of a single value).
    """
    assert warmups > 0, "Use at least one warmup"
    # Without this, zero timed passes silently produce NaN means instead of an error.
    assert num_iters > 0, "Use at least one timed iteration"
    cache = (
        torch.empty(cache_size_MiB * 2**20, dtype=torch.int8, device=device)
        if clear_cache
        else None
    )
    with torch.inference_mode():
        lin = torch.nn.Linear(k, n, bias=False, device=device, dtype=DTYPE)
        if optimize_xpu:
            assert device == "xpu"
            lin.eval()  # optimize requires specifying an optimizer when .training() is True
            lin = torch.xpu.optimize(lin, dtype=DTYPE, level="O1")
        t = torch.randn(batch_size, m, k, device=device, dtype=DTYPE)
        for _ in range(warmups):
            lin(t)
        # NOTE: @garrett.goon - using Events (as in torch.cuda.Event) is the preferred way to time
        # GPU operations, but the xpu.Event timing gave strange results when 2.1.0a0+cxx11.abi +
        # ipex. TBD whether this is an xpu or code issue. See
        # https://github.com/intel/intel-extension-for-pytorch/issues/568
        if time_with_events:
            times = time_matmul_events(lin, t, num_iters, cache, clear_cache)
        else:
            times = time_matmul_perf_counter(lin, t, num_iters, cache, clear_cache)
    # Drop every device tensor -- including the cache-flush buffer -- *before* emptying the
    # allocator cache; otherwise the still-referenced buffer is never released by empty_cache().
    del lin, t, cache
    accel.empty_cache()

    # Each of the batch_size * m * n outputs costs k multiplies and k - 1 additions.
    TFLOPs = (2 * k - 1) * batch_size * m * n / 1e12
    tflop_per_s = torch.tensor([TFLOPs / dt for dt in times])
    times_tensor = torch.tensor(times)
    return {
        "time": times_tensor.mean().item(),
        "time_err": times_tensor.std().item(),
        "TFLOPs": TFLOPs,
        "TFLOP/s": tflop_per_s.mean().item(),
        "TFLOP/s_err": tflop_per_s.std().item(),
        "m": m,
        "n": n,
        "k": k,
    }
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse | |
import matplotlib.pyplot as plt | |
import pandas as pd | |
import seaborn as sns | |
import torch | |
import tqdm | |
from matmul_bench import benchmark | |
# Script for benchmarking square matmuls and creating a plot.
# (Kept as a comment: a triple-quoted string placed after the imports is not the module
# docstring, only a discarded expression statement.)

# Global seaborn styling applied to every figure this script draws.
sns.set_theme(style="darkgrid", rc={"figure.figsize": (10, 10)})
def code_by_divisibility(n, max_exp=4):
    """Return the largest ``2**exp`` (``exp`` in {0, ..., max_exp}) that divides the dim.

    Designed for ``DataFrame.apply(code_by_divisibility, axis=1)``: ``n`` is then a result row
    and its ``k`` field is tested. A plain number is also accepted and tested directly.

    Args:
        n: a row/object exposing ``.k``, or an integer-valued number.
        max_exp: largest exponent considered; the result is capped at ``2**max_exp``.

    Returns:
        The largest qualifying power of two (at least 1, which always divides), or None
        only when ``max_exp < 0`` leaves nothing to test.
    """
    num_elements = n.k if hasattr(n, "k") else n
    # Walk exponents from largest to smallest so the first hit is the largest factor.
    for exp in range(max_exp, -1, -1):
        factor = 2**exp
        if num_elements % factor == 0:
            return factor
    return None
def main() -> None:
    """Benchmark square matmuls over a range of dims and save a TFLOP/s-vs-dim scatter plot."""
    parser = argparse.ArgumentParser()
    # Registered in the same order as before so ``--help`` output is unchanged.
    integer_options = {
        "--min-dim": 512,
        "--max-dim": 4096,
        "--step": 256,
        "--warmups": 3,
        "--num-iters": 10,
        "--cache-size-MiB": 256,
    }
    for flag, default_value in integer_options.items():
        parser.add_argument(flag, default=default_value, type=int)
    for flag in ("--clear-cache", "--time-with-events", "--optimize-xpu"):
        parser.add_argument(flag, action="store_true")
    parser.add_argument("--file", default="", type=str)
    args = parser.parse_args()

    shared_kwargs = dict(
        warmups=args.warmups,
        num_iters=args.num_iters,
        time_with_events=args.time_with_events,
        cache_size_MiB=args.cache_size_MiB,
        clear_cache=args.clear_cache,
        optimize_xpu=args.optimize_xpu,
    )
    results_df = pd.DataFrame(
        [
            benchmark(batch_size=1, m=dim, k=dim, n=dim, **shared_kwargs)
            for dim in tqdm.trange(args.min_dim, args.max_dim + 1, args.step)
        ]
    )

    # Colour each point by the largest power of two (up to 16) dividing its matrix size.
    results_df["divisible_by"] = results_df.apply(code_by_divisibility, axis=1)
    ax = sns.scatterplot(
        x="m", y="TFLOP/s", data=results_df, hue="divisible_by", palette="Set2", zorder=2
    )
    # Error bars are drawn beneath the scatter markers (lower zorder).
    plt.errorbar(
        x="m",
        y="TFLOP/s",
        yerr="TFLOP/s_err",
        data=results_df,
        ls="none",
        ecolor="silver",
        zorder=1,
    )

    device = "cuda" if torch.cuda.is_available() else "xpu"
    ax.set(title=f"Square matrix multiplies (batch_size = 1, {device=})")
    ax.set(xlabel="dim")
    output_path = args.file or f"flops_vs_dim_{device}.png"
    ax.figure.savefig(output_path, dpi=256)


if __name__ == "__main__":
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse | |
from matmul_bench import benchmark | |
""" | |
Script for benchmarking square matmuls and printing out results. | |
""" | |
def main() -> None: | |
parser = argparse.ArgumentParser() | |
parser.add_argument("--min-dim", default=512, type=int) | |
parser.add_argument("--max-dim", default=4096, type=int) | |
parser.add_argument("--step", default=256, type=int) | |
parser.add_argument("--warmups", default=3, type=int) | |
parser.add_argument("--num-iters", default=10, type=int) | |
parser.add_argument("--cache-size-MiB", default=256, type=int) | |
parser.add_argument("--clear-cache", action="store_true") | |
parser.add_argument("--time-with-events", action="store_true") | |
args = parser.parse_args() | |
for dim in range(args.min_dim, args.max_dim + 1, args.step): | |
results = benchmark( | |
batch_size=1, | |
m=dim, | |
k=dim, | |
n=dim, | |
warmups=args.warmups, | |
num_iters=args.num_iters, | |
time_with_events=args.time_with_events, | |
cache_size_MiB=args.cache_size_MiB, | |
clear_cache=args.clear_cache, | |
) | |
print(f"{dim}x{dim} square matmul TFLOP/s: {results=}", flush=True) | |
if __name__ == "__main__": | |
main() |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.