Skip to content

Instantly share code, notes, and snippets.

@sguada
Created November 2, 2020 10:56
Show Gist options
  • Save sguada/2e5edb83a91d05c4e7ddf22622573096 to your computer and use it in GitHub Desktop.
Save sguada/2e5edb83a91d05c4e7ddf22622573096 to your computer and use it in GitHub Desktop.
Double interleave
def make_reverb_dataset_double_interleave(
    trainer: str,
    num_parallel_calls: int = 8,
    prefetch: int = 0,
    batch_size: int = 2048,
    max_in_flight_samples_per_worker: int = 512,
    obs_dim: int = 32,
    action_dim: int = 2) -> tf.data.Dataset:
    """Build a batched Reverb dataset using two nested interleaves.

    The outer interleave runs `num_parallel_calls` independent pipelines.
    Each pipeline interleaves several `reverb.ReplayDataset` readers of
    `TRANSITIONS_TABLE`, batches the samples, and optionally prefetches.
    Because batching happens inside each pipeline, the returned dataset
    yields already-batched elements.

    Args:
        trainer: Passed to `reverb.ReplayDataset` as `server_address`.
        num_parallel_calls: Number of inner pipelines, and the outer
            interleave's `cycle_length` / `num_parallel_calls`. It is also
            the length of the range each inner pipeline interleaves over.
        prefetch: Prefetch buffer size applied to each inner pipeline after
            batching. Skipped when not greater than 0.
        batch_size: Number of samples per batch in each inner pipeline.
        max_in_flight_samples_per_worker: Forwarded unchanged to
            `reverb.ReplayDataset`.
        obs_dim: Size of the last dimension of the first tensor in a sample.
        action_dim: Size of the last dimension of the second tensor in a
            sample.

    Returns:
        The outer interleaved dataset, prefetched with
        `tf.data.experimental.AUTOTUNE`.
    """
    # Each sample is a 4-tuple of float32 tensors with an unknown leading
    # dimension. NOTE(review): presumably (observation, action, reward,
    # discount) -- confirm against the writer of TRANSITIONS_TABLE.
    all_shapes = (
        tf.TensorShape((None, obs_dim)),
        tf.TensorShape((None, action_dim)),
        tf.TensorShape((None, 1)),
        tf.TensorShape((None, 1)))
    all_types = tuple(4 * [tf.dtypes.float32])

    def reverb_dataset(_):
        # The range element is ignored; it only triggers creation of one
        # more identical reader.
        return reverb.ReplayDataset(
            server_address=trainer,
            table=TRANSITIONS_TABLE,
            dtypes=all_types,
            shapes=all_shapes,
            max_in_flight_samples_per_worker=max_in_flight_samples_per_worker,
            num_workers_per_iterator=1,
            emit_timesteps=False)

    def make_dataset(_):
        # Inner level: interleave individual Reverb readers, then batch.
        dataset = tf.data.Dataset.range(num_parallel_calls)
        # NOTE(review): cycle_length and num_parallel_calls are fixed at 8
        # and 4 here regardless of the `num_parallel_calls` argument --
        # confirm whether they are meant to follow it.
        dataset = dataset.interleave(
            map_func=reverb_dataset,
            cycle_length=8,
            num_parallel_calls=4,
            deterministic=False)
        dataset = dataset.batch(batch_size)
        if prefetch > 0:
            dataset = dataset.prefetch(prefetch)
        return dataset

    # Outer level: interleave the batched inner pipelines.
    return tf.data.Dataset.range(num_parallel_calls).interleave(
        make_dataset,
        cycle_length=num_parallel_calls,
        num_parallel_calls=num_parallel_calls).prefetch(
            tf.data.experimental.AUTOTUNE)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment