Skip to content

Instantly share code, notes, and snippets.

@jpivarski
Last active April 1, 2020 16:40
Show Gist options
  • Save jpivarski/7bc83e5aa70d5e3dd8483eb49800885c to your computer and use it in GitHub Desktop.
Save jpivarski/7bc83e5aa70d5e3dd8483eb49800885c to your computer and use it in GitHub Desktop.
GrowableBuffer in Numba
import sys
import operator
import numpy
# Note: these are pre-0.49 locations; things move around in Numba 0.49.
import numba
import numba.typing.arraydecl
# First, let's define the class in Python.
class GrowableBuffer:
    """An append-only 1-d NumPy buffer that grows its reservation on demand.

    Values live in an over-allocated NumPy array ``_buffer``; only the first
    ``len(self)`` entries are valid. When an append finds the buffer full,
    a new array ``resize`` times larger is allocated and the old contents
    are copied over.

    Parameters
    ----------
    dtype
        Element dtype, passed to ``numpy.empty``.
    initial : int
        Initial reservation (number of elements); must be positive.
    resize : float
        Growth factor applied on each reallocation; must be > 1.
    """

    def __init__(self, dtype, initial=1024, resize=1.5):
        assert initial > 0
        assert resize > 1.0
        self._initial = initial
        self._resize = resize
        self._buffer = numpy.empty(initial, dtype=dtype)
        # The length is kept in a one-element array (not a Python int) so
        # that lowered code can update it in place through a raw pointer.
        self._length = numpy.array([0], dtype=numpy.intp)

    def __str__(self):
        return str(self.__array__())

    def __repr__(self):
        return "growable({0})".format(str(self))

    def __len__(self):
        return self._length[0]

    def __getitem__(self, where):
        """Index the valid part of the buffer with any NumPy index.

        Indexing is applied to the first ``len(self)`` elements only, so
        out-of-range indexes raise ``IndexError`` and negative indexes count
        back from the last appended element instead of from the end of the
        (partly uninitialized) reservation. This matches the Numba lowering,
        which trims the buffer before indexing.
        """
        return self._buffer[:self._length[0]][where]

    def __array__(self, dtype=None, copy=None):
        """Return the valid part of the buffer as a NumPy array.

        With no arguments this is a view into ``_buffer`` (no copy).
        ``dtype`` converts the result (copying only if needed), and a truthy
        ``copy`` always returns a fresh array. These are the arguments NumPy
        passes through ``numpy.asarray``/``numpy.array``.
        """
        valid = self._buffer[:self._length[0]]
        if copy:
            # numpy.array copies by default.
            return numpy.array(valid, dtype=dtype)
        if dtype is not None:
            return valid.astype(dtype, copy=False)
        return valid

    @property
    def reserved(self):
        """Number of elements the current buffer can hold without resizing."""
        return len(self._buffer)

    def _ensure_reserved(self):
        """Grow ``_buffer`` until it has room for at least one more element.

        Lowered code calls this Python method through the C API whenever an
        append finds the buffer full, so this logic is not duplicated in LLVM.
        """
        while self._length[0] >= len(self._buffer):
            reservation = int(numpy.ceil(len(self._buffer) * self._resize))
            newbuffer = numpy.empty(reservation, dtype=self._buffer.dtype)
            newbuffer[:len(self._buffer)] = self._buffer
            self._buffer = newbuffer

    def append(self, what):
        """Append one value, growing the reservation first if it is full.

        This is the logic the lowered ``append`` reproduces, because it is
        called frequently.
        """
        if self._length[0] >= len(self._buffer):
            self._ensure_reserved()
        self._buffer[self._length[0]] = what
        self._length[0] += 1
# To start Numbafying this class, we define a Type. This is everything we
# need to know at compile-time.
class GrowableBufferType(numba.types.Type):
    """Compile-time Numba type of a GrowableBuffer.

    The type is parameterized by the Numba element type ``dtype`` (int64,
    float32, ...). Everything the compiler needs to know about a
    GrowableBuffer is derived from it.
    """

    def __init__(self, dtype):
        typename = "GrowableBufferType({0})".format(dtype.name)
        super(GrowableBufferType, self).__init__(name=typename)
        self.dtype = dtype

    @property
    def buffertype(self):
        """Numba type of the backing array: 1-d, C-contiguous, of ``dtype``."""
        return numba.types.Array(self.dtype, ndim=1, layout="C")
# Next, we have to identify the Type from a Python instance.
@numba.extending.typeof_impl.register(GrowableBuffer)
def typeof_GrowableBuffer(growablebuffer, c):
    """Map a Python GrowableBuffer instance to its GrowableBufferType.

    The Numba element type is derived from the NumPy dtype of the
    instance's backing ``_buffer`` array.
    """
    elementtype = numba.from_dtype(growablebuffer._buffer.dtype)
    return GrowableBufferType(elementtype)
# Next, we define what information is available at runtime.
# A model is a copy-by-value struct, and is therefore immutable.
# However, we need to update the length with every call to 'append', and we
# need to update the buffer whenever the reservation changes.
# So we do it with pointers: we'll allocate the "buffer" pointer and fill it
# ourselves, but "length" will point to the one-element array from the
# Python object. The "pyobj" is a reference-counted pointer to the
# original Python object.
@numba.extending.register_model(GrowableBufferType)
class GrowableBufferModel(numba.datamodel.models.StructModel):
    """Lowered (runtime) layout of a GrowableBuffer.

    Fields:
      buffer -- pointer to a lowered NumPy array of ``fe_type.buffertype``;
      length -- pointer to a single intp (the length counter);
      pyobj  -- the original Python object.
    """

    def __init__(self, dmm, fe_type):
        buffer_pointer = numba.types.CPointer(fe_type.buffertype)
        length_pointer = numba.types.CPointer(numba.intp)
        members = [
            ("buffer", buffer_pointer),
            ("length", length_pointer),
            ("pyobj", numba.types.pyobject),
        ]
        super(GrowableBufferModel, self).__init__(dmm, fe_type, members)
# "Unboxing" means converting a Python object into a lowered model.
# This function generates LLVM assembly to do the transformation.
@numba.extending.unbox(GrowableBufferType)
def unbox_GrowableBuffer(typ, obj, c):
    """Unbox a Python GrowableBuffer ``obj`` into a lowered GrowableBufferModel.

    This function does not run at call time; it emits LLVM IR (through
    ``c.builder`` and ``c.pyapi``) that fills the three model fields:

    - "buffer": a pointer to memory obtained with ``alloca_once``, into which
      the unboxed ``obj._buffer`` array is stored;
    - "length": the integer ``obj._length.ctypes.data`` converted to an
      ``intp`` pointer, so lowered code reads and writes the length directly
      in the Python object's one-element array;
    - "pyobj": ``obj`` itself.

    Returns a ``NativeValue`` whose error flag is set if a Python exception
    is pending after the C-API calls.
    """
    # To build the lowered model, we have to extract some attributes from
    # the Python object "obj". These are Python-C API calls (through c.pyapi).
    buffer_obj = c.pyapi.object_getattr_string(obj, "_buffer")
    length_obj = c.pyapi.object_getattr_string(obj, "_length")
    ctypes_obj = c.pyapi.object_getattr_string(length_obj, "ctypes")
    lenptr_obj = c.pyapi.object_getattr_string(ctypes_obj, "data")
    # A proxy helps us generate LLVM assembly for getting or setting model
    # attributes. If constructed without "value" (as below), we *set* values.
    proxy = c.context.make_helper(c.builder, typ)
    # For the "buffer" model attribute, we generate an instruction to allocate
    # memory to hold a lowered NumPy array object. "alloca_once" returns a
    # pointer, so we're setting proxy.buffer to a pointer.
    proxy.buffer = numba.cgutils.alloca_once(c.builder,
        c.context.get_value_type(typ.buffertype))
    # builder.store(value, pointer) assigns to the newly allocated memory.
    c.builder.store(c.pyapi.to_native_value(typ.buffertype, buffer_obj).value,
        proxy.buffer)
    # The "length" is a pointer, too, but instead of allocating space for an
    # integer (numba.intp, which is ssize_t), we'll take the already allocated
    # "length" array in the GrowableBuffer Python object (which has room for
    # one integer). Its ".ctypes.data" is the array's data address as a
    # Python int, converted here to a native pointer.
    proxy.length = c.builder.inttoptr(
        c.pyapi.number_as_ssize_t(lenptr_obj),
        c.context.get_value_type(numba.types.CPointer(numba.intp)))
    # Assign the Python object to this model.
    proxy.pyobj = obj
    # Turn the proxy into a value. (The underscored function is necessary.)
    out = proxy._getvalue()
    # Since this model contains a Python reference, ask Numba's RunTime (NRT)
    # to reference-count it.
    if c.context.enable_nrt:
        c.context.nrt.incref(c.builder, typ, out)
    # All the Python objects we created in this function have to be decrefed.
    # NOTE(review): buffer_obj is decref'd twice although only one reference
    # was obtained above. Presumably this offsets an extra reference that
    # to_native_value's array unboxing keeps on the ndarray, which is never
    # released because the unboxed array is stored behind a raw pointer.
    # The refcount assertions in the tests appear to rely on it. Confirm
    # against Numba's array unboxing before changing.
    c.pyapi.decref(buffer_obj)
    c.pyapi.decref(buffer_obj)
    c.pyapi.decref(length_obj)
    c.pyapi.decref(ctypes_obj)
    c.pyapi.decref(lenptr_obj)
    # Check for an error and return.
    is_error = numba.cgutils.is_not_null(c.builder, c.pyapi.err_occurred())
    return numba.extending.NativeValue(out, is_error)
# "Boxing" means converting a lowered model into a Python object.
# Since we have a reference to the Python object, we just have to return that.
@numba.extending.box(GrowableBufferType)
def box_GrowableBuffer(typ, val, c):
    """Box a lowered GrowableBuffer by handing back the Python object it wraps.

    The model already carries the original Python object in its "pyobj"
    field, so no new object is built. Boxing must return a new reference,
    hence the incref.
    """
    # Constructed with value=..., so the helper is used for *reading* fields.
    wrapped = c.context.make_helper(c.builder, typ, value=val).pyobj
    c.pyapi.incref(wrapped)
    return wrapped
# Define the type for __getitem__.
@numba.typing.templates.infer_global(operator.getitem)
class type_getitem(numba.typing.templates.AbstractTemplate):
    """Typing for ``growablebuffer[where]``.

    The result type is whatever Numba's NumPy indexing would give for the
    backing array type, so integers, slices, arrays and any other index that
    NumPy ``__getitem__`` accepts are all supported. Returning None (or
    raising) lets Numba move on to other candidate templates.
    """

    def generic(self, args, kwargs):
        if kwargs or len(args) != 2:
            return None
        growabletype, wheretype = args
        if not isinstance(growabletype, GrowableBufferType):
            return None
        indexed = numba.typing.arraydecl.get_array_index_type(
            growabletype.buffertype, wheretype)
        # Calling a Type with argument types builds a Signature.
        return indexed.result(growabletype, wheretype)
# The lowering function for __getitem__ with an integer or slice.
@numba.extending.lower_builtin(operator.getitem, GrowableBufferType,
                               numba.types.Integer)
@numba.extending.lower_builtin(operator.getitem, GrowableBufferType,
                               numba.types.SliceType)
def lower_getitem_int_slice(context, builder, sig, args):
    """Lower ``growablebuffer[i]`` / ``growablebuffer[a:b:c]``.

    The buffer is first trimmed to its valid length, then Numba's NumPy
    integer/slice ``__getitem__`` is applied to the trimmed view.
    """
    modeltype, indextype = sig.args
    modelval, indexval = args
    arraytype = modeltype.buffertype
    fields = context.make_helper(builder, modeltype, modelval)
    storage = builder.load(fields.buffer)
    count = builder.load(fields.length)
    # Keep only the valid part: buffer[0:length].
    valid = trim(context, builder, arraytype, storage, count)
    inner_sig = sig.return_type(arraytype, indextype)
    return numba.targets.arrayobj.getitem_arraynd_intp(
        context, builder, inner_sig, (valid, indexval))
# The lowering function for __getitem__ with any other type.
@numba.extending.lower_builtin(operator.getitem, GrowableBufferType,
                               numba.types.Any)
def lower_getitem_array(context, builder, sig, args):
    """Lower ``growablebuffer[where]`` for index types other than int/slice.

    The buffer is trimmed to its valid length, then Numba's NumPy
    fancy-indexing ``__getitem__`` (``fancy_getitem_array``) is applied, so
    array indexes and the like behave as they do on a NumPy array.
    """
    modeltype, indextype = sig.args
    modelval, indexval = args
    arraytype = modeltype.buffertype
    fields = context.make_helper(builder, modeltype, modelval)
    storage = builder.load(fields.buffer)
    count = builder.load(fields.length)
    # Keep only the valid part: buffer[0:length].
    valid = trim(context, builder, arraytype, storage, count)
    inner_sig = sig.return_type(arraytype, indextype)
    return numba.targets.arrayobj.fancy_getitem_array(
        context, builder, inner_sig, (valid, indexval))
