Skip to content

Instantly share code, notes, and snippets.

@sklam
Created November 29, 2016 09:54
Show Gist options
  • Star 2 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save sklam/ab2948068f76b6b206459fa4e2b4aafc to your computer and use it in GitHub Desktop.
Save sklam/ab2948068f76b6b206459fa4e2b4aafc to your computer and use it in GitHub Desktop.
Demo of using unsafe_cast to cast an arbitrary pointer into a jitclass instance
"""
Demonstrate potential numba feature for casting an arbitrary pointer into
a jitclass instance.
"""
from collections import OrderedDict
import numpy as np
from numba import njit, jitclass, types
from numba import extending
from numba import cgutils
#
# Implement the core logic of the casting function.
#
def _unsafe_cast_ptr_to_class(int_type, class_type):
    """Build the ``(signature, codegen)`` pair that casts an integer address
    to an instance of a jitclass.

    Parameters
    ----------
    int_type : numba integer type
        Numba type of the address argument, as seen by the typer.
    class_type : numba ClassType
        The jitclass type (e.g. the type of ``Stuff``) to cast to.

    Returns
    -------
    tuple
        ``(signature, codegen)`` as expected from a
        ``numba.extending.intrinsic`` typer.
    """
    inst_typ = class_type.instance_type
    # Type the address with the integer type that was actually passed in.
    # Declaring ``voidptr`` here would rely on an implicit integer -> voidptr
    # cast at the call site; instead codegen converts it explicitly below.
    sig = inst_typ(int_type, class_type)

    def codegen(context, builder, signature, args):
        # The class argument only drives typing; its value is unused.
        ptr, _ = args
        # LLVM type of the jitclass data struct that the address is
        # reinterpreted as pointing to.
        alloc_type = context.get_data_type(inst_typ.get_data_type())
        inst_struct = context.make_helper(builder, inst_typ)
        # NULL meminfo: the instance holds no reference to the memory it
        # points at, so the caller must keep that buffer alive for as long
        # as the instance is used.
        inst_struct.meminfo = cgutils.get_null_value(inst_struct.meminfo.type)
        # Point the instance's data at the given integer address.
        inst_struct.data = builder.inttoptr(ptr, alloc_type.as_pointer())
        return inst_struct._getvalue()

    return sig, codegen
#
# Make an intrinsic for our unsafe_cast function
#
@extending.intrinsic
def unsafe_cast(typingctx, src, dst):
    """Reinterpret the integer address *src* as an instance of jitclass *dst*.

    This is an intrinsic and can only be called from jitted code.  The
    address is not checked in any way: it must point to memory laid out
    like *dst*'s data struct.  The memory must also outlive the returned
    instance, which does not own it.

    Raises
    ------
    TypeError
        At typing time, if *src* is not an integer type or *dst* is not a
        jitclass type.
    """
    # This function body is the typing logic; codegen lives in the helper.
    if isinstance(src, types.Integer) and isinstance(dst, types.ClassType):
        return _unsafe_cast_ptr_to_class(src, dst)
    raise TypeError(
        "unsafe_cast(src, dst) expects an integer address and a jitclass "
        "type, got ({}, {})".format(src, dst)
    )
#
# Make a class
#
# Field spec for Stuff.  An OrderedDict keeps the field order explicit;
# main() aliases a record array whose ('x' int32, 'y' float32) layout is
# expected to match this order.
spec = OrderedDict([
    ('intval', types.int32),
    ('realval', types.float32),
])


@jitclass(spec)
class Stuff(object):
    """Minimal jitclass holding an int32 and a float32 field."""

    def __init__(self, x, y):
        self.intval = x
        self.realval = y

    def get(self):
        """Return the fields as an ``(intval, realval)`` tuple."""
        pair = (self.intval, self.realval)
        return pair
#
# A function to exercise our intrinsic and cast a pointer to a Stuff instance
#
@njit
def foo(ptr):
    """Return a ``Stuff`` instance whose data lives at the address *ptr*.

    The instance aliases the memory at *ptr*; nothing is copied.
    """
    # unsafe_cast is an intrinsic, so the cast has to happen in jitted code.
    inst = unsafe_cast(ptr, Stuff)
    return inst
def main():
    """Show that a Stuff instance cast from a record array's buffer shares
    memory with that array in both directions."""
    record_dtype = np.dtype([('x', 'int32'), ('y', 'float32')])
    ary = np.zeros(1, dtype=record_dtype)
    # Structured scalars obtained by indexing act as views, so writes
    # through ``rec`` land in ``ary``.
    rec = ary[0]
    rec['x'] = 123
    rec['y'] = 3.14

    # Alias the array's buffer as a Stuff instance.  ``ary`` must stay alive
    # while ``obj`` is in use, because obj does not own this memory.
    obj = foo(ary.ctypes.data)
    print(obj.get())

    # A write through the jitclass is visible in the array...
    print('modify the obj')
    obj.intval *= 100
    print(obj.get(), ary)

    # ...and a write to the array is visible through the jitclass.
    print('modify the array')
    rec['y'] *= 2
    print(obj.get(), ary)


if __name__ == '__main__':
    main()
@skirpichenko
Copy link

That is exactly what I was looking for. Great job, thanks for sharing!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment