@sidneyarcidiacono
Created May 17, 2021 21:01
Set optimizer and define train/test functions
# This snippet assumes `model`, `train_loader`, and `test_loader` were defined
# in the earlier steps of this series.
import torch

# Set our optimizer (Adam)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# Define our loss function
criterion = torch.nn.CrossEntropyLoss()

# Define our train function
def train():
    model.train()
    for data in train_loader:  # Iterate in batches over the training dataset.
        out = model(data.x, data.edge_index, data.batch)  # Perform a single forward pass.
        loss = criterion(out, data.y)  # Compute the loss.
        loss.backward()  # Derive gradients.
        optimizer.step()  # Update parameters based on gradients.
        optimizer.zero_grad()  # Clear gradients for the next batch.

# Define our test function
def test(loader):
    model.eval()
    correct = 0
    for data in loader:  # Iterate in batches over the training/test dataset.
        out = model(data.x, data.edge_index, data.batch)
        pred = out.argmax(dim=1)  # Use the class with the highest score.
        correct += int((pred == data.y).sum())  # Check against ground-truth labels.
    return correct / len(loader.dataset)  # Ratio of correct predictions.

# Run for 200 epochs (range excludes its upper bound)
for epoch in range(1, 201):
    train()
    train_acc = test(train_loader)
    test_acc = test(test_loader)
    print(f'Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}')
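For context, the `model`, `train_loader`, and `test_loader` used above come from the earlier gists in this series. Below is a minimal sketch of one way they might be set up, assuming the MUTAG graph-classification dataset and a small GCN; the dataset choice, 150/38 split, and architecture here are illustrative assumptions, not part of the original gist. (Depending on your PyTorch Geometric version, `DataLoader` may live in `torch_geometric.data` rather than `torch_geometric.loader`.)

# Hypothetical setup sketch -- dataset, split, and architecture are assumptions
# for illustration; substitute your own from the earlier steps.
import torch
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool

dataset = TUDataset(root='data/TUDataset', name='MUTAG').shuffle()
train_loader = DataLoader(dataset[:150], batch_size=64, shuffle=True)
test_loader = DataLoader(dataset[150:], batch_size=64, shuffle=False)

class GCN(torch.nn.Module):
    def __init__(self, hidden_channels):
        super().__init__()
        self.conv1 = GCNConv(dataset.num_node_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.lin = torch.nn.Linear(hidden_channels, dataset.num_classes)

    def forward(self, x, edge_index, batch):
        x = self.conv1(x, edge_index).relu()  # node embeddings
        x = self.conv2(x, edge_index).relu()
        x = global_mean_pool(x, batch)        # graph-level readout per graph in the batch
        return self.lin(x)                    # class logits

model = GCN(hidden_channels=64)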