# This "trim" function uses Numba's __getitem__ again.
def trim(context, builder, buffertype, bufferarray, length):
    """Emit code for the view ``bufferarray[0:length:1]``.

    ``bufferarray`` is a lowered array of Numba type ``buffertype``, and
    ``length`` is a lowered intp. The result has type ``buffertype`` and is
    produced by Numba's own NumPy slice ``__getitem__``.
    """
    zero = context.get_constant(numba.intp, 0)
    one = context.get_constant(numba.intp, 1)
    bounds = context.make_helper(builder, numba.types.slice2_type)
    bounds.start, bounds.stop, bounds.step = zero, length, one
    slice_sig = buffertype(buffertype, numba.types.slice2_type)
    return numba.targets.arrayobj.getitem_arraynd_intp(
        context, builder, slice_sig, (bufferarray, bounds._getvalue()))
# Define the type for __len__.
@numba.typing.templates.infer_global(len)
class type_len(numba.typing.templates.AbstractTemplate):
    """Typing for ``len(growablebuffer)``: one GrowableBuffer in, intp out."""

    def generic(self, args, kwargs):
        applies = (len(args) == 1
                   and not kwargs
                   and isinstance(args[0], GrowableBufferType))
        return numba.intp(args[0]) if applies else None
# The lowering function for __len__ is easy enough to do it directly.
@numba.extending.lower_builtin(len, GrowableBufferType)
def lower_len(context, builder, sig, args):
    """Lower ``len(growablebuffer)`` as a load through the "length" pointer."""
    (modeltype,) = sig.args
    (modelval,) = args
    fields = context.make_helper(builder, modeltype, modelval)
    return builder.load(fields.length)
# Define the type for attributes and methods.
@numba.typing.templates.infer_getattr
class GrowableBuffer_attrib(numba.typing.templates.AttributeTemplate):
    """Typing for GrowableBuffer attributes and methods.

    Attributes resolve to Types; methods (declared with ``bound_function``)
    resolve to Signatures. Numba finds the ``resolve_<name>`` methods by
    name, so those names must not change.
    """
    key = GrowableBufferType

    def generic_resolve(self, growablebuffertype, attr):
        # "_buffer" -> the backing array type; "reserved" -> intp.
        if attr == "reserved":
            return numba.intp
        if attr == "_buffer":
            return growablebuffertype.buffertype
        return None

    @numba.typing.templates.bound_function("_ensure_reserved")
    def resolve__ensure_reserved(self, growablebuffertype, args, kwargs):
        # Takes no arguments, returns None.
        takes_nothing = not args and not kwargs
        return numba.types.none() if takes_nothing else None

    @numba.typing.templates.bound_function("append")
    def resolve_append(self, growablebuffertype, args, kwargs):
        # Takes exactly one positional number, returns None.
        if kwargs or len(args) != 1:
            return None
        (item,) = args
        if not isinstance(item, numba.types.Number):
            return None
        return numba.types.none(item)
# The lowering function for all attributes.
@numba.extending.lower_getattr_generic(GrowableBufferType)
def lower_getattr_generic(context, builder,
                          growablebuffertype, growablebufferval,
                          attr):
    """Lower attribute reads on a GrowableBuffer.

    "_buffer" loads the lowered array through the "buffer" pointer.
    "reserved" is that array's length, computed with Numba's NumPy
    ``len`` implementation.
    """
    fields = context.make_helper(builder, growablebuffertype,
                                 value=growablebufferval)
    if attr == "reserved":
        len_sig = numba.types.intp(growablebuffertype.buffertype)
        storage = builder.load(fields.buffer)
        return numba.targets.arrayobj.array_len(context, builder,
                                                len_sig, (storage,))
    if attr == "_buffer":
        return builder.load(fields.buffer)
    return None
# The lowering function for the _ensure_reserved method. For this one,
# we don't want to reimplement the logic in lowered Numba.
@numba.extending.lower_builtin("_ensure_reserved", GrowableBufferType)
def lower__ensure_reserved(context, builder, sig, args):
    """Lower ``growablebuffer._ensure_reserved()`` by calling back into Python.

    Resizing is infrequent, so instead of reimplementing it in LLVM this
    emits code that:

    1. acquires the GIL;
    2. calls ``_ensure_reserved`` on the wrapped Python object;
    3. re-unboxes the (possibly reallocated) ``_buffer`` array and stores it
       through the model's "buffer" pointer.

    ``lower_append`` also calls this function directly.

    The method is typed with no arguments besides the receiver (see
    ``resolve__ensure_reserved``). The lowering is therefore registered for
    the receiver type alone, matching the one-element ``sig.args`` unpacked
    below.

    Returns the lowered equivalent of None.
    """
    growablebuffertype, = sig.args
    growablebufferval, = args
    proxy = context.make_helper(builder, growablebuffertype,
                                value=growablebufferval)
    # To call Python from a lowered function, we need a Python API object,
    # and calling into Python requires holding the GIL.
    pyapi = context.get_python_api(builder)
    gil = pyapi.gil_ensure()
    # Hold an extra reference on the object while Python code runs.
    pyapi.incref(proxy.pyobj)
    none_obj = pyapi.call_method(proxy.pyobj, "_ensure_reserved", ())
    # NOTE(review): the result of call_method is not checked for NULL. If the
    # Python method raises, the exception stays pending and the code below
    # still runs; consider propagating the error.
    # The Python call may have replaced _buffer, so re-unbox it and overwrite
    # the lowered array behind the "buffer" pointer (a subset of unboxing).
    newbuffer_obj = pyapi.object_getattr_string(proxy.pyobj, "_buffer")
    newbufferval = pyapi.to_native_value(growablebuffertype.buffertype,
                                         newbuffer_obj).value
    builder.store(newbufferval, proxy.buffer)
    # NOTE(review): newbuffer_obj is decref'd twice although only one
    # reference was obtained above. Presumably this offsets an extra
    # reference that to_native_value's array unboxing keeps on the ndarray,
    # which is never released because the array sits behind a raw pointer.
    # The refcount assertions in the tests appear to rely on it.
    pyapi.decref(newbuffer_obj)
    pyapi.decref(newbuffer_obj)
    pyapi.decref(proxy.pyobj)
    pyapi.decref(none_obj)
    pyapi.gil_release(gil)
    return context.get_dummy_value()
# The lowering function for the append method. Unlike _ensure_reserved, this
# one is called frequently and has to be fast. We will *not* defer to the
# Python implementation (or acquire the GIL, or anything like that).
@numba.extending.lower_builtin("append",
                               GrowableBufferType, numba.types.Number)
def lower_append(context, builder, sig, args):
    """Lower ``growablebuffer.append(value)`` for a numeric ``value``.

    The emitted code does the following:

    1. Loads the current length and computes the buffer's reserved size.
    2. If length >= reserved (a branch marked unlikely), calls
       ``lower__ensure_reserved`` to grow the buffer through Python.
    3. Stores ``value`` at index ``length`` of the buffer array via Numba's
       NumPy ``__setitem__``.
    4. Writes ``length + 1`` back through the "length" pointer.

    The common path makes no Python API calls and does not take the GIL.
    Returns the lowered None.
    """
    growablebuffertype, numbertype = sig.args
    growablebufferval, numberval = args
    proxy = context.make_helper(builder, growablebuffertype,
                                value=growablebufferval)
    # Get the current length and the size of the buffer to see if we have
    # to call _ensure_reserved.
    lengthval = builder.load(proxy.length)
    reservedval = numba.targets.arrayobj.array_len(context, builder,
        numba.types.intp(growablebuffertype.buffertype),
        (builder.load(proxy.buffer),))
    # LLVM for an "if" statement can be generated using a context manager.
    # (Remember that this function doesn't *run* append, it generates the
    # code for append.)
    #
    # builder.icmp_signed(">=", ...) generates the code for the predicate.
    #
    # likely=False is a compiler hint that the predicate is rarely true.
    #
    # Compile-time control flow always goes through the "with" body, but
    # run-time control flow rarely enters the "if" body that is generated.
    with builder.if_then(builder.icmp_signed(">=", lengthval, reservedval),
                         likely=False):
        # Call the lowering function directly (not through Numba's dispatch)
        # with a receiver-only signature.
        ensure_sig = numba.types.none(growablebuffertype)
        lower__ensure_reserved(context, builder,
                               ensure_sig, (growablebufferval,))
    # We have to make the proxy again so that we get the post-updated buffer
    # in case the _ensure_reserved was called.
    newproxy = context.make_helper(builder, growablebuffertype,
                                   value=growablebufferval)
    # Now call Numba's __setitem__ on the buffer array to write a new value
    # at the current "length". (lengthval was loaded before the branch;
    # _ensure_reserved only reallocates the buffer, it does not touch length.)
    setitem_sig = numba.types.none(growablebuffertype.buffertype,
                                   numba.intp,
                                   numbertype)
    numba.targets.arrayobj.setitem_array(context, builder,
        setitem_sig, (builder.load(newproxy.buffer), lengthval, numberval))
    # Add one to the length and store it in its place, which is the Python
    # object's one-element "_length" array.
    builder.store(builder.add(lengthval, context.get_constant(numba.intp, 1)),
                  newproxy.length)
    # Return None.
    return context.get_dummy_value()
############################################################ tests
# Use a GrowableBuffer in Python.
buf = GrowableBuffer(float, initial=10)
buf.append(1.1)
buf.append(2.2)
buf.append(3.3)
# Get another reference to it so we can check its reference count.
# Expected counts (CPython-specific): buf -> 2 (the global name plus
# getrefcount's own argument); tmp -> 3 (buf._buffer, the name tmp, and
# getrefcount's argument). Any leak in (un)boxing shows up as a larger count.
tmp = buf._buffer
assert (sys.getrefcount(buf), sys.getrefcount(tmp)) == (2, 3)


@numba.njit
def test1(x):
    return 3.14


# Test 1: unboxing doesn't crash.
test1(buf)
# Keep calling it and ensure that the reference counts don't grow.
for i in range(10):
    test1(buf)
assert (sys.getrefcount(buf), sys.getrefcount(tmp)) == (2, 3)


@numba.njit
def test2(x):
    return x


# Test 2: unboxing and boxing doesn't crash and returns a usable object.
assert numpy.asarray(test2(buf)).tolist() == [1.1, 2.2, 3.3]
# Keep calling it and ensure that the reference counts don't grow.
for i in range(10):
    test2(buf)
assert (sys.getrefcount(buf), sys.getrefcount(tmp)) == (2, 3)


@numba.njit
def test3(x):
    return x, x


# Test 3: do that returning two references for every one that goes in, to
# make sure that the above didn't pass by accident.
for i in range(10):
    test3(buf)
assert (sys.getrefcount(buf), sys.getrefcount(tmp)) == (2, 3)


@numba.njit
def test4(x, i):
    return x[i]


# Test 4: verify that __getitem__ works.
assert test4(buf, 0) == 1.1
assert test4(buf, 1) == 2.2
assert test4(buf, 2) == 3.3
for i in range(10):
    test4(buf, 0)
assert (sys.getrefcount(buf), sys.getrefcount(tmp)) == (2, 3)
# ... for integers
assert test4(buf, 1) == 2.2
assert (sys.getrefcount(buf), sys.getrefcount(tmp)) == (2, 3)
# ... for slices
assert test4(buf, slice(1, None)).tolist() == [2.2, 3.3]
assert (sys.getrefcount(buf), sys.getrefcount(tmp)) == (2, 3)
# ... for arrays
assert test4(buf, numpy.array([2, 1, 1, 0])).tolist() == [3.3, 2.2, 2.2, 1.1]
assert (sys.getrefcount(buf), sys.getrefcount(tmp)) == (2, 3)


@numba.njit
def test5(x):
    return len(x), x.reserved


# Test 5: verify that __len__ works and the "reserved" property works.
assert test5(buf) == (3, 10)


@numba.njit
def test6(x):
    x.append(4.4)
    x.append(5.5)


# Test 6: verify that we can append to the GrowableBuffer.
assert numpy.asarray(buf).tolist() == [1.1, 2.2, 3.3]
test6(buf)
assert numpy.asarray(buf).tolist() == [1.1, 2.2, 3.3, 4.4, 5.5]
# Grow from 5 to 65 elements. With initial=10 and resize=1.5 this goes
# through 5 resizings (10 -> 15 -> 23 -> 35 -> 53 -> 80) to make sure it
# doesn't crash and none of the reference counts grow.
tmp = buf._buffer
for i in range(30):
    test6(buf)
    # tmp is the buffer from before this call: 2 references (tmp and the
    # getrefcount argument) if test6 replaced it, 3 if it is still
    # buf._buffer.
    assert sys.getrefcount(tmp) in (2, 3)
    tmp = buf._buffer
assert (sys.getrefcount(buf), sys.getrefcount(tmp)) == (2, 3)
# The final value should have a lot of 4.4's and 5.5's in it.
assert numpy.asarray(buf).tolist() == [1.1, 2.2, 3.3, 4.4, 5.5] + [4.4, 5.5]*30
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment