Skip to content

Instantly share code, notes, and snippets.

@dbolella
Last active November 11, 2019 17:25
Show Gist options
  • Save dbolella/ccdc094e8e0a2f3eaf94c3b6b0722d67 to your computer and use it in GitHub Desktop.
Save dbolella/ccdc094e8e0a2f3eaf94c3b6b0722d67 to your computer and use it in GitHub Desktop.
Training Loop for LeNet-5 MNIST Model in S4TF
// Epoch driver: each epoch performs one optimization pass over the training
// minibatches, then one measurement pass over the test minibatches, and finally
// prints a single-line summary of both.
for epoch in 1...epochCount {
    // Fresh accumulators per epoch, so each report covers that epoch only.
    var trainingStats = Statistics()
    var testingStats = Statistics()

    // ---- Optimization pass over the training split ----
    Context.local.learningPhase = .training
    // Integer division: a trailing partial batch (if any) is not visited.
    let trainingBatchCount = dataset.trainingSize / batchSize
    for batchIndex in 0..<trainingBatchCount {
        let batchImages = dataset.trainingImages.minibatch(at: batchIndex, batchSize: batchSize)
        let batchLabels = dataset.trainingLabels.minibatch(at: batchIndex, batchSize: batchSize)

        // Evaluate the loss and its gradient with respect to the model in one call.
        // The guess counters are updated inside the closure so the logits from
        // this forward pass are reused instead of running the model a second time.
        let (batchLoss, batchGradients) = valueWithGradient(at: model) { candidate -> Tensor<Float> in
            let predictions = candidate(batchImages)
            trainingStats.updateGuessCounts(
                logits: predictions,
                labels: batchLabels,
                batchSize: batchSize)
            return softmaxCrossEntropy(logits: predictions, labels: batchLabels)
        }

        // `totalLoss` is a running sum of the per-batch loss values (not an average).
        trainingStats.totalLoss += batchLoss.scalarized()
        optimizer.update(&model, along: batchGradients)
    }

    // ---- Measurement pass over the test split (no parameter updates) ----
    Context.local.learningPhase = .inference
    // Same truncating division as above.
    let testingBatchCount = dataset.testSize / batchSize
    for batchIndex in 0..<testingBatchCount {
        let batchImages = dataset.testImages.minibatch(at: batchIndex, batchSize: batchSize)
        let batchLabels = dataset.testLabels.minibatch(at: batchIndex, batchSize: batchSize)

        let predictions = model(batchImages)
        testingStats.updateGuessCounts(
            logits: predictions,
            labels: batchLabels,
            batchSize: batchSize)
        let batchLoss = softmaxCrossEntropy(logits: predictions, labels: batchLabels)
        testingStats.totalLoss += batchLoss.scalarized()
    }

    // Accuracy is the fraction of correct guesses among all guesses recorded this epoch.
    let trainingAccuracy =
        Float(trainingStats.correctGuessCount) / Float(trainingStats.totalGuessCount)
    let testingAccuracy =
        Float(testingStats.correctGuessCount) / Float(testingStats.totalGuessCount)

    // The trailing backslashes join the literal's lines, so the report is one line.
    print("""
        [Epoch \(epoch)] \
        Training Loss: \(trainingStats.totalLoss), \
        Training Accuracy: \(trainingStats.correctGuessCount)/\(trainingStats.totalGuessCount) \
        (\(trainingAccuracy)), \
        Test Loss: \(testingStats.totalLoss), \
        Test Accuracy: \(testingStats.correctGuessCount)/\(testingStats.totalGuessCount) \
        (\(testingAccuracy))
        """)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment