Created July 3, 2019 21:13
-
-
Save pbamotra/2d20c073899ca2cb8ff58b0ed911b771 to your computer and use it in GitHub Desktop.
DALI Post-1.7
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from nvidia.dali.plugin.pytorch import DALIGenericIterator

# Example: feed batches produced by a DALI pipeline into a PyTorch training step.
# NOTE(review): ExternalSourcePipeline, iterator, net, criterion and optimizer are
# not defined in this snippet; they are assumed to be provided by earlier code.
pipe = ExternalSourcePipeline(
    data_iterator=iterator, batch_size=16, num_threads=2, device_id=0
)
pipe.build()

# Arguments to DALIGenericIterator:
# - first: the list of pipelines to run;
# - second: output_map, which maps the pipeline's consecutive outputs to the
#   names used as keys in each returned batch dict;
# - third: the epoch size, i.e. the number of samples to iterate over
#   (presumably 256 samples / batch_size 16 = 16 batches per epoch).
dali_iter = DALIGenericIterator([pipe], ['images', 'labels'], 256)

for it in dali_iter:
    # The iterator yields one dict per pipeline; there is a single pipeline here.
    batch_data = it[0]
    images, labels = batch_data["images"], batch_data["labels"]
    # Both images and labels are `torch.Tensor`s and can now be processed the way
    # we usually do in the PyTorch example - https://urlzs.com/Wa2b
    # Zero the parameter gradients.
    optimizer.zero_grad()
    # Forward + backward + optimize. The model input is the `images` batch
    # (the original referenced an undefined name `inputs`).
    outputs = net(images)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.