Skip to content

Instantly share code, notes, and snippets.

@dschwertfeger
Last active January 10, 2022 10:17
Show Gist options
  • Save dschwertfeger/64d853d83f52292c499a3bf293127f7c to your computer and use it in GitHub Desktop.
Save dschwertfeger/64d853d83f52292c499a3bf293127f7c to your computer and use it in GitHub Desktop.
A data-pipeline using TFRecords
import os
import tensorflow as tf
# Shorthand for tf.data's AUTOTUNE constant; this file passes it as
# `num_parallel_reads` and as the prefetch `buffer_size`.
AUTOTUNE = tf.data.experimental.AUTOTUNE
def _parse_batch(record_batch, sample_rate, duration):
    """Parse a batch of serialized `tf.Example` protos into audio and labels.

    Args:
        record_batch: Batch of serialized `tf.Example` protos.
        sample_rate: Number of audio samples per unit of `duration`.
        duration: Length of each clip; may be fractional (e.g. ``0.5``).

    Returns:
        A tuple ``(audio, label)`` taken from the parsed example's ``'audio'``
        (float32, ``n_samples`` values per record) and ``'label'`` (int64, one
        value per record) features.
    """
    # A feature shape must be made of integers; a fractional `duration`
    # would otherwise turn the product into a float.
    n_samples = int(round(sample_rate * duration))

    # Create a description of the features
    feature_description = {
        'audio': tf.io.FixedLenFeature([n_samples], tf.float32),
        'label': tf.io.FixedLenFeature([1], tf.int64),
    }

    # Parse the input `tf.Example` protos using the dictionary above
    example = tf.io.parse_example(record_batch, feature_description)

    return example['audio'], example['label']
def get_dataset_from_tfrecords(tfrecords_dir='tfrecords', split='train',
                               batch_size=64, sample_rate=22050, duration=4,
                               n_epochs=10):
    """Build a batched, prefetched dataset of (audio, label) pairs.

    Args:
        tfrecords_dir: Directory that holds the ``*.tfrecord`` files.
        split: ``'train'``, ``'test'`` or ``'validate'``; files whose names
            start with this value are read.
        batch_size: Number of serialized records grouped into one batch.
        sample_rate: Forwarded to `_parse_batch`.
        duration: Forwarded to `_parse_batch`.
        n_epochs: How many times the ``'train'`` split is repeated. Ignored
            for the other splits.

    Returns:
        A `tf.data.Dataset` yielding ``(audio, label)`` batches.

    Raises:
        ValueError: If `split` is not one of the three accepted values.
    """
    if split not in ('train', 'test', 'validate'):
        raise ValueError("split must be either 'train', 'test' or 'validate'")

    # List all *.tfrecord files for the selected split
    pattern = os.path.join(tfrecords_dir, '{}*.tfrecord'.format(split))
    files_ds = tf.data.Dataset.list_files(pattern)

    # Disregard data order in favor of reading speed
    ignore_order = tf.data.Options()
    ignore_order.experimental_deterministic = False
    files_ds = files_ds.with_options(ignore_order)

    # Read the ZLIB-compressed TFRecord files in an interleaved order
    ds = tf.data.TFRecordDataset(files_ds,
                                 compression_type='ZLIB',
                                 num_parallel_reads=AUTOTUNE)

    # Prepare batches first so that parsing is done once per batch
    ds = ds.batch(batch_size)

    # Parse a batch into a dataset of [audio, label] pairs. Parsing runs in
    # parallel so it does not serialize the already-parallel file reads.
    ds = ds.map(lambda x: _parse_batch(x, sample_rate, duration),
                num_parallel_calls=AUTOTUNE)

    # Repeat the training data for n_epochs. Don't repeat test/validate splits.
    if split == 'train':
        ds = ds.repeat(n_epochs)

    return ds.prefetch(buffer_size=AUTOTUNE)
def main():
    """Load the Keras model stored in ``model.h5`` and fit it on the train split."""
    n_epochs = 10

    # Build a single-pass dataset: `fit` iterates a finite dataset once per
    # epoch, so repeating it here as well would multiply the number of passes
    # over the data (10 repeats x 10 epochs).
    train_ds = get_dataset_from_tfrecords(n_epochs=1)

    model = tf.keras.models.load_model('model.h5')
    model.fit(train_ds, epochs=n_epochs)


if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment