Created
February 15, 2019 21:59
-
-
Save marcrasi/06be4c85bbb96fe5c3713c0886607cd7 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// -*- coding: utf-8 -*-
// split_into_batches_oom.ipynb
//
// Automatically generated by Colaboratory.
//
// Original file is located at
//     https://colab.research.google.com/drive/1UWb-URpxZgdkTICBZlxGWfJr7U35uvJk

// %enableCompletion
import TensorFlow | |
/// A minimal layer whose only stored state is a single `parameters` tensor,
/// which is added to every input.
struct DummyModel: Layer {
    /// The tensor added to the input in `applied(to:)`.
    var parameters: Tensor<Float>

    /// Returns `input + parameters`.
    ///
    /// Differentiable with respect to both the model (`self`) and `input`.
    ///
    /// - Parameter input: The tensor that `parameters` is added to.
    /// - Returns: The sum of `input` and `parameters`.
    @differentiable(wrt: (self, input))
    func applied(to input: Tensor<Float>) -> Tensor<Float> {
        return input + parameters
    }
}
// Length of the model's parameter tensor and size of the second dimension of `data`.
let parameterDim = Int32(1000)

// Declared `var` because the training loop updates the model in place
// through `model.allDifferentiableVariables`.
var model = DummyModel(parameters: Tensor<Float>(ones: [parameterDim]))

// 10^6 examples x 10^3 Floats each = 10^9 Floats = 4 GB (4 bytes per Float).
let exampleCount = Int32(1_000_000)

// All-ones tensor of shape [exampleCount, parameterDim]; created in full up front.
let data = Tensor<Float>(ones: [exampleCount, parameterDim])
extension DummyModel {
    /// Computes the loss for `input`: the mean of `applied(to: input)`.
    ///
    /// Differentiable with respect to the model (`self`) only; `input` is
    /// treated as a constant.
    ///
    /// - Parameter input: The tensor passed through `applied(to:)`.
    /// - Returns: The result of calling `mean()` on the model output.
    @differentiable(wrt: (self))
    func loss(input: Tensor<Float>) -> Tensor<Float> {
        return applied(to: input).mean()
    }
}
// Number of examples sliced out of `data` for each optimizer step.
let batchSize = Int32(10000)
let optimizer = SGD<DummyModel, Float>(learningRate: 0.1)

// NOTE(review): the notebook name ("split_into_batches_oom") suggests this
// loop exists to reproduce an out-of-memory condition — confirm before
// restructuring how batches are sliced.
for epoch in 0..<10 {
    print("Doing epoch \(epoch)")

    // Integer division: if `exampleCount` were not a multiple of `batchSize`,
    // the trailing partial batch would be skipped. With the current constants
    // the division is exact.
    for batchIndex in 0..<(exampleCount / batchSize) {
        // Report progress on every 20th batch only.
        if batchIndex % 20 == 0 {
            print("\tDoing batch \(batchIndex)")
        }

        // Half-open range [start, end) used to slice this batch out of `data`.
        let start = batchIndex * batchSize
        let end = start + batchSize
        let batch = data[start..<end]

        // Only the gradient is consumed below; the loss value is discarded
        // explicitly so no unused binding is left behind.
        let (_, grads) = model.valueWithGradient { model in model.loss(input: batch) }
        optimizer.update(&model.allDifferentiableVariables, along: grads)
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment