Skip to content

Instantly share code, notes, and snippets.

@bwasti
Created March 8, 2019 23:51
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save bwasti/ff8f754c034b8005cfedaf25f8ce2b17 to your computer and use it in GitHub Desktop.
Save bwasti/ff8f754c034b8005cfedaf25f8ce2b17 to your computer and use it in GitHub Desktop.
import ast
import inspect
import textwrap

import tvm
from tvm import relay
import torch
import torch.nn.functional as F
import numpy as np
# Registry of every function translated so far, keyed by its Python name.
# RelayParser.visit_FunctionDef refuses to translate a name that is already here.
_parsed_functions = {}
def jit_assert(cond, msg="[see stack]"):
    """Raise a plain ``Exception`` carrying *msg* unless *cond* is truthy.

    Used by the parser to report unsupported constructs. Unlike the ``assert``
    statement, it is not stripped when Python runs with ``-O``.
    """
    if cond:
        return
    raise Exception(msg)
def get_methods(mod):
    """Build a reverse index ``{id(obj): name}`` of *mod*'s public attributes.

    Names starting with an underscore are skipped. Attributes are visited in
    ``dir()`` order, so when several names alias the same object the last one
    wins. The keys are ``id()`` values, which only identify an object while it
    stays alive (e.g. while it is still referenced from *mod*'s namespace).
    """
    public_names = (name for name in dir(mod) if not name.startswith("_"))
    return {id(getattr(mod, name)): name for name in public_names}
# Reverse-lookup tables (id(object) -> public attribute name) that let
# RelayParser recognise which torch builtin a call refers to.
_torch_methods = get_methods(torch)
# Index the torch.Tensor *class*. torch.tensor is a factory function whose
# dir() lists only dunder names, which would leave this table empty.
# NOTE(review): not referenced by the visible code.
_tensor_methods = get_methods(torch.Tensor)
_functional_methods = get_methods(F)
class RelayParser(ast.NodeVisitor):
    """Translate the AST of one Python function into a ``relay.Function``.

    Supported subset (anything else fails through ``jit_assert``):

    * function parameters become input ``relay.Var`` objects;
    * single-name assignments ``x = expr`` (each name may be bound only once);
    * the ``+``, ``-`` and ``*`` binary operators;
    * calls to ``torch`` / ``torch.nn.functional`` builtins handled by
      ``torch_builtin`` (``ones``, ``relu``, ``conv2d``, ``batch_norm``);
    * a ``return`` statement, whose value becomes the function body.

    Each parsed function is also recorded in the module-level
    ``_parsed_functions`` dict under its Python name.
    """

    def __init__(self, globals_=None):
        """Create a parser.

        Args:
            globals_: namespace used to resolve called functions such as
                ``F.relu`` (``script`` passes ``fn.__globals__``). Defaults to
                a fresh empty dict.
        """
        self.symbols = {}  # Symbol table: Python name -> relay.Var
        self.binds = {}  # relay.Var -> the expression assigned to it
        # A fresh dict per parser instead of one shared mutable default.
        self.globals = {} if globals_ is None else globals_
        self.inputs = []  # one relay.Var per function parameter, in order
        self.input_types = []  # NOTE(review): not read in the visible code
        self.output = None  # the returned expression, set by visit_Return
        super().__init__()

    def get_id_from_node_(self, node):
        """Resolve a ``Name``/``Attribute`` chain such as ``F.relu`` to the object it names.

        Bare names are looked up in ``self.globals`` (``KeyError`` if absent).
        """
        if isinstance(node, ast.Attribute):
            obj = self.get_id_from_node_(node.value)
            return getattr(obj, node.attr)
        if isinstance(node, ast.Name):
            return self.globals[node.id]
        jit_assert(False, "Cannot get id from node {}".format(ast.dump(node)))

    def get_id_from_node(self, node):
        """Return ``id()`` of the object that *node* names."""
        return id(self.get_id_from_node_(node))

    @staticmethod
    def _literal(node, what):
        """Evaluate *node* as a Python literal, failing with a readable message."""
        try:
            return ast.literal_eval(node)
        except (ValueError, TypeError):
            jit_assert(False, "{} must be a Python literal, got {}".format(
                what, ast.dump(node)))

    @staticmethod
    def _is_none_literal(node):
        """Return True if *node* is the literal ``None``."""
        try:
            return ast.literal_eval(node) is None
        except (ValueError, TypeError):
            return False

    @staticmethod
    def _as_int_pair(value, what):
        """Normalize an int or a 2-sequence of ints to a 2-tuple of ints."""
        if isinstance(value, int):
            return (value, value)
        jit_assert(
            isinstance(value, (tuple, list)) and len(value) == 2
            and all(isinstance(v, int) for v in value),
            "Unsupported {}: {!r}".format(what, value))
        return tuple(value)

    @staticmethod
    def _bind_call(func_name, param_names, args, keywords):
        """Map parameter name -> argument AST node, following *param_names* order.

        Fails (via ``jit_assert``) on surplus positional arguments and on
        unknown or duplicated keywords (including ``**kwargs``).
        """
        jit_assert(len(args) <= len(param_names),
                   "{}: too many positional arguments".format(func_name))
        bound = dict(zip(param_names, args))
        for kw in keywords:
            jit_assert(kw.arg in param_names and kw.arg not in bound,
                       "{}: unsupported or duplicate keyword {!r}".format(
                           func_name, kw.arg))
            bound[kw.arg] = kw.value
        return bound

    def torch_builtin(self, func_name, args, keywords):
        """Lower a call to the torch builtin named *func_name* into Relay.

        Args:
            func_name: attribute name of the callee in ``torch`` or
                ``torch.nn.functional`` (e.g. ``"relu"``).
            args: positional-argument AST nodes of the call.
            keywords: ``ast.keyword`` nodes of the call.

        Returns:
            The Relay expression computing the call.

        Raises:
            Exception: via ``jit_assert`` for unsupported functions/arguments.
        """
        if func_name == "ones":
            # torch.ones accepts ones(2, 3) as well as ones((2, 3)). Sizes are
            # evaluated at parse time, so they must be literals. Keywords such
            # as dtype are ignored: the constant is always float32.
            jit_assert(len(args) > 0, "ones: a size is required")
            sizes = [self._literal(a, "ones size") for a in args]
            shape = sizes[0] if len(sizes) == 1 else tuple(sizes)
            return relay.const(tvm.ndarray.array(
                np.ones(shape).astype(np.float32)
            ))
        if func_name == "relu":
            return relay.nn.relu(self.visit(args[0]))
        if func_name == "conv2d":
            return self._lower_conv2d(args, keywords)
        if func_name == "batch_norm":
            return self._lower_batch_norm(args, keywords)
        jit_assert(False, "Couldn't match {} to torch builtin function".format(func_name))

    def _lower_conv2d(self, args, keywords):
        """Lower ``F.conv2d(input, weight, bias, stride, padding, dilation, groups)``.

        Hyper-parameters must be literals. Unspecified ones fall back to Relay's
        defaults, which match torch's (stride 1, padding 0, dilation 1, groups 1).
        """
        bound = self._bind_call(
            "conv2d",
            ("input", "weight", "bias", "stride", "padding", "dilation", "groups"),
            args, keywords)
        jit_assert("input" in bound and "weight" in bound,
                   "conv2d: input and weight are required")
        attrs = {}
        # Relay names torch's `stride` attribute `strides`.
        for torch_name, relay_name in (("stride", "strides"),
                                       ("padding", "padding"),
                                       ("dilation", "dilation")):
            if torch_name in bound:
                what = "conv2d " + torch_name
                attrs[relay_name] = self._as_int_pair(
                    self._literal(bound[torch_name], what), what)
        if "groups" in bound:
            groups = self._literal(bound["groups"], "conv2d groups")
            jit_assert(isinstance(groups, int),
                       "Unsupported conv2d groups: {!r}".format(groups))
            attrs["groups"] = groups
        out = relay.nn.conv2d(self.visit(bound["input"]),
                              self.visit(bound["weight"]), **attrs)
        bias = bound.get("bias")
        if bias is not None and not self._is_none_literal(bias):
            # torch adds a per-output-channel bias; axis 1 is the channel axis
            # of Relay's default NCHW data layout.
            out = relay.nn.bias_add(out, self.visit(bias), axis=1)
        return out

    def _lower_batch_norm(self, args, keywords):
        """Lower inference-mode ``F.batch_norm`` to ``relay.nn.batch_norm``.

        ``weight`` and ``bias`` are required (Relay's gamma/beta). ``training``
        must be omitted or a falsy literal, and ``eps`` must be a literal.
        ``momentum`` only affects training and is ignored.
        """
        bound = self._bind_call(
            "batch_norm",
            ("input", "running_mean", "running_var", "weight", "bias",
             "training", "momentum", "eps"),
            args, keywords)
        for required in ("input", "running_mean", "running_var", "weight", "bias"):
            jit_assert(required in bound,
                       "batch_norm: {} is required".format(required))
        if "training" in bound:
            jit_assert(not self._literal(bound["training"], "batch_norm training"),
                       "batch_norm: only training=False is supported")
        attrs = {}
        if "eps" in bound:
            attrs["epsilon"] = self._literal(bound["eps"], "batch_norm eps")
        r = relay.nn.batch_norm(
            self.visit(bound["input"]),
            self.visit(bound["weight"]),  # gamma
            self.visit(bound["bias"]),  # beta
            self.visit(bound["running_mean"]),  # moving mean
            self.visit(bound["running_var"]),  # moving var
            **attrs)
        # relay.nn.batch_norm yields (output, new_mean, new_var); keep output.
        return r[0]

    def relay_func_from_node(self, node, args, keywords):
        """Lower a call whose callee expression is *node* (e.g. ``F.relu``).

        The callee is identified by object identity against the ``torch`` table
        first, then the ``torch.nn.functional`` table.
        """
        f_id = self.get_id_from_node(node)
        for table in (_torch_methods, _functional_methods):
            if f_id in table:
                return self.torch_builtin(table[f_id], args, keywords)
        jit_assert(False, "Unsupported function call {}".format(ast.dump(node)))

    def add_symbol(self, name, var, expr=None):
        """Bind Python *name* to *var*; if *expr* is given, record ``var -> expr``.

        Each name may be bound only once (re-binding fails via ``jit_assert``).
        """
        if name in self.symbols:
            old = str(self.symbols[name])
            new = str(var)
            jit_assert(False, "Symbol conflict [{}] {} -> {}.".format(name, old, new))
        self.symbols[name] = var
        if expr is not None:
            self.binds[var] = expr

    def generic_visit(self, node):
        """Fail on any AST node type without a dedicated ``visit_*`` method."""
        jit_assert(False, "Couldn't parse node {}".format(ast.dump(node)))

    def visit_Module(self, node):
        """Visit the single function definition that the module must contain."""
        jit_assert(len(node.body) == 1,
                   "Only one-function source code will be fed to this parser!")
        return self.visit(node.body[0])

    def visit_FunctionDef(self, node):
        """Build a ``relay.Function`` from *node* and register it by name.

        Decorators are ignored. The body statements are visited in order, and
        the returned expression, with all assignment bindings substituted,
        becomes the function body.
        """
        jit_assert(node.name not in _parsed_functions,
                   "Conflicting function name {}".format(node.name))
        for arg in node.args.args:
            var = relay.Var(arg.arg)
            self.add_symbol(arg.arg, var)
            self.inputs.append(var)
        for stmt in node.body:
            self.visit(stmt)
        jit_assert(self.output is not None,
                   "Function {} has no return statement".format(node.name))
        func = relay.Function(self.inputs, relay.bind(self.output, self.binds))
        _parsed_functions[node.name] = func
        return func

    def visit_Return(self, node):
        """Record (and return) the lowered return value."""
        self.output = self.visit(node.value)
        return self.output

    def visit_Name(self, node):
        """Look up a previously bound parameter or assigned variable."""
        name = node.id
        jit_assert(name in self.symbols, "Couldn't find variable '{}'".format(name))
        return self.symbols[name]

    def visit_BinOp(self, node):
        """Lower ``+``, ``-`` and ``*`` to Relay element-wise ops."""
        lowering = {
            ast.Add: relay.op.add,
            ast.Sub: relay.op.subtract,
            ast.Mult: relay.op.multiply,
        }.get(type(node.op))
        jit_assert(lowering is not None,
                   "Unsupported binary operator {}".format(type(node.op).__name__))
        return lowering(self.visit(node.left), self.visit(node.right))

    def visit_Call(self, node):
        """Lower a function call through ``relay_func_from_node``."""
        return self.relay_func_from_node(node.func, node.args, node.keywords)

    def visit_Assign(self, node):
        """Lower ``name = expr`` by binding a fresh ``relay.var`` to the lowered expr."""
        jit_assert(len(node.targets) == 1 and isinstance(node.targets[0], ast.Name),
                   "Only single-name assignments are supported")
        # Substitute earlier bindings so every stored expression is closed
        # over the function inputs only.
        rhs = relay.bind(self.visit(node.value), self.binds)
        name = node.targets[0].id
        self.add_symbol(name, relay.var(name), rhs)
        return rhs
class RelayFunc(object):
    """A parsed Relay function plus per-input-shape compilation caches.

    Calling the object compiles the function with Relay's graph executor once
    per distinct tuple of input shapes, then runs it.
    """

    def __init__(self, relay_expr, inputs):
        """
        Args:
            relay_expr: the ``relay.Function`` to run (as built by ``RelayParser``).
            inputs: the function's parameter ``relay.Var`` objects, in order.
                Only their ``name_hint`` is reused when types are attached.
        """
        self.expr = relay_expr
        self.compiled = {}  # shape key -> compiled callable
        self.grad = None  # gradient function, built lazily by backward()
        self.grad_compiled = {}  # shape key -> compiled gradient callable
        self.inputs = inputs

    @staticmethod
    def _shape_key(inputs):
        """Return the cache key: each input's ``.shape``, or ``()`` if it has none."""
        return tuple(i.shape if hasattr(i, 'shape') else () for i in inputs)

    def compile_func(self, expr, shapes_list, mod=None):
        """Specialize *expr* to *shapes_list* and return a callable running it.

        Args:
            expr: Relay function taking one argument per entry of ``self.inputs``.
            shapes_list: one shape per input; each becomes a (float32 by
                default) ``relay.TensorType``.
            mod: optional module passed to ``relay.create_executor``.

        Returns:
            ``f(*args)`` evaluating the specialized function with the graph
            executor and returning ``.asnumpy()`` of the result.
        """
        jit_assert(len(shapes_list) == len(self.inputs),
                   "Expected {} inputs, got {}".format(len(self.inputs), len(shapes_list)))
        typed_inputs = [
            relay.Var(var.name_hint, relay.TensorType(shape))
            for var, shape in zip(self.inputs, shapes_list)
        ]
        # Wrap the call in a typed function so type inference specializes
        # *expr*, then take back the (now type-annotated) callee.
        wrapper = relay.Function(typed_inputs, relay.Call(expr, typed_inputs))
        typed_expr = relay.ir_pass.infer_type(wrapper).body.op
        graph = relay.create_executor('graph', mod=mod)

        def f(*args):
            return graph.evaluate(typed_expr)(*args).asnumpy()
        return f

    def __call__(self, *inputs, torch_mode=False):
        """Run the function on *inputs*, compiling it for their shapes if needed.

        With ``torch_mode=True``, inputs that have a ``.numpy`` method (torch
        tensors) are converted with ``detach().numpy()``, and the result is
        wrapped in ``torch.tensor``.
        """
        shapes = self._shape_key(inputs)
        if shapes not in self.compiled:
            try:
                self.compiled[shapes] = self.compile_func(self.expr, shapes)
            except Exception:
                print("While compiling\n", self.expr)
                raise
        compiled = self.compiled[shapes]
        if torch_mode:
            arrays = [i.detach().numpy() if hasattr(i, 'numpy') else i for i in inputs]
            return torch.tensor(compiled(*arrays))
        return compiled(*inputs)

    def backward(self, *inputs):
        """Evaluate the Relay gradient of the function at *inputs*.

        NOTE(review): the result goes through ``compile_func``'s
        ``.asnumpy()``; confirm this matches the output structure of
        ``relay.ir_pass.gradient`` for the TVM version in use.
        """
        if self.grad is None:
            self.grad = relay.ir_pass.gradient(self.expr)
        shapes = self._shape_key(inputs)
        if shapes not in self.grad_compiled:
            self.grad_compiled[shapes] = self.compile_func(self.grad, shapes)
        return self.grad_compiled[shapes](*inputs)
def script(fn):
    """Parse *fn*'s source into Relay and wrap the result in a ``RelayFunc``.

    Intended for use as a decorator. The source is dedented first, so
    functions defined in an indented scope (inside a class or another
    function) also parse. Names called by *fn* are resolved in
    ``fn.__globals__``.

    Raises:
        Exception: whatever the parser raises; the offending source is
            printed first to aid debugging.
    """
    parser = RelayParser(fn.__globals__)
    source = textwrap.dedent(inspect.getsource(fn))
    try:
        relay_expr = parser.visit(ast.parse(source))
    except Exception:
        print("Hit exception while parsing:\n\n{}\n".format(source))
        raise
    return RelayFunc(relay_expr, parser.inputs)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment