Skip to content

Instantly share code, notes, and snippets.

@JacopoMangiavacchi
Created April 21, 2020 19:20
Show Gist options
  • Save JacopoMangiavacchi/50bb754b6367219588e40ed29258135f to your computer and use it in GitHub Desktop.
Save JacopoMangiavacchi/50bb754b6367219588e40ed29258135f to your computer and use it in GitHub Desktop.
/// Builds an updatable (on-device trainable) Core ML model for MNIST digit
/// classification and writes its serialized form to `coreMLModelUrl`.
///
/// Network: conv(3x3, 32) → ReLU → max-pool(2x2) → conv(2x2, 32) → ReLU →
/// max-pool(2x2) → flatten → dense(1152 → 500) → ReLU → dense(500 → 10) → softmax.
///
/// - Inference input `image`, shape `[1, 28, 28]`; inference output `output`,
///   10 float values produced by the softmax layer.
/// - Training inputs are `image` plus `output_true` (shape `[1]`, integer label),
///   which a categorical cross-entropy loss compares against `output`.
/// - Training uses Adam (defaults below) for `self.epoch` epochs with shuffling.
///
/// - Note: `epoch` and `coreMLModelUrl` are declared outside this function.
/// - Important: This function is not `throws`, so callers cannot recover from
///   a failure. If the model cannot be serialized or written, the process
///   terminates with a descriptive message instead of an anonymous `try!`/`!` trap.
public func prepareModel() {
    // Evaluated once, outside the builder closures, so they don't capture `self`.
    let epochs = UInt(self.epoch)

    // Zero-size borders on both spatial axes, shared by every `.valid` padding below.
    let noBorder = [EdgeSizes(startEdgeSize: 0, endEdgeSize: 0),
                    EdgeSizes(startEdgeSize: 0, endEdgeSize: 0)]

    let coremlModel = Model(version: 4,
                            shortDescription: "MNIST-Trainable",
                            author: "Jacopo Mangiavacchi",
                            license: "MIT",
                            userDefined: ["SwiftCoremltoolsVersion" : "0.0.12"]) {
        Input(name: "image", shape: [1, 28, 28])
        Output(name: "output", shape: [10], featureType: .float)
        TrainingInput(name: "image", shape: [1, 28, 28])
        TrainingInput(name: "output_true", shape: [1], featureType: .int)
        NeuralNetwork(losses: [CategoricalCrossEntropy(name: "lossLayer",
                                                       input: "output",
                                                       target: "output_true")],
                      optimizer: Adam(learningRateDefault: 0.0001,
                                      learningRateMax: 0.3,
                                      miniBatchSizeDefault: 128,
                                      miniBatchSizeRange: [128],
                                      beta1Default: 0.9,
                                      beta1Max: 1.0,
                                      beta2Default: 0.999,
                                      beta2Max: 1.0,
                                      epsDefault: 0.00000001,
                                      epsMax: 0.00000001),
                      epochDefault: epochs,
                      epochSet: [epochs],
                      shuffle: true) {
            // Size comments assume the [1, 28, 28] input is [channels, height, width]
            // (consistent with kernelChannels: 1 on conv1) — confirm.

            // 1x28x28 → 32x26x26 (3x3 kernel, stride 1, no padding)
            Convolution(name: "conv1",
                        input: ["image"],
                        output: ["outConv1"],
                        outputChannels: 32,
                        kernelChannels: 1,
                        nGroups: 1,
                        kernelSize: [3, 3],
                        stride: [1, 1],
                        dilationFactor: [1, 1],
                        paddingType: .valid(borderAmounts: noBorder),
                        outputShape: [],
                        deconvolution: false,
                        updatable: true)
            ReLu(name: "relu1",
                 input: ["outConv1"],
                 output: ["outRelu1"])
            // 32x26x26 → 32x13x13
            Pooling(name: "pooling1",
                    input: ["outRelu1"],
                    output: ["outPooling1"],
                    poolingType: .max,
                    kernelSize: [2, 2],
                    stride: [2, 2],
                    paddingType: .valid(borderAmounts: noBorder),
                    avgPoolExcludePadding: true,
                    globalPooling: false)
            // 32x13x13 → 32x12x12 (2x2 kernel, stride 1, no padding)
            Convolution(name: "conv2",
                        input: ["outPooling1"],
                        output: ["outConv2"],
                        outputChannels: 32,
                        kernelChannels: 32,
                        nGroups: 1,
                        kernelSize: [2, 2],
                        stride: [1, 1],
                        dilationFactor: [1, 1],
                        paddingType: .valid(borderAmounts: noBorder),
                        outputShape: [],
                        deconvolution: false,
                        updatable: true)
            ReLu(name: "relu2",
                 input: ["outConv2"],
                 output: ["outRelu2"])
            // 32x12x12 → 32x6x6
            Pooling(name: "pooling2",
                    input: ["outRelu2"],
                    output: ["outPooling2"],
                    poolingType: .max,
                    kernelSize: [2, 2],
                    stride: [2, 2],
                    paddingType: .valid(borderAmounts: noBorder),
                    avgPoolExcludePadding: true,
                    globalPooling: false)
            // 32x6x6 → 1152 values (= 32 * 6 * 6), hence hidden1's inputChannels.
            Flatten(name: "flatten1",
                    input: ["outPooling2"],
                    output: ["outFlatten1"],
                    mode: .last)
            InnerProduct(name: "hidden1",
                         input: ["outFlatten1"],
                         output: ["outHidden1"],
                         inputChannels: 1152,
                         outputChannels: 500,
                         updatable: true)
            ReLu(name: "relu3",
                 input: ["outHidden1"],
                 output: ["outRelu3"])
            InnerProduct(name: "hidden2",
                         input: ["outRelu3"],
                         output: ["outHidden2"],
                         inputChannels: 500,
                         outputChannels: 10,
                         updatable: true)
            Softmax(name: "softmax",
                    input: ["outHidden2"],
                    output: ["output"])
        }
    }

    guard let coreMLData = coremlModel.coreMLData else {
        fatalError("prepareModel: failed to serialize the MNIST-Trainable Core ML model")
    }
    do {
        // `.atomic` writes to a temporary file and then moves it into place, so an
        // interrupted write never leaves a truncated model at `coreMLModelUrl`.
        try coreMLData.write(to: coreMLModelUrl, options: .atomic)
    } catch {
        fatalError("prepareModel: failed to write Core ML model to \(coreMLModelUrl): \(error)")
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment