Created
February 4, 2020 19:51
-
-
Save driazati/a5bee3ae642b144ad089e8b30eff6d3f to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import io

import torch
from pyarkbench import default_args, Benchmark, Timer
class Bench(Benchmark):
    """Benchmark that times one ``torch.save`` / ``torch.load`` round trip.

    The constructor stores its arguments as attributes:

    * ``x`` -- the object handed to ``torch.save``.
    * ``output`` -- either a ``str`` (used as-is as the save/load target) or
      a zero-argument callable that is invoked on every ``benchmark`` call to
      produce the target (the test table in this file passes ``io.BytesIO``).
    * ``use_new`` -- forwarded to ``torch.save`` as
      ``_use_new_zipfile_serialization``.
    """

    def __init__(self, x, out, use_new):
        # The base class is configured from pyarkbench's default arguments.
        super().__init__(*default_args.bench())
        self.x = x
        self.output = out
        self.use_new = use_new

    def benchmark(self):
        """Save ``self.x`` to the target, load it back, and report timings.

        Returns:
            A dict with the keys ``"save_time"`` and ``"load_time"``, each
            holding the ``ms_duration`` of the corresponding ``Timer``
            (presumably milliseconds -- confirm against pyarkbench).
        """
        # A string is used directly; anything else is called so that each
        # run gets a freshly constructed target.
        target = self.output if isinstance(self.output, str) else self.output()

        with Timer() as save_timer:
            torch.save(
                self.x,
                target,
                _use_new_zipfile_serialization=self.use_new,
            )

        # Seekable targets are rewound so that the load starts at the
        # beginning of what was just written.
        if hasattr(target, "seek"):
            target.seek(0)

        with Timer() as load_timer:
            torch.load(target)

        timings = {}
        timings["save_time"] = save_timer.ms_duration
        timings["load_time"] = load_timer.ms_duration
        return timings
# Benchmark cases, one tuple per case:
#   (label, tensor to serialize, output target, use new zipfile serialization)
# The output target is either a file path (str) or a zero-argument factory
# (``io.BytesIO``) that ``Bench.benchmark`` calls to obtain a fresh buffer.
# NOTE: ``io`` must be imported at the top of the file; without that import
# building this list fails with a NameError.
tests = [
    ("New to file (4 elements)", torch.ones(2, 2), 'model.zip', True),
    ("Old to file (4 elements)", torch.ones(2, 2), 'model.zip', False),
    ("New to buffer (4 elements)", torch.ones(2, 2), io.BytesIO, True),
    ("Old to buffer (4 elements)", torch.ones(2, 2), io.BytesIO, False),
    ("New to file (512^2 elements)", torch.ones(512 * 512), 'model.zip', True),
    ("Old to file (512^2 elements)", torch.ones(512 * 512), 'model.zip', False),
    ("New to buffer (512^2 elements)", torch.ones(512 * 512), io.BytesIO, True),
    ("Old to buffer (512^2 elements)", torch.ones(512 * 512), io.BytesIO, False),
]
# Run every configured case and print its timings beneath its label.
for label, tensor, target, use_new_format in tests:
    bench = Bench(tensor, target, use_new_format)
    timings = bench.run()
    print(label)
    bench.print_results(timings)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment