Skip to content

Instantly share code, notes, and snippets.

Created April 10, 2020 05:28
Show Gist options
  • Save icemelon/cd41746fefac55d033f06059df69c747 to your computer and use it in GitHub Desktop.
Save icemelon/cd41746fefac55d033f06059df69c747 to your computer and use it in GitHub Desktop.
import numpy as np
import tvm
from tvm import relay
from tvm.relay.ty import TupleType, TensorType
from tvm.relay.prelude import Prelude
from tvm.runtime.container import ADT
def _get_relay_input_vars(input_shapes, prelude):
    """Create typed Relay input variables from ``(name, shape_spec)`` pairs.

    A shape spec is interpreted recursively:

    * a non-empty sequence of Python ints, or any empty sequence, becomes a
      ``TensorType`` with that shape (no dtype is passed, so ``TensorType``'s
      default dtype is used -- float32 in TVM; confirm for your version);
    * a tuple of specs becomes a ``TupleType`` of the element types;
    * a non-empty list of specs becomes the prelude list type ``prelude.l``
      applied to the common element type.

    :param input_shapes: iterable of ``(name, shape_spec)`` pairs; only the
        first two items of each entry are read.
    :param prelude: a ``Prelude`` providing the list type constructor ``l``.
    :returns: list of ``relay.Var``, one per entry, in input order.
    :raises AssertionError: if the elements of a list spec have differing types.
    :raises NotImplementedError: for any other spec form.
    """
    def _is_int_seq(seq):
        return len(seq) > 0 and all(isinstance(i, int) for i in seq)

    def get_relay_ty(ishape):
        # Any empty sequence -- ``()`` and also ``[]`` -- is a scalar tensor
        # shape, so the list branch below only ever sees non-empty lists.
        if _is_int_seq(ishape) or len(ishape) == 0:
            return TensorType(ishape)
        if isinstance(ishape, tuple):
            return TupleType([get_relay_ty(elem) for elem in ishape])
        if isinstance(ishape, list):
            elem_tys = [get_relay_ty(s) for s in ishape]
            msg = "List elements should have identical types"
            # A prelude list is homogeneous: every element must share one type.
            assert all(ty == elem_tys[0] for ty in elem_tys), msg
            return prelude.l(elem_tys[0])
        raise NotImplementedError("unsupported input type")

    # Resolve every type first (as before), then create the variables.
    input_types = [(spec[0], get_relay_ty(spec[1])) for spec in input_shapes]
    return [relay.expr.var(name, type_annotation=itype)
            for name, itype in input_types]
def convert_to_list_adt(py_lst, prelude):
    """Build a runtime prelude list value (``Cons``/``Nil`` ADT) from arrays.

    The result is a runtime object meant to be passed as an argument to a
    compiled function. Each element is therefore wrapped with
    ``tvm.nd.array`` (a runtime NDArray) rather than ``relay.const``, which
    produces a Relay IR ``Constant`` node -- a compile-time expression, not
    runtime tensor data.

    :param py_lst: sequence of arrays (numpy arrays or anything accepted by
        ``tvm.nd.array``); converted on the default device of ``tvm.nd.array``.
    :param prelude: the ``Prelude`` whose ``nil``/``cons`` constructor tags
        are used to tag the ADT nodes.
    :returns: an ``ADT`` whose head is ``py_lst[0]``; an empty input yields
        the bare ``Nil`` node.
    """
    adt_lst = ADT(prelude.nil.tag, [])
    # Cons cells are prepended, so walk backwards to keep the original order.
    for arr in reversed(py_lst):
        adt_lst = ADT(prelude.cons.tag, [tvm.nd.array(arr), adt_lst])
    return adt_lst
def test_nth():
    """Check that ``prelude.nth(lst, 0)`` on a runtime list input returns its head."""
    batch, hidden_size = 2, 4
    input_name = "states"
    input_shapes = [(input_name, [(batch, hidden_size), (batch, hidden_size)])]
    # The input var is typed via ``TensorType(shape)`` with no explicit dtype,
    # i.e. TensorType's default (float32); numpy's ``uniform`` yields float64,
    # so cast the data to match the annotated element type.
    state_list = [np.random.uniform(size=shape).astype("float32")
                  for shape in input_shapes[0][1]]
    mod = tvm.IRModule()
    prelude = Prelude(mod)
    adt_obj = convert_to_list_adt(state_list, prelude)
    input_var = _get_relay_input_vars(input_shapes, prelude)[0]
    mod["main"] = tvm.relay.Function([input_var], prelude.nth(input_var, relay.const(0)))
    executor = relay.create_executor("vm", mod=mod, ctx=tvm.cpu(0), target="llvm")
    evaluator = executor.evaluate()
    # Actually run the compiled function and verify the result; compiling
    # alone does not exercise passing the list ADT as a runtime argument.
    result = evaluator(adt_obj)
    np.testing.assert_array_equal(result.asnumpy(), state_list[0])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment