Last active
June 22, 2020 14:59
-
-
Save Mehdi-Amine/3b1e2ec26f569c05dd07dbd665e6e935 to your computer and use it in GitHub Desktop.
Training the confined but happy neural network.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch.optim as optim

# Training script: fits `net` with SGD + momentum for a fixed number of
# epochs, then logs the mean train/validation loss after every epoch.
# NOTE(review): `Network`, `nn`, `torch`, `train_dl`, `valid_dl`, `summary`,
# `train_summary_writer` and `valid_summary_writer` are not defined in this
# snippet; they are expected to already be in scope -- confirm before running.

learning_rate = 0.01
epochs = 20


def _mean_batch_loss(model, loss_fn, data_loader):
    """Return the average of the per-batch loss values over ``data_loader``.

    Each batch contributes one loss value; the values are summed and divided
    by the number of batches (``len(data_loader)``).
    """
    batch_losses = (
        loss_fn(model(features), labels) for features, labels in data_loader
    )
    return sum(batch_losses) / len(data_loader)


net = Network(input_size=3, lin1_size=7, lin2_size=2)
optimizer = optim.SGD(net.parameters(), lr=learning_rate, momentum=0.9)
criterion = nn.CrossEntropyLoss()

for epoch in range(epochs):
    # --- Optimisation pass over the training batches ---
    net.train()
    # `features`/`labels` instead of `input`/`target`: the original loop
    # variable `input` shadowed the builtin of the same name.
    for features, labels in train_dl:
        pred = net(features)
        loss = criterion(pred, labels)
        loss.backward()
        optimizer.step()
        # Reset gradients so they do not accumulate into the next batch.
        optimizer.zero_grad()

    # --- Evaluation pass (no gradient tracking) ---
    net.eval()
    with torch.no_grad():
        train_loss = _mean_batch_loss(net, criterion, train_dl)
        valid_loss = _mean_batch_loss(net, criterion, valid_dl)

    # Writing to Tensorboard
    with train_summary_writer.as_default():
        summary.scalar('train-loss', train_loss, step=epoch)
    with valid_summary_writer.as_default():
        summary.scalar('valid-loss', valid_loss, step=epoch)

    print(epoch, train_loss, valid_loss)

# Sample output:
#   0 tensor(0.4353) tensor(0.4295)
#   ...
#   19 tensor(0.1386) tensor(0.1409)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment