
@tchaton
Created November 3, 2021 18:55
cifar_datamodule.py
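```python
from flash.image import ImageClassifier, ImageClassificationData


class CIFAR10DataModule(ImageClassificationData):
    # CIFAR-10 has exactly 10 classes; expose that as a fixed value.
    @property
    def num_classes(self):
        return 10


# `train_set`, `test_set`, `train_transforms` and `test_transforms` are not
# defined in this gist; they are assumed to be created beforehand
# (e.g. the CIFAR-10 train/test splits and their transforms).
dm = CIFAR10DataModule.from_datasets(
    train_dataset=train_set,
    test_dataset=test_set,
    train_transform=train_transforms,
    test_transform=test_transforms,
    # Do not forget to set `predict_transform`:
    # it is what we will use for uncertainty estimation!
    predict_transform=test_transforms,
    batch_size=64,
)
```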
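The gist does not show how predictions made with `predict_transform` are turned into an uncertainty score. As a minimal sketch of one common approach (an assumption here, not necessarily the author's method), the snippet below averages the class probabilities from several stochastic forward passes, as in Monte Carlo dropout, and scores a sample by the entropy of that average. The function name `predictive_entropy` and the toy probability lists are hypothetical; plain Python lists stand in for real model outputs.

```python
import math
from typing import List


def predictive_entropy(mc_probs: List[List[float]]) -> float:
    """Entropy (in nats) of the mean class distribution over several passes.

    `mc_probs` holds one probability vector per forward pass. This input
    format is hypothetical, e.g. softmax outputs collected with dropout active.
    """
    n_passes = len(mc_probs)
    n_classes = len(mc_probs[0])
    # Average each class probability across the stochastic passes.
    mean = [sum(p[c] for p in mc_probs) / n_passes for c in range(n_classes)]
    # H = sum p * log(1/p); zero entries are skipped (0 * log 0 is taken as 0).
    return sum(p * math.log(1.0 / p) for p in mean if p > 0)


# Ten CIFAR-10 classes: five passes that agree vs. ten passes that all disagree.
confident = [[1.0] + [0.0] * 9 for _ in range(5)]
uncertain = [[1.0 if c == k else 0.0 for c in range(10)] for k in range(10)]
print(f"{predictive_entropy(confident):.4f}")  # 0.0000
print(f"{predictive_entropy(uncertain):.4f}")  # 2.3026, i.e. ln(10)
```

Higher entropy means the passes disagree more, so such samples are natural candidates to flag as uncertain.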