# Keras masking layer performance
#
# GitHub gist by @ragulpr (gist ID 4353ff31a665021ab2d91de84ec20348),
# last active March 20, 2017 22:13.
import keras
from keras.layers import *
from keras.models import Model
import theano as T
import tensorflow as tf
# Import NumPy explicitly. Nothing above imports it by name; `np` presumably
# only resolved through the `from keras.layers import *` star import, which
# is fragile across Keras versions.
import numpy as np

# Record the library versions the timings at the bottom were produced with.
print('theano ver.', T.__version__)
print('tensorflow ver.', tf.__version__)
print('keras ver.', keras.__version__)

np.set_printoptions(precision=4)  # compact float printing for any array output
np.random.seed(1)  # make the np.random-generated benchmark inputs reproducible
from keras.models import Sequential
import timeit
def _stacked_lstm_predict(inputs, mask=None, units=20):
    """Build a fresh two-layer LSTM stack and return ``model.predict(inputs)``.

    Args:
        inputs: array whose first axis is the sample axis; ``inputs.shape[1:]``
            becomes the model's ``input_shape`` (for the LSTM layers this is
            expected to be ``(timesteps, features)``).
        mask: if not None, a ``Masking(mask_value=mask)`` layer is placed in
            front of the LSTMs; if None, no masking layer is added.
        units: number of units in each LSTM layer.

    Returns:
        The result of ``model.predict(inputs)``. Both LSTM layers use
        ``return_sequences=True``, so there is one output per timestep.

    Note:
        The model is untrained and rebuilt on every call, so timing this
        function measures model construction as well as inference.
    """
    model = Sequential()
    # Take the shape from the data actually being predicted on, so the model
    # can never disagree with it.
    input_shape = inputs.shape[1:]
    if mask is None:
        model.add(LSTM(units, return_sequences=True, input_shape=input_shape))
    else:
        model.add(Masking(mask_value=mask, input_shape=input_shape))
        model.add(LSTM(units, return_sequences=True))
    model.add(LSTM(units, return_sequences=True))
    return model.predict(inputs)


def lstm_masked(inputs=None):
    """Predict with a Masking layer (``mask_value``) in front of two LSTMs.

    Args:
        inputs: array to predict on; defaults to the module-level ``data``,
            looked up at call time so that reassigning ``data`` is honored.

    Returns:
        Per-timestep outputs of the second LSTM layer.
    """
    if inputs is None:
        inputs = data
    return _stacked_lstm_predict(inputs, mask=mask_value)


def lstm_nonmasked(inputs=None):
    """Predict with the same two-LSTM stack, without a Masking layer.

    Args:
        inputs: array to predict on; defaults to the module-level ``data``,
            looked up at call time so that reassigning ``data`` is honored.

    Returns:
        Per-timestep outputs of the second LSTM layer.
    """
    if inputs is None:
        inputs = data
    return _stacked_lstm_predict(inputs)
# Benchmark configuration (module-level; the LSTM helpers read some of these
# globals at call time).
n_samples = 1000
n_timesteps = 1000
n_features = 20
mask_value = -1


def _make_half_masked_data():
    """Return uniform(0, 1) inputs whose first half of timesteps equals ``mask_value``.

    Shape is ``(n_samples, n_timesteps, n_features)``. The first
    ``n_timesteps // 2`` timesteps of every sample are overwritten with
    ``mask_value``, i.e. the portion the masked model is expected to skip.
    """
    values = np.random.uniform(0, 1, (n_samples, n_timesteps, n_features))
    # Integer division: under Python 3, `n_timesteps / 2` is a float, and
    # numpy rejects float slice bounds.
    values[:, :n_timesteps // 2, :] = mask_value
    return values


data = _make_half_masked_data()
print('Time without: ', timeit.timeit(lstm_nonmasked, number=2))
print('Time with : ', timeit.timeit(lstm_masked, number=2))

# Regenerate the inputs and repeat in the opposite order, presumably so that
# neither variant consistently benefits from running second (warm-up effects).
data = _make_half_masked_data()
print('Time with : ', timeit.timeit(lstm_masked, number=2))
print('Time without: ', timeit.timeit(lstm_nonmasked, number=2))
# TENSORFLOW BACKEND
# ('tensorflow ver.', '1.0.1')
# ('keras ver.', '2.0.1')
# ('Time without: ', 90.77176785469055)
# ('Time with : ', 89.05559015274048)
# ('Time with : ', 91.29713892936707)
# ('Time without: ', 90.39765214920044)
# THEANO BACKEND
# ('theano ver.', '0.8.2')
# ('Time without: ', 16.689538955688477)
# ('Time with : ', 20.254252910614014)
# ('Time with : ', 15.486266136169434)
# ('Time without: ', 13.057673215866089)
# Can someone scale this up?
# Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.