import tvm
import topi
from topi.util import get_const_tuple
import numpy as np
from topi.nn.pad import pad
# on a2: python3 -m tvm.exec.rpc_server --port=8499
# target = 'llvm -mcpu=core-avx2'
# target = 'llvm -device=arm_cpu -target=aarch64-linux-gnu -mattr=+v8.4a,+fp16fml,+fullfp16'
target = 'llvm -device=arm_cpu -target=aarch64-linux-gnu -mattr=+v8.2a,+fullfp16,+fp-armv8,+dotprod,+crc,+crypto,+neon'
dtype = 'float16'
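# Note: +fullfp16 (Armv8.2-A) enables native float16 arithmetic, which the
# float16 dtype above relies on; without it AArch64 codegen typically promotes
# fp16 math to fp32.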
def get_fp32_len():
    # vector width (in elements) used for tiling; despite the name it is also
    # used here for the float16 path
    return 8
def _fallback_schedule(in_channel, height, width, num_filter, filter_height, filter_width, padding, strides):
    HPAD, WPAD = padding
    HSTR, WSTR = strides
    simd_width = get_fp32_len()
    out_width = (width + 2 * WPAD - filter_width) // WSTR + 1
    oc_bn = 1
    for bn in range(simd_width, 0, -1):
        if num_filter % bn == 0:
            oc_bn = bn
            break
    ic_bn = 1
    for bn in range(oc_bn, 0, -1):
        if in_channel % bn == 0:
            ic_bn = bn
            break
    reg_n = 1
    for n in range(31, 0, -1):
        if out_width % n == 0:
            reg_n = n
            break
    return ic_bn, oc_bn, reg_n, False
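# Illustrative note: for the workload driven from __main__ below
# (in_channel=64, num_filter=64, 56x56 input, 3x3 kernel, padding (1, 1),
# strides (1, 1)) this fallback picks ic_bn=8, oc_bn=8, reg_n=28,
# unroll_kw=False:
#   _fallback_schedule(64, 56, 56, 64, 3, 3, (1, 1), (1, 1)) == (8, 8, 28, False)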
def conv_compute(data, kernel, in_channel, height, width, num_filter, filter_height, filter_width, padding, strides):
    out_dtype = data.dtype
    dilation_h, dilation_w = 1, 1
    HPAD, WPAD = padding
    HSTR, WSTR = strides
    batch_size, in_channel, in_height, in_width = get_const_tuple(data.shape)
    num_filter, _, kernel_height, kernel_width = get_const_tuple(kernel.shape)
    pad_height = in_height + 2 * HPAD
    pad_width = in_width + 2 * WPAD
    dilated_kernel_h = (kernel_height - 1) * dilation_h + 1
    dilated_kernel_w = (kernel_width - 1) * dilation_w + 1
    out_height = (in_height + 2 * HPAD - dilated_kernel_h) // HSTR + 1
    out_width = (in_width + 2 * WPAD - dilated_kernel_w) // WSTR + 1
    # pack data
    DOPAD = (HPAD != 0 or WPAD != 0)
    if DOPAD:
        data_pad = pad(data, (0, 0, HPAD, WPAD), name="data_pad")
    else:
        data_pad = data
    # fetch schedule
    ic_bn, oc_bn, reg_n, unroll_kw = _fallback_schedule(in_channel, height, width, num_filter,
                                                        filter_height, filter_width, padding, strides)
    shape = (batch_size, in_channel // ic_bn, pad_height, ic_bn, pad_width)
    data_vec = tvm.compute(shape,
                           lambda n, C, h, c, w: data_pad[n, C * ic_bn + c, h, w],
                           name='data_vec')
    # pack kernel
    shape = (num_filter // oc_bn, in_channel // ic_bn,
             kernel_height, kernel_width, ic_bn, oc_bn)
    kernel_vec = tvm.compute(shape,
                             lambda CO, CI, h, w, ci, co:
                             kernel[CO * oc_bn + co, CI * ic_bn + ci, h, w],
                             name='kernel_vec')
    # convolution
    oshape = (batch_size, num_filter // oc_bn, out_height, out_width, oc_bn)
    unpack_shape = (batch_size, num_filter, out_height, out_width)
    ic = tvm.reduce_axis((0, in_channel), name='ic')
    kh = tvm.reduce_axis((0, kernel_height), name='kh')
    kw = tvm.reduce_axis((0, kernel_width), name='kw')
    conv = tvm.compute(oshape, lambda n, oc_chunk, oh, ow, oc_block:
                       tvm.sum(data_vec[n, ic // ic_bn, oh * HSTR + kh * dilation_h, ic % ic_bn,
                                        ow * WSTR + kw * dilation_w].astype(out_dtype) *
                               kernel_vec[oc_chunk, ic // ic_bn, kh, kw, ic % ic_bn,
                                          oc_block].astype(out_dtype),
                               axis=[ic, kh, kw]), name='conv')
    unpack = tvm.compute(unpack_shape,
                         lambda n, c, h, w: conv[n, c // oc_bn, h, w, c % oc_bn]
                         .astype(out_dtype),
                         name='output_unpack',
                         tag='conv2d_nchw')
    return unpack
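# Illustrative note: with the default workload at the bottom of this script
# (1x64x56x56 input, 64x64x3x3 kernel, padding 1, stride 1) the intermediate
# shapes work out to:
#   data_vec   (1, 8, 58, 8, 58)   # [N, C//ic_bn, H+2*HPAD, ic_bn, W+2*WPAD]
#   kernel_vec (8, 8, 3, 3, 8, 8)  # [O//oc_bn, I//ic_bn, KH, KW, ic_bn, oc_bn]
#   conv       (1, 8, 56, 56, 8)   # NCHW[x]c
#   unpack     (1, 64, 56, 56)     # back to NCHW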
def conv_schedule(C, in_channel, height, width, num_filter, filter_height, filter_width, padding, strides):
    s = tvm.create_schedule(C.op)
    op = C.op
    output = op.output(0)
    conv_out = op.input_tensors[0]
    kernel_vec = conv_out.op.input_tensors[1]
    kernel = kernel_vec.op.input_tensors[0]
    if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag:
        s[kernel].compute_inline()
    data_vec = conv_out.op.input_tensors[0]
    data = data_vec.op.input_tensors[0]
    data_pad = None
    if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag:
        data_pad = data
        data = data_pad.op.input_tensors[0]
    _, _, kh, kw = get_const_tuple(kernel.shape)
    # fetch schedule
    ic_bn, oc_bn, reg_n, unroll_kw = _fallback_schedule(in_channel, height, width, num_filter,
                                                        filter_height, filter_width, padding, strides)
    # no stride and padding info here
    HPAD, WPAD = padding
    DOPAD = (HPAD != 0 or WPAD != 0)
    A, W = data, kernel_vec
    A0, A1 = data_pad, data_vec
    # schedule data
    if DOPAD:
        s[A0].compute_inline()
    batch, ic_chunk, ih, ic_block, iw = s[A1].op.axis
    parallel_axis = s[A1].fuse(ic_chunk, ih)
    s[A1].parallel(parallel_axis)
    # schedule kernel pack
    oc_chunk, ic_chunk, oh, ow, ic_block, oc_block = s[W].op.axis
    s[W].reorder(oc_chunk, oh, ic_chunk, ow, ic_block, oc_block)
    if oc_bn > 1:
        s[W].vectorize(oc_block)
    parallel_axis = s[W].fuse(oc_chunk, oh)
    s[W].parallel(parallel_axis)
    # schedule conv
    C, O0 = conv_out, output
    CC = s.cache_write(C, 'global')
    _, oc_chunk, oh, ow, oc_block = s[C].op.axis
    ow_chunk, ow_block = s[C].split(ow, factor=reg_n)
    s[C].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block)
    s[C].fuse(oc_chunk, oh)
    s[C].vectorize(oc_block)
    s[CC].compute_at(s[C], ow_chunk)
    _, oc_chunk, oh, ow, oc_block = s[CC].op.axis
    ic, kh, kw = s[CC].op.reduce_axis
    ow_chunk, ow_block = s[CC].split(ow, factor=reg_n)
    ic_chunk, ic_block = s[CC].split(ic, factor=ic_bn)
    if unroll_kw:
        s[CC].reorder(oc_chunk, oh, ow_chunk, ic_chunk, kh, ic_block, kw, ow_block, oc_block)
        s[CC].unroll(kw)
    else:
        s[CC].reorder(oc_chunk, oh, ow_chunk, ic_chunk, kh, kw, ic_block, ow_block, oc_block)
    s[CC].fuse(oc_chunk, oh)
    s[CC].vectorize(oc_block)
    s[CC].unroll(ow_block)
    return s
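# Rough shape of the loop nest produced for the conv stage (unroll_kw=False):
#   n -> fused(oc_chunk, oh) -> ow_chunk ->
#     [CC: ic_chunk -> kh -> kw -> ic_block -> ow_block (unrolled) -> oc_block (vectorized)]
#     -> ow_block -> oc_block (vectorized write-back)
# Note the fused (oc_chunk, oh) axis of the conv stage is left serial here;
# only the data_vec and kernel_vec packing stages are parallelized.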
def run_conv2d(batch_size, in_channel, height, width, num_filter, filter_height, filter_width, padding, strides):
    A = tvm.placeholder((batch_size, in_channel, height, width), name='A', dtype=dtype)
    W = tvm.placeholder((num_filter, in_channel, filter_height, filter_width), name='W', dtype=dtype)
    a_shape = get_const_tuple(A.shape)
    w_shape = get_const_tuple(W.shape)

    def get_ref_data():
        a_np = np.random.uniform(size=a_shape).astype(dtype)
        w_np = np.random.uniform(size=w_shape).astype(dtype)
        from topi.testing.conv2d_nchw_python import conv2d_nchw_python
        # use the actual strides/padding of the workload for the reference result
        conv_np = conv2d_nchw_python(a_np, w_np, stride=strides, padding=padding)
        return a_np, w_np, conv_np

    a_np, w_np, conv_np = get_ref_data()
    C = conv_compute(A, W, in_channel, height, width, num_filter, filter_height, filter_width, padding, strides)
    s = conv_schedule(C, in_channel, height, width, num_filter, filter_height, filter_width, padding, strides)
    # s = tvm.create_schedule(C.op)
    print(tvm.lower(s, [A, W, C], simple_mode=True))
    from tvm import rpc
    from tvm.contrib import util
    host = '0.0.0.0'
    port = 8499
    remote = rpc.connect(host, port)
    ctx = remote.cpu()
    # ctx = tvm.cpu()
    func = tvm.build(s, [A, W, C], target=target)
    func.save('conv.s')
    temp = util.tempdir()
    path = temp.relpath('lib.tar')
    func.export_library(path)
    remote.upload(path)
    func = remote.load_module('lib.tar')
    conv = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=dtype), ctx)
    time_f = func.time_evaluator(func.entry_name, ctx, number=50)
    cost_conv = time_f(tvm.nd.array(a_np, ctx), tvm.nd.array(w_np, ctx), conv).mean
    print('conv: %g ms/op' % (cost_conv * 1000.0))
    # np.testing.assert_allclose(conv.asnumpy(), conv_np, rtol=1e-5)
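    # Note: if the correctness check above is enabled, rtol=1e-5 is very tight
    # for float16; a looser tolerance (e.g. rtol=1e-2, atol=1e-2) is usually
    # needed for fp16 accumulation.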
if __name__ == "__main__":
    run_conv2d(batch_size=1, in_channel=64, height=56, width=56, num_filter=64, filter_height=3, filter_width=3,
               padding=(1, 1), strides=(1, 1))