Skip to content

Instantly share code, notes, and snippets.

@Seanny123
Forked from ragulpr/py
Last active May 29, 2017 08:49
Show Gist options
  • Save Seanny123/1d63e493686dac41f4bcfe8f9f0aae81 to your computer and use it in GitHub Desktop.
Save Seanny123/1d63e493686dac41f4bcfe8f9f0aae81 to your computer and use it in GitHub Desktop.
Keras masking example
from keras.layers import Masking, Dense
from keras.layers.recurrent import LSTM
from keras.models import Sequential
import numpy as np
np.set_printoptions(precision=4)
np.random.seed(1)
def print_output_val(out_val, inputs=None):
    """Print each sample's input next to the corresponding model output.

    A header row with the timestep indices (0 .. n_timesteps - 1) is printed
    first, then, for every sample, the transposed input and the transposed
    output so that time runs left to right.

    Args:
        out_val: Array indexed as ``[sample, timestep, unit]``; its first two
            dimensions determine how many samples and timesteps are printed.
        inputs: Array indexed as ``[sample, timestep, feature]`` holding the
            inputs that produced ``out_val``. Defaults to the module-level
            ``data`` array when omitted.
    """
    if inputs is None:
        # Fall back to the script's global input tensor.
        inputs = data
    sample_count, step_count = out_val.shape[:2]

    print("\n'--> time'")
    # The index row has to be printed explicitly; a bare ``np.linspace(...)``
    # expression only builds the array and throws it away.
    print(np.linspace(0, step_count - 1, step_count))
    for sample in range(sample_count):
        print('# sample = ', sample)
        print('input:')
        print(inputs[sample, :, :].T)
        print('output_val:')
        print(out_val[sample, :, :].T)
def sequential_non_temporal_example():
    """Run the masked input through a model with no recurrent layer.

    The model is a Masking layer followed by a single-unit linear Dense
    layer. It is not trained; the predictions for the module-level ``data``
    array are computed with the initial weights and printed via
    ``print_output_val``.
    """
    masking_layer = Masking(
        mask_value=mask_value,
        input_shape=(n_timesteps, n_features),
    )
    readout_layer = Dense(1, activation='linear', kernel_initializer="one")

    # Passing the layers to the constructor adds them in the listed order.
    model = Sequential([masking_layer, readout_layer])
    print_output_val(model.predict(data))
def sequential_temporal_example():
    """Run the masked input through a model that contains an LSTM.

    The model is a Masking layer, a two-unit LSTM that returns the full
    sequence, and a single-unit linear Dense layer. It is not trained; the
    predictions for the module-level ``data`` array are computed with the
    initial weights and printed via ``print_output_val``.
    """
    layer_stack = [
        Masking(mask_value=mask_value, input_shape=(n_timesteps, n_features)),
        # return_sequences=True keeps one output per timestep.
        LSTM(2, return_sequences=True, kernel_initializer="one"),
        Dense(1, activation='linear', kernel_initializer="one"),
    ]
    # Passing the layers to the constructor adds them in the listed order.
    model = Sequential(layer_stack)
    print_output_val(model.predict(data))
# --- Demo configuration (read as globals by the functions in this file) ---
n_samples = 3
n_timesteps = 7
n_features = 2
# ``np.nan`` is the spelling available in every NumPy release; the ``np.NaN``
# alias was removed in NumPy 2.0.
# NOTE(review): NaN never compares equal to itself, so a NaN mask value
# presumably can never match any timestep -- confirm against the Masking
# layer's comparison. Other values that were tried: -999999999.0, -1.0, -1, 0.0
mask_value = np.nan

# Every sample and every feature holds the same ramp 1, 2, ..., n_timesteps
# along the time axis; one broadcast assignment fills the whole tensor.
time_ramp = np.linspace(1, n_timesteps, n_timesteps)
data = np.empty((n_samples, n_timesteps, n_features))
data[...] = time_ramp[np.newaxis, :, np.newaxis]

# Mask a single feature value of one sample and timestep (no effect).
data[1, 0, 0] = mask_value
# Mask all feature values of one sample and timestep
# (propagates 0.*mask_value at layer of step/sample?).
data[2, 3, :] = mask_value

print('####################### sequential_non_temporal_example #######################:')
sequential_non_temporal_example()
print('####################### sequential_temporal_example #######################:')
# As non-temporal but masked timestep state does not propagate through time:
sequential_temporal_example()
# ####################### sequential_non_temporal_example #######################:
# _________________________________________________________________
# Layer (type) Output Shape Param #
# =================================================================
# masking_1 (Masking) (None, 7, 2) 0
# _________________________________________________________________
# dense_1 (Dense) (None, 7, 1) 3
# =================================================================
# Total params: 3.0
# Trainable params: 3.0
# Non-trainable params: 0.0
# _________________________________________________________________
# --> time
# [ 0. 1. 2. 3. 4. 5. 6.]
# # sample = 0
# input:
# [[ 1. 2. 3. 4. 5. 6. 7.]
# [ 1. 2. 3. 4. 5. 6. 7.]]
# output_val:
# [[ 2. 4. 6. 8. 10. 12. 14.]]
# # sample = 1
# input:
# [[ nan 2. 3. 4. 5. 6. 7.]
# [ 1. 2. 3. 4. 5. 6. 7.]]
# output_val:
# [[ nan 4. 6. 8. 10. 12. 14.]]
# # sample = 2
# input:
# [[ 1. 2. 3. nan 5. 6. 7.]
# [ 1. 2. 3. nan 5. 6. 7.]]
# output_val:
# [[ 2. 4. 6. nan 10. 12. 14.]]
# ####################### sequential_temporal_example #######################:
# _________________________________________________________________
# Layer (type) Output Shape Param #
# =================================================================
# masking_2 (Masking) (None, 7, 2) 0
# _________________________________________________________________
# lstm_1 (LSTM) (None, 7, 2) 40
# _________________________________________________________________
# dense_2 (Dense) (None, 7, 1) 3
# =================================================================
# Total params: 43.0
# Trainable params: 43.0
# Non-trainable params: 0.0
# _________________________________________________________________
# --> time
# [ 0. 1. 2. 3. 4. 5. 6.]
# # sample = 0
# input:
# [[ 1. 2. 3. 4. 5. 6. 7.]
# [ 1. 2. 3. 4. 5. 6. 7.]]
# output_val:
# [[ 1.2603 1.9066 1.9871 1.9982 1.9998 2. 2. ]]
# # sample = 1
# input:
# [[ nan 2. 3. 4. 5. 6. 7.]
# [ 1. 2. 3. 4. 5. 6. 7.]]
# output_val:
# [[ nan nan nan nan nan nan nan]]
# # sample = 2
# input:
# [[ 1. 2. 3. nan 5. 6. 7.]
# [ 1. 2. 3. nan 5. 6. 7.]]
# output_val:
# [[ 1.2603 1.9066 1.9871 nan nan nan nan]]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment