Created May 17, 2021 21:01
-
-
Save sidneyarcidiacono/e55e73cf35c29a50fc9003e5c0f7fcd4 to your computer and use it in GitHub Desktop.
Set optimizer and define train/test functions
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Loss function: cross-entropy, which takes raw (unnormalized) class scores
# and integer class targets.
criterion = torch.nn.CrossEntropyLoss()
# Optimizer: Adam over every parameter of the model, learning rate 0.01.
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01)
# Initialize our train function
def train():
    """Run one training epoch over ``train_loader``.

    Uses the module-level ``model``, ``optimizer``, ``criterion`` and
    ``train_loader``; the model's parameters are updated in place.

    Returns:
        float: mean of the per-batch losses for this epoch, or 0.0 if the
        loader yielded no batches. (The original returned ``None``; callers
        that ignore the result are unaffected.)
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    for data in train_loader:  # Iterate in batches over the training dataset.
        # Clear gradients *before* backward so that gradients left over from
        # any earlier autograd work cannot leak into this batch's update.
        optimizer.zero_grad()
        out = model(data.x, data.edge_index, data.batch)  # Forward pass.
        loss = criterion(out, data.y)  # Compute the loss.
        loss.backward()  # Derive gradients.
        optimizer.step()  # Update parameters based on gradients.
        total_loss += loss.item()
        num_batches += 1
    return total_loss / num_batches if num_batches else 0.0
# Define our test function
def test(loader):
    """Return the accuracy of the module-level ``model`` on ``loader``.

    Args:
        loader: iterable of batches exposing ``x``, ``edge_index``, ``batch``
            and ``y``, plus a ``dataset`` attribute whose length is used as
            the number of examples (presumably one label per example —
            confirm against the loader's construction).

    Returns:
        float: fraction of correct predictions, or 0.0 if the dataset is
        empty (the original raised ``ZeroDivisionError`` in that case).
    """
    model.eval()
    correct = 0
    # Inference only: don't build autograd graphs, which wastes memory/time.
    with torch.no_grad():
        for data in loader:  # Iterate in batches over the given dataset.
            out = model(data.x, data.edge_index, data.batch)
            pred = out.argmax(dim=1)  # Pick the class with the highest score.
            correct += int((pred == data.y).sum())  # Compare to ground truth.
    total = len(loader.dataset)
    return correct / total if total else 0.0  # Ratio of correct predictions.
# Train for 200 epochs (numbered 1..200), reporting train and test accuracy
# after each one.
for epoch in range(1, 200 + 1):
    train()
    # Evaluated left to right: training split first, then the test split.
    train_acc, test_acc = test(train_loader), test(test_loader)
    print(f'Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}')
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.