Created
November 29, 2016 09:54
-
-
Save sklam/ab2948068f76b6b206459fa4e2b4aafc to your computer and use it in GitHub Desktop.
Demo of unsafe_cast: casting an arbitrary pointer into a jitclass instance
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Demonstrate potential numba feature for casting an arbitrary pointer into | |
a jitclass instance. | |
""" | |
from collections import OrderedDict | |
import numpy as np | |
from numba import njit, jitclass, types | |
from numba import extending | |
from numba import cgutils | |
#
# Implement the core logic of the casting function.
#
def _unsafe_cast_ptr_to_class(int_type, class_type):
    """Build the ``(signature, codegen)`` pair for the ``unsafe_cast`` intrinsic.

    Parameters
    ----------
    int_type : numba integer type
        Numba type of the first argument: an integer holding a raw address.
    class_type : numba ClassType
        The jitclass type whose instance the address is reinterpreted as.

    Returns
    -------
    tuple
        ``(sig, codegen)`` where ``sig`` maps ``(int_type, class_type)`` to
        ``class_type.instance_type``, and ``codegen`` emits the LLVM IR.
    """
    inst_typ = class_type.instance_type
    # Declare the address argument with the integer type the typing check
    # actually accepted. Previously it was declared as ``voidptr``, which left
    # ``int_type`` unused and relied on an implicit int -> pointer cast at the
    # call site.
    sig = inst_typ(int_type, class_type)

    def codegen(context, builder, signature, args):
        ptr, _ = args  # the class argument carries no runtime value we need
        # LLVM type of the jitclass's data struct; the address is assumed
        # to point at memory laid out exactly like it.
        alloc_type = context.get_data_type(inst_typ.get_data_type())
        inst_struct = context.make_helper(builder, inst_typ)
        # A NULL meminfo means no reference-counted allocation backs this
        # instance. The caller must keep the pointed-to memory alive.
        inst_struct.meminfo = cgutils.get_null_value(inst_struct.meminfo.type)
        # ``ptr`` is an integer here, so convert it with inttoptr. A bitcast
        # cannot change an integer into a pointer.
        inst_struct.data = builder.inttoptr(ptr, alloc_type.as_pointer())
        return inst_struct._getvalue()

    return sig, codegen
#
# Make an intrinsic for our unsafe_cast function
#
@extending.intrinsic
def unsafe_cast(typingctx, src, dst):
    """Reinterpret the integer address ``src`` as an instance of jitclass ``dst``.

    This is a numba intrinsic, so it can only be called from jit-compiled
    code, e.g. ``unsafe_cast(ptr, Stuff)``. No ownership is taken, and no
    check is made that the address points at a correctly laid-out struct.

    Raises
    ------
    TypeError
        During typing, if ``src`` is not an integer type or ``dst`` is not a
        jitclass type.
    """
    # Typing logic: the address must be an integer, and the target must be
    # the jitclass type itself (e.g. ``Stuff``), which types as ClassType.
    if isinstance(src, types.Integer) and isinstance(dst, types.ClassType):
        # Return the (signature, codegen) pair that lowers the cast.
        return _unsafe_cast_ptr_to_class(src, dst)
    raise TypeError(
        "unsafe_cast(ptr, cls) expects an integer address and a jitclass; "
        "got ({}, {})".format(src, dst)
    )
#
# Make a class
#
# Field specification for the jitclass below. An ordered mapping is used so
# that the declared field order (intval, then realval) is preserved. The
# casting demo presumably relies on that order matching the memory it points
# at.
spec = OrderedDict([
    ('intval', types.int32),
    ('realval', types.float32),
])
@jitclass(spec)
class Stuff(object):
    # jitclass with an int32 field ``intval`` and a float32 field
    # ``realval``, as declared in ``spec``.

    def __init__(self, x, y):
        # Populate both fields in declaration order.
        self.intval = x
        self.realval = y

    def get(self):
        # Return both fields as the tuple (intval, realval).
        pair = self.intval, self.realval
        return pair
#
# A function to exercise our intrinsic and cast a pointer to a Stuff instance
#
@njit
def foo(ptr):
    """Return the integer address ``ptr`` reinterpreted as a ``Stuff`` instance."""
    # unsafe_cast is an intrinsic, so this call only works inside
    # jit-compiled code such as this function.
    instance = unsafe_cast(ptr, Stuff)
    return instance
def main():
    """Demonstrate that a casted ``Stuff`` instance shares memory with a numpy record.

    Mutations made through the jitclass instance are visible in the array, and
    vice versa, because both refer to the same buffer.
    """
    # The record layout must match the jitclass's native data struct.
    # NumPy packs structured dtypes by default; ``align=True`` pads them like a
    # C compiler would, which is presumably what numba's (unpacked) struct does
    # too. For these two 4-byte fields the offsets are identical either way,
    # but this keeps the demo correct if mixed-size fields are added.
    record_dtype = np.dtype([('x', 'int32'), ('y', 'float32')], align=True)
    ary = np.zeros(1, dtype=record_dtype)
    ary[0]['x'] = 123
    ary[0]['y'] = 3.14
    # Raw address of the array buffer. ``obj`` does not own this memory (its
    # meminfo is NULL), so ``ary`` must stay alive for as long as ``obj`` is used.
    ptr = ary.ctypes.data
    obj = foo(ptr)
    print(obj.get())
    print('modify the obj')
    obj.intval *= 100
    print(obj.get(), ary)
    print('modify the array')
    ary[0]['y'] *= 2
    print(obj.get(), ary)
# Run the demo only when executed as a script, not on import.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
That is exactly what I was looking for. Great job, thanks for sharing!