@aribornstein
Last active May 20, 2021 15:54
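import flash
import pystiche.demo
# Import path as used in Flash releases around the time of this gist; it may differ in other versions.
from flash.image.style_transfer import StyleTransfer, StyleTransferData

# Load the data (expects the COCO128 images to already be in data/coco128/images)
data_module = StyleTransferData.from_folders(train_folder="data/coco128/images", batch_size=4)
# Load the style image. Assumption: the original snippet never defines style_image,
# so a pystiche demo image stands in here; substitute your own image tensor.
style_image = pystiche.demo.images()["paint"].read(size=256)
# Build the model
model = StyleTransfer(style_image=style_image, backbone="vgg16")
# Create the trainer and train
trainer = flash.Trainer(max_epochs=2)
trainer.fit(model, data_module)
# Save it!
trainer.save_checkpoint("style_transfer_model.pt")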