Skip to content

Instantly share code, notes, and snippets.

@NMZivkovic
Created March 24, 2019 13:47
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save NMZivkovic/2686fe70b764e13eb6a0f54ecbc5012d to your computer and use it in GitHub Desktop.
Save NMZivkovic/2686fe70b764e13eb6a0f54ecbc5012d to your computer and use it in GitHub Desktop.
/**
 * Build and compile a small convolutional classifier with TensorFlow.js.
 *
 * Layer stack (in order):
 *   conv2d (5x5 kernel, 8 filters, stride 1, ReLU)
 *   -> maxPooling2d (2x2 pool, stride 2)
 *   -> conv2d (5x5 kernel, 16 filters, stride 1, ReLU)
 *   -> maxPooling2d (2x2 pool, stride 2)
 *   -> flatten
 *   -> dense (numClasses units, softmax)
 * Every weighted layer uses the 'varianceScaling' kernel initializer.
 * The model is compiled with the Adam optimizer (library-default settings),
 * 'categoricalCrossentropy' loss and an 'accuracy' metric.
 *
 * Calling with no arguments builds exactly the original 28x28x1 -> 10-class
 * model.
 *
 * NOTE(review): relies on a `tf` (TensorFlow.js) binding being in scope at
 * call time. It is not imported here; presumably it is loaded globally, e.g.
 * via a <script> tag. Confirm against the hosting page.
 * NOTE(review): categoricalCrossentropy expects one-hot encoded labels, so
 * training data presumably needs that encoding. Verify against the caller.
 *
 * @param {Object} [options] - Optional model configuration.
 * @param {number[]} [options.inputShape=[28, 28, 1]] - Shape of a single input
 *   sample, excluding the batch dimension. This is presumably
 *   [height, width, channels] under the default channels-last layout.
 * @param {number} [options.numClasses=10] - Width of the final softmax layer,
 *   i.e. the number of output classes.
 * @returns {Object} The compiled tf.Sequential model.
 */
function createModelFunction({ inputShape = [28, 28, 1], numClasses = 10 } = {}) {
  const cnn = tf.sequential();

  // First feature extractor. Only the first layer of a sequential model
  // needs an explicit inputShape; later layers infer theirs.
  cnn.add(tf.layers.conv2d({
    inputShape,
    kernelSize: 5,
    filters: 8,
    strides: 1,
    activation: 'relu',
    kernelInitializer: 'varianceScaling',
  }));
  cnn.add(tf.layers.maxPooling2d({ poolSize: [2, 2], strides: [2, 2] }));

  // Second feature extractor with twice as many filters.
  cnn.add(tf.layers.conv2d({
    kernelSize: 5,
    filters: 16,
    strides: 1,
    activation: 'relu',
    kernelInitializer: 'varianceScaling',
  }));
  cnn.add(tf.layers.maxPooling2d({ poolSize: [2, 2], strides: [2, 2] }));

  // Collapse the spatial feature maps into a vector for the classifier head.
  cnn.add(tf.layers.flatten());
  cnn.add(tf.layers.dense({
    units: numClasses,
    kernelInitializer: 'varianceScaling',
    activation: 'softmax',
  }));

  cnn.compile({
    optimizer: tf.train.adam(),
    loss: 'categoricalCrossentropy',
    metrics: ['accuracy'],
  });
  return cnn;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment