@kedevked
Created March 8, 2019 21:38
Build a sequential model for transfer learning
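import * as tf from '@tensorflow/tfjs';

// Builds the trainable classification head used for transfer learning.
// featureExtractor: the model whose output activations feed this head.
// units: number of units in the hidden dense layer.
// NUM_CLASSES is expected to be a constant defined elsewhere in the module.
async function buildModel(featureExtractor, units) {
  return tf.sequential({
    layers: [
      // Flattens the input to a vector so we can use it in a dense layer.
      // While technically a layer, this only performs a reshape (and has no
      // trainable parameters).
      // outputs[0].shape is [null, ...dims]; slice(1) drops the batch
      // dimension, because inputShape must not include it.
      tf.layers.flatten(
        { inputShape: featureExtractor.outputs[0].shape.slice(1) }),
      // Layer 1: a hidden fully connected layer with ReLU activation.
      tf.layers.dense({
        units: units,
        activation: 'relu',
        kernelInitializer: 'varianceScaling',
        useBias: true
      }),
      // Layer 2. The number of units of the last layer should correspond
      // to the number of classes we want to predict; softmax turns the
      // outputs into class probabilities.
      tf.layers.dense({
        units: NUM_CLASSES,
        kernelInitializer: 'varianceScaling',
        useBias: false,
        activation: 'softmax'
      })
    ]
  });
}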
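As a worked check of the layer comments in buildModel, the sketch below computes the flattened input size and the number of trainable weights in this head. It is plain JavaScript with no TensorFlow.js dependency; the helper name headParamCounts and the [null, 7, 7, 256] feature-map shape are hypothetical values chosen for illustration, not taken from the original code.

```javascript
// Hypothetical helper: computes the flattened input size and the number of
// trainable parameters of a head shaped like the one built by buildModel.
// Assumes the extractor's output shape starts with a batch dimension (null).
function headParamCounts(extractorOutputShape, units, numClasses) {
  // Drop the batch dimension, as shape.slice(1) does in buildModel.
  const inputShape = extractorOutputShape.slice(1);
  // Flatten multiplies the remaining dimensions together; it has no weights.
  const flatSize = inputShape.reduce((acc, d) => acc * d, 1);
  // Dense layer 1: one weight per (input, unit) pair plus one bias per unit.
  const dense1 = flatSize * units + units;
  // Dense layer 2: useBias is false, so only the kernel weights count.
  const dense2 = units * numClasses;
  return { inputShape, flatSize, dense1, dense2, total: dense1 + dense2 };
}

// Hypothetical example: a [null, 7, 7, 256] feature map, 100 hidden units, 4 classes.
const counts = headParamCounts([null, 7, 7, 256], 100, 4);
console.log(counts);
```

With those assumed values, flatten yields 12544 inputs, the hidden layer holds 12544 × 100 + 100 = 1,254,500 weights and the bias-free output layer 100 × 4 = 400, so for these numbers nearly all of the head's weights sit in the first dense layer.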
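To make the data flow concrete, here is a rough plain-JavaScript sketch of what the head computes for a single example: flatten, a dense layer with bias followed by ReLU, then a bias-free dense layer followed by softmax. It is an illustration only, not how TensorFlow.js implements these layers; the tiny 2×2×1 input, the three hidden units, the two classes and all weight values are made up.

```javascript
// Flatten a nested array of any depth into a 1-D vector (a pure reshape).
function flatten(x) {
  return Array.isArray(x) ? x.flatMap(flatten) : [x];
}

// Dense layer: out[j] = sum_i input[i] * kernel[i][j] (+ bias[j] if given).
function dense(input, kernel, bias) {
  const units = kernel[0].length;
  const out = new Array(units).fill(0);
  for (let i = 0; i < input.length; i++) {
    for (let j = 0; j < units; j++) {
      out[j] += input[i] * kernel[i][j];
    }
  }
  return bias ? out.map((v, j) => v + bias[j]) : out;
}

// ReLU clamps negative activations to zero.
const relu = (v) => v.map((x) => Math.max(0, x));

// Softmax, shifted by the max logit for numerical stability.
function softmax(logits) {
  const m = Math.max(...logits);
  const exps = logits.map((z) => Math.exp(z - m));
  const sum = exps.reduce((a, b) => a + b, 0);
  return exps.map((e) => e / sum);
}

// Made-up 2x2 feature map with one channel, as an extractor might emit.
const features = [[[1], [0]], [[2], [1]]];
const x = flatten(features);

// Hidden layer (3 units, with bias), mirroring useBias: true.
const kernel1 = [
  [0.5, -1.0, 0.2],
  [0.3, 0.8, -0.5],
  [-0.2, 0.4, 0.1],
  [0.1, -0.3, 0.6],
];
const bias1 = [0.0, 0.1, -0.1];
const hidden = relu(dense(x, kernel1, bias1));

// Output layer (2 classes, no bias), mirroring useBias: false + softmax.
const kernel2 = [
  [1.0, -1.0],
  [0.5, 0.5],
  [-0.5, 1.0],
];
const probs = softmax(dense(hidden, kernel2));
console.log(probs);
```

The output has one entry per class and sums to 1, which is why the last layer's unit count must match the number of classes.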