-
-
Save taiya/f24f9d93817dd866548acaf104251080 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# --- Performance considerations
# https://www.tensorflow.org/datasets/performances
# --- For more advanced transformation pipelines
# https://www.tensorflow.org/tfx/tutorials/transform/simple
# --- Pre-processing dataset (read this CAREFULLY)
# https://www.tensorflow.org/api_docs/python/tf/data/Dataset#map
# https://www.tensorflow.org/guide/data_performance#vectorizing_mapping
import time
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds

# --- not sure when this is called...
# tfds.enable_progress_bar()
# tfds.disable_progress_bar()

# Build a tiny demo pipeline: 10 integers, reshuffled each epoch,
# 4 epochs concatenated into one stream, served in batches of 5.
ds = tf.data.Dataset.range(10)    #< dummy [0..9] dataset
ds = ds.shuffle(len(ds), seed=1)  #< buffer == dataset size -> full shuffle; reshuffles each epoch
ds = ds.repeat(4)                 #< num epochs
ds = ds.batch(5)                  #< batch size (10/5 = 2 batches per epoch, 8 total)
def map_functor(tensor):
    """Element-wise transform applied per batch: add 100 to every value.

    Works on tf tensors and plain numbers alike (it is just `+ 100`),
    and is vectorized over the batch since map() runs after batch().
    """
    return tensor + 100
# NOTE(review): tf.data.experimental.AUTOTUNE is a deprecated alias of
# tf.data.AUTOTUNE (TF >= 2.3) — kept here for compatibility; confirm TF version.
ds = ds.map(map_functor, num_parallel_calls=tf.data.experimental.AUTOTUNE)
# warning: the file cache ignores changes to code!
# NOTE(review): caching *after* shuffle/repeat freezes the shuffled 4-epoch
# stream on disk; usually cache() goes before shuffle() — confirm intent.
ds = ds.cache(filename='/tmp/range10cache')
ds = ds.prefetch(tf.data.experimental.AUTOTUNE)
ds = tfds.as_numpy(ds)  #< for use in JAX (after this, cache/... not available)

# Iterate the whole stream once, printing each numpy batch.
for ibatch, batch in enumerate(ds):
    print(f"batch[{ibatch+1:02}/{len(ds):02}] → {batch}")

# --- check dataset performance (second run should show the cache warm-up win)
tfds.benchmark(ds, batch_size=32)
tfds.benchmark(ds, batch_size=32)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment