Created July 3, 2019 21:12
-
-
Save pbamotra/90ba7cabbdf1cc964c664c0bf85e68d4 to your computer and use it in GitHub Desktop.
DALI Post-1.6
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Build the Python-side batch source: each step yields one batch of encoded
# images plus their labels from the processed data file / image directory.
eii = ExternalInputIterator(
    batch_size=16,
    data_file=processed_data_file,
    image_dir=images_directory,
)
# NOTE(review): presumably this iterator is what gets passed as
# `data_iterator` to ExternalSourcePipeline — confirm at the call site.
iterator = iter(eii)
class ExternalSourcePipeline(Pipeline):
    """DALI pipeline that decodes and resizes image batches fed from Python.

    On every iteration, ``iter_setup`` pulls one ``(images, labels)`` batch
    from ``data_iterator``. It feeds the two parts into separate
    ``ExternalSource`` operators. The images are decoded with the "mixed"
    (CPU+GPU) JPEG decoder and resized on the GPU to 224x224. The pipeline
    outputs ``(resized_images, labels)``.

    Args:
        data_iterator: iterator yielding one ``(images, labels)`` tuple per
            batch. NOTE(review): presumably ``images`` is a list of encoded
            JPEG buffers and ``labels`` a matching list of arrays, as
            ``ExternalSource`` expects — confirm against ExternalInputIterator.
        batch_size: number of samples per batch.
        num_threads: number of CPU worker threads DALI may use.
        device_id: id of the GPU the pipeline runs on.
    """

    def __init__(self, data_iterator, batch_size, num_threads, device_id):
        super(ExternalSourcePipeline, self).__init__(batch_size,
                                                     num_threads,
                                                     device_id,
                                                     seed=12)
        self.data_iterator = data_iterator
        # Two independent external inputs: one for encoded images, one for labels.
        self.input = ops.ExternalSource()
        self.input_label = ops.ExternalSource()
        # NOTE(review): newer DALI releases deprecate nvJPEGDecoder in favour of
        # ImageDecoder / decoders.Image — confirm against the installed version.
        self.decode = ops.nvJPEGDecoder(device="mixed", output_type=types.RGB)
        # Resizing is required because the loaded images may be of different
        # sizes, and building batched GPU tensors needs arrays of equal size.
        self.res = ops.Resize(device="gpu",
                              resize_x=224,
                              resize_y=224,
                              interp_type=types.INTERP_TRIANGULAR)

    def define_graph(self):
        """Wire the operators: external inputs -> decode -> resize."""
        # The ExternalSource outputs are kept on self so iter_setup can
        # target them with feed_input.
        self.jpegs = self.input()
        self.labels = self.input_label()
        images = self.decode(self.jpegs)
        output = self.res(images)
        return (output, self.labels)

    def iter_setup(self):
        """Feed the next batch from the external iterator into the pipeline.

        Raises:
            StopIteration: propagated when ``data_iterator`` is exhausted.
        """
        # Use the builtin next() rather than the Python-2-only `.next()`
        # method so any iterator (including Python 3 generators) works.
        images, labels = next(self.data_iterator)
        self.feed_input(self.jpegs, images)
        self.feed_input(self.labels, labels)
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment