Skip to content

Instantly share code, notes, and snippets.

@hugohadfield
Last active June 5, 2020 13:08
Show Gist options
  • Save hugohadfield/e238a139d13cd0dd4e17bb39b7b577bd to your computer and use it in GitHub Desktop.
Save hugohadfield/e238a139d13cd0dd4e17bb39b7b577bd to your computer and use it in GitHub Desktop.
import ast
import functools
import inspect
import textwrap
import time

import astpretty
import numba
import numpy as np
from numba import types
from numba.extending import overload

from clifford.g3c import *
# Bind the layout's product kernels (used by the ga_mul, ga_xor and ga_or
# overloads below for the geometric, outer and inner products) to
# module-level names so jitted code can call them directly.
gmt_func, omt_func, imt_func = (
    layout.gmt_func,
    layout.omt_func,
    layout.imt_func,
)
# Raw coefficient arrays of the basis vectors e1 and e2.
e1_val, e2_val = e1.value, e2.value
def get_as_ga_func(layout):
    """Return a numba-jitted scalar-to-multivector promotion function.

    The returned ``as_ga(x)`` builds a zeroed coefficient array of length
    ``layout.gaDims`` (numpy's default float64 dtype) and stores ``x`` at
    index 0, the coefficient that the ga_add/ga_sub overloads treat as the
    scalar part.
    """
    n_coeffs = layout.gaDims

    @numba.njit
    def as_ga(x):
        coeffs = np.zeros(n_coeffs)
        coeffs[0] = x
        return coeffs

    return as_ga


# Promotion function for the g3c layout imported from clifford above; used
# by the ga_xor / ga_or overloads when one operand is a scalar.
as_ga = get_as_ga_func(layout)
def ga_add(a, b):
    """Placeholder for ``a + b`` with geometric-algebra semantics.

    ``GATransformer`` rewrites ``a + b`` into ``ga_add(a, b)``. The real
    implementation is supplied inside numba-compiled code by the
    ``ol_ga_add`` overload. Called from plain Python, this does nothing and
    returns None.
    """
@overload(ga_add, inline='always')
def ol_ga_add(a, b):
    """numba overload implementing ``ga_add(a, b)``, i.e. GA-aware ``a + b``.

    Multivectors arrive as numba arrays of coefficients with index 0 treated
    as the scalar part. Adding a scalar therefore only changes coefficient 0,
    rather than broadcasting over every coefficient as plain array ``+``
    would. Array + array and scalar + scalar use numba's native ``+``.
    """
    scalar_types = (types.Integer, types.Float)
    if isinstance(a, scalar_types) and isinstance(b, types.Array):
        def impl(a, b):
            # Copy as float64 (the dtype ``as_ga`` produces) so the caller's
            # array is not mutated and no precision is lost.
            op = b.astype(np.float64)
            op[0] += a
            return op
        return impl
    elif isinstance(a, types.Array) and isinstance(b, scalar_types):
        def impl(a, b):
            op = a.astype(np.float64)
            op[0] += b
            return op
        return impl
    else:
        def impl(a, b):
            return a + b
        return impl
def ga_sub(a, b):
    """Placeholder for ``a - b`` with geometric-algebra semantics.

    ``GATransformer`` rewrites ``a - b`` into ``ga_sub(a, b)``. The real
    implementation is supplied inside numba-compiled code by the
    ``ol_ga_sub`` overload. Called from plain Python, this does nothing and
    returns None.
    """
@overload(ga_sub, inline='always')
def ol_ga_sub(a, b):
    """numba overload implementing ``ga_sub(a, b)``, i.e. GA-aware ``a - b``.

    Multivectors arrive as numba arrays of coefficients with index 0 treated
    as the scalar part. Subtracting involving a scalar therefore only touches
    coefficient 0. Array - array and scalar - scalar use numba's native
    ``-``.
    """
    scalar_types = (types.Integer, types.Float)
    if isinstance(a, scalar_types) and isinstance(b, types.Array):
        def impl(a, b):
            # scalar - mv == (-mv) with the scalar added to coefficient 0.
            # float64 matches the dtype ``as_ga`` produces and keeps precision.
            op = -b.astype(np.float64)
            op[0] += a
            return op
        return impl
    elif isinstance(a, types.Array) and isinstance(b, scalar_types):
        def impl(a, b):
            # astype always copies, so the caller's array is not mutated.
            op = a.astype(np.float64)
            op[0] -= b
            return op
        return impl
    else:
        def impl(a, b):
            return a - b
        return impl
