Created
April 8, 2020 02:28
-
-
Save xqch1983/17c5c3a4252b3efe849867158a880140 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tvm | |
import topi | |
from tvm import autotvm | |
#from ...intrin import * | |
from topi.nn.pad import pad | |
from topi.nn.util import get_pad_tuple | |
from topi.util import simplify, get_const_tuple | |
import math | |
# Compute/accumulation dtype used for the placeholders and the conv2d sum.
dtype = "float32"
out_dtype = dtype

# Switch for printing lowered stmts and other debug output.
DEBUG = 1

# Tiling parameters for the tensorized schedule.
STRIPE_LEN = 16  # length of the fused oh*ow tile handed to the intrinsic
TBATCH = 1       # batch tile
TIC = 16         # input-channel tile (its split is currently commented out)
TOC = 16         # output-channel tile
# ---- Problem setup: NCHW input and OIHW filter placeholders ----
shape_Input = (1, 128, 38, 38)
# [passed] shape_Input = (1, 128, 3, 38)
Input = tvm.placeholder(shape_Input, name="Input", dtype=dtype)

shape_Kernel = (64, 128, 3, 3)
Filter = tvm.placeholder(shape_Kernel, name="Filter", dtype=dtype)

# Convolution hyper-parameters; padding is (0, 0), so no pad stage is needed.
stride = (1, 1)
padding = (0, 0)
dilation = (1, 1)
stride_h, stride_w = stride
dilation_h, dilation_w = dilation

# Unpack the symbolic shapes for the shape arithmetic below.
batch, in_channel, in_height, in_width = Input.shape
num_filter, channel, kernel_h, kernel_w = Filter.shape
# ---- Compute the output shape (standard dilated-conv arithmetic) ----
dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
    padding, (dilated_kernel_h, dilated_kernel_w))
out_channel = num_filter
out_height = simplify((in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1)
out_width = simplify((in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1)

# Explicit padding amounts in NCHW order (pad stage itself is disabled).
pad_before = [0, 0, pad_top, pad_left]
pad_after = [0, 0, pad_down, pad_right]
# padding or not
# temp = pad(Input, pad_before, pad_after, name="pad_temp")

# Reduction axes: input channels and the kernel window.
rc = tvm.reduce_axis((0, in_channel), name='rc')
ry = tvm.reduce_axis((0, kernel_h), name='ry')
rx = tvm.reduce_axis((0, kernel_w), name='rx')

# BUG FIX: the original line ended in a stray trailing comma, which made
# shape_output a 1-tuple wrapping the shape tuple instead of the shape itself.
shape_output = (batch, out_channel, out_height, out_width)
# Declarative direct conv2d in NCHW layout.  padding == (0, 0), so the raw
# Input is indexed directly (the pad stage above is commented out).
ofm = tvm.compute(
    (batch, out_channel, out_height, out_width),
    lambda n, c_out, h, w: tvm.sum(
        Input[n, rc, h * stride_h + ry * dilation_h,
              w * stride_w + rx * dilation_w].astype(out_dtype) *
        Filter[c_out, rc, ry, rx].astype(out_dtype),
        axis=[rc, ry, rx]),
    tag="conv2d_nchw", attrs={'stride': stride, 'padding': padding, 'dilation': dilation})
# Build the baseline schedule for ofm.
# BUG FIX: sch_global was originally created inside `if DEBUG:`, yet the
# scheduling code further down uses it unconditionally — with DEBUG == 0 the
# script would raise NameError.  The schedule is therefore always created;
# only the lowering / printing stays behind the DEBUG switch.
sch_global = tvm.create_schedule(ofm.op)
if DEBUG:
    # Lower the untouched default schedule so the initial IR can be inspected.
    stmt = tvm.lower(sch_global, [Input, Filter, ofm], simple_mode=True)
    print(".............initial after defition in compute part:\t \n ", stmt, "\n")
'''
https://discuss.tvm.ai/t/tensorize-with-stride-for-input-tensors/6018
[tqchen]
By default the tensor buffer declaration requires a compact buffer,
that means that the tensorized region need to be contiguous.
To relax the constraint, you can declare a buffer with symbolic strides
when declaring the tensor intrin, of course your low level instruction
must also support strided matrices as an input
Translation (of the Chinese notes that originally followed):
By default, a tensor buffer declaration requires a compact buffer,
meaning the region to be tensorized must be contiguous.
To lift this constraint, declare a buffer with symbolic strides when
declaring the tensor intrinsic, e.g. ww = tvm.var("ww").
Of course, the low-level instruction must also support strided matrices
as input.
How to use tensorize https://discuss.tvm.ai/t/how-to-use-tensorize/424
There are two cases here.
1, describe the compute logic, we can call it original compute
2, you need to do some schedule (split) to expose the axis you wanna tensorize, and mark it with tensorize api
3, you need to describe intrinsic pattern, which includes two parts, one is a clone of original compute, another is the intrinsic pattern you wanna use
4, after that, in scheduleOps, tensorize will try to do pattern match for original compute and clone compute, if success, will replace the marked axis with intrinsic.
'''
def intrin_partial_conv2d_ohow_toc(outs, shape_data, shape_kernel, tohow, toc):
    """Declare a tensor intrinsic computing one partial conv2d output tile.

    The intrinsic's clone compute covers a (batch, num_filter, 1, tohow)
    output tile, reducing over all input channels and the full kernel window,
    mirroring the global `ofm` compute so tensorize can pattern-match it.

    Parameters
    ----------
    outs : the output tensor of the enclosing compute.
        NOTE(review): currently unused inside this function.
    shape_data : shape of the data slice the intrinsic consumes,
        e.g. [1, 128, 3, 38] — assumed NCHW; TODO confirm against caller.
    shape_kernel : shape of the kernel slice, e.g. [16, 128, 3, 3].
    tohow : length of the fused oh*ow tile produced per intrinsic call.
    toc : output-channel tile size.
        NOTE(review): unused here; num_filter is taken from shape_kernel.

    Returns
    -------
    The tensor intrinsic from tvm.decl_tensor_intrin, whose body emits a
    single "SailStartCommand" extern call.
    """
    data_intrin = tvm.placeholder(shape_data, name='data')
    kernel_intrin = tvm.placeholder(shape_kernel, name='kernel')
    batch, in_channel, _, _ = shape_data              # e.g. [1,128,38,38]
    num_filter, _, kernel_h, kernel_w = shape_kernel  # e.g. [16,128,3,3]
    # Fresh reduction axes for the clone compute (must match the original).
    ry = tvm.reduce_axis((0, kernel_h), name='ry')
    rx = tvm.reduce_axis((0, kernel_w), name='rx')
    rc = tvm.reduce_axis((0, in_channel), name='rc')
    # shape_out = [1,toc,1,tohow]  # [1,16,1,36]
    ofm_intrin = tvm.compute(
        (batch, num_filter, 1, tohow),
        lambda n, c, h, w: tvm.sum(
            data_intrin[n, rc, h * stride_h + ry * dilation_h,
                        w * stride_w + rx * dilation_w].astype(out_dtype) *
            kernel_intrin[c, rc, ry, rx].astype(out_dtype),
            axis=[rc, ry, rx]),
        tag="conv2d_nchw", attrs={'stride': stride, 'padding': padding, 'dilation': dilation})
    if DEBUG:
        schedule_intrin = tvm.create_schedule(ofm_intrin.op)
        dom_map = tvm.schedule.InferBound(schedule_intrin)
        print("[dom_map] in intrinsic:\t\n", dom_map, "\n")
    # Symbolic strides relax the compact-buffer requirement, so the
    # tensorized data region need not be contiguous (see the note string
    # earlier in this file).
    nn, cc, hh, ww = tvm.var("nn"), tvm.var("cc"), tvm.var("hh"), tvm.var("ww")
    # ww = tvm.floordiv((ww*36),36)
    # NOTE(review): tt is computed but never used below — leftover experiment?
    tt = tvm.floordiv((tvm.floormod((ww * 16), 36) + 15), 36)
    data_buf = tvm.decl_buffer(data_intrin.shape, data_intrin.dtype,
                               name="DATA",
                               offset_factor=1, strides=[nn, cc, hh, ww])
    kernel_buf = tvm.decl_buffer(kernel_intrin.shape, kernel_intrin.dtype,
                                 name="KERNEL",
                                 offset_factor=1)
    ofm_nn, ofm_cc, ofm_hh, ofm_ww = tvm.var("ofm_nn"), tvm.var("ofm_cc"), tvm.var("ofm_hh"), tvm.var("ofm_ww")
    # NOTE(review): the stride order [ofm_cc, ofm_nn, ...] swaps the first two
    # axes relative to the (n, c, h, w) shape order — confirm this is intended.
    ofm_buf = tvm.decl_buffer(ofm_intrin.shape, ofm_intrin.dtype,
                              name="OFM",
                              offset_factor=1, strides=[ofm_cc, ofm_nn, ofm_hh, ofm_ww])
    # ,strides=[1,64,36,36])
    if DEBUG:
        # Lower the clone compute on its own (with a reorder) for inspection.
        sch_temp = tvm.create_schedule(ofm_intrin.op)
        temp_n, temp_c, temp_h, temp_w = sch_temp[ofm_intrin].op.axis
        sch_temp[ofm_intrin].reorder(temp_n, temp_w, temp_h, temp_c)
        stmt_temp = tvm.lower(sch_temp, [data_intrin, kernel_intrin, ofm_intrin], simple_mode=True)
        print("..........intrin_partial_conv2d stmt 【inside of definition 】.......", stmt_temp, "\n")

    def intrin_func(ins, outs):
        # Replace the matched tile computation with one extern command.
        ib = tvm.ir_builder.create()
        ib.emit(tvm.stmt.stmt_seq(
            tvm.call_extern("int32", "SailStartCommand", "cdma")))
        return ib.get()

    def intrin_func2(ins, outs):
        # Alternative lowering via a "gemv_update" extern call.  Currently
        # unused — see the commented-out decl_tensor_intrin call below.
        ib = tvm.ir_builder.create()
        aa, bb = ins
        cc = outs[0]
        print(".............intrin_fun2......:\n", aa.strides, bb.strides, cc.strides)
        ib.emit(tvm.call_extern("int32", "gemv_update",
                                cc.access_ptr("r"),
                                aa.access_ptr("r"),
                                bb.access_ptr("r"),
                                cc.access_ptr("w")))
        # ib.emit(tvm.call_extern(outs[0].dtype, 'vadd', ins[0].access_ptr("r"), ins[1].access_ptr('r'), outs[0].access_ptr('wr')))
        return ib.get()

    with tvm.build_config(offset_factor=1):
        print(".............intrin_fun2......:\n", data_buf.strides, kernel_buf.strides, ofm_buf.strides)
        # [PASS] (previously verified variant)
        # return tvm.decl_tensor_intrin(ofm_intrin.op, intrin_func2, binds={data_intrin: data_buf, kernel_intrin: kernel_buf, ofm_intrin: ofm_buf}, name="sp_conv2d")
        return tvm.decl_tensor_intrin(ofm_intrin.op, intrin_func,
                                      binds={data_intrin: data_buf, kernel_intrin: kernel_buf, ofm_intrin: ofm_buf},
                                      name="sp_conv2d")
# ---- Schedule the global conv2d and expose the axis to tensorize ----
dom_map = tvm.schedule.InferBound(sch_global)
print("[dom_map-1 before schedule]:\t\n", dom_map, "\n")
n, oc, oh, ow = sch_global[ofm].op.axis
ic, kh, kw = sch_global[ofm].op.reduce_axis
N, IC, IH, IW = Input.shape
KN, KC, KH, KW = Filter.shape
# NOTE(review): this re-unpacks the same op's axes, rebinding oc/oh/ow; the
# reorder below still uses `n` from the first unpack — confirm intent.
n_op, oc, oh, ow = ofm.op.axis
# Fuse the spatial axes, then tile the fused axis and the output channels so
# the (tohow, toc) inner axes can be handed to the tensor intrinsic.
ohow = sch_global[ofm].fuse(oh, ow)
Ntohow, tohow = sch_global[ofm].split(ohow, factor=STRIPE_LEN)
Ntoc, toc = sch_global[ofm].split(oc, factor=TOC)  # [,16]
# Ntic, tic = sch_global[ofm].split(ic, factor=TIC)
sch_global[ofm].reorder(n, Ntoc, Ntohow, tohow, toc, ic, kh, kw)
if DEBUG:
    dom_map = tvm.schedule.InferBound(sch_global)
    print("..........after fuse && split && reorder ......dom_map: dom_map-1-2:\n", dom_map, "\n")
    temp_stmt3 = tvm.lower(sch_global, (Input, Filter, ofm), simple_mode=True)
    print("...........after [oh,ow] fuse && split[Ntohow,tohow] [Ntoc,toc] && reorder (n,Ntoc, Ntohow,tohow,toc,ic,kh,kw).....lower stmt: ", temp_stmt3, "\n")
# ---- Prepare shapes for the tensor intrinsic and apply tensorize ----
# NOTE(review): the trailing 38 looks like it should be IW (the input
# width) rather than a hard-coded constant — confirm.
shape_data_intrinsic = [N, IC, KH, 38]  # [1,128,3,38]
# shape_data = [1,128,3,38]
shape_kernel_intrinsic = [TOC, IC, KH, KW]  # [16,128,3,3]
# shape_kernel = [16,128,3,3]
# ideal output [1,16,1,36 ]
# Replace the inner fused-spatial axis with the declared intrinsic.
partial_conv2d = intrin_partial_conv2d_ohow_toc(ofm, shape_data_intrinsic, shape_kernel_intrinsic, tohow=STRIPE_LEN, toc=TOC)
sch_global[ofm].tensorize(tohow, partial_conv2d)
if DEBUG:
    # Verify the tensorize pattern match the same way TVM's unit tests do:
    # infer the tensorized regions, extract the matched body, and check it
    # equals the intrinsic's declared compute body after simplification.
    sch_global = sch_global.normalize()
    dom_map = tvm.schedule.InferBound(sch_global)
    print("[after tensorize dom_map-2]:\t\n", dom_map, "\n")
    finfer = tvm.get_global_func("test.op.InferTensorizeRegion")
    out_dom, in_dom = finfer(sch_global[ofm], dom_map)
    print("[output_dom]:\t\n", out_dom, "\n")
    print("[input_dom]:\t\n", in_dom, "\n")
    fmatch = tvm.get_global_func("test.op.MatchTensorizeBody")
    body = fmatch(sch_global[ofm], out_dom, in_dom, partial_conv2d)
    print("body[0]:\t", tvm.ir_pass.CanonicalSimplify(body[0]), "\n")
    print("partial_conv2d.op.body[0]:\t", tvm.ir_pass.CanonicalSimplify(partial_conv2d.op.body[0]), "\n")
    assert tvm.ir_pass.Equal(tvm.ir_pass.CanonicalSimplify(body[0]),
                             tvm.ir_pass.CanonicalSimplify(partial_conv2d.op.body[0]))
    print("***********Success on body[0] vs partial_conv2d.op.body[0]*************")
print("......................cutting line .........................\n")
# Final lowering of the tensorized schedule.
lowered_stmt = tvm.lower(sch_global, (Input, Filter, ofm), simple_mode=True)
print(lowered_stmt)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment