Skip to content

Instantly share code, notes, and snippets.

@HappyStorm
Created April 17, 2017 18:23
Show Gist options
  • Save HappyStorm/cb6c22ffec18a8fbb4912e9c79b6d87c to your computer and use it in GitHub Desktop.
Save HappyStorm/cb6c22ffec18a8fbb4912e9c79b6d87c to your computer and use it in GitHub Desktop.
Self-defined generator
# Definition of the batch generator.
def generator_from_ndarray(x, y, batch_size=32):
    """Return an infinite generator of ``(x_batch, y_batch)`` slices.

    Batches are taken in order along the first axis.  When ``len(x)`` is
    not a multiple of *batch_size*, the last batch of each pass is smaller.
    After the last batch the generator starts over from the first one and
    never raises ``StopIteration``.

    Args:
        x: Input samples.  Anything supporting ``len`` and slicing along
            the first axis works (e.g. a NumPy array or a list).
        y: Targets, aligned with *x* along the first axis.
        batch_size: Maximum number of samples per batch.  Must be positive.

    Returns:
        A generator yielding ``(x[start:stop], y[start:stop])`` tuples
        forever.

    Raises:
        ValueError: If *batch_size* is not positive, *x* is empty, or *x*
            and *y* differ in length.  Validation runs immediately, when
            this function is called, not on the first ``next()``.
    """
    n_samples = len(x)
    if batch_size <= 0:
        raise ValueError("batch_size must be positive, got %r" % (batch_size,))
    if n_samples == 0:
        # With zero batches the endless loop below would spin forever
        # without ever yielding, so next() would hang.
        raise ValueError("x must contain at least one sample")
    if len(y) != n_samples:
        raise ValueError(
            "x and y must have the same length, got %d and %d"
            % (n_samples, len(y)))

    def _batches():
        while True:
            # Stepping by batch_size produces ceil(n_samples / batch_size)
            # batches using integer arithmetic only.  This avoids
            # np.ceil(n / b), which floors first under Python 2 integer
            # division and drops the final partial batch.
            for start in range(0, n_samples, batch_size):
                stop = start + batch_size
                yield x[start:stop], y[start:stop]

    return _batches()
# Array shapes this script was written against:
#   x_train.shape = (5000, 100, 600, 1)
#   y_train.shape = (5000, 1)
#   x_valid.shape = (1000, 100, 600, 1)
#   y_valid.shape = (1000, 1)

# Training via fit_generator.
# steps_per_epoch and validation_steps count *batches*, not samples.
# Passing x_train.shape[0] (5000 samples) with a batch size of 20 made each
# "epoch" run through the training set 20 times, and validation likewise.
BATCH_SIZE = 20
# Ceiling division, so a final partial batch is still included.
train_steps = (x_train.shape[0] + BATCH_SIZE - 1) // BATCH_SIZE
valid_steps = (x_valid.shape[0] + BATCH_SIZE - 1) // BATCH_SIZE
history = model.fit_generator(
    generator=generator_from_ndarray(x_train, y_train, BATCH_SIZE),
    steps_per_epoch=train_steps,
    epochs=100, verbose=1, callbacks=callbacks,
    validation_data=generator_from_ndarray(x_valid, y_valid, BATCH_SIZE),
    validation_steps=valid_steps,
    # NOTE(review): newer Keras 2 releases rename this to max_queue_size;
    # confirm against the installed version.
    max_q_size=40,
    # NOTE(review): with workers > 1 and thread-based workers, several
    # threads call next() on the same plain Python generator.  Generators
    # are not thread-safe and can raise "generator already executing".
    # Confirm, or use workers=1 or a thread-safe iterator.
    workers=2, initial_epoch=init_epoch)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment