// Source: GitHub Gist JacopoMangiavacchi/a68dd1c19f6c0a915f052cf58ad20db3
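// Builds and compiles the ML Compute training graph.
// `graph` (MLCGraph), `inputTensor` and `lossLabelTensor` (MLCTensor) and `device` (MLCDevice)
// are properties configured elsewhere in the enclosing type.
private func buildTrainingGraph() {
    // Softmax cross-entropy loss, averaged (.mean) over the batch.
    let lossLayer = MLCLossLayer(descriptor: MLCLossDescriptor(type: .softmaxCrossEntropy,
                                                               reductionType: .mean))

    // Adam optimizer: learning rate 0.001, no gradient rescaling, no regularization,
    // moment decay rates beta1 = 0.9 and beta2 = 0.999, epsilon for numerical stability.
    let optimizer = MLCAdamOptimizer(descriptor: MLCOptimizerDescriptor(learningRate: 0.001,
                                                                        gradientRescale: 1.0,
                                                                        regularizationType: .none,
                                                                        regularizationScale: 0.0),
                                     beta1: 0.9,
                                     beta2: 0.999,
                                     epsilon: 1e-7,
                                     timeStep: 1)

    trainingGraph = MLCTrainingGraph(graphObjects: [graph],
                                     lossLayer: lossLayer,
                                     optimizer: optimizer)

    // Bind the input tensor ("image") and the loss-label tensor ("label") by name.
    trainingGraph.addInputs(["image" : inputTensor],
                            lossLabels: ["label" : lossLabelTensor])

    // Compile the training graph for the selected device.
    trainingGraph.compile(options: [], device: device)
}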