Skip to content

Instantly share code, notes, and snippets.

@DannyWeitekamp
Created April 7, 2022 16:22
Show Gist options
  • Save DannyWeitekamp/e7d8c0a42df44cf4ddcf439546590714 to your computer and use it in GitHub Desktop.
Save DannyWeitekamp/e7d8c0a42df44cf4ddcf439546590714 to your computer and use it in GitHub Desktop.
from numba import njit, f8
from numba.typed import List
from numba.extending import models, register_model
class Interval:
    """
    Pure-Python value object for a half-open range [lo, hi) of reals.

    Instances are what the numba extension below boxes/unboxes; compiled
    code sees them as ``IntervalType``.
    """

    def __init__(self, lo, hi):
        # Bounds are stored as given; no ordering check is performed.
        self.lo, self.hi = lo, hi

    @property
    def width(self):
        """Distance between the bounds, ``hi - lo``."""
        span = self.hi - self.lo
        return span

    def __repr__(self):
        bounds = (self.lo, self.hi)
        return 'Interval(%f, %f)' % bounds
from numba import types
class IntervalType(types.Type):
    """Numba type representing ``Interval`` instances inside compiled code."""

    def __init__(self):
        # The display name numba uses for this type in signatures/messages.
        super().__init__(name='Interval')
from numba.extending import typeof_impl
# Singleton numba type used for every Interval value.
interval_type = IntervalType()


@typeof_impl.register(Interval)
def typeof_index(val, c):
    """Tell numba's typeof machinery that any ``Interval`` value is ``interval_type``."""
    return interval_type
from numba.extending import as_numba_type
from numba.extending import type_callable
@type_callable(Interval)
def type_interval(context):
    """Typing hook for calls to ``Interval(lo, hi)`` inside jitted code.

    The returned typer resolves the call to ``interval_type`` when both
    argument types are numba ``Float`` types; otherwise it returns None,
    so this hook provides no typing for the call.
    """
    def typer(lo, hi):
        both_float = isinstance(lo, types.Float) and isinstance(hi, types.Float)
        if not both_float:
            return None
        return interval_type
    return typer


# Map the Python class itself to its numba type (for annotations etc.).
as_numba_type.register(Interval, interval_type)
from numba.extending import models, register_model
@register_model(IntervalType)
class IntervalModel(models.StructModel):
    """Native layout of an ``IntervalType`` value: a struct of two float64 fields."""

    def __init__(self, dmm, fe_type):
        field_names = ('lo', 'hi')
        members = [(name, types.float64) for name in field_names]
        super().__init__(dmm, fe_type, members)
from numba.extending import make_attribute_wrapper
# Expose each struct member as a same-named attribute in compiled code.
for _field in ('lo', 'hi'):
    make_attribute_wrapper(IntervalType, _field, _field)
del _field
from numba.extending import overload_attribute
@overload_attribute(IntervalType, "width")
def get_width(interval):
    """Provide ``interval.width`` in compiled code, mirroring ``Interval.width``.

    Returns the implementation numba compiles for the attribute access.
    """
    def _width_impl(interval):
        upper = interval.hi
        lower = interval.lo
        return upper - lower
    return _width_impl
from numba.extending import lower_builtin
from numba.core import cgutils
@lower_builtin(Interval, types.Float, types.Float)
def impl_interval(context, builder, sig, args):
    """Lower ``Interval(lo, hi)`` in compiled code to a native interval struct.

    :param context: numba target context.
    :param builder: LLVM IR builder.
    :param sig: call signature; ``sig.args`` are the two ``Float`` argument
        types and ``sig.return_type`` is the interval type.
    :param args: the LLVM values of ``lo`` and ``hi``.
    :returns: the populated struct value.
    """
    typ = sig.return_type
    lo, hi = args
    lo_ty, hi_ty = sig.args
    interval = cgutils.create_struct_proxy(typ)(context, builder)
    # The typer accepts any ``types.Float`` (e.g. float32), but IntervalModel
    # stores float64 members; cast so the struct stores always match. Casting
    # float64 -> float64 is a no-op.
    interval.lo = context.cast(builder, lo, lo_ty, types.float64)
    interval.hi = context.cast(builder, hi, hi_ty, types.float64)
    return interval._getvalue()
from numba.extending import unbox, NativeValue
@unbox(IntervalType)
def unbox_interval(typ, obj, c):
    """
    Convert an Interval object to a native interval structure.

    Reads the ``lo`` and ``hi`` attributes of *obj*, converts each with
    ``float_as_double`` and stores the results in a new ``typ`` struct.

    :param typ: the numba type being unboxed (registered for ``IntervalType``).
    :param obj: the Python object to unbox (passed to ``object_getattr_string``).
    :param c: unboxing context providing ``pyapi``, ``context`` and ``builder``.
    :returns: ``NativeValue`` wrapping the struct; its ``is_error`` flag is set
        when a Python exception is pending after the conversions.
    """
    # Attribute objects obtained here are released by the decrefs below.
    lo_obj = c.pyapi.object_getattr_string(obj, "lo")
    hi_obj = c.pyapi.object_getattr_string(obj, "hi")
    interval = cgutils.create_struct_proxy(typ)(c.context, c.builder)
    # NOTE(review): lo_obj/hi_obj are not NULL-checked before conversion; a
    # missing attribute presumably surfaces via the err_occurred() check
    # below — confirm.
    interval.lo = c.pyapi.float_as_double(lo_obj)
    interval.hi = c.pyapi.float_as_double(hi_obj)
    c.pyapi.decref(lo_obj)
    c.pyapi.decref(hi_obj)
    # Report failure if any Python exception is currently pending.
    is_error = cgutils.is_not_null(c.builder, c.pyapi.err_occurred())
    return NativeValue(interval._getvalue(), is_error=is_error)
from numba.extending import box
@box(IntervalType)
def box_interval(typ, val, c):
"""
Convert a native interval structure to an Interval object.
"""
interval = cgutils.create_struct_proxy(typ)(c.context, c.builder, value=val)
lo_obj = c.pyapi.float_from_double(interval.lo)
hi_obj = c.pyapi.float_from_double(interval.hi)
class_obj = c.pyapi.unserialize(c.pyapi.serialize_object(Interval))
res = c.pyapi.call_function_objargs(class_obj, (lo_obj, hi_obj))
c.pyapi.decref(lo_obj)
c.pyapi.decref(hi_obj)
c.pyapi.decref(class_obj)
return res
from numba import jit
@jit(nopython=True)
def inside_interval(interval, x):
    """Return whether ``x`` lies in the half-open range [lo, hi)."""
    above_lower = interval.lo <= x
    return above_lower and x < interval.hi
@jit(nopython=True)
def interval_width(interval):
    """Return ``interval.width`` evaluated in nopython mode (hi - lo)."""
    width = interval.width
    return width
@jit(nopython=True)
def sum_intervals(i, j):
    """Return a new Interval whose bounds are the element-wise sums of ``i`` and ``j``."""
    new_lo = i.lo + j.lo
    new_hi = i.hi + j.hi
    return Interval(new_lo, new_hi)
# Smoke checks: passing Interval into jitted code exercises unboxing, and
# sum_intervals returning an Interval exercises boxing.
# NOTE: ``assert`` is stripped under ``python -O``; these are demo checks only.
assert inside_interval(Interval(1.0, 5.0), 4)
assert not inside_interval(Interval(1.0, 5.0), 6)
print(interval_width(Interval(1.0, 6.0)))
print(sum_intervals(Interval(1.0, 6.0), Interval(1.0, 6.0)))
###########
## ^ Everything above is taken from https://numba.pydata.org/numba-doc/latest/extending/interval-example.html
## v Below: an experiment implementing "getiter" (iteration) lowering for IntervalType
###########
from numba import f8
from numba.extending import lower_builtin
@lower_builtin("getiter", IntervalType)
def iterval_getiter(context, builder, sig, args):
    """Experimental lowering for ``iter(interval)`` in compiled code.

    Prints a marker together with its first argument, then returns a
    placeholder value.
    """
    # NOTE(review): the other lowering in this file returns an LLVM value
    # (``interval._getvalue()``); the str returned here is only a placeholder
    # and presumably cannot serve as an iterator value.
    # NOTE(review): IntervalType subclasses plain ``types.Type`` rather than an
    # iterable numba type, so type inference presumably rejects
    # ``for x in Interval(...)`` before this lowering is reached — confirm.
    marker = "THIS SHOULD GET PRINTED!!"
    first_arg = args[0]
    print(marker, first_arg)
    return "WHATEVER"
@jit(nopython=True)
def iter_interval(i, j):
    """Experiment: iterate over ``Interval(f8(i), f8(j))`` in nopython mode.

    ``f8(...)`` converts both bounds to float64 so the call matches the
    ``Interval(Float, Float)`` typer and lowering.
    """
    # Use a loop variable distinct from the parameter ``i`` so numba does not
    # have to unify the argument's type with the iterator's element type.
    for value in Interval(f8(i), f8(j)):
        print(value)


# NOTE(review): IntervalType is not declared iterable, so this call presumably
# fails during numba type inference — that failure is what this probes.
iter_interval(1, 10)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment