@aribornstein
Last active May 20, 2021 19:19
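# Imports (assumes Lightning Flash ~0.3 with its video extras installed)
import flash
from flash.core.classification import Labels
from flash.video import VideoClassificationData, VideoClassifier
from torch.utils.data import RandomSampler

# Load the data from directories
datamodule = VideoClassificationData.from_folders(
    train_folder="./train",
    clip_sampler="uniform",
    clip_duration=1,
    video_sampler=RandomSampler,
    decode_audio=False,
)

# Build the VideoClassifier with a PyTorchVideo backbone.
model = VideoClassifier(
    backbone="x3d_xs",
    num_classes=datamodule.num_classes,
    serializer=Labels(),
)

# Finetune the model
trainer = flash.Trainer()
trainer.finetune(model, datamodule=datamodule)

# Make a prediction
predictions = model.predict("./videos_dir")
print(predictions)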
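`VideoClassificationData.from_folders` reads the training videos from a directory tree. `datamodule.num_classes` comes from that same tree. The sketch below assumes the common one-subfolder-per-class layout, for example `./train/archery/a.mp4` and `./train/bowling/c.avi`. It shows in plain Python how such a layout could map to integer labels.

Some names here are hypothetical and used only for illustration: `index_video_folder`, `VIDEO_EXTENSIONS`, and the example class names. This is not Flash's or PyTorchVideo's implementation. Assigning indices in sorted class-name order is an assumption.

```python
import tempfile
from pathlib import Path
from typing import Dict, List, Tuple

# Hypothetical set of extensions treated as videos (assumption, not Flash's list).
VIDEO_EXTENSIONS = {".mp4", ".avi", ".mkv", ".mov", ".webm"}


def index_video_folder(root: str) -> Tuple[List[Tuple[str, int]], Dict[str, int]]:
    """Hypothetical helper: return (video_path, label) pairs and a class-name -> index map.

    Assumes each immediate subdirectory of `root` is one class and holds that
    class's video files.
    """
    root_path = Path(root)
    # Class names are the subdirectory names, sorted so indices are deterministic.
    class_names = sorted(p.name for p in root_path.iterdir() if p.is_dir())
    class_to_idx = {name: idx for idx, name in enumerate(class_names)}

    samples: List[Tuple[str, int]] = []
    for name in class_names:
        # Sort files too, so the sample order does not depend on the filesystem.
        for video in sorted((root_path / name).iterdir()):
            if video.is_file() and video.suffix.lower() in VIDEO_EXTENSIONS:
                samples.append((str(video), class_to_idx[name]))
    return samples, class_to_idx


if __name__ == "__main__":
    # Build a tiny throwaway tree: two classes, one non-video file that is skipped.
    with tempfile.TemporaryDirectory() as tmp:
        layout = {"archery": ["a.mp4", "b.mp4"], "bowling": ["c.avi", "notes.txt"]}
        for cls, files in layout.items():
            (Path(tmp) / cls).mkdir()
            for f in files:
                (Path(tmp) / cls / f).touch()
        samples, class_to_idx = index_video_folder(tmp)
        print(class_to_idx)  # {'archery': 0, 'bowling': 1}
        print(len(class_to_idx))  # 2
```

Sorting the class names keeps the label indices stable no matter what order the filesystem lists directories in. The number of keys in `class_to_idx` plays the same role as `datamodule.num_classes` in the snippet above.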