@stefanthaler
Created July 22, 2020 11:26
TensorFlow Dataset transformation that groups sequential data into length buckets and truncates each sequence to its bucket boundary instead of padding.
import tensorflow as tf
import numpy as np
from tensorflow.python.framework import ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.util.tf_export import tf_export
from tensorflow.python.data.experimental import group_by_window
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import math_ops

def buckettrunc_by_sequence_length(
        element_length_func,
        bucket_boundaries,
        bucket_batch_sizes,
        drop_remainder=False):
    # Map function: assign each element to a length bucket.
    def element_to_bucket_id(*args):
        """Return the int64 id of the length bucket for this element."""
        seq_length = math_ops.cast(element_length_func(*args), dtypes.int64)

        boundaries = sorted(list(bucket_boundaries))              # e.g. [10, 15, 20]
        buckets_min = boundaries                                  # [10, 15, 20]
        buckets_max = boundaries[1:] + [np.iinfo(np.int32).max]   # [15, 20, int32 max]

        # Sequences shorter than the first boundary cannot be truncated to it.
        err_msg = ("Sequence length needs to be greater than or equal to the "
                   "first bucket boundary (%i)." % boundaries[0])
        check = check_ops.assert_greater_equal(
            seq_length,
            constant_op.constant(boundaries[0], dtype=dtypes.int64),
            message=err_msg)

        with ops.control_dependencies([check]):
            # For each bucket: buckets_min <= seq_length < buckets_max.
            conditions_c = math_ops.logical_and(
                math_ops.greater_equal(x=seq_length, y=buckets_min),
                math_ops.less(x=seq_length, y=buckets_max))
            bucket_id = math_ops.reduce_min(array_ops.where(conditions_c))
        return bucket_id
    # Reduce function: truncate each sequence in the group to its bucket
    # boundary, then batch.
    def batching_fn(bucket_id, grouped_dataset):
        batch_size = window_size_fn(bucket_id)
        boundaries = tf.constant(sorted(bucket_boundaries), dtype=tf.dtypes.int64)
        bucket_boundary = boundaries[bucket_id]
        begin = tf.constant(value=0, dtype=tf.dtypes.int64, name='seq_begin')
        # Truncate to the bucket boundary instead of padding.
        grouped_dataset = grouped_dataset.map(
            lambda seq: tf.slice(seq, begin=[begin], size=[bucket_boundary]))
        return grouped_dataset.batch(batch_size, drop_remainder=drop_remainder)
    # Batch size function
    batch_sizes = tf.constant(bucket_batch_sizes, dtype=tf.dtypes.int64)

    def window_size_fn(bucket_id):
        window_size = batch_sizes[bucket_id]
        return window_size
    def _apply_fn(dataset):
        return dataset.apply(group_by_window(
            key_func=element_to_bucket_id,
            reduce_func=batching_fn,
            window_size_func=window_size_fn))

    return _apply_fn

# element length function
def seq_len(seq):
    return tf.shape(seq)[0]

# data generator
def gen():
    for i in [np.array([1, 1, 1]), np.array([2, 2, 2, 2, 2]), np.array([3, 3, 3, 3, 3, 3, 3])]:
        yield i

# data pipeline
dataset = tf.data.Dataset.from_generator(gen, (tf.int32), (tf.TensorShape([None])))
dataset = dataset.apply(buckettrunc_by_sequence_length(
    element_length_func=seq_len,
    bucket_boundaries=[3, 7],
    bucket_batch_sizes=[2, 2],
    drop_remainder=False))
print(list(dataset.take(3).as_numpy_iterator()))
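
# Quick check of the resulting batches. A sketch of the expectation: with the
# three generated sequences (lengths 3, 5 and 7), boundaries [3, 7] and a batch
# size of 2 per bucket, the two shorter sequences should land in the first
# bucket and be truncated to length 3 (one batch of shape (2, 3)), while the
# length-7 sequence forms a leftover batch of shape (1, 7), assuming
# group_by_window flushes the partially filled window at the end of the input.
for batch in dataset:
    print(batch.shape)
    print(batch.numpy())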