Gist: markub3327/3ab35c8039d4c9c9eb932943aa36b24f
Function to split a TensorFlow dataset (tf.data.Dataset) into train, validation, and test splits.
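The split fractions in get_dataset_partitions_tf become element counts through int(), which truncates, and the test split simply receives whatever take()/skip() leave over. Here is a minimal pure-Python sketch of that arithmetic. split_sizes is a hypothetical helper written for illustration; it is not part of the gist.

```python
def split_sizes(ds_size, train_split=0.7, val_split=0.15):
    """Hypothetical helper mirroring the size arithmetic in get_dataset_partitions_tf."""
    # int() truncates toward zero, so fractional elements are dropped
    train_size = int(train_split * ds_size)
    val_size = int(val_split * ds_size)
    # ds.skip(train_size).skip(val_size) hands every leftover element to the test split
    test_size = ds_size - train_size - val_size
    return train_size, val_size, test_size


print(split_sizes(1000))  # (700, 150, 150)
print(split_sizes(1001))  # (700, 150, 151)
```

Because of the truncation, the test split can end up slightly larger than test_split * ds_size. This is why test_split only appears in the sanity check at the top of the function.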
```python
import math


def get_dataset_partitions_tf(ds, ds_size, train_split=0.7, val_split=0.15, test_split=0.15, shuffle=True, shuffle_size=10000, batch_size=32):
    # Compare with a tolerance: e.g. 0.7 + 0.2 + 0.1 evaluates to 0.9999999999999999, not 1
    assert math.isclose(train_split + val_split + test_split, 1.0)
    if shuffle:
        # Specify seed and disable reshuffling so take()/skip() always see the same order:
        # the splits stay disjoint and have the same distribution between runs
        ds = ds.shuffle(shuffle_size, seed=12, reshuffle_each_iteration=False)
    train_size = int(train_split * ds_size)
    val_size = int(val_split * ds_size)
    train_ds = ds.take(train_size).shuffle(buffer_size=batch_size * 8).batch(batch_size)
    val_ds = ds.skip(train_size).take(val_size).shuffle(buffer_size=batch_size * 8).batch(batch_size)
    # The test split gets every remaining element
    test_ds = ds.skip(train_size + val_size).shuffle(buffer_size=batch_size * 8).batch(batch_size)
    return train_ds, val_ds, test_ds
```
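The fixed seed is not the only thing that keeps the splits apart. reshuffle_each_iteration=False matters just as much, because train_ds, val_ds and test_ds each iterate the shuffled source on their own. The sketch below models this in plain Python. simulate_splits is a hypothetical helper, and the idea that every pass may draw a fresh order when reshuffling is enabled is an assumption about the source's behaviour; this is not TensorFlow code. With one fixed order, the three splits are disjoint. With a fresh order per pass, test elements leak into the training set.

```python
import random


def simulate_splits(n, train_size, val_size, reshuffle_each_iteration, seed=12):
    """Hypothetical pure-Python model of shuffle -> take/skip (not TensorFlow itself)."""
    rng = random.Random(seed)
    fixed_order = list(range(n))
    rng.shuffle(fixed_order)

    def one_pass():
        # Assumption: each split iterates the shared shuffled source separately,
        # and a new order is drawn per pass only when reshuffling is enabled
        if reshuffle_each_iteration:
            order = list(range(n))
            rng.shuffle(order)
            return order
        return fixed_order

    train = set(one_pass()[:train_size])                     # ds.take(train_size)
    val = set(one_pass()[train_size:train_size + val_size])  # ds.skip(train_size).take(val_size)
    test = set(one_pass()[train_size + val_size:])           # ds.skip(train_size + val_size)
    return train, val, test


for reshuffle in (False, True):
    train, val, test = simulate_splits(100, 70, 15, reshuffle)
    # Overlap sizes are always 0 when reshuffle is False
    print(reshuffle, len(train & test), len(val & test))
```

The same reasoning explains why the function shuffles again only after take()/skip(). The second shuffle(buffer_size=batch_size * 8) reorders elements within a split without moving any element across splits.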
@markub3327 (Author) commented:

@angeligareta What do you think about this solution?
