@aribornstein
Last active May 20, 2021 19:09
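# Assumes the Lightning Flash 0.3-era layout, where the detection task lives in flash.image
import flash
from flash.image import ObjectDetectionData, ObjectDetector

# Load the data: a folder of training images plus a COCO-format annotation file
datamodule = ObjectDetectionData.from_coco(
    train_folder="./train",
    train_ann_file="./annotations.json",
)

# Build the model, taking the class count from the dataset
model = ObjectDetector(model="retinanet", num_classes=datamodule.num_classes)

# Create the trainer
trainer = flash.Trainer()

# Finetune the model
trainer.finetune(model, datamodule=datamodule)

# Save it!
trainer.save_checkpoint("object_detection_model.pt")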