Skip to content

Instantly share code, notes, and snippets.

@jdevoldere
Created July 6, 2020 17:10
Show Gist options
Save jdevoldere/84e08d25fe8fde43e64ca5cafdef100c to your computer and use it in GitHub Desktop.
Multiprocessing batch generator in Keras
import multiprocessing
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Dense, Input, LSTM
from tensorflow.keras.models import Model
import time
def batch_generator(queue, batch_size=48, timesteps=24, features=6, delay=1.0):
    """Produce random dummy batches forever and push them onto ``queue``.

    Meant to be the target of a ``multiprocessing.Process``; it never returns.

    Args:
        queue: Object with a ``put`` method (e.g. ``multiprocessing.Queue``)
            that receives ``[x_batch, y_batch]`` lists.
        batch_size: Number of samples per batch (first axis of both arrays).
        timesteps: Size of the second axis of ``x_batch``.
        features: Size of the third axis of ``x_batch``.
        delay: Seconds to sleep before each ``put``, simulating expensive
            batch preparation.
    """
    # Use a fresh, OS-entropy-seeded generator per call rather than the
    # global np.random state: with the "fork" start method every worker
    # inherits an identical copy of that global state, so all workers
    # would otherwise emit the same "random" batches.
    rng = np.random.default_rng()
    while True:
        # randomly generated dummy data
        x_batch = rng.random((batch_size, timesteps, features))
        y_batch = rng.integers(2, size=batch_size)  # binary labels 0/1
        time.sleep(delay)  # simulate heavy calculations
        queue.put([x_batch, y_batch])
def batch_retriever(queue=None):
    """Yield ``(x_batch, y_batch)`` tuples pulled from a queue, forever.

    Args:
        queue: Object with a blocking ``get`` method returning a 2-item
            sequence. When omitted, falls back to the module-level ``q``
            created in the ``__main__`` block (original behaviour).

    Yields:
        ``(x_batch, y_batch)`` tuples, suitable for ``Model.fit``.
    """
    # Resolved lazily on the first next(), exactly like the original global
    # lookup, so callers that rely on the global `q` are unaffected.
    if queue is None:
        queue = q
    while True:
        x_batch, y_batch = queue.get()
        yield x_batch, y_batch
if __name__ == '__main__':
    # run batch generator
    # Bounded queue: producers block in put() once 23 batches are waiting,
    # which caps how far they can run ahead of training.
    q = multiprocessing.Queue(maxsize=23)
    generator_count = 5
    generators = []
    for _ in range(generator_count):
        # daemon=True: batch_generator loops forever, so a non-daemon worker
        # would keep the interpreter from exiting after fit() returns.
        generator = multiprocessing.Process(
            target=batch_generator, args=(q,), daemon=True
        )
        generator.start()
        generators.append(generator)

    try:
        # define model
        i = Input(batch_shape=(None, 24, 6))
        o = LSTM(128, return_sequences=True)(i)
        o = LSTM(128, return_sequences=False)(o)
        o = Dense(1, activation='sigmoid')(o)
        m = Model(inputs=[i], outputs=[o])
        m.compile(
            optimizer=tf.keras.optimizers.Adam(learning_rate=0.01),
            loss="binary_crossentropy",
            metrics=["binary_accuracy"],
        )

        # train model; batch_retriever() pulls batches from the global `q`
        gen = batch_retriever()
        m.fit(gen, epochs=10, steps_per_epoch=100)
    finally:
        # Stop the infinite producer loops explicitly, also on errors or
        # Ctrl-C. terminate() may leave the queue inconsistent, which is
        # acceptable because nothing reads it afterwards.
        for generator in generators:
            generator.terminate()
        for generator in generators:
            generator.join()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment