@Borda
Created August 19, 2021 14:46
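import matplotlib.pyplot as plt
import numpy as np
import torch

# assumes `dm` is a data module defined earlier (not shown here) that exposes `val_dataloader()`
fig = plt.figure(figsize=(3, 7))
for imgs, lbs in dm.val_dataloader():
    # some stats about the batch: per-class label counts (assumes one-hot / multi-hot labels)
    print(f'batch labels: {torch.sum(lbs, dim=0)}')
    print(f'image size: {imgs[0].shape}')
    # show just the first three images of the batch, moving channels CHW -> HWC for imshow
    for i in range(3):
        ax = fig.add_subplot(3, 1, i + 1, xticks=[], yticks=[])
        ax.imshow(np.rollaxis(imgs[i].numpy(), 0, 3))
        ax.set_title(str(lbs[i].tolist()))
    # stop after the first batch
    break
plt.show()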
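Because `dm` and its data are not part of this gist, here is a minimal NumPy-only sketch of the two array operations the loop relies on: summing labels over the batch axis (what `torch.sum(lbs, dim=0)` reports) and moving the channel axis last so `imshow` accepts the image. It assumes labels are one-hot / multi-hot with shape (N, C) and images are channel-first (C, H, W), as the `np.rollaxis(..., 0, 3)` call implies; the helper names `label_counts` and `chw_to_hwc` and the synthetic batch are hypothetical.

```python
import numpy as np


def label_counts(labels: np.ndarray) -> np.ndarray:
    """Per-class counts for a batch of one-hot / multi-hot labels of shape (N, C)."""
    return labels.sum(axis=0)


def chw_to_hwc(img: np.ndarray) -> np.ndarray:
    """Move the channel axis from first (C, H, W) to last (H, W, C), the layout imshow expects."""
    return np.rollaxis(img, 0, 3)


# hypothetical synthetic batch: 4 RGB images of 8x6 pixels, 3 classes
rng = np.random.default_rng(0)
imgs = rng.random((4, 3, 8, 6), dtype=np.float32)
lbs = np.array([[1, 0, 0],
                [0, 1, 1],
                [1, 1, 0],
                [0, 0, 1]])

print(label_counts(lbs))          # [2 2 2]
print(chw_to_hwc(imgs[0]).shape)  # (8, 6, 3)
```

For a 3-D array, `np.rollaxis(img, 0, 3)` gives the same result as `img.transpose(1, 2, 0)`; on a torch tensor the equivalent is `img.permute(1, 2, 0)`.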