Skip to content

Instantly share code, notes, and snippets.

@driazati
Created February 4, 2020 19:51
Show Gist options
  • Save driazati/a5bee3ae642b144ad089e8b30eff6d3f to your computer and use it in GitHub Desktop.
Save driazati/a5bee3ae642b144ad089e8b30eff6d3f to your computer and use it in GitHub Desktop.
import io

import torch
from pyarkbench import default_args, Benchmark, Timer
class Bench(Benchmark):
    """Time one ``torch.save`` followed by one ``torch.load`` of an object.

    Args:
        x: The object handed to ``torch.save`` (a tensor in this script).
        out: Where to save. A callable (e.g. the ``io.BytesIO`` class) is
            treated as a factory and called once per ``benchmark()`` call
            to get a fresh target. Anything else (a path string, a
            ``pathlib.Path``, or an already-open writable, seekable file
            object) is used as-is.
        use_new: Forwarded to ``torch.save`` as
            ``_use_new_zipfile_serialization``. True selects the new
            zipfile format; False selects the legacy format.
    """

    def __init__(self, x, out, use_new):
        # Forward pyarkbench's default benchmark settings to the base class.
        super().__init__(*default_args.bench())
        self.x = x
        self.output = out
        self.use_new = use_new

    def benchmark(self):
        """Run one save/load round trip and return both timings.

        Returns:
            dict with keys ``"save_time"`` and ``"load_time"``, each the
            ``ms_duration`` reported by the surrounding ``Timer``.
        """
        out = self.output
        # Dispatch on callability rather than on ``str``. The old
        # ``isinstance(out, str)`` check tried to *call* pathlib.Path objects
        # and open file objects, which raises TypeError.
        if callable(out):
            out = out()
        seekable = hasattr(out, "seek")
        if seekable:
            # A reused file object still holds the previous run's bytes.
            # Start from an empty stream so save/load see only this run's data.
            # This is a no-op for a freshly created buffer and happens outside
            # the timers.
            out.seek(0)
            out.truncate()
        with Timer() as save_time:
            torch.save(self.x, out, _use_new_zipfile_serialization=self.use_new)
        if seekable:
            # Rewind so torch.load reads from the start of what was just written.
            out.seek(0)
        with Timer() as load_time:
            torch.load(out)
        return {
            "save_time": save_time.ms_duration,
            "load_time": load_time.ms_duration,
        }
# Benchmark matrix. Each entry is:
#   (label, object to save, save target, use new zipfile format?)
# A str target is a file path on disk. io.BytesIO is passed as the class
# itself, so Bench instantiates a new in-memory buffer for each run.
tests = [
    # Small tensor: 2 x 2 = 4 elements.
    ("New to file (4 elements)", torch.ones(2, 2), 'model.zip', True),
    ("Old to file (4 elements)", torch.ones(2, 2), 'model.zip', False),
    ("New to buffer (4 elements)", torch.ones(2, 2), io.BytesIO, True),
    ("Old to buffer (4 elements)", torch.ones(2, 2), io.BytesIO, False),
    # Large tensor: 512 * 512 elements, 1-D.
    ("New to file (512^2 elements)", torch.ones(512 * 512), 'model.zip', True),
    ("Old to file (512^2 elements)", torch.ones(512 * 512), 'model.zip', False),
    ("New to buffer (512^2 elements)", torch.ones(512 * 512), io.BytesIO, True),
    ("Old to buffer (512^2 elements)", torch.ones(512 * 512), io.BytesIO, False),
]

for label, tensor, target, use_new in tests:
    runner = Bench(tensor, out=target, use_new=use_new)
    # Run first, then print the label followed by the collected timings.
    timings = runner.run()
    print(label)
    runner.print_results(timings)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment