Truncating and grouping sequential data in TensorFlow

Sequential data, i.e. ordered collections of elements, is a fundamental and frequently used data type in deep learning. For example, text can be translated by treating sentences as sequences of words. Often, the sequences used to train such models vary in length, which poses a problem for training deep learning models.

TensorFlow offers mechanisms to train deep learning models on sequences of varying lengths. For example, LSTMs, one of the most prominent deep models for sequential data, are implemented in TensorFlow using tf.while_loop, which supports sequences of arbitrary length. Dynamically in this context means that the sequence lengths do not need to be known beforehand.

TensorFlow’s while mechanism is very flexible, as it needs to know neither the length of the sequences nor the size of the mini-batches beforehand. However, all sequences within a mini-batch must have the same length, so learning is only efficient if sequences of similar lengths are batched together.

To enable efficient learning when elements have varying lengths, TensorFlow provides an experimental method, tf.data.experimental.bucket_by_sequence_length, that groups sequences into buckets by length and pads them. For example, if you want two buckets of lengths 5 and 10, every sequence in your data set is assigned to one of the buckets and padded up to the length of its bucket, i.e. 5 or 10.
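As a rough illustration of this padding approach, the sketch below applies tf.data.experimental.bucket_by_sequence_length with a single boundary at 5, which creates two buckets (lengths below 5, and lengths of 5 or more). The toy sequences and batch sizes are made up. Note that by default this transformation pads each batch only to its longest sequence; setting pad_to_bucket_boundary=True pads to the maximum length of the bucket instead.

```python
import tensorflow as tf

# toy sequences of lengths 3, 7, 4 and 8 (illustrative data)
sequences = [[1] * 3, [2] * 7, [3] * 4, [4] * 8]
dataset = tf.data.Dataset.from_generator(
    lambda: iter(sequences),
    output_signature=tf.TensorSpec(shape=[None], dtype=tf.int32))

# one boundary -> two buckets: lengths < 5 and lengths >= 5, each batched in pairs
dataset = dataset.apply(tf.data.experimental.bucket_by_sequence_length(
    element_length_func=lambda seq: tf.shape(seq)[0],
    bucket_boundaries=[5],
    bucket_batch_sizes=[2, 2]))

# each batch is zero-padded to the longest sequence it contains
shapes = [batch.numpy().shape for batch in dataset]
print(shapes)
```

The short sequences (lengths 3 and 4) end up in one batch padded to length 4, and the long ones (7 and 8) in another batch padded to length 8.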

Padding is useful for sequences that have a clear beginning and end. However, it does not make sense for sequences of unbounded length, for which appropriate padding values are unknown. In such a scenario, truncating sequences is more useful.

In this post, I describe how to implement a TensorFlow dataset transformation that truncates sequences and groups them by length so that models can be trained more efficiently. A gist with the complete sample code can be found here.

Solution sketch

Our goal is to create a TensorFlow method that puts sequences into “buckets” of a certain length, and truncates them to the minimum length of a bucket. A bucket is defined by a minimum and maximum sequence length. For example, we could define two buckets, 5-9 and 10-14. A sequence of length 6 would be put into the first bucket and truncated to 5, and a sequence of length 13 put into the second bucket and truncated to length 10.
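The bucketing rule itself needs no TensorFlow. Below is a minimal pure-Python sketch of the idea, using a hypothetical helper named bucket_and_truncate. It assumes the bucket boundaries are given as sorted lower bounds and, as an assumption differing slightly from the 5-9 and 10-14 example, treats the last bucket as open-ended.

```python
def bucket_and_truncate(sequence, bucket_boundaries):
    """Return (bucket_id, truncated_sequence), or None if the sequence is too short.

    bucket_boundaries holds the sorted lower bounds of the buckets; bucket i covers
    lengths from bucket_boundaries[i] up to (but excluding) bucket_boundaries[i + 1],
    and the last bucket has no upper bound.
    """
    bucket_id = None
    for i, lower in enumerate(bucket_boundaries):
        if len(sequence) >= lower:
            bucket_id = i  # keep the largest lower bound that still fits
    if bucket_id is None:
        return None  # shorter than the smallest bucket
    # truncate to the minimum length of the bucket
    return bucket_id, sequence[:bucket_boundaries[bucket_id]]


print(bucket_and_truncate(list(range(6)), [5, 10]))   # (0, [0, 1, 2, 3, 4])
print(bucket_and_truncate(list(range(13)), [5, 10]))  # (1, [0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
```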

TensorFlow experimentally supports grouping datasets using the group_by_window function (tf.data.experimental.group_by_window), which works in a MapReduce-like fashion: it maps each element to a key and then reduces the elements that share a key. This means that for our goal we need two functions: a map function (the key_func parameter) that tells us which bucket a sequence belongs to, and a reduce function (the reduce_func parameter) that tells us how to create the batches for each bucket. The reduce function is also where we truncate the sequences.

Additionally, we need either a function that returns the batch size for each bucket (the window_size_func parameter) or a single batch size for all buckets (the window_size parameter). A per-bucket batch size helps us make full use of our GPUs' memory, as we can use larger batches for shorter sequences and smaller batches for longer sequences.
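To make the key, reduce and window-size roles concrete, here is a toy sketch of group_by_window on plain integers rather than sequences. The grouping by parity and the window sizes 3 and 2 are arbitrary choices for illustration. Both the key and the window size must be scalar tf.int64 values.

```python
import tensorflow as tf

dataset = tf.data.Dataset.range(6)  # int64 elements 0, 1, ..., 5

grouped = dataset.apply(tf.data.experimental.group_by_window(
    # "map" step: the key must be a scalar tf.int64; here, even vs. odd
    key_func=lambda x: x % 2,
    # "reduce" step: receives the key and a dataset holding one window of that key;
    # windows have at most 3 elements, so batch(3) turns each window into one batch
    reduce_func=lambda key, window: window.batch(3),
    # per-key window size (also tf.int64): 3 for even numbers, 2 for odd numbers
    window_size_func=lambda key: tf.where(tf.equal(key, 0),
                                          tf.constant(3, tf.int64),
                                          tf.constant(2, tf.int64))))

batches = [batch.numpy().tolist() for batch in grouped]
print(batches)
```

The even numbers form one window of three, the odd numbers a full window of two plus a final partial window that is flushed when the input ends.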


For the implementation of our solution, we follow the structure of our solution sketch. We first describe the map function, which maps a sequence to a bucket by length. We then describe the reduce function, which truncates and batches the sequences, followed by the batch size function and the function that applies the transformation to a dataset.

Map function

The goal of our map function is to map a sequence to a bucket given its length. For example, given two buckets, one for lengths 10 to 14 and one for lengths 15 to 19, a sequence of length 11 is mapped to the first bucket and a sequence of length 18 to the second.

In Python, such a function could easily be implemented with a few lines of code, for example:

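def bucket_id(sequence, bucket_boundaries=(2, 5, 8)):
  # bucket i covers lengths in [bucket_boundaries[i], bucket_boundaries[i + 1])
  for i in range(len(bucket_boundaries) - 1):
    if bucket_boundaries[i] <= len(sequence) < bucket_boundaries[i + 1]:
      return i
  return None  # the length lies outside all buckets

bucket_id([1, 2, 3, 4])  # returns 0, since 2 <= 4 < 5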

In TensorFlow, computations are expressed as operations on tensors in a computation graph. This has the advantage that calculations can be performed efficiently in batches on a GPU, but it requires some rethinking of the implementation. One way to implement the logic above is TensorFlow's where function, which returns the indices at which a condition is true. To that end, we define two arrays with the bucket boundaries, one containing the lower bounds and one containing the upper bounds, and combine TensorFlow's element-wise comparison and logical functions into a single condition. Because tf.where returns the matching indices as a tensor of shape [number of matches, 1], we finally reduce it to a scalar bucket id using reduce_min (or any other function that provides similar functionality).
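To see what tf.where and reduce_min do here, consider a single sequence of length 17 and the boundaries 10, 15 and 20. The values are illustrative; the last upper bound is the largest int32 value, so the last bucket is open-ended.

```python
import numpy as np
import tensorflow as tf

seq_length = tf.constant(17)                    # length of one sequence
buckets_min = [10, 15, 20]                      # inclusive lower bounds
buckets_max = [15, 20, np.iinfo(np.int32).max]  # exclusive upper bounds

# element-wise check of buckets_min <= seq_length < buckets_max
conditions = tf.logical_and(tf.greater_equal(seq_length, buckets_min),
                            tf.less(seq_length, buckets_max))

indices = tf.where(conditions)      # int64 indices of the True entries, shape (1, 1)
bucket_id = tf.reduce_min(indices)  # collapsed to a scalar int64
print(conditions.numpy(), indices.numpy(), bucket_id.numpy())
```

Only the second condition holds, so tf.where returns [[1]] and reduce_min turns it into the scalar bucket id 1.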

An implementation of the described code could look like this:

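import numpy as np
import tensorflow as tf

# bucket boundaries (lower bounds), shared by the map and reduce functions
bucket_boundaries = [10, 15, 20]

def element_to_bucket_id(*args):
  # args holds the components of one dataset element; the sequence is the first one
  seq_length = tf.shape(args[0])[0]

  boundaries = sorted(list(bucket_boundaries))  # [10, 15, 20]
  buckets_min = boundaries  # [10, 15, 20]
  buckets_max = boundaries[1:] + [np.iinfo(np.int32).max]  # [15, 20, int32 max]

  # condition: for each bucket, is buckets_min <= seq_length < buckets_max?
  conditions_c = tf.math.logical_and(
      tf.math.greater_equal(x=seq_length, y=buckets_min),  # x >= y
      tf.math.less(x=seq_length, y=buckets_max))  # x < y

  # tf.where returns the int64 indices of the true entries (shape [num_true, 1]);
  # reduce_min collapses them to the scalar int64 key that group_by_window expects.
  # Sequences shorter than the smallest boundary match no bucket and should be
  # filtered out before grouping.
  bucket_id = tf.math.reduce_min(tf.where(conditions_c))
  return bucket_id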
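The map function assumes that every sequence falls into one of the buckets. For a sequence shorter than the smallest boundary, none of the conditions hold, tf.where returns an empty tensor, and reduce_min of an empty tensor yields the largest int64 value instead of a valid bucket id. One way to avoid this (an assumption, not part of the original code) is to drop such sequences with Dataset.filter before grouping; the toy data and the name min_length are illustrative.

```python
import tensorflow as tf

# a sequence of length 4 matches none of the buckets [10, 15), [15, 20), [20, inf)
conditions = tf.constant([False, False, False])
invalid_id = tf.reduce_min(tf.where(conditions))  # reduce_min over an empty tensor

# drop such sequences before grouping
min_length = 10  # hypothetical name for the smallest bucket boundary
sequences = [list(range(n)) for n in (4, 12, 17, 25)]
dataset = tf.data.Dataset.from_generator(
    lambda: iter(sequences),
    output_signature=tf.TensorSpec(shape=[None], dtype=tf.int32))
dataset = dataset.filter(lambda seq: tf.shape(seq)[0] >= min_length)

lengths = [int(tf.shape(seq)[0]) for seq in dataset]
print(int(invalid_id), lengths)
```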

Reduce function

The reduce function takes care of truncating the sequences and creating the batches. For each bucket, group_by_window calls the reduce function with the bucket id and a dataset (following the conventions of TensorFlow’s Dataset API) that contains a window of up to window-size elements from that bucket.

One way to truncate the sequences is TensorFlow’s slice function, which can be applied to every element of the grouped dataset by passing a lambda function to the dataset's map method.

Batching can be achieved using the batch method. An implementation of the reduce function could look like this:

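drop_remainder = False  # set to True to discard incomplete final batches

def batching_fn(bucket_id, grouped_dataset):
    batch_size = window_size_fn(bucket_id)
    boundaries = tf.constant(bucket_boundaries, dtype=tf.dtypes.int64)
    # bucket_id is a scalar int64 tensor, so look the boundary up with tf.gather
    bucket_boundary = tf.gather(boundaries, bucket_id)
    begin = tf.constant(value=0, dtype=tf.dtypes.int64, name='seq_begin')

    # truncate every sequence to the lower boundary of its bucket
    grouped_dataset = grouped_dataset.map(
        lambda seq: tf.slice(seq, begin=[begin], size=[bucket_boundary]))
    return grouped_dataset.batch(batch_size, drop_remainder=drop_remainder)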

Batch size function

Defining the batch size through a function allows us to choose a different batch size for each bucket. This can improve performance: shorter sequences need less memory and can therefore be processed in larger batches, which makes better use of the GPU memory.

One of the simplest forms of such a function takes the bucket id and maps it to a batch size.

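batch_sizes = [30, 20, 10]  # one batch size per bucket: larger batches for shorter sequences

def window_size_fn(bucket_id):
    # bucket_id is a scalar int64 tensor, so a Python list cannot be indexed with it;
    # tf.gather looks up the batch size (group_by_window expects an int64 size)
    window_size = tf.gather(tf.constant(batch_sizes, dtype=tf.int64), bucket_id)
    return window_size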
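One possible way to pick these per-bucket batch sizes, sketched here as a hypothetical helper, is to keep the number of sequence elements per batch roughly constant. Because every sequence in a bucket is truncated to the bucket's lower boundary, the size of a batch is simply the batch size times that boundary. This assumes memory use scales roughly with the number of elements per batch; the budget of 300 is an arbitrary example value.

```python
def batch_sizes_for_budget(bucket_boundaries, max_elements_per_batch):
    # Every sequence in bucket i is truncated to bucket_boundaries[i] elements,
    # so a batch of b sequences holds b * bucket_boundaries[i] elements.
    # Choose the largest b that stays within the budget (but at least 1).
    return [max(1, max_elements_per_batch // length) for length in bucket_boundaries]


print(batch_sizes_for_budget([10, 15, 20], 300))  # [30, 20, 15]
```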

Application function

TensorFlow’s Dataset API lets us package a transformation as a function that takes a dataset and returns the transformed dataset; it is this function, rather than an already transformed dataset, that is passed to Dataset.apply. Here, the function only applies the group_by_window function, which takes the map function, the reduce function and the batch size function as parameters. An implementation of such an application function could look like this:

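def apply_fn(dataset):
  return dataset.apply(tf.data.experimental.group_by_window(
      key_func=element_to_bucket_id, reduce_func=batching_fn,
      window_size_func=window_size_fn))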
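To tie the pieces together, here is a minimal end-to-end sketch under a few assumptions: each dataset element is a single 1-D int32 sequence, the boundaries are 10, 15 and 20 with made-up batch sizes 3, 2 and 2, and sequences shorter than the smallest boundary are filtered out first. It mirrors the functions above but indexes the per-bucket constants with tf.gather, since the bucket id is a tensor rather than a Python integer. Recent TensorFlow 2 releases also expose the same transformation as the Dataset.group_by_window method.

```python
import numpy as np
import tensorflow as tf

bucket_boundaries = [10, 15, 20]  # sorted lower bounds of the buckets
batch_sizes = [3, 2, 2]           # illustrative batch size per bucket


def element_to_bucket_id(seq):
    # map step: find the bucket with buckets_min <= length < buckets_max
    seq_length = tf.shape(seq)[0]
    buckets_min = bucket_boundaries
    buckets_max = bucket_boundaries[1:] + [np.iinfo(np.int32).max]
    conditions = tf.logical_and(tf.greater_equal(seq_length, buckets_min),
                                tf.less(seq_length, buckets_max))
    return tf.reduce_min(tf.where(conditions))  # scalar int64 key


def window_size_fn(bucket_id):
    # bucket_id is a tensor, so look the batch size up with tf.gather
    return tf.gather(tf.constant(batch_sizes, dtype=tf.int64), bucket_id)


def batching_fn(bucket_id, grouped_dataset):
    # reduce step: truncate every sequence to its bucket's lower bound, then batch
    bucket_boundary = tf.gather(tf.constant(bucket_boundaries, dtype=tf.int64), bucket_id)
    truncated = grouped_dataset.map(
        lambda seq: tf.slice(seq, begin=tf.zeros([1], dtype=tf.int64),
                             size=tf.reshape(bucket_boundary, [1])))
    return truncated.batch(window_size_fn(bucket_id))


def apply_fn(dataset):
    return dataset.apply(tf.data.experimental.group_by_window(
        key_func=element_to_bucket_id,
        reduce_func=batching_fn,
        window_size_func=window_size_fn))


# toy data: sequences 0, 1, ..., n-1 of various lengths
lengths = [12, 17, 4, 22, 11, 16, 25, 13]
dataset = tf.data.Dataset.from_generator(
    lambda: (list(range(n)) for n in lengths),
    output_signature=tf.TensorSpec(shape=[None], dtype=tf.int32))

# drop sequences that are shorter than the smallest bucket
dataset = dataset.filter(lambda seq: tf.shape(seq)[0] >= bucket_boundaries[0])

batches = [batch.numpy() for batch in apply_fn(dataset)]
for batch in batches:
    print(batch.shape)
```

With this toy data, the sequence of length 4 is dropped and each bucket yields one full batch, with shapes (3, 10), (2, 15) and (2, 20); the order in which the batches appear depends on when each window fills up.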


In this blog post, I described how to implement a dataset transformation that groups sequences into predefined buckets by length and truncates them to the minimum length of their bucket.

A gist with the complete sample code can be found here.