# Source: GitHub Gist by aribornstein (GitHub display name: PythicCoder).
# The gist page's navigation text ("Skip to content", "View GitHub Profile",
# etc.) was captured when this file was scraped and is condensed here.
# Lightning Flash walkthrough: fine-tune a sentiment classifier on the IMDB
# reviews dataset, then checkpoint, reload, predict with and export the model.
#
# The gist stored these snippets as separate files, and the scrape emitted them
# in reverse order (every name was used before it was assigned, and the data
# was downloaded only after it had been read). They are arranged here in
# execution order.
#
# NOTE(review): the imports for ``flash``, ``download_data``,
# ``TextClassificationData`` and ``TextClassifier`` are not shown in the gist,
# and their module paths differ between Lightning Flash releases - add the ones
# matching the installed version.
import torch

# 1. Download and unpack the IMDB archive into data/.
download_data('https://pl-flash-data.s3.amazonaws.com/imdb.zip', 'data/')

# 2. Build a datamodule from the CSV splits: the "review" column is the input
#    text and the "sentiment" column is the target label.
datamodule = TextClassificationData.from_files(
    train_file="data/imdb/train.csv",
    valid_file="data/imdb/valid.csv",
    test_file="data/imdb/test.csv",
    input="review",
    target="sentiment",
)

# 3. A two-class text classifier on a 'roberta-base' backbone.
model = TextClassifier(num_classes=2, backbone='roberta-base')

# 4. Configure the trainer. The gist showed the same script scaling by changing
#    only the Trainer arguments; pick one of these instead of the default:
#      flash.Trainer(max_epochs=1, gpus=8)                # 8 GPUs
#      flash.Trainer(max_epochs=1, gpus=8, num_nodes=32)  # 32 nodes, 8 GPUs each
#      flash.Trainer(max_epochs=1, tpu_cores=1)           # one TPU core
trainer = flash.Trainer(max_epochs=1)

# 5. Fine-tune the model on the datamodule.
trainer.finetune(model, datamodule=datamodule)

# 6. Save a checkpoint, then restore the model from it.
trainer.save_checkpoint("text_class_model.pt")
model = TextClassifier.load_from_checkpoint("text_class_model.pt")

# 7. Predict the sentiment of a single review.
prediction = model.predict("This movie is great!")

# Large-scale batch inference was sketched in the gist as
# ``Trainer(num_gpus=32).predict(millions_of_reviews)``. That line is pseudocode:
# - ``Trainer`` is never defined (the rest of the script uses ``flash.Trainer``);
# - the Lightning flag is ``gpus``, not ``num_gpus``;
# - ``millions_of_reviews`` is a placeholder.
# NOTE(review): check the ``Trainer.predict`` signature of the installed
# Lightning version (it may also need the model) before wiring this up.

# 8. Export for deployment.
# NOTE(review): ``filepath`` and ``input_sample`` are never defined in the
# gist. Presumably ``input_sample`` is an example batch in the model's input
# format - TODO define both before running this step.
model.to_onnx(filepath, input_sample, export_params=True)
torch.jit.save(model.to_torchscript(), "model.pt")