@mirth
Created August 23, 2020 21:23
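import baker
import torch
from torch import nn
from torch.optim import Adam
from ignite.engine import create_supervised_trainer, create_supervised_evaluator
from ignite.metrics import Accuracy, Loss

# get_data_loaders and NoiseClassifier are defined elsewhere in the repository


@baker.command
def run(batch_size=8, epochs=5, device='cuda'):
    train_loader, val_loader = get_data_loaders(batch_size)
    model = NoiseClassifier().to(device)
    # No explicit model.train()/model.eval() here: ignite's supervised trainer
    # and evaluator switch the model into the right mode on every step
    optimizer = Adam(model.parameters())
    criterion = nn.BCELoss()  # expects the model to output sigmoid probabilities
    trainer = create_supervised_trainer(model, optimizer, criterion, device=device)
    # In the binary case Accuracy requires 0/1 predictions, so round the probabilities first
    val_metrics = {
        "accuracy": Accuracy(output_transform=lambda out: (torch.round(out[0]), out[1])),
        "loss": Loss(criterion),
    }
    evaluator = create_supervised_evaluator(model, metrics=val_metrics, device=device)
    # ...
    # Here we are attaching loggers
    # Look into the repository for full code
    # ...
    trainer.run(train_loader, max_epochs=epochs)


if __name__ == "__main__":
    baker.run()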