def ga_mul(a, b):
    """Placeholder for ``a * b`` (geometric product).

    ``GATransformer`` rewrites ``a * b`` into ``ga_mul(a, b)``. The real
    implementation is supplied inside numba-compiled code by the
    ``ol_ga_mul`` overload. Called from plain Python, this does nothing and
    returns None.
    """
@overload(ga_mul, inline='always')
def ol_ga_mul(a, b):
    """numba overload implementing ``ga_mul(a, b)``.

    Two arrays (multivector coefficient arrays) are combined with the
    layout's ``gmt_func``. Every other operand combination uses numba's
    native ``*``, which scales each coefficient by a scalar operand.
    """
    both_arrays = isinstance(a, types.Array) and isinstance(b, types.Array)
    if not both_arrays:
        def scale_impl(a, b):
            return a*b
        return scale_impl

    def gmt_impl(a, b):
        return gmt_func(a, b)
    return gmt_impl
def ga_xor(a, b):
    """Placeholder for ``a ^ b`` (outer product).

    ``GATransformer`` rewrites ``a ^ b`` into ``ga_xor(a, b)``. The real
    implementation is supplied inside numba-compiled code by the
    ``ol_ga_xor`` overload. Called from plain Python, this does nothing and
    returns None.
    """
@overload(ga_xor, inline='always')
def ol_ga_xor(a, b):
    """numba overload implementing ``ga_xor(a, b)`` via the layout's ``omt_func``.

    Array operands are used as-is. A scalar (integer or float) operand paired
    with an array is first promoted to a multivector with ``as_ga``.
    """
    a_is_array = isinstance(a, types.Array)
    b_is_array = isinstance(b, types.Array)
    a_is_scalar = isinstance(a, (types.Integer, types.Float))
    b_is_scalar = isinstance(b, (types.Integer, types.Float))

    if a_is_array and b_is_array:
        def impl(a, b):
            return omt_func(a, b)
    elif a_is_array and b_is_scalar:
        def impl(a, b):
            return omt_func(a, as_ga(b))
    elif a_is_scalar and b_is_array:
        def impl(a, b):
            return omt_func(as_ga(a), b)
    else:
        # NOTE(review): two scalars fall through to numba's ``^``, i.e.
        # bitwise XOR for integers (presumably a typing error for floats),
        # not the GA outer product of two scalars (their product). Confirm
        # this is intended.
        def impl(a, b):
            return a^b
    return impl
def ga_or(a, b):
    """Placeholder for ``a | b`` (inner product).

    ``GATransformer`` rewrites ``a | b`` into ``ga_or(a, b)``. The real
    implementation is supplied inside numba-compiled code by the
    ``ol_ga_or`` overload. Called from plain Python, this does nothing and
    returns None.
    """
@overload(ga_or, inline='always')
def ol_ga_or(a, b):
    """numba overload implementing ``ga_or(a, b)`` via the layout's ``imt_func``.

    Array operands are used as-is. A scalar (integer or float) operand paired
    with an array is first promoted to a multivector with ``as_ga``.
    """
    a_is_array = isinstance(a, types.Array)
    b_is_array = isinstance(b, types.Array)
    a_is_scalar = isinstance(a, (types.Integer, types.Float))
    b_is_scalar = isinstance(b, (types.Integer, types.Float))

    if a_is_array and b_is_array:
        def impl(a, b):
            return imt_func(a, b)
    elif a_is_array and b_is_scalar:
        def impl(a, b):
            return imt_func(a, as_ga(b))
    elif a_is_scalar and b_is_array:
        def impl(a, b):
            return imt_func(as_ga(a), b)
    else:
        # NOTE(review): two scalars fall through to numba's ``|`` (bitwise
        # OR for integers) rather than a GA inner product. Confirm intended.
        def impl(a, b):
            return a|b
    return impl
class jit_func(object):
    """Decorator that compiles a function written with geometric-algebra operators.

    The decorated function's source is re-parsed. Its ``*``, ``^``, ``|``,
    ``+`` and ``-`` binary operators are rewritten by ``GATransformer`` into
    calls to the numba-overloaded ``ga_*`` helpers, and the rewritten
    function is compiled with ``numba.njit``.

    The returned wrapper:

    * replaces each argument that has a ``.value`` attribute with that value;
    * calls the compiled function;
    * wraps the resulting array with ``layout.MultiVector``.

    Parameters
    ----------
    ast_debug : bool, optional
        If True, pretty-print the AST before and after the rewrite.
    """

    def __init__(self, ast_debug=False):
        self.ast_debug = ast_debug

    def __call__(self, func):
        fname = func.__name__

        # Dedenting lets functions defined in an indented scope parse.
        # Clearing ``decorator_list`` removes this decorator (and any others)
        # however many lines they span, so exec'ing the rewritten definition
        # does not re-enter this decorator.
        source = textwrap.dedent(inspect.getsource(func))
        tree = ast.parse(source)
        for node in tree.body:
            if isinstance(node, ast.FunctionDef) and node.name == fname:
                node.decorator_list = []

        if self.ast_debug:
            print('\n\n\n\n TRANFORMING FROM \n\n\n\n')
            astpretty.pprint(tree)
        tree = GATransformer().visit(tree)
        ast.fix_missing_locations(tree)
        if self.ast_debug:
            print('\n\n\n\n TRANFORMING TO \n\n\n\n')
            astpretty.pprint(tree)

        # Execute the rewritten definition against this module's globals so
        # the ga_* helpers and layout product kernels resolve.
        # NOTE(review): names only visible in ``func``'s own module are not
        # available here. Confirm if jit_func is used from other modules.
        code = compile(tree, '<ast>', "exec")
        namespace = {}
        exec(code, globals(), namespace)
        jitted = numba.njit(namespace[fname])

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Hand numba the coefficient arrays of multivectors. Arguments
            # without ``.value`` (e.g. plain numbers) pass through unchanged.
            values = [getattr(a, 'value', a) for a in args]
            kw_values = {k: getattr(v, 'value', v) for k, v in kwargs.items()}
            return layout.MultiVector(value=jitted(*values, **kw_values))

        return wrapper
class GATransformer(ast.NodeTransformer):
def visit_BinOp(self, node):
if isinstance(node.op, ast.Mult):
new_node = ast.Call(
func = ast.Name(id='ga_mul', ctx=ast.Load()),
args = [node.left, node.right],
keywords = []
)
new_node = GATransformer().visit(new_node)
return new_node
elif isinstance(node.op, ast.BitXor):
new_node = ast.Call(
func = ast.Name(id='ga_xor', ctx=ast.Load()),
args = [node.left, node.right],
keywords = []
)
new_node = GATransformer().visit(new_node)
return new_node
elif isinstance(node.op, ast.BitOr):
new_node = ast.Call(
func = ast.Name(id='ga_or', ctx=ast.Load()),
args = [node.left, node.right],
keywords = []
)
new_node = GATransformer().visit(new_node)
return new_node
elif isinstance(node.op, ast.Add):
new_node = ast.Call(
func = ast.Name(id='ga_add', ctx=ast.Load()),
args = [node.left, node.right],
keywords = []
)
new_node = GATransformer().visit(new_node)
return new_node
elif isinstance(node.op, ast.Sub):
new_node = ast.Call(
func = ast.Name(id='ga_sub', ctx=ast.Load()),
args = [node.left, node.right],
keywords = []
)
new_node = GATransformer().visit(new_node)
return new_node
return node
@jit_func(ast_debug=True)
def test_func(A, B, C):
    """Expression exercising every rewritten operator, compiled via jit_func.

    Same expression as ``slow_test_func``, which evaluates it uncompiled.
    """
    products = ((A*B)*C) | (B^A)
    op = products - 3.1 - A - 7*B + 5 + C + 2.5 + (2^(A*B*C)^3) + (A|5)
    return op
def slow_test_func(A, B, C):
    """Uncompiled reference for ``test_func``.

    Evaluates the same expression using the operands' own Python operators.
    It is compared against ``test_func`` for output and timing.
    """
    products = ((A*B)*C) | (B^A)
    op = products - 3.1 - A - 7*B + 5 + C + 2.5 + (2^(A*B*C)^3) + (A|5)
    return op
# Print the jitted and plain-Python results side by side so they can be
# compared by eye.
print(test_func(e1, e2, einf))
print(slow_test_func(e1, e2, einf))


def _time_per_call_us(func, args, nrepeats):
    """Return the mean wall-clock time of ``func(*args)`` in microseconds.

    Uses ``time.perf_counter`` (monotonic, highest available resolution)
    rather than ``time.time``, which is unsuitable for micro-benchmarks.
    """
    start_time = time.perf_counter()
    for _ in range(nrepeats):
        func(*args)
    end_time = time.perf_counter()
    return 1E6*(end_time - start_time)/nrepeats


nrepeats = 100000
# test_func was already called once above, so numba compilation is not
# included in its timing.
print(_time_per_call_us(test_func, (e1, e2, einf), nrepeats))
print(_time_per_call_us(slow_test_func, (e1, e2, einf), nrepeats))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment