# Keras masking layer performance
# Source: GitHub gist ragulpr/4353ff31a665021ab2d91de84ec20348
import keras
import numpy as np
import tensorflow as tf
import theano as T
from keras.layers import *
from keras.models import Model
# Report library versions so the timings recorded at the bottom of this
# file can be tied to specific backend / Keras builds.
print('theano ver.', T.__version__)
print('tensorflow ver.', tf.__version__)
print('keras ver.', keras.__version__)

np.set_printoptions(precision=4)
# Seed NumPy's global RNG so the input data generated with
# np.random.uniform later in the script is reproducible.
np.random.seed(1)

from keras.models import Sequential
import timeit
def _build_model(masked):
    """Return a two-layer, 20-unit LSTM ``Sequential`` stack for the benchmark.

    Both LSTM layers use ``return_sequences=True``. When *masked* is true, a
    ``Masking`` layer configured with the module-level ``mask_value`` is
    prepended and owns the input shape; otherwise the first LSTM does.

    Reads the module-level globals ``mask_value``, ``n_timesteps`` and
    ``n_features`` (defined later in the script) at call time. The model is
    never compiled here; callers only use ``predict``.
    """
    model = Sequential()
    input_shape = (n_timesteps, n_features)
    if masked:
        model.add(Masking(mask_value=mask_value, input_shape=input_shape))
        model.add(LSTM(20, return_sequences=True))
    else:
        model.add(LSTM(20, return_sequences=True, input_shape=input_shape))
    model.add(LSTM(20, return_sequences=True))
    return model


def lstm_masked():
    """Build the masked LSTM stack and return its predictions on ``data``.

    ``data`` is the module-level input array. A fresh model is built on every
    call, so timing this function measures model construction plus
    ``predict``, not inference alone.
    """
    return _build_model(masked=True).predict(data)


def lstm_nonmasked():
    """Build the unmasked LSTM stack and return its predictions on ``data``.

    Identical to :func:`lstm_masked` except that no ``Masking`` layer is
    used, so every timestep (including the ``mask_value``-filled ones) is
    processed. Timing includes model construction, as in ``lstm_masked``.
    """
    return _build_model(masked=False).predict(data)
# Benchmark configuration. lstm_masked / lstm_nonmasked read these (and
# `data`) as module-level globals at call time.
n_samples = 1000
n_timesteps = 1000
n_features = 20
# np.random.uniform(0, 1) samples from [0, 1), so random features can never
# equal -1; only the timesteps explicitly overwritten below match the mask.
mask_value = -1


def _make_data():
    """Return uniform random input of shape (n_samples, n_timesteps, n_features).

    The first half of every sequence is overwritten with ``mask_value``.
    """
    x = np.random.uniform(0, 1, (n_samples, n_timesteps, n_features))
    # Floor division: on Python 3 `n_timesteps / 2` is a float, and a float
    # slice bound raises TypeError.
    x[:, :n_timesteps // 2, :] = mask_value
    return x


# NOTE: each timed call rebuilds its model, so these numbers include model
# construction, not just predict(). The pair is run a second time in the
# opposite order, presumably to expose warm-up / ordering effects.
data = _make_data()
print('Time without: ', timeit.timeit(lstm_nonmasked, number=2))
print('Time with : ', timeit.timeit(lstm_masked, number=2))

data = _make_data()
print('Time with : ', timeit.timeit(lstm_masked, number=2))
print('Time without: ', timeit.timeit(lstm_nonmasked, number=2))
# Recorded results.
# TENSORFLOW BACKEND
# ('tensorflow ver.', '1.0.1')
# ('keras ver.', '2.0.1')
# ('Time without: ', 90.77176785469055)
# ('Time with : ', 89.05559015274048)
# ('Time with : ', 91.29713892936707)
# ('Time without: ', 90.39765214920044)
# THEANO BACKEND
# ('theano ver.', '0.8.2')
# ('Time without: ', 16.689538955688477)
# ('Time with : ', 20.254252910614014)
# ('Time with : ', 15.486266136169434)
# ('Time without: ', 13.057673215866089)
# Can someone run this at a larger scale?
# Discussion: sign in to GitHub to comment on this gist.