"""Benchmark inference/training speed and maximum batch size of neural
networks from the PyTorch torchvision model zoo.

Requires Python 3.8-3.11, PyTorch 2.0.
"""
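
# Example invocations (assuming this script is saved as batch_test.py):
#
#   python batch_test.py                   # test every model on all GPUs
#   python batch_test.py -d 0 -y 2020      # GPU 0, only the 2020 model group
#   python batch_test.py --no-fp16 --params 1e6 5e7
#
# Each tested group writes batch_test.<year>-<count>[.<tag>].csv and .svg
# (see OUTFILE and _measure_n_plots below).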

from __future__ import annotations  # keeps `int | None` hints valid on 3.8+

import os
import warnings
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
from collections.abc import Callable
from concurrent.futures import ProcessPoolExecutor
from functools import partial
from itertools import count
from time import perf_counter

# Configure CUDA module loading and the allocator before torch is imported
os.environ['CUDA_MODULE_LOADING'] = 'LAZY'
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'backend:cudaMallocAsync'
os.environ['PYTORCH_NVML_BASED_CUDA_CHECK'] = '1'

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.multiprocessing as tomp
from matplotlib.ticker import EngFormatter
from torch import nn
from torchvision import models
from tqdm.auto import tqdm

plt.rcParams['svg.fonttype'] = 'none'
tomp.set_sharing_strategy('file_system')
warnings.filterwarnings('ignore')

MAX_BATCH = 8192
TIMEOUT = 5  # Time in seconds to spend measuring speed for each run
NBITS = 5  # Only the top NBITS bits of the resulting batch size are precise
SHAPE = (224, 224)
OUTFILE = 'batch_test'

NETS: dict[str, Callable[..., nn.Module]] = {
    name: partial(models.get_model, name, weights=None)
    for name in models.list_models(module=models)
}
for _name in ('inception_v3', 'googlenet'):
    if _name in NETS:
        NETS[_name] = partial(NETS[_name], aux_logits=False)
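# For example, NETS['resnet50']() builds an untrained torchvision ResNet-50;
# inception_v3 and googlenet are built without their auxiliary classifier
# heads (aux_logits=False), so their outputs are plain tensors.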

YV_2_NET = {  # (arxiv year, torchvision version) -> model family
    # 14
    (2014, '0.1.6'): ['alexnet', 'vgg'],  # 1+8 / 9
    (2015, '0.1.6'): ['resnet'],  # 5
    # 7
    (2015, '0.1.8'): ['inception_v3'],  # 1
    (2016, '0.1.8'): ['densenet', 'squeezenet'],  # 4+2 / 6
    # 9
    (2014, '0.3.0'): ['googlenet'],  # 1
    (2016, '0.3.0'): ['resnext'],  # 3
    (2018, '0.3.0'): ['mobilenet_v2', 'shufflenet'],  # 1+4 / 5
    # 6
    (2016, '0.4.0'): ['wide_resnet'],  # 2
    (2018, '0.4.0'): ['mnasnet'],  # 4
    # 2
    (2019, '0.9.0'): ['mobilenet_v3'],  # 2
    # 23
    (2019, '0.11.0'): ['efficientnet'],  # 8
    (2020, '0.11.0'): ['regnet'],  # 15
    # 9
    (2020, '0.12.0'): ['vit'],  # 5
    (2022, '0.12.0'): ['convnext'],  # 4
    # 6
    (2021, '0.13.0'): ['efficientnet_v2', 'swin'],  # 3+3 / 6
    # 4
    (2022, '0.14.0'): ['maxvit', 'swin_v2'],  # 1+3 / 4
}
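# In main() below, a model is matched to families by name prefix and takes
# the latest year: e.g. 'efficientnet_v2_s' matches both 'efficientnet'
# (2019) and 'efficientnet_v2' (2021), so max() groups it under 2021;
# models matching no family fall into the catch-all year -1.
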
NETS = dict(sorted(NETS.items()))


def main() -> None:
    # Find min/max parameter counts across all models
    metas = {
        name: models.get_model_weights(name).DEFAULT.meta for name in NETS
    }
    param_counts = {name: meta['num_params'] for name, meta in metas.items()}
    min_max_params = min(param_counts.values()), max(param_counts.values())

    # Parse args
    p = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
    p.add_argument('-d', '--device', type=int, help='device id to use')
    p.add_argument(
        '-y',
        '--year',
        type=int,
        choices=sorted(y for y, _ in YV_2_NET),
        help='release year')
    p.add_argument('-i', '--show', action='store_true')
    p.add_argument('-t', '--tag')
    p.add_argument('--no-fp16', action='store_false', dest='fp16')
    p.add_argument(
        '--params',
        type=float,
        nargs=2,
        default=min_max_params,
        help='min/max model param count')
    args = p.parse_args()

    min_params, max_params = args.params
    min_params = max(min_max_params[0], min_params)
    max_params = min(min_max_params[1], max_params)

    # Group models by their families
    print(f'Found {len(NETS)} models to test')
    yr_nets: dict[int, dict[str, Callable[..., nn.Module]]] = {}
    for net, net_fn in NETS.items():
        if not (min_params <= param_counts[net] <= max_params):
            continue
        year = max(
            (yr for (yr, _), fams in YV_2_NET.items()
             if any(map(net.startswith, fams))),
            default=-1,
        )
        yr_nets.setdefault(year, {})[net] = net_fn
    yr_nets = dict(sorted(yr_nets.items()))

    num_selected = sum(len(ns) for ns in yr_nets.values())
    print(f'Selected {num_selected} models:')
    print(
        *sorted((param_counts[n], n) for ns in yr_nets.values() for n in ns),
        sep='\n')

    # Run
    print(f'Using {SHAPE} shape for image and {TIMEOUT}s per run')
    if num_selected < len(NETS):  # Subset selected: merge all into one group
        yr_nets = {0: {k: v for f in yr_nets.values() for k, v in f.items()}}
    with tqdm(yr_nets.items(), desc='measuring') as bar:
        for year, nets in bar:
            bar.set_postfix(year=year)
            if args.year not in (None, year):
                continue
            outfile = f'{OUTFILE}.{year}-{len(nets)}'
            if args.tag is not None:
                outfile = f'{outfile}.{args.tag}'
            _measure_n_plots(nets, outfile, args, metas)


def _measure_n_plots(nets: dict[str, Callable[..., nn.Module]], outfile: str,
                     args, metas):
    stats = []
    for name, net_fn in tqdm(nets.items()):
        net = net_fn()
        size = sum(t.numel() for ts in (net.parameters, net.buffers)
                   for t in ts()) / 1e6  # Params + buffers, in millions
        in1k_top1 = metas[name]['_metrics']['ImageNet-1K']['acc@1']
        in1k_top5 = metas[name]['_metrics']['ImageNet-1K']['acc@5']
        name = f'{name}/{size:.1f}M'
        # Skip the FP16 runs when --no-fp16 is given
        if args.fp16:
            be16, se16 = _find_max_batch(net, fp16=True, dev=args.device)
        else:
            be16, se16 = 0, 0
        be32, se32 = _find_max_batch(
            net,
            high=be16 or MAX_BATCH,
            dev=args.device,
        )
        if args.fp16:
            bt16, st16 = _find_max_batch(
                net,
                is_train=True,
                fp16=True,
                high=be16 or MAX_BATCH,
                dev=args.device,
            )
        else:
            bt16, st16 = 0, 0
        bt32, st32 = _find_max_batch(
            net,
            is_train=True,
            high=be32,
            dev=args.device,
        )
        stats.append({
            'name': name,
            'train/batch': (bt16, bt32),
            'train/fps': (st16, st32),
            'infer/batch': (be16, be32),
            'infer/fps': (se16, se32),
            'in1k_err': (round(100 - in1k_top1, 3), round(100 - in1k_top5, 3)),
        })
        tqdm.write(f'{name}: '
                   f'{bt32}..{be16 or be32} bs, '
                   f'{st32}..{se16 or se32} fps')

    df = pd.DataFrame.from_records(stats)
    df = df.sort_values('in1k_err')
    # df = df.sort_values('train/batch')
    df.to_csv(f'{outfile}.csv', index=False)
    fig, ax = plt.subplots(figsize=(8, 16))
    ax.set_xscale('log', base=2)
    ax.set(xlim=(1, 16384))
    ax.xaxis.set_major_formatter(EngFormatter())
    ax.grid(True)

    xs = np.arange(df.shape[0], dtype=float)
    cmap = plt.get_cmap('tab10')
    labels = {
        'train/batch': ['fp16', 'fp32'],
        'train/fps': ['fp16', 'fp32'],
        'infer/batch': ['fp16', 'fp32'],
        'infer/fps': ['fp16', 'fp32'],
        'in1k_err': ['top@1', 'top@5'],
    }
    akwds = {
        'fontsize': 8,
        'xytext': (0, 0),
        'textcoords': 'offset points',
        'va': 'center',
    }
    n = len(labels)
    h = 1 / (n + 1)
    offsets = np.linspace(h - .5, .5 - h, n).tolist()
    for i, (offset, (l_head, l_tails)) in enumerate(
            zip(offsets, labels.items())):
        bkwds = {'height': h, 'color': cmap(i)}
        fp16, fp32 = zip(*df[l_head].values.tolist())
        for l_tail, dat, alpha, ha in zip(l_tails, [fp16, fp32], [0.5, None],
                                          ['left', 'right']):
            if dat is None:
                continue
            label_ = f'{l_head}/{l_tail}'
            bar = ax.barh(xs - offset, dat, label=label_, alpha=alpha, **bkwds)
            for r in bar:
                xy = (r.get_width(), r.get_y() + r.get_height() / 2)
                ax.annotate(r.get_width(), xy=xy, ha=ha, **akwds)
    ax.set_yticks(xs)
    ax.set_yticklabels(df['name'].values.tolist())
    ax.legend()
    fig.tight_layout()
    fig.savefig(f'{outfile}.svg')
    if args.show:
        plt.show()


def _find_max_batch(net: nn.Module,
                    is_train: bool = False,
                    fp16: bool = False,
                    low: int = 0,
                    high: int = MAX_BATCH,
                    dev: int | None = None) -> tuple[int, int]:
    speed = 0.
    exc_ = None
    # Candidate batch sizes in [low, high], quantized to NBITS of precision
    rg = sorted({_ev_round(x, NBITS) for x in range(low, high + 1)})
    with tqdm(
            desc=(('infer', 'train')[is_train] + f'/fp{(32, 16)[fp16]}'),
            leave=False) as bar:
        while len(rg) > 2:  # Binary search over the candidates
            mid_pos = len(rg) // 2
            batch = rg[mid_pos]
            try:
                speed_ = ProcessPoolExecutor(1).submit(
                    _test_speed, net, fp16, is_train, batch, dev).result()
            # CUDA device is not supported
            except _LegacyDeviceError:
                return 0, 0
            # Out of memory. Shrink
            except (
                    RuntimeError,
                    torch.cuda.OutOfMemoryError,  # type:ignore[misc]
            ) as exc:
                rg, exc_ = rg[:mid_pos + 1], exc
            # Not all memory acquired. Grow
            else:
                rg, exc_, speed = rg[mid_pos:], None, max(speed, speed_)
            bar.set_postfix_str(f'range: {rg[0]}..{rg[-1]} @ {len(rg)} items')
            bar.update()
    if rg[0] == 0 and exc_ is not None:
        raise exc_ from None
    return rg[0], int(speed)
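# Note: each probe is submitted to a fresh single-worker subprocess, so an
# OOM or a corrupted CUDA context dies with the worker instead of poisoning
# the parent process; only the measured speed (a plain float) crosses the
# process boundary.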


def _ev_round(x: int, bits: int = 5) -> int:
    # Zero all but the top `bits` bits of x
    nbits = max(x.bit_length() - bits, 0)
    return (x >> nbits) << nbits
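# Worked examples with the default bits=5: _ev_round(1000) == 992 and
# _ev_round(37) == 36, while any x below 32 fits in 5 bits and is unchanged.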


def _test_speed(net: nn.Module,
                fp16: bool,
                is_train: bool,
                batch: int,
                dev: int | None = None) -> float:
    devs = list(range(torch.cuda.device_count())) if dev is None else [dev]
    dprops = *map(torch.cuda.get_device_properties, devs),
    if len({dp.name for dp in dprops}) > 1:  # Some GPUs differ, use the first
        devs, dprops = devs[:1], dprops[:1]
    if fp16 and dprops[0].major < 7:  # Forbid FP16 below compute capability 7.0
        raise _LegacyDeviceError
    max_mem = sum(dp.total_memory for dp in dprops)

    if len(devs) > 1:
        net = torch.nn.DataParallel(net)
    net.cuda(dev).train(is_train)
    data = torch.rand(batch, 3, *SHAPE, device=f'cuda:{devs[0]}')
    do_step = partial(_step, net, data, is_train=is_train, fp16=fp16)

    with torch.set_grad_enabled(is_train):
        start = perf_counter()
        for n in count(batch, batch):  # n = total samples processed so far
            do_step()
            # Forbid VRAM spilling into RAM, which would stall performance
            used_mem = sum(
                max(s[f'{t}_bytes.all.peak']
                    for t in ('active', 'allocated', 'reserved'))
                for s in map(torch.cuda.memory_stats, devs))
            if used_mem >= max_mem:
                raise torch.cuda.OutOfMemoryError  # type:ignore[misc]
            if (done := perf_counter() - start) > TIMEOUT:
                return n / done
    return 0.


def _step(net: nn.Module, data: torch.Tensor, is_train: bool, fp16: bool):
    # Run the forward pass under autocast when FP16 is requested
    with torch.autocast('cuda', enabled=fp16):
        loss = net(data).sum()
    if is_train:
        net.zero_grad()
        loss.backward()
    loss.item()  # Block until the queued GPU work for this step has finished


class _LegacyDeviceError(Exception):
    pass


if __name__ == '__main__':
    main()