Last active
April 8, 2016 07:17
-
-
Save Maher4Ever/a24bb17312b7152aead2e218891597ea to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import enum | |
import wrapt | |
import numba | |
from numba.typing.typeof import typeof_impl | |
from numba.lowering import Lower | |
from numba import ir | |
# Numba type used to represent a Python enum class during compilation.
class Enum(numba.types.Dummy):
    """Numba type wrapping a single Python enum class.

    The wrapped class is kept on ``self.enum``; the type's name is
    ``"Enum(<enum>)"`` where ``<enum>`` is the string form of that class.
    """

    def __init__(self, enum):
        # NOTE: the parameter name shadows the ``enum`` module inside this
        # method only; it is kept because it is part of the constructor's
        # signature.
        self.enum = enum
        type_name = "Enum(%s)" % enum
        super(Enum, self).__init__(type_name)

    @property
    def key(self):
        """Return the wrapped enum class.

        NOTE(review): presumably Numba uses ``key`` for type identity
        (equality/hashing) -- confirm against the Numba version in use.
        """
        return self.enum

    @property
    def native_type(self):
        """Return the Numba type that enum members are lowered to (uint32)."""
        return numba.uint32
# Tell Numba's typeof machinery how to type an enum class (an instance of
# enum.EnumMeta): wrap it in the Enum Numba type.
@typeof_impl.register(enum.EnumMeta)
def _typeof_enum(val, c):
    """Return the ``Enum`` Numba type for the enum class *val*.

    *c* is the second argument supplied by ``typeof_impl``; it is not used.
    """
    enum_type = Enum(val)
    return enum_type
# Code that uses an enum member (e.g. ``Color.RED``) is seen by Numba as an
# attribute lookup on the enum class. Patch attribute resolution so that any
# attribute lookup on an Enum-typed value resolves to the enum's native type.
@wrapt.patch_function_wrapper('numba.typing.context', 'BaseContext.resolve_getattr')
def patch_resolving_attrs(wrapped, instance, args, kwargs):
    """Resolve attribute lookups on ``Enum`` types to their native type.

    ``args[0]`` is taken to be the type of the value whose attribute is being
    resolved. When it is an ``Enum``, its ``native_type`` is returned without
    inspecting the attribute name; otherwise the call is forwarded unchanged
    to the original ``resolve_getattr``.
    """
    value_type = args[0]
    if not isinstance(value_type, Enum):
        # Not an enum: defer to Numba's normal attribute resolution.
        return wrapped(*args, **kwargs)
    return value_type.native_type
# Attach a dict to the Lower class for remembering enum classes seen during
# lowering. It is a class attribute, so it is shared by every Lower instance
# that does not define its own ``enummap``.
Lower.enummap = dict()
# Patch the code responsible for lowering Python instructions
# to LLVM IR. For each enum 2 instructions are generated:
#   1. Assignment of a GlobalVariable to a register (this is the enum class)
#   2. Reading the attribute of the GlobalVariable (this is the actual enum)
# When we encounter the first instruction, the enum class is stored in
# ``enummap`` with the name of the target register as the key, and the
# instruction itself is not lowered.
@wrapt.patch_function_wrapper('numba.lowering', 'Lower.lower_inst')
def patch_lower_inst(wrapped, instance, args, kwargs):
    """Intercept assignments of enum classes; lower everything else normally.

    ``args[0]`` is the IR instruction being lowered. If it assigns a global
    whose value is an enum class, the class is recorded in
    ``instance.enummap`` under the target variable's name and ``None`` is
    returned without calling the original ``lower_inst``. Any other
    instruction is forwarded to the original ``lower_inst`` and its result
    is returned to the caller.
    """
    inst = args[0]
    is_enum_global = (
        isinstance(inst, ir.Assign)
        and isinstance(inst.value, ir.Global)
        and isinstance(inst.value.value, enum.EnumMeta)
    )
    if is_enum_global:
        instance.enummap[inst.target.name] = inst.value.value
        return None
    # Propagate the wrapped result so the patch stays transparent to callers
    # (previously the return value was silently dropped).
    return wrapped(*args, **kwargs)
# When a ``getattr`` expression is lowered, check whether the object the
# attribute is read from is a register that was recorded in ``enummap``.
# If so, emit a constant holding the value of the named enum member instead
# of performing a real attribute lookup.
@wrapt.patch_function_wrapper('numba.lowering', 'Lower.lower_expr')
def patch_lower_expr(wrapped, instance, args, kwargs):
    """Lower ``EnumClass.MEMBER`` to a constant; defer everything else.

    ``args[0]`` is the result type and ``args[1]`` the expression to lower.
    For a ``getattr`` expression whose source variable cannot be loaded
    (``loadvar`` raises ``KeyError``) but whose name is present in
    ``instance.enummap``, a constant of the result type is returned holding
    ``enum_class[attr].value``. In every other case the original
    ``lower_expr`` is called with the unchanged arguments.
    """
    result_type, expression = args[0], args[1]
    if expression.op != "getattr":
        return wrapped(*args, **kwargs)

    source_name = expression.value.name
    # Enums are the only "attributes" whose source variable is not stored,
    # so only take over when loading the variable fails; valid variables
    # keep the stock behaviour.
    try:
        instance.loadvar(source_name)
    except KeyError:
        if source_name in instance.enummap:
            enum_class = instance.enummap[source_name]
            member_value = enum_class[expression.attr].value
            return instance.context.get_constant_generic(
                instance.builder, result_type, member_value)
    return wrapped(*args, **kwargs)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment