Skip to content

Instantly share code, notes, and snippets.

@Maher4Ever
Last active April 8, 2016 07:17
Show Gist options
  • Save Maher4Ever/a24bb17312b7152aead2e218891597ea to your computer and use it in GitHub Desktop.
Save Maher4Ever/a24bb17312b7152aead2e218891597ea to your computer and use it in GitHub Desktop.
import enum
import wrapt
import numba
from numba.typing.typeof import typeof_impl
from numba.lowering import Lower
from numba import ir
# Define a Numba type for enums.
class Enum(numba.types.Dummy):
    """Numba type standing in for a Python ``enum.Enum`` class.

    The wrapped enum class is kept as ``self.enum`` and returned from
    ``key``; numba's ``Type`` machinery uses ``key`` to decide type identity
    (per numba's type model -- confirm for the pinned numba version).
    """

    def __init__(self, enum):
        # The parameter is the enum *class*; it shadows the ``enum`` module
        # inside this method only.
        self.enum = enum
        type_name = "Enum(%s)" % enum
        super(Enum, self).__init__(type_name)

    @property
    def native_type(self):
        """Numba type that ``EnumClass.MEMBER`` lookups are typed as."""
        # NOTE(review): every member is treated as ``uint32``; members whose
        # values are negative, wider than 32 bits or non-integer will not fit
        # this type -- confirm this restriction is intended.
        return numba.uint32

    @property
    def key(self):
        """The wrapped enum class."""
        return self.enum
# Register the Enum type: teach numba's ``typeof`` how to type enum classes.
@typeof_impl.register(enum.EnumMeta)
def _typeof_enum(val, c):
    """Return the numba ``Enum`` type wrapping the enum class ``val``.

    ``c`` (presumably numba's typeof context) is not used.
    """
    enum_type = Enum(val)
    return enum_type
# When code containing enums is parsed, ``EnumClass.MEMBER`` is interpreted
# as looking up an attribute on the enum class. Hence attribute typing is
# patched so that a member lookup on an ``Enum`` type yields its native type.
@wrapt.patch_function_wrapper('numba.typing.context', 'BaseContext.resolve_getattr')
def patch_resolving_attrs(wrapped, instance, args, kwargs):
    """Type ``EnumClass.MEMBER`` as the enum's native type.

    Only names present in the enum class's ``__members__`` are intercepted.
    Any other attribute falls through to numba's normal resolution
    (presumably reported as an unknown attribute) instead of being typed as
    the native type and then failing with a ``KeyError`` when the lowering
    patch evaluates ``enum[expr.attr]``.
    """
    def _extract(typ, attr, *_args, **_kwargs):
        # Bind the original ``resolve_getattr(typ, attr)`` arguments whether
        # they were passed positionally or by keyword.
        return typ, attr

    typ, attr = _extract(*args, **kwargs)
    if isinstance(typ, Enum) and attr in typ.enum.__members__:
        return typ.native_type
    return wrapped(*args, **kwargs)
# Class-level fallback mapping (register name -> enum class) read and written
# by the lowering patches below. As a class attribute, this single dict is
# shared by every ``Lower`` instance that does not define its own.
Lower.enummap = dict()
# Patch the code responsible for lowering Python IR instructions to LLVM IR.
# For each enum access two instructions are generated:
#   1. Assignment of an ``ir.Global`` to a register (this is the enum class).
#   2. Reading an attribute of that register (this is the actual enum member).
# The first instruction is intercepted here: instead of being lowered, the
# enum class is recorded under the register's name so that the attribute
# read can later be turned into a constant.
@wrapt.patch_function_wrapper('numba.lowering', 'Lower.lower_inst')
def patch_lower_inst(wrapped, instance, args, kwargs):
    """Record enum-class globals instead of lowering them.

    Every other instruction is forwarded to the original
    ``Lower.lower_inst`` and its result is returned unchanged.
    """
    def _extract(inst, *_args, **_kwargs):
        # Bind ``inst`` whether it was passed positionally or by keyword.
        return inst

    inst = _extract(*args, **kwargs)
    if (isinstance(inst, ir.Assign) and
            isinstance(inst.value, ir.Global) and
            isinstance(inst.value.value, enum.EnumMeta)):
        # Keep the mapping on this ``Lower`` instance rather than mutating
        # the class-level ``Lower.enummap`` dict, which every instance
        # shares (entries would otherwise stay visible to all later
        # instances and accumulate without bound). The instance attribute
        # shadows the class attribute, so ``instance.enummap`` reads see it.
        enummap = instance.__dict__.setdefault('enummap', {})
        enummap[inst.target.name] = inst.value.value
        return None
    return wrapped(*args, **kwargs)
# Then, when an attribute lookup is lowered, check whether its target
# register holds a recorded enum class. When that's the case, return a
# constant expression with the value of the enum member.
@wrapt.patch_function_wrapper('numba.lowering', 'Lower.lower_expr')
def patch_lower_expr(wrapped, instance, args, kwargs):
    """Fold ``EnumClass.MEMBER`` attribute lookups into constants.

    Every other expression is delegated to the original
    ``Lower.lower_expr``.
    """
    def _extract(resty, expr, *_args, **_kwargs):
        # Bind the arguments whether passed positionally or by keyword.
        return resty, expr

    resty, expr = _extract(*args, **kwargs)
    if expr.op == "getattr":
        name = expr.value.name
        # Enum classes are the only getattr targets whose register is never
        # stored (the lower_inst patch skips them). Any register that has a
        # slot is therefore handled by the original implementation.
        # ``getvar`` only looks the slot up; ``loadvar`` would additionally
        # emit a load instruction just to perform this check (in numba's
        # ``Lower``, ``loadvar`` is ``getvar`` plus ``builder.load`` --
        # confirm for the pinned numba version).
        try:
            instance.getvar(name)
        except KeyError:
            enum_cls = instance.enummap.get(name)
        else:
            enum_cls = None
        if enum_cls is not None:
            member = enum_cls[expr.attr]
            return instance.context.get_constant_generic(
                instance.builder, resty, member.value)
    return wrapped(*args, **kwargs)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment