Skip to content

Instantly share code, notes, and snippets.

@lopuhin
Created October 15, 2019 10:06
Show Gist options
  • Save lopuhin/0d100ef7df01fdfc91d9685f6e01ff64 to your computer and use it in GitHub Desktop.
Save lopuhin/0d100ef7df01fdfc91d9685f6e01ff64 to your computer and use it in GitHub Desktop.
PyTorch: high memory usage during CPU inference with variable input shapes
import os
os.environ['OMP_NUM_THREADS'] = '1'
import argparse
import resource
import sys
import time
import numpy as np
import torch
from torchvision.models import resnet34
def main():
    """Benchmark ResNet-34 CPU inference on inputs of varying height.

    Parses command-line options, draws ``--n`` input heights from a
    log-normal distribution (truncated to int and clipped to
    ``[--min-height, --max-height]``), then runs one forward pass per height
    on a random ``(1, 3, height, width)`` tensor.  Prints height statistics,
    the max-RSS growth every 100 iterations, and final timing and memory
    summaries.  Memory growth is measured relative to the max RSS observed
    right after the model has been constructed.
    """
    parser = argparse.ArgumentParser()
    arg = parser.add_argument
    arg('--width', type=int, default=320)
    arg('--min-height', type=int, default=200)
    arg('--max-height', type=int, default=7680)
    arg('--n', type=int, default=1000)
    arg('--seed', type=int, default=42)
    args = parser.parse_args()
    if args.n < 1:
        # The statistics below (np.max / np.percentile) are undefined for an
        # empty sample, so fail early with a readable message.
        parser.error('--n must be at least 1')

    print(f'torch {torch.__version__}')

    # Seeded generator so the sequence of heights is reproducible.
    rng = np.random.RandomState(args.seed)
    heights = [np.clip(int(100 * np.exp(rng.normal(1, 1.8))),
                       args.min_height, args.max_height)
               for _ in range(args.n)]
    hp50, hp95 = np.percentile(heights, [50, 95])
    print(f'heights: mean={np.mean(heights):.0f}, p50={hp50:.0f} '
          f'p95={hp95:.0f} max={np.max(heights):.0f}')

    model = resnet34()
    model.eval()

    # Baseline is taken after model construction, so the reported growth
    # covers only what happens during the inference loop.
    start_memory = get_ru_maxrss()
    times = []
    for i, height in enumerate(heights):
        if i and i % 100 == 0:
            print(f'n={i} memory growth (kb): '
                  f'{get_ru_maxrss() - start_memory:,}')
        # Input creation is deliberately kept outside the timed region.
        x = torch.randn((1, 3, height, args.width))
        t0 = time.perf_counter()
        with torch.no_grad():
            y = model(x)
            # Pull a scalar out of the result to be extra sure it is
            # evaluated; unlike an assert, this is not removed by python -O.
            y.mean().item()
        times.append(time.perf_counter() - t0)
    end_memory = get_ru_maxrss()

    tp50, tp95 = np.percentile(times, [50, 95])
    print(f'time: mean={np.mean(times):.3f} s, '
          f'p50={tp50:.3f} s, p95={tp95:.3f} s')
    print(f'memory (kb): {start_memory:,} initial, '
          f'{end_memory - start_memory:,} growth')
def get_ru_maxrss():
    """Return the max RSS usage of the current process, in kilobytes.

    Reads ``ru_maxrss`` from ``resource.getrusage(resource.RUSAGE_SELF)`` and
    normalises the unit, so callers always get kilobytes.
    """
    size = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    if sys.platform == 'darwin':
        # On Mac OS X ru_maxrss is in bytes, on Linux it is in KB.
        size //= 1024
    return size
# Run the benchmark only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment