
@sgugger
Created April 15, 2019 21:42
Bug in training loop
import TensorFlow

// Synthetic data: 1024 flattened 28x28 inputs and all-zero integer labels.
var xTrain = Tensor<Float>(randomNormal: [1024, 784])
var yTrain = Tensor<Int32>(repeating: 0, shape: [1024])
public struct MyModel: Layer {
public var layer1: Dense<Float>
public var layer2: Dense<Float>
public init(nIn: Int, nHid: Int, nOut: Int){
layer1 = Dense(inputSize: nIn, outputSize: nHid, activation: relu)
layer2 = Dense(inputSize: nHid, outputSize: nOut)
}
@differentiable
public func applied(to input: Tensor<Float>, in context: Context) -> Tensor<Float> {
return input.sequenced(in: context, through: layer1, layer2)
}
}
var model = MyModel(nIn: 784, nHid: 50, nOut: 10)
let lr: Float = 0.5
var optimizer = SGD<MyModel, Float>(learningRate: lr)
// A batch of inputs and labels that can be stored in a Dataset.
public struct DataBatch<Inputs: Differentiable & TensorGroup, Labels: TensorGroup>: TensorGroup {
    public var xb: Inputs
    public var yb: Labels

    public init(xb: Inputs, yb: Labels) {
        self.xb = xb
        self.yb = yb
    }
}
// Batch the full dataset into minibatches of 64 (generic arguments spelled out so the annotation compiles).
let train_ds: Dataset<DataBatch<Tensor<Float>, Tensor<Int32>>> =
    Dataset(elements: DataBatch(xb: xTrain, yb: yTrain)).batched(Int64(64))
// Generic training loop: one pass over `dataset`, updating `model` with `optimizer`.
public func train<Opt: Optimizer, Labels: TensorGroup>(
    _ model: inout Opt.Model,
    on dataset: Dataset<DataBatch<Opt.Model.Input, Labels>>,
    using optimizer: inout Opt,
    lossFunc: @escaping @differentiable (Opt.Model.Output, @nondiff Labels) -> Tensor<Opt.Scalar>
) where Opt.Model.Input: TensorGroup,
        Opt.Model.CotangentVector == Opt.Model.AllDifferentiableVariables,
        Opt.Scalar: TensorFlowFloatingPoint
{
    let context = Context(learningPhase: .training)
    for batch in dataset {
        // Differentiate the loss with respect to the model parameters.
        let (loss, 𝛁model) = model.valueWithGradient { model -> Tensor<Opt.Scalar> in
            let pred = model.applied(to: batch.xb, in: context)
            return lossFunc(pred, batch.yb)
        }
        optimizer.update(&model.allDifferentiableVariables, along: 𝛁model)
    }
}
train(&model, on: train_ds, using: &optimizer, lossFunc: softmaxCrossEntropy)
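As a follow-up, a minimal sketch (not part of the original gist) of running the trained model in inference mode; it assumes the same 0.2-era Swift for TensorFlow API used above, in particular `Context(learningPhase: .inference)` and `Tensor.argmax(squeezingAxis:)`.

// Minimal usage sketch: evaluate the trained model outside the training phase.
let inferenceContext = Context(learningPhase: .inference)
let predictions = model.applied(to: xTrain, in: inferenceContext)
let predictedClasses = predictions.argmax(squeezingAxis: 1)  // one class index per example
print(predictedClasses.shape)  // expected: [1024]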