TensorFlow Dataset transformation that groups sequential data into length buckets and truncates each sequence to its bucket boundary instead of padding it.
import tensorflow as tf
import numpy as np
from tensorflow.python.framework import ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.util.tf_export import tf_export
from tensorflow.python.data.experimental.ops.grouping import group_by_window
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import math_ops
def buckettrunc_by_sequence_length(element_length_func,
                                   bucket_boundaries,
                                   bucket_batch_sizes,
                                   drop_remainder=False):
    """Dataset transformation that groups sequences into length buckets and
    truncates them to the bucket boundary instead of padding them."""
    # Map function
    def element_to_bucket_id(*args):
        """Return int64 id of the length bucket for this element."""
        seq_length = math_ops.cast(element_length_func(*args), dtypes.int64)
        # Sequences shorter than the first boundary cannot be truncated to a
        # full bucket, so fail loudly instead of emitting a ragged batch.
        check = check_ops.assert_greater_equal(
            seq_length,
            constant_op.constant(bucket_boundaries[0], dtype=dtypes.int64),
            message="Sequence length needs to be greater than or equal to "
                    "the first bucket boundary (%i)." % bucket_boundaries[0])
        with ops.control_dependencies([check]):
            boundaries = sorted(list(bucket_boundaries))              # e.g. [10, 15, 20]
            buckets_min = boundaries                                  # [10, 15, 20]
            buckets_max = boundaries[1:] + [np.iinfo(np.int32).max]   # [15, 20, int32 max]
            conditions_c = math_ops.logical_and(                      # for each bucket,
                math_ops.greater_equal(x=seq_length, y=buckets_min),  # x >= bucket min
                math_ops.less(x=seq_length, y=buckets_max))           # x <  bucket max
            bucket_id = math_ops.reduce_min(array_ops.where(conditions_c))
        return bucket_id
    # Reduce function
    def batching_fn(bucket_id, grouped_dataset):
        batch_size = window_size_fn(bucket_id)
        boundaries = tf.constant(bucket_boundaries, dtype=tf.dtypes.int64)
        bucket_boundary = boundaries[bucket_id]
        begin = tf.constant(value=0, dtype=tf.dtypes.int64, name='seq_begin')
        # Truncate every sequence in this bucket to the bucket boundary.
        grouped_dataset = grouped_dataset.map(
            lambda seq: tf.slice(seq, begin=[begin], size=[bucket_boundary]))
        return grouped_dataset.batch(batch_size, drop_remainder=drop_remainder)
    # Batch size function
    batch_sizes = tf.constant(bucket_batch_sizes, dtype=tf.dtypes.int64)

    def window_size_fn(bucket_id):
        window_size = batch_sizes[bucket_id]
        return window_size
    def _apply_fn(dataset):
        return dataset.apply(group_by_window(
            key_func=element_to_bucket_id,
            reduce_func=batching_fn,
            window_size_func=window_size_fn))

    return _apply_fn
def seq_len(seq):
    return tf.shape(seq)[0]
# data generator
def gen():
    for i in [np.array([1, 1, 1]), np.array([2, 2, 2, 2, 2]), np.array([3, 3, 3, 3, 3, 3, 3])]:
        yield i
# data pipeline
dataset = tf.data.Dataset.from_generator(gen, (tf.int32), (tf.TensorShape([None])))
dataset = dataset.apply(buckettrunc_by_sequence_length(
    element_length_func=seq_len,
    bucket_boundaries=[3, 5, 7],    # illustrative values chosen to match the sample lengths
    bucket_batch_sizes=[2, 2, 2],   # illustrative values; the original call's arguments were not preserved
    drop_remainder=False))
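To sanity-check the transformation, you can iterate the pipeline eagerly. A minimal sketch, assuming TF 2.x eager execution and the illustrative arguments above: each of the three sample sequences (lengths 3, 5, 7) should fall into its own bucket and come out as a single batch truncated to that bucket's boundary.

# Inspect the resulting batches (assumes TF 2.x eager execution).
for batch in dataset:
    print(batch.shape, batch.numpy())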