Skip to content

Instantly share code, notes, and snippets.

@avmoldovan
Forked from louiskirsch/tf_ring_buffer.py
Created June 12, 2020 18:47
Show Gist options
  • Save avmoldovan/ab361a24b1be3ab90ec064ca53da6ccd to your computer and use it in GitHub Desktop.
Save avmoldovan/ab361a24b1be3ab90ec064ca53da6ccd to your computer and use it in GitHub Desktop.
A TensorFlow ring buffer implementation
class RingBuffer:
    """A fixed-size ring buffer whose storage lives in TensorFlow variables.

    Each slot holds one item made of ``len(components)`` parts; part ``i`` of
    every item is stored along the leading axis of ``self.components[i]``.
    Once the buffer is full, new inserts overwrite the oldest slots.
    """

    def __init__(self, scope_name, components, size):
        """
        Create a new ring buffer of size `size`.

        Each item in the ring buffer is a tuple of variables of size `len(components)`.

        :param scope_name: A scope name for the newly created variables
        :param components: Defines the type of items in the buffer. An iterable of tuples (name: str, shape: Iterable, dtype)
        :param size: The maximum size of the buffer
        """
        self.size = size
        # The buffer is bookkeeping state, not model parameters: keep it out of
        # the TRAINABLE_VARIABLES collection so optimizers and weight
        # regularizers that iterate over trainable variables ignore it.
        with tf.variable_scope(scope_name, initializer=tf.zeros_initializer()):
            self.components = [
                tf.get_variable(name, [size] + list(shape), dtype, trainable=False)
                for name, shape, dtype in components
            ]
            # Index of the next slot to write; advanced modulo `size` by insert().
            self.offset = tf.get_variable(
                'offset', shape=[], dtype=tf.int32, trainable=False)

    def insert(self, tensors):
        """Build an op that writes a batch of items into the buffer.

        Items are written starting at the current offset; writes that run past
        the end of the buffer wrap around to the front, overwriting the oldest
        slots.

        :param tensors: One tensor per component, in the same order as the
            `components` given to the constructor. All tensors must share the
            same statically known leading (batch) dimension `n`, with
            `n <= size`.
        :return: A `tf.group` op that performs the writes and then advances
            the offset by `n` (modulo `size`).
        :raises ValueError: if the number of tensors differs from the number of
            components, a leading dimension is unknown or inconsistent, or the
            batch is larger than the buffer.
        """
        tensors = list(tensors)
        if len(tensors) != len(self.components):
            raise ValueError(
                'Expected %d tensors (one per component), got %d'
                % (len(self.components), len(tensors)))
        elem_count = tensors[0].shape.as_list()[0]
        if elem_count is None:
            raise ValueError('The leading dimension of inserted tensors must be statically known')
        if elem_count > self.size:
            # A batch larger than the buffer would make the tail and front
            # writes overlap (or exceed the buffer), giving undefined results.
            raise ValueError(
                'Cannot insert %d elements into a ring buffer of size %d'
                % (elem_count, self.size))
        ops = []
        for tensor, component in zip(tensors, self.components):
            if tensor.shape.as_list()[0] != elem_count:
                raise ValueError(
                    'All inserted tensors must have the same leading dimension; '
                    'expected %d, got %s' % (elem_count, tensor.shape.as_list()[0]))
            # Fill the tail of the buffer
            start = self.offset
            end = tf.minimum(self.size, self.offset + elem_count)
            fill_count = end - start
            ops.append(component[start:end].assign(tensor[:fill_count]))
            # Fill the front of the buffer if elements are still left
            # (an empty slice when everything fit in the tail).
            end = elem_count - fill_count
            ops.append(component[:end].assign(tensor[fill_count:]))
        # Advance the offset only after every write above has run, so the
        # slices computed from the old offset are not affected.
        with tf.control_dependencies(ops):
            ops.append(self.offset.assign((self.offset + elem_count) % self.size))
        return tf.group(*ops)

    def sample(self, count):
        """Draw up to `count` distinct slots uniformly at random.

        Indices come from a random permutation of all `size` slots, so slots
        that have never been written (still holding their zero initial values)
        can be sampled too. If `count` exceeds `size`, only `size` items are
        returned.

        :param count: Number of items to draw.
        :return: A list with one tensor per component, each gathered at the
            same sampled indices.
        """
        indices = tf.random_shuffle(tf.range(self.size))[:count]
        return [tf.gather(component, indices) for component in self.components]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment