@yzhliu
Created January 3, 2018 02:08
import tvm

from topi import generic, tag  # assumed imports; inside the topi source tree this would be `from .. import generic, tag`


@generic.schedule_conv2d_nchw.register(["cpu"])
def schedule_conv2d(outs):
    """Create schedule for tensors"""
    print('Run in x86 sch ...')
    s = tvm.create_schedule([x.op for x in outs])

    def traverse(op):
        """Traverse operators from computation graph"""
        # inline all one-to-one-mapping operators except the last stage (output)
        if tag.is_broadcast(op.tag):
            if op not in s.outputs:
                s[op].compute_inline()
            for tensor in op.input_tensors:
                if tensor.op.input_tensors:
                    traverse(tensor.op)

        if 'conv2d_nchw' in op.tag:
            conv = op.output(0)
            kernel = op.input_tensors[1]
            data = op.input_tensors[0]
            data_pad = None
            if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag:
                data_pad = data
                data = data_pad.op.input_tensors[0]

            C = conv
            print(C.op.axis)
            print(C.op.reduce_axis)
            print(data_pad.op.axis)  # assumes the conv has a pad stage

            # batch, out-channel, height, width / reduce-channel, kernel y, kernel x
            n, c, h, w = C.op.axis
            rc, ry, rx = C.op.reduce_axis

            # move the reduction channel outside the spatial loops and fully
            # unroll the fused 3x3 kernel loops at the innermost position
            s[C].reorder(n, c, rc, h, w, ry, rx)
            r = s[C].fuse(ry, rx)
            s[C].unroll(r)

            # vectorize the width by 8 and parallelize over output channels
            xo, xi = s[C].split(w, factor=8)
            s[C].parallel(c)
            s[C].vectorize(xi)
            s[C].pragma(n, "parallel_launch_point")

    traverse(outs[0].op)
    return s
"""
$ python test_topi_conv2d_nchw_intel.py
TVM: Initializing cython mode...
Run in x86 sch ...
[iter_var(nn, Range(min=0, extent=1)), iter_var(ff, Range(min=0, extent=64)), iter_var(yy, Range(min=0, extent=56)), iter_var(xx, Range(min=0, extent=56))]
[iter_var(rc, Range(min=0, extent=64)), iter_var(ry, Range(min=0, extent=3)), iter_var(rx, Range(min=0, extent=3))]
[iter_var(i0, Range(min=0, extent=1)), iter_var(i1, Range(min=0, extent=64)), iter_var(i2, Range(min=0, extent=58)), iter_var(i3, Range(min=0, extent=58))]
// attr [pad_temp] storage_scope = "global"
allocate pad_temp[float32 * 1 * 64 * 58 * 58]
produce pad_temp {
  for (i1, 0, 64) {
    for (i2, 0, 58) {
      for (i3, 0, 58) {
        pad_temp[((((i1*58) + i2)*58) + i3)] = tvm_if_then_else(((((1 <= i2) && (i2 < 57)) && (1 <= i3)) && (i3 < 57)), A[(((((i1*56) + i2)*56) + i3) + -57)], 0.000000f)
      }
    }
  }
}
produce compute {
  parallel (ff, 0, 64) {
    for (yy.init, 0, 56) {
      for (xx.outer.init, 0, 7) {
        compute[ramp((((((ff*56) + yy.init)*7) + xx.outer.init)*8), 1, 8)] = x8(0.000000f)
      }
    }
    for (rc, 0, 64) {
      for (yy, 0, 56) {
        for (xx.outer, 0, 7) {
          compute[ramp((((((ff*56) + yy)*7) + xx.outer)*8), 1, 8)] = (compute[ramp((((((ff*56) + yy)*7) + xx.outer)*8), 1, 8)] + (pad_temp[ramp(((((rc*58) + yy)*58) + (xx.outer*8)), 1, 8)]*x8(W[(((ff*64) + rc)*9)])))
          compute[ramp((((((ff*56) + yy)*7) + xx.outer)*8), 1, 8)] = (compute[ramp((((((ff*56) + yy)*7) + xx.outer)*8), 1, 8)] + (pad_temp[ramp((((((rc*58) + yy)*58) + (xx.outer*8)) + 1), 1, 8)]*x8(W[((((ff*64) + rc)*9) + 1)])))
          compute[ramp((((((ff*56) + yy)*7) + xx.outer)*8), 1, 8)] = (compute[ramp((((((ff*56) + yy)*7) + xx.outer)*8), 1, 8)] + (pad_temp[ramp((((((rc*58) + yy)*58) + (xx.outer*8)) + 2), 1, 8)]*x8(W[((((ff*64) + rc)*9) + 2)])))
          compute[ramp((((((ff*56) + yy)*7) + xx.outer)*8), 1, 8)] = (compute[ramp((((((ff*56) + yy)*7) + xx.outer)*8), 1, 8)] + (pad_temp[ramp((((((rc*58) + yy)*58) + (xx.outer*8)) + 58), 1, 8)]*x8(W[((((ff*64) + rc)*9) + 3)])))
          compute[ramp((((((ff*56) + yy)*7) + xx.outer)*8), 1, 8)] = (compute[ramp((((((ff*56) + yy)*7) + xx.outer)*8), 1, 8)] + (pad_temp[ramp((((((rc*58) + yy)*58) + (xx.outer*8)) + 59), 1, 8)]*x8(W[((((ff*64) + rc)*9) + 4)])))
          compute[ramp((((((ff*56) + yy)*7) + xx.outer)*8), 1, 8)] = (compute[ramp((((((ff*56) + yy)*7) + xx.outer)*8), 1, 8)] + (pad_temp[ramp((((((rc*58) + yy)*58) + (xx.outer*8)) + 60), 1, 8)]*x8(W[((((ff*64) + rc)*9) + 5)])))
          compute[ramp((((((ff*56) + yy)*7) + xx.outer)*8), 1, 8)] = (compute[ramp((((((ff*56) + yy)*7) + xx.outer)*8), 1, 8)] + (pad_temp[ramp((((((rc*58) + yy)*58) + (xx.outer*8)) + 116), 1, 8)]*x8(W[((((ff*64) + rc)*9) + 6)])))
          compute[ramp((((((ff*56) + yy)*7) + xx.outer)*8), 1, 8)] = (compute[ramp((((((ff*56) + yy)*7) + xx.outer)*8), 1, 8)] + (pad_temp[ramp((((((rc*58) + yy)*58) + (xx.outer*8)) + 117), 1, 8)]*x8(W[((((ff*64) + rc)*9) + 7)])))
          compute[ramp((((((ff*56) + yy)*7) + xx.outer)*8), 1, 8)] = (compute[ramp((((((ff*56) + yy)*7) + xx.outer)*8), 1, 8)] + (pad_temp[ramp((((((rc*58) + yy)*58) + (xx.outer*8)) + 118), 1, 8)]*x8(W[((((ff*64) + rc)*9) + 8)])))
        }
      }
    }
  }
}
Use memoize topi.tests.test_topi_conv2d.verify_con2d_nchw.get_ref_data.pkl(5, (1, 64, 56, 56), 'float32', 1, 1, (64, 64, 3, 3))
0.0134 secs/op
"""