@kingoflolz
Created March 27, 2021 10:28
A quick script for shuffling TFRecord datasets
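import tensorflow as tf
from tqdm import tqdm

# The index file lists one input TFRecord path per line.
with open("data/openwebtext2_new_inputs.train.index") as f:
    index = f.read().splitlines()

dataset = tf.data.Dataset.from_tensor_slices(index)
# Read from 128 input files at once, taking records from each in turn.
dataset = dataset.interleave(tf.data.TFRecordDataset, cycle_length=128,
                             num_parallel_calls=tf.data.experimental.AUTOTUNE)
# Approximate shuffle: samples from a rolling buffer of 10,000 records.
d = dataset.shuffle(10000).prefetch(100)

writer = None
i = 0
for idx, example in enumerate(tqdm(d)):
    # Roll over to a new output file every 100,000 records.
    if idx % 100000 == 0:
        if writer is not None:
            writer.close()
        writer = tf.io.TFRecordWriter(
            f"gs://neo-datasets/openwebtext2_new_inputs_shuffled/train/openwebtext_{i}.tfrecords")
        i += 1
    writer.write(example.numpy())

# Close the last output file (guarded in case the dataset was empty).
if writer is not None:
    writer.close()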
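The shuffle in this script is approximate: `interleave` mixes records from 128 files at a time, and `shuffle(10000)` fills a buffer of 10,000 records, then repeatedly emits a random buffered record and refills its slot with the next input. The pure-Python sketch below models that buffer behaviour so its limits are easy to see. It is a simplified illustration, not TensorFlow's implementation, and the function name `buffer_shuffle` is hypothetical.

```python
import random
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def buffer_shuffle(items: Iterable[T], buffer_size: int, seed: int = 0) -> Iterator[T]:
    """Hypothetical model of buffer-based shuffling (not TensorFlow's code).

    Fill a buffer of `buffer_size` items, then for each new input emit a
    random buffered item and put the new input in its slot.
    """
    if buffer_size < 1:
        raise ValueError("buffer_size must be >= 1")
    rng = random.Random(seed)
    buffer: List[T] = []
    for item in items:
        if len(buffer) < buffer_size:
            # Still filling the buffer; nothing is emitted yet.
            buffer.append(item)
            continue
        # Buffer is full: emit a random element and replace it with the new item.
        j = rng.randrange(buffer_size)
        yield buffer[j]
        buffer[j] = item
    # Input exhausted: drain whatever is left in random order.
    rng.shuffle(buffer)
    yield from buffer


if __name__ == "__main__":
    shuffled = list(buffer_shuffle(range(20), buffer_size=5, seed=42))
    print(shuffled)
```

In this model a record can move earlier by at most `buffer_size - 1` positions (it can only be emitted once it has entered the buffer), so with `buffer_size=1` the order is unchanged. A larger buffer gives a more thorough mix, at the cost of holding more records in memory.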