@aribornstein
Last active May 20, 2021 19:18
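import flash
from flash.image import SemanticSegmentation, SemanticSegmentationData
from flash.image.segmentation.serialization import SegmentationLabels

# Load the data: input images in ./train, per-pixel label masks in ./targets
datamodule = SemanticSegmentationData.from_folders(
    train_folder="./train",
    train_target_folder="./targets",
    num_classes=21,
)

# Build the model: an FCN with a ResNet-50 backbone from torchvision
model = SemanticSegmentation(
    backbone="torchvision/fcn_resnet50",
    num_classes=datamodule.num_classes,
    serializer=SegmentationLabels(visualize=True),  # return label maps and display them
)

# Create the trainer (max_epochs set explicitly; tune it for a real training run)
trainer = flash.Trainer(max_epochs=1)

# Fine-tune the model on the training data
trainer.finetune(model, datamodule=datamodule)

# Segment a new image
predictions = model.predict(["image.png"])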
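`SegmentationLabels` turns the model's per-pixel class scores into a label map. Assuming it picks the highest-scoring class at each pixel (an argmax over the class dimension), that step can be sketched in plain Python as below. `logits_to_labels` is a hypothetical helper, not part of Flash, and nested lists stand in for a `(num_classes, height, width)` tensor.

```python
from typing import List

# Hypothetical helper (not a Flash API): per-pixel argmax over the class axis.
# `logits` is indexed as logits[class][row][col].
def logits_to_labels(logits: List[List[List[float]]]) -> List[List[int]]:
    num_classes = len(logits)
    height = len(logits[0])
    width = len(logits[0][0])
    labels = []
    for y in range(height):
        row = []
        for x in range(width):
            # Collect every class score for this pixel
            scores = [logits[c][y][x] for c in range(num_classes)]
            # Index of the highest score; ties resolve to the lowest class index
            row.append(scores.index(max(scores)))
        labels.append(row)
    return labels


# Usage: 3 classes on a 2x2 image
example_logits = [
    [[0.1, 2.0], [0.3, 0.0]],  # class 0 scores
    [[1.5, 0.2], [0.1, 0.0]],  # class 1 scores
    [[0.2, 0.1], [0.9, 0.5]],  # class 2 scores
]
print(logits_to_labels(example_logits))  # [[1, 0], [2, 2]]
```

Each entry of the result is a class index in `[0, num_classes)`, which is the same range the target masks in `./targets` are expected to use.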