
@abhaikollara
Last active March 25, 2017 08:19
Bug in Keras batch_norm
# Script for reproducing a BatchNormalization bug
# https://github.com/fchollet/keras/issues/5643
from keras.models import Sequential, Model
from keras.layers import Dense, BatchNormalization
import numpy as np
# Two identically configured models (Keras 1.x-style output_dim/input_dim arguments)
m1 = Sequential([
    Dense(output_dim=5, input_dim=5),
    BatchNormalization(),
    Dense(output_dim=5),
])
m2 = Sequential([
    Dense(output_dim=5, input_dim=5),
    BatchNormalization(),  # Without this line, the script runs to completion
    Dense(output_dim=5),
])
x = np.ones((3, 5))
y = np.ones((3, 5))
m1.compile(optimizer='sgd', loss='categorical_crossentropy')
m2.compile(optimizer='sgd', loss='categorical_crossentropy')
h = m2.fit(x, y, verbose=0)  # Fitting m2 before creating m3 avoids the bug
# m3 chains the two models: m1's output is fed into m2 (Keras 1.x input=/output= arguments)
m3 = Model(input=m1.inputs, output=m2(m1(m1.inputs)))
m3.compile(optimizer='sgd', loss='categorical_crossentropy')
h1 = m1.fit(x, y, verbose=0)
h2 = m2.fit(x, y, verbose=0)  # Fails unless m2 was fitted before m3 was created
h3 = m3.fit(x, y, verbose=0)
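For context, the layer at the center of this report is stateful: besides its learned scale and shift, it keeps moving averages that are updated on every training step. The following is a minimal NumPy sketch of that behaviour, assuming the standard batch-normalization formulation (batch statistics while training, moving averages at inference). The class name `BatchNormSketch` and the `momentum`/`epsilon` values are hypothetical and not taken from Keras' implementation, and the sketch does not reproduce the bug itself.

```python
import numpy as np


class BatchNormSketch:
    """Hypothetical NumPy model of a batch-normalization layer (illustrative only).

    Training mode normalizes with the current batch statistics and updates the
    moving averages; inference mode normalizes with the moving averages instead.
    The momentum/epsilon defaults are assumptions, not Keras internals.
    """

    def __init__(self, num_features, momentum=0.99, epsilon=1e-3):
        self.gamma = np.ones(num_features)   # learned scale
        self.beta = np.zeros(num_features)   # learned shift
        self.moving_mean = np.zeros(num_features)
        self.moving_var = np.ones(num_features)
        self.momentum = momentum
        self.epsilon = epsilon

    def __call__(self, x, training):
        if training:
            # Per-feature statistics over the batch axis.
            mean = x.mean(axis=0)
            var = x.var(axis=0)
            # State updates that must run alongside each training step.
            self.moving_mean = self.momentum * self.moving_mean + (1 - self.momentum) * mean
            self.moving_var = self.momentum * self.moving_var + (1 - self.momentum) * var
        else:
            mean, var = self.moving_mean, self.moving_var
        # epsilon keeps the division finite when the batch variance is zero.
        x_hat = (x - mean) / np.sqrt(var + self.epsilon)
        return self.gamma * x_hat + self.beta


bn = BatchNormSketch(5)
batch = np.arange(15, dtype=float).reshape(3, 5)
out = bn(batch, training=True)
print(out.mean(axis=0))  # approximately zero for every feature
print(bn.moving_mean)
```

With an input like the script's `np.ones((3, 5))`, every row reaching the layer is identical, so the batch variance is zero and only `epsilon` keeps the normalization finite.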