Skip to content

Instantly share code, notes, and snippets.

@pwinston
Created July 10, 2020 20:54
Show Gist options
  • Save pwinston/febb9ad3ae6da11be76d77743ebe9e0f to your computer and use it in GitHub Desktop.
Save pwinston/febb9ad3ae6da11be76d77743ebe9e0f to your computer and use it in GitHub Desktop.
#!/usr/bin/env python
import contextlib
import time
import numpy as np
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import Dataset
from torch import randn
from napari.utils import resize_dask_cache
import dask
import dask.array as da
# NOTE(review): presumably sets napari's dask cache size to 0 bytes so the
# repeated accesses timed below are not served from a cache -- confirm
# against the napari documentation for resize_dask_cache.
resize_dask_cache(0)
@contextlib.contextmanager
def perf_timer(name: str):
    """Time the enclosed ``with`` body and print the elapsed milliseconds.

    Args:
        name: Label printed in front of the measured duration.

    Yields:
        None. The context manager exists only for its timing side effect.

    The report line has the form ``"{name} {ms}ms"`` and is written to
    standard output when the ``with`` block exits.
    """
    start_ns = time.perf_counter_ns()
    try:
        yield
    finally:
        # Report in ``finally`` so a timing line is still printed when the
        # timed body raises; the exception then propagates unchanged.
        elapsed_ms = (time.perf_counter_ns() - start_ns) / 1e6
        print(f"{name} {elapsed_ms}ms")
# MNIST with train=False (the test split), rooted at ../data/MNIST, with
# ToTensor applied to every image through a one-step Compose pipeline.
mnist_test = MNIST(
    '../data/MNIST',
    train=False,
    transform=transforms.Compose([transforms.ToTensor()]),
    download=True,
)
def add_noise(img, noise_std=0.4):
    """Return ``img`` plus scaled standard-normal noise of the same shape.

    Args:
        img: Tensor to perturb. It is not modified; a new tensor is returned.
        noise_std: Factor applied to the standard-normal samples before they
            are added. Defaults to 0.4, the value previously hard-coded.

    Returns:
        A tensor with the same shape as ``img``.
    """
    # Allocate the noise on the image's device so the addition also works
    # for tensors that do not live on the default device.
    noise = randn(img.size(), device=img.device)
    return img + noise * noise_std
class SyntheticNoiseDataset(Dataset):
    """Dataset wrapper that pairs each image with a freshly noised copy.

    Each item of the wrapped ``data`` must be indexable, with the image at
    position 0 (anything after it, such as a label, is ignored).
    """

    def __init__(self, data, mode='train'):
        """Store the wrapped dataset and a mode tag.

        Args:
            data: Indexable, sized collection whose items hold the image at
                position 0.
            mode: Free-form tag kept on the instance; nothing in this class
                reads it.
        """
        self.mode = mode
        self.data = data

    def __len__(self):
        """Return the number of items in the wrapped dataset."""
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(noisy_image, clean_image)`` for the item at ``index``.

        The noise is generated by ``add_noise`` on every call, so repeated
        lookups of the same index call it again each time.
        """
        img = self.data[index][0]
        return add_noise(img), img
# Wrap the MNIST test split so every item becomes a (noisy, clean) image pair.
noisy_mnist_test = SyntheticNoiseDataset(mnist_test, mode='test')
def _load_clean_image(i):
    """Return element 1 (the clean image) of ``noisy_mnist_test[i]`` as NumPy.

    Indexing the dataset also computes the noisy copy, which is discarded.
    """
    return noisy_mnist_test[i][1].detach().numpy()


def make_stack(count):
    """Build a lazy dask stack of the first ``count`` clean images.

    Args:
        count: Number of dataset items to include, taken from index 0 up.

    Returns:
        A dask array made by stacking ``count`` delayed loads, each declared
        as a (1, 28, 28) float32 array and reshaped to (28, 28). Nothing is
        read from the dataset until the array is computed.
    """
    # Wrap the loader once; each call below creates one delayed task.
    lazy_load = dask.delayed(_load_clean_image)
    return da.stack(
        [
            da.from_delayed(
                lazy_load(i),
                shape=(1, 28, 28),
                dtype=np.float32,
            ).reshape((28, 28))
            for i in range(count)
        ]
    )
def test_access(label, stack):
    """Time materializing the first few slices of ``stack``.

    Args:
        label: Name printed in the ``Accessing ...`` header line.
        stack: Indexable, sized array-like; each timed slice is converted
            with ``np.asarray``.

    Prints one header line, then one ``perf_timer`` line per slice for up to
    the first three indices.
    """
    print(f"Accessing {label}")
    # Read the index named in the timer label (the original always read
    # stack[0]); cap the loop so stacks shorter than three do not overrun.
    for i in range(min(3, len(stack))):
        with perf_timer(f"stack[{i}] = "):
            np.asarray(stack[i])
# Build every stack first, then time them in increasing size order.
stack_10, stack_100, stack_1000, stack_10000 = (
    make_stack(size) for size in (10, 100, 1000, 10000)
)

for stack_label, lazy_stack in (
    ("stack_10", stack_10),
    ("stack_100", stack_100),
    ("stack_1000", stack_1000),
    ("stack_10000", stack_10000),
):
    test_access(stack_label, lazy_stack)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment