Last active
June 19, 2021 10:41
-
-
Save nizhib/28c1b3ced8f60620a38dcd8fc5d919b3 to your computer and use it in GitHub Desktop.
Multi-Image Data Prefetcher
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class DataPrefetcher:
    """Prefetch batches from a DataLoader onto the GPU on a side CUDA stream.

    Overlaps host-to-device copies with compute on the current stream.
    Each batch produced by ``loader`` is assumed to be laid out as
    ``(*images, *labels, *others)``: the first ``num_images`` entries are
    image tensors, the next ``num_labels`` entries are label tensors, and
    anything after that is passed through untouched.

    Images are cast to float and normalized in place with ImageNet mean/std
    scaled by 255 (inputs are presumably uint8 pixel values — confirm
    against the collate function in use).
    """

    def __init__(self, loader, num_images=1, num_labels=1):
        self.loader = iter(loader)
        self.num_images = num_images
        self.num_labels = num_labels
        # Placeholder; updated to the true trailing-item count on first preload().
        self.num_others = 1
        self.stream = torch.cuda.Stream()
        # ImageNet statistics pre-scaled by 255 so they apply directly to
        # float-cast uint8 pixels, shaped (1, C, 1, 1) for broadcasting.
        self.mean = (
            torch.tensor([0.485 * 255, 0.456 * 255, 0.406 * 255])
            .cuda()
            .view(1, 3, 1, 1)
        )
        self.std = (
            torch.tensor([0.229 * 255, 0.224 * 255, 0.225 * 255])
            .cuda()
            .view(1, 3, 1, 1)
        )
        self.next_images = [None] * self.num_images
        self.next_labels = [None] * self.num_labels
        self.next_others = [None] * self.num_others
        self.preload()

    def preload(self):
        """Fetch the next batch and start its async transfer on self.stream.

        When the loader is exhausted, resets all pending slots to None so
        ``next()`` signals end-of-epoch via None entries.
        """
        try:
            # Keep the try body minimal: only this call can raise StopIteration.
            item = next(self.loader)
        except StopIteration:
            self.next_images = [None] * self.num_images
            self.next_labels = [None] * self.num_labels
            self.next_others = [None] * self.num_others
            return
        for i in range(self.num_images):
            self.next_images[i] = item[i]
        for j in range(self.num_labels):
            self.next_labels[j] = item[self.num_images + j]
        # Everything past images+labels is carried through untouched.
        self.next_others = list(item[self.num_images + self.num_labels:])
        self.num_others = len(self.next_others)
        with torch.cuda.stream(self.stream):
            for i in range(self.num_images):
                self.next_images[i] = self.next_images[i].cuda(non_blocking=True)
            for j in range(self.num_labels):
                self.next_labels[j] = self.next_labels[j].cuda(non_blocking=True)
            for i in range(self.num_images):
                # Cast, then normalize in place, still on the prefetch stream.
                self.next_images[i] = self.next_images[i].float()
                self.next_images[i] = self.next_images[i].sub_(self.mean).div_(self.std)
            for j in range(self.num_labels):
                self.next_labels[j] = self.next_labels[j].float()

    def next(self):
        """Return the prefetched batch and start prefetching the next one.

        Returns a tuple ``(*images, *labels, *others)``; entries are None
        once the underlying loader is exhausted.
        """
        torch.cuda.current_stream().wait_stream(self.stream)
        images = list(self.next_images)
        labels = list(self.next_labels)
        others = self.next_others
        # record_stream tells the caching allocator these tensors are now
        # consumed on the current stream, so their memory is not recycled
        # while the copy stream may still be writing it.
        if images and images[0] is not None:
            for image in images:
                image.record_stream(torch.cuda.current_stream())
        if labels and labels[0] is not None:
            for label in labels:
                label.record_stream(torch.cuda.current_stream())
        self.preload()
        return (*images, *labels, *others)
def fast_collate(batch):
    """Collate samples whose leading entries are image arrays into uint8 batches.

    Each sample in ``batch`` is a sequence laid out as ``(*images, *rest)``.
    An entry counts as an image while it has a ``.shape`` of rank >= 2;
    images are assumed HWC (or HW for single-channel) and uniform in size
    across the batch — only the first sample's shape is inspected.

    Returns ``(*image_batches, *rest_lists)`` where each image batch is a
    uint8 tensor of shape (B, C, H, W) and each remaining field is collated
    into a plain list.
    """
    # Count leading per-sample entries that look like images (rank >= 2).
    num_images = 0
    first = batch[0]
    while (
        num_images < len(first)
        and hasattr(first[num_images], "shape")
        and len(first[num_images].shape) >= 2
    ):
        num_images += 1
    images = [[sample[i] for sample in batch] for i in range(num_images)]
    others = [sample[num_images:] for sample in batch]
    b = len(images[0])
    h = images[0][0].shape[0]
    w = images[0][0].shape[1]
    # Per-slot channel count: rank-3 arrays are HWC, rank-2 are single-channel.
    c = [
        images[i][0].shape[2] if images[i][0].ndim == 3 else 1
        for i in range(num_images)
    ]
    images_th = [
        torch.zeros((b, c[i], h, w), dtype=torch.uint8)
        for i in range(num_images)
    ]
    for i in range(len(images_th)):
        for j in range(b):
            # copy_ casts and writes in place; the original `+=` relied on
            # in-place type promotion, which raises in modern PyTorch when
            # to_tensor yields a float result.
            images_th[i][j].copy_(to_tensor(images[i][j]))
    return (*images_th, *map(list, zip(*others)))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment