Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save elgehelge/afe367d022c012b205cc0e4a57351715 to your computer and use it in GitHub Desktop.
Keras + generator: a minimal example
import numpy as np
import tensorflow as tf
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense
# Every batch produced below has this many rows and this many feature columns.
batch_size = 1
data_size = 1


def generator():
    """Yield ten ``(data, targets)`` batches, one per value ``v`` in ``range(10)``.

    For each ``v`` the batch pairs a ``(batch_size, data_size)`` NumPy array
    filled with ``v`` and a ``(batch_size, 1)`` array drawn with
    ``np.random.random`` (NumPy's global RNG, so targets differ between runs).
    Messages are printed when the generator starts, for every value, and when
    it runs out, which makes restarts by ``tf.data`` visible on stdout.
    """
    print('initializing generator')
    for value in range(10):
        print(f'"{value}"')
        features = value * np.ones((batch_size, data_size))
        labels = np.random.random((batch_size, 1))
        yield features, labels
    print('no more values left')
# Two-layer regressor trained on the batches produced by generator().
model = Sequential([
    Dense(10),
    # generator() yields targets of shape (batch_size, 1), so the prediction
    # layer needs exactly one unit. Sizing it with data_size only worked
    # because data_size == 1; any other value would make the prediction width
    # disagree with the target width.
    Dense(1),
])
model.compile(optimizer='adam', loss='mae')
# Build with an explicit (batch, features) input shape so summary() has layer
# shapes to report before any data has been seen.
model.build((batch_size, data_size))
model.summary()

# This signature must agree with what generator() yields: float64 arrays of
# shape (batch_size, data_size) for the inputs and (batch_size, 1) for the
# targets.
# NOTE(review): output_types/output_shapes are deprecated in newer TensorFlow
# releases in favour of output_signature=(tf.TensorSpec(...), ...); kept here
# so the example also runs on older TF 2.x versions.
output_types = (tf.float64, tf.float64)
output_shapes = (tf.TensorShape((batch_size, data_size)),
                 tf.TensorShape((batch_size, 1)))
dataset = tf.data.Dataset.from_generator(
    generator, output_types=output_types, output_shapes=output_shapes)
# generator() stops after ten batches; repeat() re-invokes it (so its
# 'initializing generator' message prints again) so that fit() can keep
# drawing the 2 * 6 = 12 batches requested below.
dataset = dataset.repeat()
model.fit(dataset, epochs=2, steps_per_epoch=6)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.