# --- Performance considerations
# https://www.tensorflow.org/datasets/performances
# --- For more advanced transformation pipelines
# https://www.tensorflow.org/tfx/tutorials/transform/simple
# --- Pre-processing dataset (read this CAREFULLY)
# https://www.tensorflow.org/api_docs/python/tf/data/Dataset#map
# https://www.tensorflow.org/guide/data_performance#vectorizing_mapping
import time
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
# --- these toggle the progress bar tfds shows while downloading/preparing
# --- datasets; they have no effect here since the pipeline below is built
# --- with tf.data directly
# tfds.enable_progress_bar()
# tfds.disable_progress_bar()
ds = tf.data.Dataset.range(10) #< dummy [0..9] dataset
ds = ds.shuffle(len(ds), seed=1) #< shuffle each epoch
ds = ds.repeat(4) #< num epochs
ds = ds.batch(5) #< batch sizes
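# --- (note) order matters here: shuffle-before-repeat reshuffles every epoch
# --- (reshuffle_each_iteration defaults to True, even with a seed), and
# --- batch-after-repeat lets batches span epoch boundaries whenever the epoch
# --- size is not a multiple of the batch size (10 % 5 == 0, so they align here)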
def map_functor(tensor):
    return tensor + 100
ds = ds.map(map_functor, num_parallel_calls=tf.data.experimental.AUTOTUNE)
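# --- (note) because .map() runs after .batch(), map_functor is called once per
# --- batch on a whole 5-element tensor: the "vectorizing mapping" pattern from
# --- the guide linked above. A hypothetical per-element variant, usually
# --- slower for cheap ops like this one, would map before batching:
# ds = tf.data.Dataset.range(10).map(map_functor).batch(5)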
ds = ds.cache(filename='/tmp/range10cache') #< warning: ignores changes to code!
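# --- (alternative sketch) ds.cache() with no filename caches in RAM instead:
# --- no stale-file pitfall when the code changes, but everything must fit in
# --- memory and the cache is rebuilt on every run
# ds = ds.cache()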
ds = ds.prefetch(tf.data.experimental.AUTOTUNE) #< overlap data production with consumption
ds = tfds.as_numpy(ds) #< for use in JAX (after this, tf.data ops like .cache() are no longer available)
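# --- (sketch, assuming jax is installed; step() is a hypothetical placeholder)
# --- each batch is now a plain np.ndarray, so a jitted JAX function can
# --- consume it directly:
# import jax
# import jax.numpy as jnp
#
# @jax.jit
# def step(x):
#     return jnp.mean(x)
#
# for batch in ds:
#     print(step(batch))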
for ibatch, batch in enumerate(ds):
    print(f"batch[{ibatch+1:02}/{len(ds):02}] → {batch}")
# --- check dataset performance; the tfds guide benchmarks twice to compare a
# --- cold first pass against a warm (cached) second pass.
# --- batch_size must match .batch(5) above so examples/sec is reported correctly
tfds.benchmark(ds, batch_size=5)
tfds.benchmark(ds, batch_size=5)