Skip to content

Instantly share code, notes, and snippets.

@Hironsan
Created September 9, 2017 04:03
Show Gist options
  • Save Hironsan/e041d6606164bc14c50aa56b989c5fc0 to your computer and use it in GitHub Desktop.
Save Hironsan/e041d6606164bc14c50aa56b989c5fc0 to your computer and use it in GitHub Desktop.
fit vs fit_generator in Keras
import numpy as np
from keras.preprocessing import sequence
from keras.models import Sequential
from keras.layers import Dense, Embedding
from keras.layers import LSTM
from keras.datasets import imdb
def batch_iter(data, labels, batch_size, shuffle=True):
    """Build an endless mini-batch generator over ``data`` and ``labels``.

    Args:
        data: Sized, indexable collection of samples. When ``shuffle`` is
            true it must support indexing with a NumPy integer array
            (e.g. a ``numpy.ndarray``).
        labels: Targets aligned with ``data``; must have the same length.
        batch_size: Positive number of samples per batch. The last batch of
            each pass is smaller when the data size is not a multiple of it.
        shuffle: If true, a fresh random permutation is drawn at the start
            of every pass over the data.

    Returns:
        A tuple ``(num_batches_per_epoch, generator)``. The generator never
        terminates: after ``num_batches_per_epoch`` batches it starts a new
        pass, yielding ``(X, y)`` pairs each time.

    Raises:
        ValueError: If ``batch_size`` is not positive, if ``data`` and
            ``labels`` differ in length, or if ``data`` is empty.
    """
    data_size = len(data)
    if batch_size <= 0:
        raise ValueError(
            'batch_size must be positive, got {!r}'.format(batch_size))
    if len(labels) != data_size:
        raise ValueError(
            'data and labels must have the same length, got {} and {}'
            .format(data_size, len(labels)))
    if data_size == 0:
        # Without this check the generator below would hand out empty
        # batches forever.
        raise ValueError('data must contain at least one sample')

    # Exact integer ceiling division; avoids going through a float.
    num_batches_per_epoch = int(-(-data_size // batch_size))

    def data_generator():
        while True:
            # Shuffle the data at each epoch, using one permutation for both
            # arrays so samples stay aligned with their labels.
            if shuffle:
                shuffle_indices = np.random.permutation(data_size)
                epoch_data = data[shuffle_indices]
                epoch_labels = labels[shuffle_indices]
            else:
                epoch_data = data
                epoch_labels = labels
            for start_index in range(0, data_size, batch_size):
                end_index = min(start_index + batch_size, data_size)
                yield (epoch_data[start_index:end_index],
                       epoch_labels[start_index:end_index])

    return num_batches_per_epoch, data_generator()
def main(mode):
    """Train a small LSTM binary classifier on the Keras IMDB data for one epoch.

    Args:
        mode: When equal to ``'fit'`` the padded arrays are passed directly
            to ``model.fit``. Any other value trains through
            ``model.fit_generator`` with generators built by ``batch_iter``.
    """
    vocab_size = 20000
    seq_len = 80
    batch_size = 32

    # Load the dataset restricted to the `vocab_size` word cap, then pad
    # every sequence to the same length.
    train_split, test_split = imdb.load_data(num_words=vocab_size)
    x_train = sequence.pad_sequences(train_split[0], maxlen=seq_len)
    y_train = train_split[1]
    x_test = sequence.pad_sequences(test_split[0], maxlen=seq_len)
    y_test = test_split[1]

    # Embedding -> LSTM -> single sigmoid unit.
    model = Sequential([
        Embedding(vocab_size, 128),
        LSTM(128, dropout=0.2, recurrent_dropout=0.2),
        Dense(1, activation='sigmoid'),
    ])
    model.compile(loss='binary_crossentropy',
                  optimizer='adam',
                  metrics=['accuracy'])

    if mode == 'fit':
        # In-memory path: Keras slices the arrays into batches itself.
        model.fit(x_train, y_train,
                  batch_size=batch_size,
                  epochs=1,
                  validation_data=(x_test, y_test))
        return

    # Generator path: batches come from batch_iter, so Keras needs the
    # number of steps that make up one pass for each generator.
    train_steps, train_batches = batch_iter(x_train, y_train, batch_size)
    valid_steps, valid_batches = batch_iter(x_test, y_test, batch_size)
    model.fit_generator(train_batches,
                        steps_per_epoch=train_steps,
                        epochs=1,
                        validation_data=valid_batches,
                        validation_steps=valid_steps)
if __name__ == '__main__':
    import sys

    # The training mode is the first command-line argument: 'fit' trains
    # with model.fit, any other value trains with model.fit_generator.
    if len(sys.argv) < 2:
        # Exit with a readable message instead of an IndexError traceback.
        sys.exit('usage: {} MODE  (MODE is "fit" for model.fit; any other '
                 'value uses model.fit_generator)'.format(sys.argv[0]))
    mode = sys.argv[1]
    main(mode)
@leferrad
Copy link

Great code! Simple and very useful! One question regarding line 10. I wonder why this line is not like this:

num_batches_per_epoch = int(len(data) / float(batch_size))

I guess you were trying to avoid the truncation that happens in a division between integers, which can also be achieved by casting the denominator to float. Or is there another reason that I'm missing?
Thanks!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment