|                               | GPU fp32 | GPU fp16 | TPU fp32* | TPU fp16* |
| ----------------------------- | -------- | -------- | --------- | --------- |
| Largest batch size            | 32       | 64       | 1024      | 2048      |
| Min-inference batch size      | 4        | 64       | 1024      | 2048      |
| Masker inference (s/i)        | 0.059    | 0.019    | 3.80e-5   | 1.36e-5   |
| Painter inference (s/i)       | 0.068    | 0.041    | 2.53e-5   | 1.16e-5   |
| Inference loop (s)            | 130.382  | 60.567   | 0.073     | 0.0392    |
| Inference loop (i/s)          | ~8       | ~17      | ~14 000   | ~26 000   |
| Full dataset with loading (s) | 151.546  | 76.953   | 18.05     | 15.31     |
| Total Device -> CPU (s)       | 2.816    | 2.528    | inf       | inf       |
- `Largest batch size` => largest batch size that would fit in memory (for multiples of 2)
- `Min-inference batch size` => batch size associated with the smallest pure on-device inference time. We look at this metric because we assume loading time scales linearly.
- `Masker inference` => average time per image for the masker's inferences (s/i = seconds/image)
- `Painter inference` => average time per image for the painter's inferences (s/i = seconds/image)
- `Masker + painter inference` => number of images per second for pure inference (i/s = images/second)
- `Inference loop` => smallest on-device inference time for the entire dataset
- `Full dataset with loading` => overall time to process the entire dataset: numpy array -> torch tensor -> transforms -> inference, but not back to CPU
- `Device -> CPU` => (average time taken to get the inference back from the device for 1 batch) * (number of batches)
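For context, here is a rough sketch of the kind of timing loop behind these metrics (`masker`, `painter` and `batches` are placeholders, not the actual benchmark code):

```python
import time

# Illustrative sketch only. masker, painter and batches are hypothetical stand-ins:
# two models already on the device and an iterable of (B, 3, 640, 640) tensors.
# Note: on XLA devices execution is lazy, so these "inference" timings mostly
# measure graph building unless execution is forced (mark_step, a .cpu() call, ...).
masker_times, painter_times = [], []

loop_start = time.perf_counter()
for batch in batches:
    t0 = time.perf_counter()
    masks = masker(batch)                              # placeholder model
    t1 = time.perf_counter()
    painted = painter(batch, masks)                    # placeholder model, hypothetical signature
    t2 = time.perf_counter()
    masker_times.append((t1 - t0) / batch.shape[0])    # -> "Masker inference (s/i)"
    painter_times.append((t2 - t1) / batch.shape[0])   # -> "Painter inference (s/i)"
inference_loop_s = time.perf_counter() - loop_start    # -> "Inference loop (s)"

t0 = time.perf_counter()
_ = painted.cpu()                                      # bring one batch back to the host
per_batch_to_cpu = time.perf_counter() - t0
total_to_cpu_s = per_batch_to_cpu * len(batches)       # -> "Total Device -> CPU (s)"
```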
Comments
- data => 1024 images (a list of 100 different images, repeated 5+ times). Because it gets really long on GPU and TPUs still have the back-to-CPU issue, I did not try larger batches. Images have a wide range of shapes but are all transformed into a `3 x 640 x 640` tensor.
- TPU fp32 and TPU fp16 => numbers in these columns were computed after loading 4096 images, to have respectively 4 and 2 batches. To account for 4 times more data than in the GPU columns, measures for `Inference loop` and `Full dataset with loading` are divided by 4 (which could be slightly off).
  - `inf` => after 10+ minutes, still no answer. Cannot even stop the process with `ctrl+c`; have to kill it (`kill -9`).
  - `TPU fp16` => prepend `XLA_USE_BF16=1` to the command (see the snippet after this list).
- TPU: time to perform transforms (per sample): 0.004, time to send to device (per sample): 0.011
- Numbers on TPU have a high variance (not something I measured, but observed: some full processings take 60 seconds, others 75 or 80 with the same params).
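The `XLA_USE_BF16=1` switch mentioned above can also be set from inside the script rather than prepended to the command; a minimal sketch (setting it before the torch_xla import is just a conservative choice so the runtime is sure to see it):

```python
import os

# Should be equivalent to prepending XLA_USE_BF16=1 to the launch command:
# with this flag set, torch.float tensors are stored as bfloat16 on the TPU.
os.environ["XLA_USE_BF16"] = "1"

import torch  # noqa: E402
import torch_xla.core.xla_model as xm  # noqa: E402

device = xm.xla_device()
```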
```python
import time

import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm


class Timer:
    def __init__(self, name="", store=None, precision=3):
        self.name = name
        self.precision = precision

    def format(self, n):
        return f"{n:.{self.precision}f}"

    def __enter__(self):
        """Start a new timer as a context manager"""
        self._start_time = time.perf_counter()
        return self

    def __exit__(self, *exc_info):
        """Stop the context manager timer"""
        t = time.perf_counter()
        new_time = t - self._start_time
        print(f"[{self.name}] Elapsed time: {self.format(new_time)}")


if __name__ == "__main__":
    device = xm.xla_device()
    torch.set_grad_enabled(False)

    model = nn.Sequential(
        *[
            nn.Conv2d(3, 256, 3, 1, 1),
            nn.Conv2d(256, 512, 3, 1, 1),
            nn.Conv2d(512, 256, 3, 1, 1),
            nn.Conv2d(256, 3, 3, 1, 1),
        ]
    ).to(device)
    data = torch.rand(2, 3, 640, 640, device=device)

    with Timer("inference", precision=6):
        y = model(data)
    print(y.shape)

    with Timer("back from device"):
        y = y.cpu().numpy()
```
Output:

```
[inference] Elapsed time: 0.000643
torch.Size([2, 3, 640, 640])
[back from device] Elapsed time: 11.328
```
In an ordinary training script, such as our imagenet example, you'll notice that we don't actually use `xm.mark_step`. This is because it's called every time the XLA dataloader yields. Here in this script, you're not using a dataloader, so you'll have to call it yourself.

As to your last comment: if you'd like to measure the two models separately, you'd do that, yes. I'd do the whole thing 10-20 times first to get rid of compilation overhead, and then measure as pointed out in my first comment.
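Applied to the reproduction script above, that suggestion would look roughly like the following sketch (my addition, not code from the reply; it assumes `xm.wait_device_ops()` is available in the installed torch_xla version):

```python
# Sketch: force execution so the timer measures real device work. The first
# iteration also pays XLA compilation, so run the whole thing a few times
# before trusting the numbers.
with Timer("inference + execution", precision=6):
    y = model(data)
    xm.mark_step()        # dispatch the accumulated XLA graph for execution
    xm.wait_device_ops()  # block until the device has finished

with Timer("back from device"):
    y = y.cpu().numpy()   # should now mostly measure the device -> host transfer
```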
Also, note that every time you mark step, you'll force execution, and most of the time you'll get better performance if you don't do that; i.e. running both models and then calling `xm.mark_step()` once should be faster than calling it after each model. The exception here is if the total IR graph built before the single `mark_step` is too large and executing it actually causes an OOM or swapping on the TPU host.
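The two variants being compared are presumably along these lines (an illustrative reconstruction with placeholder `masker` and `painter` models, not the reply's original code; it assumes the same setup as the script above, with `xm` imported and both models on the XLA device):

```python
# Variant 1: build one XLA graph for both models and force execution once.
masks = masker(batch)
painted = painter(batch, masks)
xm.mark_step()

# Variant 2: force execution after each model. This is what lets you time the
# two models separately, but it is usually slower because the device executes
# two smaller graphs instead of one.
masks = masker(batch)
xm.mark_step()
painted = painter(batch, masks)
xm.mark_step()
```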