Last active
May 26, 2016 18:40
-
-
Save jimfleming/1c15dcda3e0d05b3947ac28fc757714a to your computer and use it in GitHub Desktop.
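Minimal example reproducing a "missing placeholder" bug in Keras, triggered by a BatchNormalization layer inside a sub-model wrapped in TimeDistributed.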
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
from keras.models import Model | |
from keras.layers import Input, Dense, RepeatVector, Activation, Flatten, Reshape | |
from keras.layers.convolutional import Convolution2D, MaxPooling2D, AveragePooling2D, UpSampling2D | |
from keras.layers.normalization import BatchNormalization | |
from keras.layers.recurrent import LSTM | |
from keras.layers.wrappers import TimeDistributed | |
def InterModel(input_shape):
    """Return a Dense(10) -> BatchNormalization -> ReLU Keras sub-model.

    ``RecurrentModel`` applies this once to the raw input, before the result
    is repeated across timesteps.

    Args:
        input_shape: Shape tuple passed straight to ``Input``.

    Returns:
        A Keras ``Model`` producing a 10-unit ReLU-activated output.
    """
    inputs = Input(input_shape)
    net = Dense(10)(inputs)
    net = BatchNormalization()(net)
    activated = Activation('relu')(net)
    return Model(inputs, activated)
def StepModel(input_shape, use_batch_norm=True):
    """Build the per-timestep sub-model that ``TimeDistributed`` applies.

    Dense(10) -> BatchNormalization -> ReLU, structurally the same as
    ``InterModel``.

    Args:
        input_shape: Shape of a single timestep's input, e.g. ``(10,)``.
        use_batch_norm: When True (the default, matching the original code),
            include the BatchNormalization layer whose presence triggers the
            missing-placeholder bug this example reproduces. Pass False to
            build the variant that, per the original TODO, trains without
            error. This replaces commenting the line out by hand.

    Returns:
        A Keras ``Model`` producing a 10-unit ReLU-activated output.
    """
    input_ = Input(input_shape)
    x = Dense(10)(input_)
    if use_batch_norm:
        # The layer under suspicion: with it present the example fails,
        # without it the example works (per the original TODO).
        # NOTE(review): presumably related to how BatchNormalization's
        # train/test switch interacts with TimeDistributed - unconfirmed.
        x = BatchNormalization()(x)
    output = Activation('relu')(x)
    return Model(input_, output)
# `InterModel` and `StepModel` are structurally identical, but they are used
# at different places in `RecurrentModel`: once before the sequence is built,
# and once per timestep inside `TimeDistributed`.
def RecurrentModel(input_shape):
    """Assemble the full model: encode, repeat, run an LSTM, then map each step.

    Pipeline: ``InterModel`` -> ``RepeatVector(5)`` ->
    ``LSTM(10, return_sequences=True)`` -> ``TimeDistributed(StepModel)``.

    Args:
        input_shape: Shape tuple of the model input, e.g. ``(10,)``.

    Returns:
        A Keras ``Model`` whose output has 5 timesteps of 10 units each.
    """
    timesteps = 5
    # The LSTM width and the per-step model's input size must agree.
    lstm_units = 10

    inputs = Input(input_shape)
    encoded = InterModel(input_shape)(inputs)
    repeated = RepeatVector(timesteps)(encoded)
    sequence = LSTM(lstm_units, return_sequences=True)(repeated)
    stepped = TimeDistributed(StepModel(input_shape=(lstm_units,)))(sequence)
    return Model(inputs, stepped)
# Build and compile the full model, then fit it on all-zero dummy data.
# Running this reproduces the bug while StepModel includes BatchNormalization.
model = RecurrentModel(input_shape=(10,))
model.compile(loss='mse', optimizer='rmsprop')

# Dummy data: inputs are (samples, 10); targets are (samples, 5, 10) to match
# the model's 5-timestep, 10-unit output.
X_train = np.zeros(shape=(1000, 10))
y_train = np.zeros(shape=(1000, 5, 10))
model.fit(X_train, y_train, nb_epoch=5, batch_size=20)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Relevant error when the `BatchNormalization` line in `StepModel` (line 23 in the original gist) is left uncommented: