@aribornstein
Last active May 20, 2021 15:54
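# Imports (not in the original snippet; assumes the Lightning Flash 0.3-era API,
# where StyleTransfer and StyleTransferData live in flash.image.style_transfer)
import flash
import pystiche
from flash.image.style_transfer import StyleTransfer, StyleTransferData

# Load the data (assumes the COCO128 images are already in data/coco128/images)
data_module = StyleTransferData.from_folders(train_folder="data/coco128/images", batch_size=4)
# Build the model; `style_image` was undefined in the original, so a pystiche
# demo image is assumed here as a stand-in for your own style image tensor
style_image = pystiche.demo.images()["paint"].read(size=256)
model = StyleTransfer(style_image=style_image, backbone="vgg16")
# Create the trainer and train
trainer = flash.Trainer(max_epochs=2)
trainer.fit(model, data_module)
# Save it!
trainer.save_checkpoint("style_transfer_model.pt")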
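The `backbone="vgg16"` argument picks the network whose feature maps are compared during training. Perceptual style transfer in the Gatys et al. tradition usually scores "style" by comparing Gram matrices (channel-by-channel correlations) of those feature maps. Below is a minimal NumPy sketch of that style term, assuming features arrive as `(channels, height, width)` arrays and normalizing by the number of elements (one common convention). The helpers `gram_matrix` and `style_loss` are hypothetical illustrations, not Flash's or pystiche's internal implementation.

```python
import numpy as np


def gram_matrix(features: np.ndarray) -> np.ndarray:
    """Gram matrix of a (channels, height, width) feature map, normalized by its size."""
    c, h, w = features.shape
    flat = features.reshape(c, h * w)  # one row per channel
    return flat @ flat.T / (c * h * w)  # channel-by-channel correlations


def style_loss(generated: np.ndarray, style: np.ndarray) -> float:
    """Mean squared difference between the Gram matrices of two feature maps."""
    diff = gram_matrix(generated) - gram_matrix(style)
    return float(np.mean(diff ** 2))


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    style_features = rng.standard_normal((64, 32, 32))

    # Identical feature maps give zero style loss
    print(style_loss(style_features, style_features))  # 0.0

    # Shuffling spatial positions (same permutation for every channel) leaves the
    # Gram matrix unchanged up to rounding: style statistics ignore *where* a pattern appears
    perm = rng.permutation(32 * 32)
    shuffled = style_features.reshape(64, -1)[:, perm].reshape(64, 32, 32)
    print(np.isclose(style_loss(shuffled, style_features), 0.0))  # True
```

That spatial invariance is why a Gram-based term can transfer brush strokes and textures from the style image without copying its layout; a separate content term on a deeper layer keeps the layout of the training images.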