# Gist by @slaterb1, created March 22, 2017
from keras.layers import Input, Masking, LSTM, Dense
from keras.models import Model
import numpy as np
# models created using Theano backend!
# Case 1: model with return_sequences=True (output_shape = (1,10,1))
##############################################################
input1 = Input(batch_shape=(1,10,16))
mask1 = Masking(mask_value=2.)(input1)
lstm1 = LSTM(16, return_sequences=True)(mask1)
dense_layer = Dense(1, activation='sigmoid')
# Dense does not declare mask support, so Keras would reject the mask coming
# out of the LSTM; flipping the flag lets the mask pass through this layer.
dense_layer.supports_masking = True
dense1 = dense_layer(lstm1)
model1 = Model(input1, dense1)
model1.compile(optimizer='adam', loss='binary_crossentropy')
# Case 2: model with return_sequences=False (output_shape = (1,1))
###############################################################
lstm2 = LSTM(16, return_sequences=False)(mask1)  # reuses input1/mask1 from Case 1
dense2 = dense_layer(lstm2)  # same Dense instance as Case 1, so its weights are shared with model1
model2 = Model(input1, dense2)
model2.compile(optimizer='adam', loss='binary_crossentropy')
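# Sanity check: the output shapes should match the case descriptions above,
# (1, 10, 1) for the per-timestep model and (1, 1) for the sequence-level model.
print(model1.output_shape)
print(model2.output_shape)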
# initialize train data and labels
###############################################################
data = np.zeros((3,10,16))
data2 = np.ones((2,10,16))
labels_net1 = np.ones((3,10,1))
labels2_net1 = np.zeros((2,10,1))
labels_net2 = np.ones((3,1))
labels2_net2 = np.zeros((2,1))
train_data1 = np.concatenate([data, data2], axis=0)
train_labels1 = np.concatenate([labels_net1, labels2_net1], axis=0)
train_labels2 = np.concatenate([labels_net2, labels2_net2], axis=0)
# add 'masked' data to train_data
################################################################
masked_train_data = np.copy(train_data1)
masked_train_data[1,1,:] = 2
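# Sanity check: Masking drops a timestep only when every feature equals
# mask_value, so sample 1, timestep 1 should be the only fully-masked entry.
fully_masked = np.all(masked_train_data == 2., axis=-1)  # boolean array, shape (5, 10)
print(np.argwhere(fully_masked))  # expect [[1 1]]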
# train models
#################################################################
model1.fit(masked_train_data, train_labels1, nb_epoch=1000, batch_size=1)
model2.fit(masked_train_data, train_labels2, nb_epoch=1000, batch_size=1)
model3 = Model(input1, dense1)  # want to retrain first network without masked data; note this reuses model1's layer objects, so it continues training the same weights
model3.compile(optimizer='adam', loss='binary_crossentropy')
model3.fit(train_data1, train_labels1, nb_epoch=1000, batch_size=1)
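# Because model3 was built from model1's layer objects, the two models should
# report identical weights at this point.
weights_shared = all(np.array_equal(w1, w3)
                     for w1, w3 in zip(model1.get_weights(), model3.get_weights()))
print('model1 and model3 share weights:', weights_shared)  # expect True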
# create test data
##################################################################
test_data1 = np.ones((1,10,16))
test_data2 = np.zeros((1,10,16))
# add 'mask' to test data
test_data1[0,3,:] = 2
test_data2[0,3,:] = 2
# predictions
##################################################################
model1_predictions1 = model1.predict(test_data1)
model1_predictions2 = model1.predict(test_data2)
model2_predictions1 = model2.predict(test_data1)
model2_predictions2 = model2.predict(test_data2)
model3_predictions1 = model3.predict(test_data1)
print(model1_predictions1)
print(model1_predictions2)
print(model2_predictions1)
print(model2_predictions2)
print(model3_predictions1)
# Glorious printouts
#####################################################################
# model1_predictions1, y_true = [0., 0., ..., 0.]
#[[[ 2.14141060e-08]
# [ 7.29542982e-10]
# [ 3.53702262e-10]
# [ 3.53702262e-10] <-- this is the masked line, output is same as previous
# [ 2.82663781e-10]
# [ 2.61021177e-10]
# [ 2.52067062e-10]
# [ 2.47708437e-10]
# [ 2.45296505e-10]
# [ 2.43848081e-10]]]
# model1_predictions2, y_true = [1., 1., ... ,1.]
#[[[ 0.99999619]
# [ 1. ]
# [ 1. ]
# [ 1. ] <-- masked line
# [ 1. ]
# [ 1. ]
# [ 1. ]
# [ 1. ]
# [ 1. ]
# [ 1. ]]]
# model2_predictions1, y_true = 0.
#[[ 1.09495701e-08]] <- single output for the sequence containing the masked timestep
# model2_predictions2, y_true = 1.
#[[ 1.]] <- single output, masked timestep present
# model3_predictions1, y_true = [0., 0., ..., 0.] (identical to model1_predictions1 because model3 shares model1's layers)
#[[[ 2.14141060e-08]
# [ 7.29542982e-10]
# [ 3.53702262e-10]
# [ 3.53702262e-10] <- this is the masked line, output is same as previous
# [ 2.82663781e-10]
# [ 2.61021177e-10]
# [ 2.52067062e-10]
# [ 2.47708437e-10]
# [ 2.45296505e-10]
# [ 2.43848081e-10]]]
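# If the mask is propagated through the LSTM, the state is not updated on the
# masked timestep, so the per-timestep output at index 3 should equal the
# output at index 2 (as seen in the printouts above).
print(np.allclose(model1_predictions1[0, 3], model1_predictions1[0, 2]))  # expect True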