Skip to content

Instantly share code, notes, and snippets.

@bstriner
Created June 1, 2017 01:30
Show Gist options
  • Save bstriner/072bf2993cca32aadeb18e0b43833a1a to your computer and use it in GitHub Desktop.
Save bstriner/072bf2993cca32aadeb18e0b43833a1a to your computer and use it in GitHub Desktop.
import keras.backend as K
from keras.callbacks import CSVLogger
from keras.datasets import mnist
from keras.layers import Input, Lambda, Dense, Flatten, BatchNormalization, Activation
from keras.models import Model
def main(epochs=300, batch_size=128, log_path="mnist_training.csv"):
    """Train an MNIST classifier whose loss is computed inside the graph.

    Both the images and the labels are fed to the model as ``Input`` tensors.
    A ``Lambda`` layer turns them into a per-sample negative log likelihood,
    and that tensor is attached with ``model.add_loss``. The output's own
    loss is set to ``None``, so ``fit`` is called with ``[None]`` as the
    target and needs no target arrays.

    Args:
        epochs: Number of training epochs (default 300, as originally
            hard-coded).
        batch_size: Mini-batch size (default 128).
        log_path: CSV file that ``CSVLogger`` writes per-epoch metrics to.
    """
    # Both inputs and targets are `Input` tensors.
    input_x = Input((28, 28), name='input_x', dtype='uint8')  # uint8 [0-255]
    y_true = Input((1,), name='y_true', dtype='uint8')  # uint8 [0-9]

    # Build the prediction network as usual.
    h = Flatten()(input_x)
    h = Lambda(lambda _x: K.cast(_x, 'float32'),
               output_shape=lambda _x: _x,
               name='cast')(h)  # cast uint8 to float32
    h = BatchNormalization()(h)  # normalize pixels
    for _ in range(3):  # hidden relu and batchnorm layers
        h = Dense(256)(h)
        h = BatchNormalization()(h)
        h = Activation('relu')(h)
    y_pred = Dense(10, activation='softmax', name='y_pred')(h)  # softmax output

    def _nll(tensors):
        """Return -log(p[true class]) per sample from ``[y_true, y_pred]``."""
        # Python 3 removed tuple parameters in lambdas (PEP 3113), so the
        # original `lambda (_yt, _yp): ...` is a SyntaxError; unpack here.
        yt, yp = tensors
        # A column of row indices paired with the label column picks
        # yp[i, yt[i]] for each sample; the result has yt's (batch, 1) shape.
        # NOTE(review): this advanced-indexing form is kept from the original
        # and depends on the backend supporting it; confirm on the backend used.
        rows = K.reshape(K.arange(K.shape(yt)[0]), (-1, 1))
        return -K.log(yp[rows, yt] + K.epsilon())

    # Lambda layer performs the loss calculation (negative log likelihood).
    # Its output shape is the shape of the first input, y_true.
    loss = Lambda(_nll,
                  output_shape=lambda shapes: shapes[0],
                  name='loss')([y_true, y_pred])

    # Model `inputs` are both x and y; `outputs` is the per-sample loss.
    model = Model(inputs=[input_x, y_true], outputs=[loss])
    # Manually add the summed loss to the model.
    model.add_loss(K.sum(loss, axis=None))
    # Loss and loss weight of None, so this output is omitted from the
    # compiled loss (only the add_loss term is optimized).
    model.compile('adam', loss=[None], loss_weights=[None])

    # Accuracy cannot be passed to compile(), because metrics for skipped
    # outputs are skipped too; append the tensor directly instead.
    # Labels are cast to argmax's int64 and the boolean match to floatx, so
    # the comparison and the mean do not rely on implicit dtype handling.
    # NOTE(review): `metrics_names` / `metrics_tensors` exist on older
    # Keras 2.x models; confirm against the Keras version in use.
    correct = K.equal(K.argmax(y_pred, axis=1),
                      K.cast(K.flatten(y_true), 'int64'))
    accuracy = K.mean(K.cast(correct, K.floatx()))
    model.metrics_names.append('accuracy')
    model.metrics_tensors.append(accuracy)

    # Model summary
    model.summary()

    # Train: each split is unpacked into two arrays (images, labels) that feed
    # the two inputs; the single output has no target, hence [None].
    train, test = mnist.load_data()
    cb = CSVLogger(log_path)
    model.fit(list(train), [None], epochs=epochs, batch_size=batch_size,
              callbacks=[cb], validation_data=(list(test), [None]))
# Run training only when executed as a script, not on import.
if __name__ == "__main__":
    main()
@bstriner
Copy link
Author

bstriner commented Jun 1, 2017

This shows how to pass `None` as the loss for an output so that Keras omits it during training. This lets you use an `Input` tensor as both an input and a target. The model then needs no separate target arrays when you train it.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment