Skip to content

Instantly share code, notes, and snippets.

@songyuc
Last active October 19, 2022 16:23
Show Gist options
  • Save songyuc/211247ffa29bdcb0525467edbf2d568f to your computer and use it in GitHub Desktop.
Save songyuc/211247ffa29bdcb0525467edbf2d568f to your computer and use it in GitHub Desktop.
Reproduction script: a WARNING is logged when calling `reset()` on the DALI iterator.
import nvidia.dali.fn as fn
from nvidia.dali import pipeline_def
from nvidia.dali.plugin.pytorch import DALIGenericIterator
from tqdm import tqdm
@pipeline_def
def coco_pipeline(coco_root="/local/COCO"):
    """Define a DALI pipeline that reads COCO train2017 and outputs image ids.

    Args:
        coco_root: Directory containing the ``train2017`` image folder and the
            ``annotations_trainval2017`` folder. Defaults to the location that
            used to be hard-coded, so existing call sites behave as before.

    Returns:
        The ``image_ids`` output of the COCO reader. The images, bounding
        boxes and labels produced by the same reader are not returned.
    """
    image_dir = f"{coco_root}/train2017"
    annotations_path = (
        f"{coco_root}/annotations_trainval2017/annotations/instances_train2017.json"
    )

    # Only the image ids are needed; the other reader outputs are discarded
    # on purpose (underscore-prefixed to make that explicit).
    _images, _bounding_boxes, _labels, image_ids = fn.readers.coco(
        file_root=image_dir,
        image_ids=True,  # return image_id for validation
        annotations_file=annotations_path,
        skip_empty=True,
        ratio=True,
        ltrb=True,
        random_shuffle=False,
        shuffle_after_epoch=True,  # data shuffling
        # Named so that an iterator can refer to this reader via reader_name.
        name="Reader",
    )
    return image_ids
def main(num_epochs=2):
    """Iterate the COCO image-id pipeline for several epochs, printing batches.

    Args:
        num_epochs: Number of full passes over the iterator. Defaults to 2,
            the value that used to be hard-coded.

    Returns:
        Always 0.
    """
    pipe = coco_pipeline(batch_size=4, num_threads=4, device_id=0)
    # pipe.build()  # only for debugging
    train_loader = DALIGenericIterator(
        pipe,
        ["image_ids"],
        reader_name="Reader",
    )

    for _ in range(num_epochs):
        print("len(train_loader) = ", len(train_loader))
        for data in tqdm(train_loader):
            image_ids = data[0]["image_ids"].T
            print(image_ids)
        # Reset exactly once per epoch, and only after the batch loop above
        # has run to exhaustion. NOTE(review): calling reset() while an epoch
        # is still in progress is expected to make DALI log a warning and
        # ignore the call - confirm against the installed DALI version.
        train_loader.reset()
    return 0
# Run the demo only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment