Skip to content

Instantly share code, notes, and snippets.

@xqch1983
Created April 8, 2020 02:28
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save xqch1983/17c5c3a4252b3efe849867158a880140 to your computer and use it in GitHub Desktop.
Save xqch1983/17c5c3a4252b3efe849867158a880140 to your computer and use it in GitHub Desktop.
import tvm
import topi
from tvm import autotvm
#from ...intrin import *
from topi.nn.pad import pad
from topi.nn.util import get_pad_tuple
from topi.util import simplify, get_const_tuple
import math
# ---- configuration ----
out_dtype = dtype = "float32"
# Switch for printing intermediate stmt / bound maps while debugging.
DEBUG = 1
STRIPE_LEN = 16   # tile length of the fused oh*ow output axis
TBATCH = 1        # batch tile size (currently unused)
TIC = 16          # input-channel tile size (currently unused)
TOC = 16          # output-channel tile size

# ---- prepare data for Input and Filter ----
shape_Input = (1, 128, 38, 38)
# [passed] shape_Input = (1, 128, 3, 38)
Input = tvm.placeholder(shape_Input, name="Input", dtype=dtype)
shape_Kernel = (64, 128, 3, 3)
Filter = tvm.placeholder(shape_Kernel, name="Filter", dtype=dtype)
stride = (1, 1)
padding = (0, 0)
dilation = (1, 1)
dilation_h, dilation_w = dilation
stride_h, stride_w = stride
batch, in_channel, in_height, in_width = Input.shape
num_filter, channel, kernel_h, kernel_w = Filter.shape

# ---- compute the output shape ----
dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
    padding, (dilated_kernel_h, dilated_kernel_w))
out_channel = num_filter
out_height = simplify((in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1)
out_width = simplify((in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1)

# ---- compute graph ----
pad_before = [0, 0, pad_top, pad_left]
pad_after = [0, 0, pad_down, pad_right]
# padding or not (disabled: padding == (0, 0) here)
# temp = pad(Input, pad_before, pad_after, name="pad_temp")
rc = tvm.reduce_axis((0, in_channel), name='rc')
ry = tvm.reduce_axis((0, kernel_h), name='ry')
rx = tvm.reduce_axis((0, kernel_w), name='rx')
# NOTE(review): the original line ended with a stray comma, which wrapped the
# shape in an extra 1-tuple. shape_output is unused downstream, fixed anyway.
shape_output = (batch, out_channel, out_height, out_width)
# Direct NCHW conv2d definition (reads the unpadded Input directly).
ofm = tvm.compute(
    (batch, out_channel, out_height, out_width),
    lambda n, c_out, h, w: tvm.sum(
        Input[n, rc, h * stride_h + ry * dilation_h,
              w * stride_w + rx * dilation_w].astype(out_dtype) *
        Filter[c_out, rc, ry, rx].astype(out_dtype),
        axis=[rc, ry, rx]),
    tag="conv2d_nchw",
    attrs={'stride': stride, 'padding': padding, 'dilation': dilation})
# The schedule is needed unconditionally by the tensorize section below, so
# create it outside the DEBUG guard (the original only built it when DEBUG).
sch_global = tvm.create_schedule(ofm.op)
if DEBUG:
    stmt = tvm.lower(sch_global, [Input, Filter, ofm], simple_mode=True)
    print(".............initial after defition in compute part:\t \n ", stmt, "\n")
'''
https://discuss.tvm.ai/t/tensorize-with-stride-for-input-tensors/6018
[tqchen]
By default the tensor buffer declaration requires a compact buffer,
that means that the tensorized region need to be contiguous.
To relax the constraint, you can declare a buffer with symbolic strides
when declaring the tensor intrin, of course your low level instruction
must also support strided matrices as an input
Translation (originally in Chinese):
By default, a tensor buffer declaration requires a compact buffer,
which means the region to be tensorized must be contiguous.
To relax this constraint, declare a buffer with symbolic strides when
declaring the tensor intrinsic, e.g. ww = tvm.var("ww").
Of course, the low-level instruction must also support strided matrices
as input.
How to use tensorize https://discuss.tvm.ai/t/how-to-use-tensorize/424
There are two cases here.
1, describe the compute logic, we can call it original compute
2, you need to do some schedule (split) to expose the axis you wanna tensorize, and mark it with tensorize api
3, you need to describe intrinsic pattern, which includes two parts, one is a clone of original compute, another is the intrinsic pattern you wanna use
4, after that, in scheduleOps, tensorize will try to do pattern match for original compute and clone compute, if success, will replace the marked axis with intrinsic.
'''
def intrin_partial_conv2d_ohow_toc(outs, shape_data, shape_kernel, tohow, toc):
    """Declare a tensor intrinsic computing one partial conv2d output tile.

    The intrinsic's compute is a clone of the global conv2d restricted to a
    (batch, num_filter, 1, tohow) tile, so TVM's tensorize pattern matcher can
    replace the tiled axes with one external call.  Reads the module-level
    globals stride_h/stride_w, dilation_h/dilation_w, out_dtype, stride,
    padding, dilation and DEBUG.

    Parameters
    ----------
    outs : tvm.Tensor
        Global output tensor (accepted for signature symmetry; unused here).
    shape_data : tuple
        Shape of the data slice one intrinsic call reads, e.g. (1, 128, 3, 38).
    shape_kernel : tuple
        Shape of the kernel tile, e.g. (16, 128, 3, 3).
    tohow : int
        Extent of the fused oh*ow axis covered by one intrinsic call.
    toc : int
        Output-channel tile size (unused; num_filter is taken from shape_kernel).

    Returns
    -------
    tvm.TensorIntrin
        Intrinsic bound to symbolically-strided data/output buffers.
    """
    data_intrin = tvm.placeholder(shape_data, name='data')
    kernel_intrin = tvm.placeholder(shape_kernel, name='kernel')
    batch, in_channel, _, _ = shape_data              # e.g. [1, 128, 38, 38]
    num_filter, _, kernel_h, kernel_w = shape_kernel  # e.g. [16, 128, 3, 3]
    ry = tvm.reduce_axis((0, kernel_h), name='ry')
    rx = tvm.reduce_axis((0, kernel_w), name='rx')
    rc = tvm.reduce_axis((0, in_channel), name='rc')
    # Clone of the original compute over one (batch, num_filter, 1, tohow) tile.
    ofm_intrin = tvm.compute(
        (batch, num_filter, 1, tohow),
        lambda n, c, h, w: tvm.sum(
            data_intrin[n, rc, h * stride_h + ry * dilation_h,
                        w * stride_w + rx * dilation_w].astype(out_dtype) *
            kernel_intrin[c, rc, ry, rx].astype(out_dtype),
            axis=[rc, ry, rx]),
        tag="conv2d_nchw",
        attrs={'stride': stride, 'padding': padding, 'dilation': dilation})
    if DEBUG:
        schedule_intrin = tvm.create_schedule(ofm_intrin.op)
        dom_map = tvm.schedule.InferBound(schedule_intrin)
        print("[dom_map] in intrinsic:\t\n", dom_map, "\n")
    # Symbolic strides relax the compact-buffer requirement so the intrinsic
    # can match a non-contiguous (strided) region of the global tensors.
    nn, cc, hh, ww = tvm.var("nn"), tvm.var("cc"), tvm.var("hh"), tvm.var("ww")
    data_buf = tvm.decl_buffer(data_intrin.shape, data_intrin.dtype,
                               name="DATA",
                               offset_factor=1, strides=[nn, cc, hh, ww])
    kernel_buf = tvm.decl_buffer(kernel_intrin.shape, kernel_intrin.dtype,
                                 name="KERNEL",
                                 offset_factor=1)
    ofm_nn, ofm_cc, ofm_hh, ofm_ww = (tvm.var("ofm_nn"), tvm.var("ofm_cc"),
                                      tvm.var("ofm_hh"), tvm.var("ofm_ww"))
    # NOTE(review): stride vars are ordered [ofm_cc, ofm_nn, ...] as in the
    # original — the names are only symbolic placeholders, so this is harmless.
    ofm_buf = tvm.decl_buffer(ofm_intrin.shape, ofm_intrin.dtype,
                              name="OFM",
                              offset_factor=1,
                              strides=[ofm_cc, ofm_nn, ofm_hh, ofm_ww])
    if DEBUG:
        sch_temp = tvm.create_schedule(ofm_intrin.op)
        temp_n, temp_c, temp_h, temp_w = sch_temp[ofm_intrin].op.axis
        sch_temp[ofm_intrin].reorder(temp_n, temp_w, temp_h, temp_c)
        stmt_temp = tvm.lower(sch_temp, [data_intrin, kernel_intrin, ofm_intrin],
                              simple_mode=True)
        print("..........intrin_partial_conv2d stmt 【inside of definition 】.......", stmt_temp, "\n")

    def intrin_func(ins, outs):
        # Emit one external call standing in for the hardware instruction.
        ib = tvm.ir_builder.create()
        ib.emit(tvm.stmt.stmt_seq(
            tvm.call_extern("int32", "SailStartCommand", "cdma")))
        return ib.get()

    def intrin_func2(ins, outs):
        # Alternative body: hand buffer access pointers to an external
        # gemv_update routine.  Kept for reference; verified to work too.
        ib = tvm.ir_builder.create()
        aa, bb = ins
        cc = outs[0]
        print(".............intrin_fun2......:\n", aa.strides, bb.strides, cc.strides)
        ib.emit(tvm.call_extern("int32", "gemv_update",
                                cc.access_ptr("r"),
                                aa.access_ptr("r"),
                                bb.access_ptr("r"),
                                cc.access_ptr("w")))
        return ib.get()

    with tvm.build_config(offset_factor=1):
        print(".............intrin_fun2......:\n", data_buf.strides, kernel_buf.strides, ofm_buf.strides)
        # [PASS] intrin_func2 may be substituted for intrin_func below.
        return tvm.decl_tensor_intrin(
            ofm_intrin.op, intrin_func,
            binds={data_intrin: data_buf, kernel_intrin: kernel_buf,
                   ofm_intrin: ofm_buf},
            name="sp_conv2d")
# ---- global schedule: tile the output channels and the fused spatial axis ----
dom_map = tvm.schedule.InferBound(sch_global)
print("[dom_map-1 before schedule]:\t\n", dom_map, "\n")
n, oc, oh, ow = sch_global[ofm].op.axis
ic, kh, kw = sch_global[ofm].op.reduce_axis
N, IC, IH, IW = Input.shape
KN, KC, KH, KW = Filter.shape
n_op, oc, oh, ow = ofm.op.axis
# Fuse oh/ow into one axis, then split it and the output channels into tiles.
ohow = sch_global[ofm].fuse(oh, ow)
Ntohow, tohow = sch_global[ofm].split(ohow, factor=STRIPE_LEN)
Ntoc, toc = sch_global[ofm].split(oc, factor=TOC)  # [, 16]
# Ntic, tic = sch_global[ofm].split(ic, factor=TIC)
sch_global[ofm].reorder(n, Ntoc, Ntohow, tohow, toc, ic, kh, kw)
if DEBUG:
    dom_map = tvm.schedule.InferBound(sch_global)
    print("..........after fuse && split && reorder ......dom_map: dom_map-1-2:\n", dom_map, "\n")
    temp_stmt3 = tvm.lower(sch_global, (Input, Filter, ofm), simple_mode=True)
    print("...........after [oh,ow] fuse && split[Ntohow,tohow] [Ntoc,toc] && reorder (n,Ntoc, Ntohow,tohow,toc,ic,kh,kw).....lower stmt: ", temp_stmt3, "\n")

# ---- prepare for tensorize ----
shape_data_intrinsic = [N, IC, KH, 38]      # [1, 128, 3, 38]
shape_kernel_intrinsic = [TOC, IC, KH, KW]  # [16, 128, 3, 3]
# ideal output tile: [1, 16, 1, 36]

# ---- apply tensorize on the inner fused-spatial axis ----
partial_conv2d = intrin_partial_conv2d_ohow_toc(
    ofm, shape_data_intrinsic, shape_kernel_intrinsic,
    tohow=STRIPE_LEN, toc=TOC)
sch_global[ofm].tensorize(tohow, partial_conv2d)
if DEBUG:
    # Verify the intrinsic body pattern-matches the tensorized region.
    sch_global = sch_global.normalize()
    dom_map = tvm.schedule.InferBound(sch_global)
    print("[after tensorize dom_map-2]:\t\n", dom_map, "\n")
    finfer = tvm.get_global_func("test.op.InferTensorizeRegion")
    out_dom, in_dom = finfer(sch_global[ofm], dom_map)
    print("[output_dom]:\t\n", out_dom, "\n")
    print("[input_dom]:\t\n", in_dom, "\n")
    fmatch = tvm.get_global_func("test.op.MatchTensorizeBody")
    body = fmatch(sch_global[ofm], out_dom, in_dom, partial_conv2d)
    print("body[0]:\t", tvm.ir_pass.CanonicalSimplify(body[0]), "\n")
    print("partial_conv2d.op.body[0]:\t", tvm.ir_pass.CanonicalSimplify(partial_conv2d.op.body[0]), "\n")
    assert tvm.ir_pass.Equal(tvm.ir_pass.CanonicalSimplify(body[0]),
                             tvm.ir_pass.CanonicalSimplify(partial_conv2d.op.body[0]))
    print("***********Success on body[0] vs partial_conv2d.op.body[0]*************")
print("......................cutting line .........................\n")
lowered_stmt = tvm.lower(sch_global, (Input, Filter, ofm), simple_mode=True)
print(lowered_stmt)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment