Created April 15, 2019 21:42
Bug in training loop
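The snippet below, written against the Swift for TensorFlow API of April 2019, builds a two-layer dense model, wraps synthetic inputs and labels in a generic DataBatch, and drives a generic train function over a batched Dataset; the final call to train is presumably what exercises the reported bug.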
import TensorFlow

// Synthetic data: random inputs and all-zero labels standing in for a real dataset.
var xTrain = Tensor<Float>(randomNormal: [1024, 784])
var yTrain = Tensor<Int32>(repeating: 0, shape: [1024])
// A two-layer fully connected model written against this snapshot's Layer API.
public struct MyModel: Layer {
    public var layer1: Dense<Float>
    public var layer2: Dense<Float>

    public init(nIn: Int, nHid: Int, nOut: Int) {
        layer1 = Dense(inputSize: nIn, outputSize: nHid, activation: relu)
        layer2 = Dense(inputSize: nHid, outputSize: nOut)
    }

    @differentiable
    public func applied(to input: Tensor<Float>, in context: Context) -> Tensor<Float> {
        return input.sequenced(in: context, through: layer1, layer2)
    }
}
var model = MyModel(nIn: 784, nHid: 50, nOut: 10)
let lr: Float = 0.5
var optimizer = SGD<MyModel, Float>(learningRate: lr)
// Pairs an input batch with its labels so a Dataset can yield both together.
public struct DataBatch<Inputs: Differentiable & TensorGroup, Labels: TensorGroup>: TensorGroup {
    public var xb: Inputs
    public var yb: Labels

    public init(xb: Inputs, yb: Labels) {
        self.xb = xb
        self.yb = yb
    }
}
// 64-sample batches; the element type must be spelled out in full, since
// Dataset<DataBatch> with unbound generic parameters is not a valid annotation.
let train_ds: Dataset<DataBatch<Tensor<Float>, Tensor<Int32>>> =
    Dataset(elements: DataBatch(xb: xTrain, yb: yTrain)).batched(Int64(64))
// Generic training loop: for each batch, compute the loss and its gradient with
// respect to the model, then have the optimizer update the model's parameters.
public func train<Opt: Optimizer, Labels: TensorGroup>(
    _ model: inout Opt.Model,
    on dataset: Dataset<DataBatch<Opt.Model.Input, Labels>>,
    using optimizer: inout Opt,
    lossFunc: @escaping @differentiable (Opt.Model.Output, @nondiff Labels) -> Tensor<Opt.Scalar>
) where Opt.Model.Input: TensorGroup,
        Opt.Model.CotangentVector == Opt.Model.AllDifferentiableVariables,
        Opt.Scalar: TensorFlowFloatingPoint
{
    let context = Context(learningPhase: .training)
    for batch in dataset {
        let (loss, 𝛁model) = model.valueWithGradient { model -> Tensor<Opt.Scalar> in
            let pred = model.applied(to: batch.xb, in: context)
            return lossFunc(pred, batch.yb)
        }
        optimizer.update(&model.allDifferentiableVariables, along: 𝛁model)
    }
}
train(&model, on: train_ds, using: &optimizer, lossFunc: softmaxCrossEntropy)
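For comparison, any function matching the lossFunc shape can be passed in place of softmaxCrossEntropy. A minimal sketch, assuming this snapshot's @differentiable(wrt:) attribute spelling; myLoss is a hypothetical name that simply delegates to the library loss:

// Hypothetical custom loss with the signature `train` expects; it only
// delegates to softmaxCrossEntropy, so training behavior is unchanged.
@differentiable(wrt: logits)
func myLoss(logits: Tensor<Float>, labels: Tensor<Int32>) -> Tensor<Float> {
    return softmaxCrossEntropy(logits: logits, labels: labels)
}
train(&model, on: train_ds, using: &optimizer, lossFunc: myLoss)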