Skip to content

Instantly share code, notes, and snippets.

@icemelon
Created April 10, 2020 05:28
Show Gist options
  • Save icemelon/cd41746fefac55d033f06059df69c747 to your computer and use it in GitHub Desktop.
Save icemelon/cd41746fefac55d033f06059df69c747 to your computer and use it in GitHub Desktop.
test_nth
import numpy as np
import tvm
from tvm import relay
from tvm.relay.ty import TupleType, TensorType
from tvm.relay.prelude import Prelude
from tvm.runtime.container import ADT
def _get_relay_input_vars(input_shapes, prelude):
def _is_int_seq(seq):
return len(seq) > 0 and all([isinstance(i, int) for i in seq])
def get_relay_ty(ishape):
if _is_int_seq(ishape) or len(ishape) == 0:
return TensorType(ishape)
elif isinstance(ishape, tuple):
return TupleType([get_relay_ty(elem) for elem in ishape])
elif isinstance(ishape, list):
assert len(ishape) > 0
elem_tys = [get_relay_ty(s) for s in ishape]
msg = "List elements should have identical types"
assert all(map(lambda ty: ty == elem_tys[0], elem_tys)), msg
return prelude.l(elem_tys[0])
raise NotImplementedError("unsupported input type")
input_types = [(tup[0], get_relay_ty(tup[1])) for tup in input_shapes]
return [relay.expr.var(name, type_annotation=itype)
for name, itype in input_types]
def convert_to_list_adt(py_lst, prelude):
    """Build a runtime cons-list ``ADT`` holding the arrays of ``py_lst``.

    The list is built back to front: starting from an ``ADT`` tagged with
    ``prelude.nil.tag``, each array is prepended in a node tagged with
    ``prelude.cons.tag``. The resulting ADT therefore preserves the order of
    ``py_lst``.

    Args:
        py_lst: Sequence of arrays, each accepted by ``relay.const``
            (presumably numpy arrays; TODO confirm for other element kinds).
        prelude: Relay ``Prelude`` whose ``nil``/``cons`` constructor tags
            label the ADT nodes.

    Returns:
        A ``tvm.runtime.container.ADT`` intended as a runtime argument for a
        prelude-list-typed Relay input.
    """
    adt_lst = ADT(prelude.nil.tag, [])
    for arr in reversed(py_lst):
        # ADT fields must be runtime objects, so store the NDArray payload
        # (.data) rather than the relay.Constant IR expression node itself.
        # Going through relay.const keeps its default dtype narrowing
        # (float64 -> float32, int64 -> int32).
        tensor = relay.const(arr).data
        adt_lst = ADT(prelude.cons.tag, [tensor, adt_lst])
    return adt_lst
def test_nth():
    """Check that ``prelude.nth(states, 0)`` returns the first list element.

    A two-element Relay list of ``(batch, hidden_size)`` tensors is passed to
    the VM executor as a runtime ``ADT`` argument. The value returned for
    index 0 is then compared against the first input array.
    """
    batch, hidden_size = 2, 4
    input_name = "states"
    input_shapes = [(input_name, [(batch, hidden_size), (batch, hidden_size)])]
    # The input variable is annotated with TensorType's default dtype
    # (float32), so generate float32 data to match it instead of numpy's
    # default float64.
    state_list = [np.random.uniform(size=shape).astype("float32")
                  for shape in input_shapes[0][1]]

    mod = tvm.IRModule()
    prelude = Prelude(mod)
    adt_obj = convert_to_list_adt(state_list, prelude)
    params = {input_name: adt_obj}

    input_var = _get_relay_input_vars(input_shapes, prelude)[0]
    mod["main"] = tvm.relay.Function(
        [input_var], prelude.nth(input_var, relay.const(0)))

    executor = relay.create_executor("vm", mod=mod, ctx=tvm.cpu(0), target="llvm")
    evaluator = executor.evaluate()
    result = evaluator(**params)
    # The result used to be discarded, so the test only checked that
    # execution did not crash. Verify the selected element's value too.
    np.testing.assert_allclose(result.asnumpy(), state_list[0])
if __name__ == "__main__":
    # Guard the direct call so that importing this module (e.g. during pytest
    # collection) does not run the test a second time as a side effect.
    test_nth()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment