import tvm
from tvm import relay
from tvm.contrib.download import download_testdata
from tvm.driver import tvmc
from tvm.relay.backend import Runtime, Executor
from tvm.relay.backend.contrib.uma import UMABackend
from tvm.relay.backend.contrib.uma.api.utils import PassPhase
from tvm.relay.dataflow_pattern import is_op, wildcard, is_constant
from tvm.target import Target
from tvm.tir.transform import PrimFuncPass


def qnn_conv2d_pattern():
    """Dataflow pattern for an int8 qnn.conv2d whose quantization
    parameters (zero points and scales) are constants."""
    dtype = "int8"
    return is_op("qnn.conv2d")(
        wildcard().has_dtype(dtype),  # data
        wildcard().has_dtype(dtype),  # kernel
        is_constant(),  # input_zero_point
        is_constant(),  # kernel_zero_point
        is_constant(),  # input_scale
        is_constant(),  # kernel_scale
    )
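
# For reference, the pattern above is meant to match Relay calls of this
# shape (illustrative fragment with made-up argument names, not part of
# the script):
#
#   %0 = qnn.conv2d(%data, %weight,
#                   %input_zero_point, %kernel_zero_point,
#                   %input_scale, %kernel_scale,
#                   padding=[1, 1, 1, 1], channels=8, kernel_size=[3, 3])
#
# Only int8 data/kernel tensors with constant quantization parameters are
# offloaded to the accelerator.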


@tvm.tir.transform.prim_func_pass(opt_level=2)
class VanillaPass(PrimFuncPass):
    """TIR pass that replaces TOPI's conv2d_nchw block with a call to an
    external accelerator function."""

    _EXTERNAL_FUNCTION_NAME = "my_ai_hw_conv2dnchw"
    _TVM_BLOCK_MATCH_NAME = "conv2d_nchw"

    def transform_function(
        self,
        func: tvm.tir.PrimFunc,
        mod: tvm.ir.IRModule,
        ctx: tvm.ir.transform.PassContext,
    ) -> tvm.tir.PrimFunc:
        return self._my_ai_hw_conv2d_pass(func, mod, ctx)

    @classmethod
    def _my_ai_hw_conv2d_pass(cls, func, mod, ctx):
        _loops = dict()
        _handles = []
        _entry_node = None

        def _has_block(name: str, func: tvm.tir.PrimFunc) -> bool:
            """Determine whether a tir.Block named `name` exists in `func`."""

            def _hb(op):
                if isinstance(op, tvm.tir.Block):
                    _found_blocks.append(op.name_hint)

            _found_blocks = []
            tvm.tir.stmt_functor.post_order_visit(func.body, _hb)
            return name in _found_blocks

        def _detect_and_replace_conv2d(
            func: tvm.tir.PrimFunc,
            _mod: tvm.ir.IRModule,
            _ctx: tvm.ir.transform.PassContext,
        ) -> tvm.tir.PrimFunc:
            def _replace_conv2d(op):
                if op == _entry_node:
                    irb = tvm.tir.ir_builder.create()
                    # Collect the buffer addresses (inputs, weights, output).
                    buffers = [b[1].data for b in _handles]
                    # Extract the loop extents; every loop must start at 0.
                    for _, v in _loops.items():
                        assert v.min.value == 0
                    offset_order = ["co", "w", "h", "ci", "kh", "kw"]
                    offsets = [_loops[i].extent.value for i in offset_order]
                    args = buffers + offsets
                    irb.emit(tir_call(irb, True, cls._EXTERNAL_FUNCTION_NAME, *args))
                    irb_result = irb.get()
                    return irb_result
                elif isinstance(op, tvm.tir.SeqStmt):
                    # Drop the pad block of TOPI's conv2d_nchw by returning
                    # only the second statement of the sequence.
                    return op.seq[1]
                return op

            sch = tvm.tir.Schedule(func)

            if _has_block(cls._TVM_BLOCK_MATCH_NAME, func):
                conv2d_block = sch.get_block(cls._TVM_BLOCK_MATCH_NAME)
                rv_loops = sch.get_loops(conv2d_block)
                assert len(rv_loops) == 7
                loops = dict(
                    n=rv_loops[0],
                    co=rv_loops[1],
                    h=rv_loops[2],
                    w=rv_loops[3],
                    ci=rv_loops[4],
                    kh=rv_loops[5],
                    kw=rv_loops[6],
                )
                _entry_node = sch.get(rv_loops[1])
                _loops = {k: sch.get(v) for k, v in loops.items()}
                _handles = func.buffer_map.items()

                x = tvm.tir.stmt_functor.ir_transform(
                    func.body, None, _replace_conv2d, ["tir.For", "tir.SeqStmt"]
                )
                return func.with_body(x)
            else:
                return func

        r = _detect_and_replace_conv2d(func, mod, ctx)
        return r
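
# The emitted call_extern targets an external C function that the codegen
# include must provide. Judging by the arguments assembled above (three
# buffer pointers followed by six loop extents), its signature presumably
# looks like this hypothetical sketch; the actual definition would live in
# conv2dnchw.cc, and the element type depends on that implementation:
#
#   extern "C" int my_ai_hw_conv2dnchw(int8_t *ifmap, int8_t *weights,
#                                      int8_t *result, int oc, int iw,
#                                      int ih, int ic, int kh, int kw);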


def tir_call(ib: tvm.tir.ir_builder, extern: bool, name: str, *args):
    """
    ib: ir_builder
    extern: bool
        True  --> tvm.tir.call_extern
        False --> tvm.tir.call_packed
    name: str
        function name
    *args:
        arguments for the function call
    """

    def buf_from_array(ib, arr, dtype):
        # Allocate enough memory to store the whole array.
        var = ib.allocate("int32", (len(arr),), scope="global")
        for i, v in enumerate(arr):
            var[i] = v
        # Declare a buffer, which is basically a view on the chunk of
        # memory that we allocated.
        buf = tvm.tir.decl_buffer((len(arr),), dtype, data=var, scope="global")
        return buf

    if extern:
        args = [i.data if isinstance(i, tvm.tir.Buffer) else i for i in args]
        return tvm.tir.call_extern("int32", name, *args)
    else:
        args = [
            buf_from_array(ib, i, "int32")
            if isinstance(i, (tuple, list, tvm.ir.container.Array))
            else i
            for i in args
        ]
        return tvm.tir.call_packed(name, *args)
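
# Illustrative use (this mirrors what VanillaPass emits above; buffer names
# and extents are made up): an extern call whose buffer arguments are
# lowered to their raw data pointers.
#
#   irb = tvm.tir.ir_builder.create()
#   irb.emit(tir_call(irb, True, "my_ai_hw_conv2dnchw",
#                     ifmap_buf, weight_buf, result_buf, 8, 14, 14, 3, 3, 3))
#   # --> @tir.call_extern("my_ai_hw_conv2dnchw", ifmap_ptr, weight_ptr,
#   #                      result_ptr, 8, 14, 14, 3, 3, 3, dtype="int32")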


@relay.op.strategy.override_native_generic_func("custom_conv2d_strategy")
def custom_conv2d_strategy(attrs, inputs, out_type, target):
    # Reuse the generic conv2d strategy; only the data and kernel tensors
    # (inputs[:2]) matter for lowering, the quantization parameters are
    # dropped.
    return relay.op.strategy.generic.conv2d_strategy(
        attrs, inputs[:2], out_type, target
    )


class VanillaAcceleratorBackend(UMABackend):
    """UMA backend for the VanillaAccelerator accelerator."""

    def __init__(self):
        super().__init__()
        self._register_pattern("qnn_conv2d", qnn_conv2d_pattern())
        self._register_operator_strategy("qnn.conv2d", custom_conv2d_strategy)
        self._register_tir_pass(PassPhase.TIR_PHASE_0, VanillaPass())
        self._register_codegen(
            fmt="c", includes=lambda: '#include "doesntmatter/conv2dnchw.cc"'
        )

    @property
    def target_name(self):
        return "vanilla_accelerator"


def test_vanilla():
    # MLPerf Tiny visual wake words model, quantized to int8.
    tflite_path = download_testdata(
        "https://github.com/mlcommons/tiny/raw/master"
        "/benchmark/training/visual_wake_words/trained_models/vww_96_int8.tflite",
        "model.tflite",
    )
    tvmc_model = tvmc.load(str(tflite_path))
    uma_backend = VanillaAcceleratorBackend()
    uma_backend.register()
    ir_module = uma_backend.partition(tvmc_model.mod, tvmc_model.params)
    c_target = Target("c")
    vanilla_target = Target(uma_backend.target_name, host=c_target)
    eventual_target = [c_target, vanilla_target]
    tvm.relay.build(
        ir_module,
        target=eventual_target,
        params=tvmc_model.params,
        runtime=Runtime("crt", {"system-lib": False}),
        executor=Executor("aot", {"unpacked-api": True, "interface-api": "c"}),
    )
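

# Entry point so the gist can be run directly as a script.
if __name__ == "__main__":
    test_vanilla()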