# Gist by @vinx13, created August 3, 2018: TVM conv2d CUDA schedule with a tensorize experiment.
import sys
import logging
import tvm
from tvm import autotvm
import topi
import numpy as np
from topi.testing import conv2d_nchw_python
import functools
import operator

def intrin(c, h, w):
    # Declare a tensor intrinsic: the dot product of two c x h x w tensors,
    # reducing over all three axes into a single output element.
    x = tvm.placeholder((c, h, w), name='x')
    y = tvm.placeholder((c, h, w), name='y')
    rc = tvm.reduce_axis((0, c), name='rc')
    rh = tvm.reduce_axis((0, h), name='rh')
    rw = tvm.reduce_axis((0, w), name='rw')
    z = tvm.compute((1,), lambda NOTUSED: tvm.sum(x[rc, rh, rw] * y[rc, rh, rw], axis=[rc, rh, rw]))

    def intrin_func(ins, outs):
        # Placeholder lowering: just store 0.0 into the output buffer.
        return outs[0].vstore(0, 0.0)  # TODO

    with tvm.build_config(data_alignment=4, offset_factor=1) as cfg:
        binds = {t: tvm.decl_buffer(t.shape, t.dtype, t.op.name,
                                    data_alignment=cfg.data_alignment,
                                    offset_factor=cfg.offset_factor,
                                    scope='local')
                 for t in [x, y, z]}
        return tvm.decl_tensor_intrin(z.op, intrin_func, binds=binds)
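
# A hedged sketch (not part of the original gist) of how the TODO body above
# might be completed: lower the intrinsic to calls to hypothetical extern
# device functions `dot3d_reset` / `dot3d_update`, which the user would have to
# provide separately. Because the reduction is split across an outer loop (rco
# below), tensorize expects a (body, reduce_reset, reduce_update) triple. In
# real use this would be a nested def inside intrin(), capturing c, h, w from
# the closure; they are passed explicitly here only to keep the sketch
# self-contained.
def intrin_func_sketch(ins, outs, c, h, w):
    xx, yy = ins
    zz = outs[0]

    def _update():
        ib = tvm.ir_builder.create()
        ib.emit(tvm.call_extern('float32', 'dot3d_update',   # hypothetical extern
                                zz.access_ptr('w'),
                                xx.access_ptr('r'),
                                yy.access_ptr('r'),
                                c, h, w))
        return ib.get()

    def _reset():
        ib = tvm.ir_builder.create()
        ib.emit(tvm.call_extern('float32', 'dot3d_reset',    # hypothetical extern
                                zz.access_ptr('w')))
        return ib.get()

    # body, reduce_reset, reduce_update
    return _update(), _reset(), _update()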

def conv2d(N, H, W, CI, CO, KH, KW, stride, padding):
    data = tvm.placeholder((N, CI, H, W), name='data')
    kernel = tvm.placeholder((CO, CI, KH, KW), name='kernel')
    conv = topi.nn.conv2d_nchw(data, kernel, stride, padding)
    s = tvm.create_schedule([conv.op])

    # inline padding; keep the original placeholder around as raw_data so it
    # can still be returned as the kernel's input buffer
    pad_data = s[conv].op.input_tensors[0]
    s[pad_data].compute_inline()
    data, raw_data = pad_data, data

    output = conv
    OL = s.cache_write(conv, 'local')

    # create cache stages
    AA = s.cache_read(data, 'shared', [OL])
    WW = s.cache_read(kernel, 'shared', [OL])
    AL = s.cache_read(AA, 'local', [OL])
    WL = s.cache_read(WW, 'local', [OL])

    # tile and bind spatial axes
    n, f, y, x = s[output].op.axis
    bf, fi = s[output].split(f, factor=64)
    tf, fi = s[output].split(fi, factor=8)
    by, yi = s[output].split(y, factor=64)
    ty, yi = s[output].split(yi, factor=8)
    bx, xi = s[output].split(x, factor=64)
    tx, xi = s[output].split(xi, factor=8)

    s[output].bind(bf, tvm.thread_axis("blockIdx.z"))
    s[output].bind(by, tvm.thread_axis("blockIdx.y"))
    s[output].bind(bx, tvm.thread_axis("blockIdx.x"))
    s[output].bind(tf, tvm.thread_axis("threadIdx.z"))
    s[output].bind(ty, tvm.thread_axis("threadIdx.y"))
    s[output].bind(tx, tvm.thread_axis("threadIdx.x"))
    s[output].reorder(n, bf, by, bx, tf, ty, tx, fi, yi, xi)
    s[OL].compute_at(s[output], tx)

    # tile and bind reduction axes
    n, f, y, x = s[OL].op.axis
    rc, ry, rx = s[OL].op.reduce_axis
    rco, rci = s[OL].split(rc, factor=8)
    ryo, ryi = s[OL].split(ry, factor=3)
    rxo, rxi = s[OL].split(rx, factor=3)
    s[OL].reorder(rco, ryo, rxo, n, f, y, x, rci, ryi, rxi)

    # tensorize the innermost reduction loops with the dot-product intrinsic
    s[OL].tensorize(rci, intrin(8, 7, 7))

    s[AA].compute_at(s[OL], rxo)
    s[WW].compute_at(s[OL], rxo)
    s[AL].compute_at(s[OL], x)
    s[WL].compute_at(s[OL], x)

    return s, [raw_data, kernel, conv]

N, H, W, CO, CI, KH, KW, strides, padding = 1, 7, 7, 512, 512, 3, 3, (1, 1), (1, 1)

with tvm.target.create("cuda"):
    s, arg_bufs = conv2d(N, H, W, CI, CO, KH, KW, strides, padding)
    print(tvm.lower(s, arg_bufs, simple_mode=True))
    func = tvm.build(s, arg_bufs)
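
# A hedged sketch (not in the original gist) of running the built kernel and
# checking it against the topi reference implementation already imported above.
# Note that while the intrinsic body is still the TODO stub writing 0.0, this
# check is expected to fail; it only becomes meaningful once the intrinsic is
# completed.
ctx = tvm.gpu(0)
a_np = np.random.uniform(size=(N, CI, H, W)).astype(np.float32)
w_np = np.random.uniform(size=(CO, CI, KH, KW)).astype(np.float32)
c_np = conv2d_nchw_python(a_np, w_np, strides, padding)

a_tvm = tvm.nd.array(a_np, ctx=ctx)
w_tvm = tvm.nd.array(w_np, ctx=ctx)
c_tvm = tvm.nd.empty(c_np.shape, ctx=ctx)
func(a_tvm, w_tvm, c_tvm)

np.testing.assert_allclose(c_np, c_tvm.asnumpy(), rtol=1e-2)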