Skip to content

Instantly share code, notes, and snippets.

@nizhib
Last active June 19, 2021 10:41
Show Gist options
  • Save nizhib/28c1b3ced8f60620a38dcd8fc5d919b3 to your computer and use it in GitHub Desktop.
Save nizhib/28c1b3ced8f60620a38dcd8fc5d919b3 to your computer and use it in GitHub Desktop.
Multi-Image Data Prefetcher
class DataPrefetcher:
    """Overlap host-to-device transfer of the next batch with work on the current one.

    Wraps a loader (any iterable of indexable batches, e.g. a DataLoader using
    ``fast_collate``). On a dedicated CUDA side stream it does three things to
    the *next* batch:

    - moves its image and label tensors to the GPU,
    - converts them to float,
    - normalizes the images with ``(image - mean) / std``.

    Each loader item is laid out as ``(*images, *labels, *others)``:

    - the first ``num_images`` entries are image tensors,
    - the next ``num_labels`` entries are label tensors,
    - anything after that is passed through untouched.

    Args:
        loader: Iterable yielding batches laid out as above.
        num_images: Number of leading image tensors per batch (may be 0).
        num_labels: Number of label tensors following the images (may be 0).
        mean: Per-channel values subtracted from every image. The default is
            the ImageNet mean scaled to 0-255, i.e. it presumes raw uint8-range
            pixels (as produced by ``fast_collate``).
        std: Per-channel values every image is divided by, on the same scale
            as ``mean``. Defaults to the ImageNet std scaled to 0-255.
    """

    _IMAGENET_MEAN_255 = (0.485 * 255, 0.456 * 255, 0.406 * 255)
    _IMAGENET_STD_255 = (0.229 * 255, 0.224 * 255, 0.225 * 255)

    def __init__(
        self,
        loader,
        num_images=1,
        num_labels=1,
        mean=_IMAGENET_MEAN_255,
        std=_IMAGENET_STD_255,
    ):
        self.loader = iter(loader)
        self.num_images = num_images
        self.num_labels = num_labels
        # Placeholder until the first batch reveals how many passthrough
        # fields there are; it only matters if the loader is empty from the start.
        self.num_others = 1
        self.stream = torch.cuda.Stream()
        # Shaped (1, C, 1, 1) so they broadcast over (N, C, H, W) image batches.
        self.mean = self._channel_stats(mean)
        self.std = self._channel_stats(std)
        self.next_images = [None] * self.num_images
        self.next_labels = [None] * self.num_labels
        self.next_others = [None] * self.num_others
        self.preload()

    @staticmethod
    def _channel_stats(values):
        """Return per-channel ``values`` as a CUDA tensor shaped (1, C, 1, 1)."""
        return (
            torch.tensor(values, dtype=torch.get_default_dtype())
            .cuda()
            .view(1, -1, 1, 1)
        )

    def preload(self):
        """Fetch the next batch and enqueue its GPU transfer on the side stream.

        When the loader is exhausted, every ``next_*`` slot is set to ``None``,
        so the following :meth:`next` call returns a tuple of ``None`` values.
        """
        try:
            item = next(self.loader)
        except StopIteration:
            self.next_images = [None] * self.num_images
            self.next_labels = [None] * self.num_labels
            self.next_others = [None] * self.num_others
            return

        split = self.num_images + self.num_labels
        # Rebind fresh lists rather than mutating in place: next() hands the
        # previous lists straight to the caller.
        self.next_images = [item[i] for i in range(self.num_images)]
        self.next_labels = [item[i] for i in range(self.num_images, split)]
        self.next_others = list(item[split:])
        self.num_others = len(self.next_others)

        with torch.cuda.stream(self.stream):
            # non_blocking=True only makes the copy truly asynchronous when the
            # source tensors are in pinned memory; verify the loader pins them.
            self.next_images = [
                image.cuda(non_blocking=True).float().sub_(self.mean).div_(self.std)
                for image in self.next_images
            ]
            self.next_labels = [
                label.cuda(non_blocking=True).float() for label in self.next_labels
            ]

    def next(self):
        """Return the prefetched batch and start prefetching the one after it.

        Returns:
            tuple: ``(*images, *labels, *others)``. Images are normalized float
            CUDA tensors, labels are float CUDA tensors, and ``others`` are passed
            through from the loader unchanged. Once the loader is exhausted,
            every element is ``None``.
        """
        current = torch.cuda.current_stream()
        # Make the consumer's stream wait for the side-stream copies to finish.
        current.wait_stream(self.stream)
        images = self.next_images
        labels = self.next_labels
        others = self.next_others
        # These tensors were allocated on the side stream but will be used on
        # `current`. Tell the caching allocator so their memory is not reused
        # too early. Check each tensor individually: num_images or num_labels
        # may be 0.
        for tensor in (*images, *labels):
            if tensor is not None:
                tensor.record_stream(current)
        self.preload()
        return (*images, *labels, *others)
def fast_collate(batch):
    """Collate samples into uint8 image batches plus per-field passthrough lists.

    Each sample is a sequence whose leading entries are images, followed by
    any number of other fields. An image is anything with a ``shape`` of rank
    >= 2. The length of that image prefix is determined from the first sample
    only.

    Images at the same position across the batch are packed into one
    ``(batch, channels, H, W)`` ``torch.uint8`` tensor:

    - ``H`` and ``W`` are the first two axes of the first sample's first image.
    - ``channels`` is the third axis for 3-D images, or 1 for 2-D ones.

    Each image is converted with ``to_tensor`` (defined elsewhere) and added
    into its zero-initialized slot. It presumably must yield a uint8-compatible
    ``(channels, H, W)`` tensor (TODO confirm).

    Args:
        batch: Non-empty sequence of samples.

    Returns:
        tuple: ``(*image_tensors, *other_fields)``, where each other field is
        a list with one entry per sample. If no leading field is image-like,
        only the other fields are returned.
    """
    first = batch[0]
    num_images = 0
    while (
        num_images < len(first)
        and hasattr(first[num_images], "shape")
        and len(first[num_images].shape) >= 2
    ):
        num_images += 1

    # Transpose the non-image tail: one list per field, one entry per sample.
    other_fields = [
        list(field) for field in zip(*(sample[num_images:] for sample in batch))
    ]
    if num_images == 0:
        # Nothing to stack; previously this crashed indexing an empty list.
        return tuple(other_fields)

    height = first[0].shape[0]
    width = first[0].shape[1]
    batch_size = len(batch)
    image_tensors = []
    for i, prototype in enumerate(first[:num_images]):
        channels = prototype.shape[2] if prototype.ndim == 3 else 1
        stacked = torch.zeros(
            (batch_size, channels, height, width), dtype=torch.uint8
        )
        for j, sample in enumerate(batch):
            stacked[j] += to_tensor(sample[i])
        image_tensors.append(stacked)
    return (*image_tensors, *other_fields)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment