@rcdexta
Created August 4, 2021 05:22
PyTorch Simple Neural Network
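import torch
from torch import nn
from torchvision import datasets, transforms
import helper  # local helper module (not part of PyTorch) that provides view_classify

# Assumed setup so that `trainloader` exists: MNIST, normalized, batches of 64
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5,), (0.5,))])
trainset = datasets.MNIST('~/.pytorch/MNIST_data/', download=True, train=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)

# Hyperparameters for our network
input_size = 784          # 28 x 28 pixel images, flattened
hidden_sizes = [128, 64]
output_size = 10          # one output per digit class

# Build a feed-forward network
model = nn.Sequential(nn.Linear(input_size, hidden_sizes[0]),
                      nn.ReLU(),
                      nn.Linear(hidden_sizes[0], hidden_sizes[1]),
                      nn.ReLU(),
                      nn.Linear(hidden_sizes[1], output_size),
                      nn.Softmax(dim=1))
print(model)

# Forward pass through the network and display output
images, labels = next(iter(trainloader))
images = images.view(images.shape[0], 1, 784)  # flatten each image to a 784-vector
with torch.no_grad():
    ps = model(images[0])  # input (1, 784) -> class probabilities (1, 10)
helper.view_classify(images[0].view(1, 28, 28), ps)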
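To check the architecture and the forward pass without downloading MNIST, here is a minimal sketch that rebuilds the same layer sizes and feeds the network a random batch. The random input and the batch size of 64 are assumptions standing in for real images. The sketch counts the trainable parameters and confirms that the final `nn.Softmax(dim=1)` turns each row of the output into a probability distribution over the 10 classes.

```python
import torch
from torch import nn

# Same layer sizes as the gist: 784 -> 128 -> 64 -> 10, softmax over the class dimension
model = nn.Sequential(nn.Linear(784, 128), nn.ReLU(),
                      nn.Linear(128, 64), nn.ReLU(),
                      nn.Linear(64, 10), nn.Softmax(dim=1))

# Each Linear layer has in_features * out_features weights plus out_features biases:
# (784*128 + 128) + (128*64 + 64) + (64*10 + 10)
n_params = sum(p.numel() for p in model.parameters())
print(n_params)  # 109386

# Assumption: a random batch of 64 flattened "images" stands in for MNIST data
torch.manual_seed(0)
batch = torch.randn(64, 784)
with torch.no_grad():
    ps = model(batch)

print(ps.shape)  # torch.Size([64, 10])
# Softmax over dim=1 makes every row non-negative and sum to 1
print(torch.allclose(ps.sum(dim=1), torch.ones(64)))  # True
```

Because `Softmax` normalizes along `dim=1`, the input needs a batch dimension; that is why the gist passes `images[0]` with shape `(1, 784)` rather than a bare 784-element vector.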