@Sirsirious
Created December 4, 2020 13:38
import trax

# First, we get the streams from TFDS
train_stream = trax.data.TFDS('imdb_reviews', keys=('text', 'label'), train=True)()
eval_stream = trax.data.TFDS('imdb_reviews', keys=('text', 'label'), train=False)()
# Next, we build the pipeline
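data_pipeline = trax.data.Serial(
    # Tokenize the text field (index 0) with the 8k subword vocabulary
    trax.data.Tokenize(vocab_file='en_8k.subword', keys=[0]),
    trax.data.Shuffle(),
    # Drop examples whose tokenized text is longer than 2048 tokens
    trax.data.FilterByLength(max_length=2048, length_keys=[0]),
    # Group examples of similar length; shorter examples get larger batches.
    # batch_sizes has one more entry than boundaries (for the overflow bucket).
    trax.data.BucketByLength(boundaries=[32, 128, 512, 2048],
                             batch_sizes=[512, 128, 32, 8, 1],
                             length_keys=[0]),
    trax.data.AddLossWeights()
)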
# Finally, we get the generators
train_batches_stream = data_pipeline(train_stream)
eval_batches_stream = data_pipeline(eval_stream)
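
The `BucketByLength` step above pairs four boundaries with five batch sizes. To make that pairing concrete, here is a minimal pure-Python sketch of the bucketing idea. It is not Trax's implementation: `bucket_by_length` below is a hypothetical helper written for illustration, it assumes an example lands in the first bucket whose boundary is at least its length, and it skips the padding that Trax applies to each batch.

```python
from collections import defaultdict


def bucket_by_length(examples, boundaries, batch_sizes, length_key=0):
    """Hypothetical helper (not part of Trax): yield batches of similar-length examples."""
    # One batch size per boundary, plus one for examples longer than every boundary.
    assert len(batch_sizes) == len(boundaries) + 1
    buckets = defaultdict(list)
    for example in examples:
        length = len(example[length_key])
        # Assumed rule: first bucket whose boundary is >= length, else the overflow bucket.
        idx = next((i for i, b in enumerate(boundaries) if length <= b),
                   len(boundaries))
        buckets[idx].append(example)
        # Emit the bucket as a batch once it reaches its configured size.
        if len(buckets[idx]) == batch_sizes[idx]:
            yield buckets.pop(idx)


# Toy (tokens, label) pairs with token lengths 1, 3, 2, 5, 4.
examples = [([1], 0), ([1, 2, 3], 1), ([1, 2], 0), ([1] * 5, 1), ([1, 2, 3, 4], 0)]
batches = list(bucket_by_length(examples, boundaries=[2, 4], batch_sizes=[2, 2, 1]))
print([len(batch) for batch in batches])  # [2, 1, 2]
```

In the toy run, the two shortest examples fill the first bucket, the length-5 example goes straight out through the overflow bucket (batch size 1), and the length-3 and length-4 examples form the last batch.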