@aribornstein
Last active May 20, 2021 19:03
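import flash
from flash.core.classification import Labels
from flash.image import ImageClassificationData, ImageClassifier
from torchmetrics import F1

# NOTE: `genres` is not defined in this snippet; it is assumed to be the list
# of genre names (class labels), in the same order as the model's outputs.

# Load train data from the train directory. from_folders infers a single
# label per image from its subfolder name, so genuinely multi-label targets
# may need a CSV-based loader instead.
datamodule = ImageClassificationData.from_folders(train_folder="./train")

# Build a multi-label classifier on a ResNet-200D backbone
model = ImageClassifier(
    backbone="resnet200d",
    num_classes=datamodule.num_classes,
    multi_label=True,
    metrics=F1(num_classes=datamodule.num_classes),
    # Report every genre whose predicted probability exceeds 0.25
    serializer=Labels(genres, multi_label=True, threshold=0.25),
)

# Create the trainer
trainer = flash.Trainer()

# Fine-tune the model with the backbone frozen
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# Predict the genres of a movie poster
prediction = model.predict(["data/movie_posters/predict/tt0085318.jpg"])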
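Setting `multi_label=True` with `Labels(..., threshold=0.25)` means each genre is scored independently rather than competing in a softmax. The plain-Python sketch below illustrates that general idea under stated assumptions: it is not Flash's or torchmetrics' implementation, the helper names (`logits_to_labels`, `multi_hot`, `micro_f1`) and the example genres are hypothetical, and the strict `>` comparison and micro averaging are assumptions. It applies a sigmoid to each class logit, keeps every class above the threshold, and scores multi-hot predictions with micro-averaged F1, 2·TP / (2·TP + FP + FN).

```python
import math
from typing import List, Sequence


def sigmoid(x: float) -> float:
    """Map a raw logit to an independent per-class probability."""
    return 1.0 / (1.0 + math.exp(-x))


def logits_to_labels(
    logits: Sequence[float], class_names: Sequence[str], threshold: float = 0.25
) -> List[str]:
    """Keep every class whose probability exceeds `threshold`.

    Hypothetical helper: assumes a strict `>` comparison, which may differ
    from the library's exact rule.
    """
    return [name for name, z in zip(class_names, logits) if sigmoid(z) > threshold]


def multi_hot(labels: Sequence[str], class_names: Sequence[str]) -> List[int]:
    """Encode a set of label names as a 0/1 vector over `class_names`."""
    return [1 if name in labels else 0 for name in class_names]


def micro_f1(preds: Sequence[Sequence[int]], targets: Sequence[Sequence[int]]) -> float:
    """Micro-averaged F1 over multi-hot rows: 2*TP / (2*TP + FP + FN)."""
    tp = fp = fn = 0
    for pred_row, target_row in zip(preds, targets):
        for p, t in zip(pred_row, target_row):
            if p and t:
                tp += 1
            elif p:
                fp += 1
            elif t:
                fn += 1
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0


if __name__ == "__main__":
    genres = ["Action", "Comedy", "Drama"]  # hypothetical class names
    logits = [1.2, -2.0, -0.9]              # made-up model outputs
    print(logits_to_labels(logits, genres))                 # ['Action', 'Drama']
    print(logits_to_labels(logits, genres, threshold=0.5))  # ['Action']
    preds = [[1, 0, 1], [0, 1, 0]]
    targets = [[1, 0, 0], [0, 1, 1]]
    print(round(micro_f1(preds, targets), 4))               # 0.6667
```

With these made-up logits, Drama's probability is about 0.29, so a threshold of 0.25 keeps it while the conventional 0.5 would drop it; a low threshold trades precision for recall when posters often carry several genres.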