@masahi
Created October 11, 2017 02:24
# pylint: disable=invalid-name, unused-argument
"""Reduction ops"""
from __future__ import absolute_import

import tvm
import topi
import topi.cuda

from . import registry as reg
from .registry import OpPattern

def _schedule_reduce(_, outs, target):
    """Generic schedule for reduce"""
    if target == "cuda" or target == "opencl":
        # GPU targets use the specialized reduction schedule from topi.cuda.
        return topi.cuda.schedule_reduce(outs)
    assert target.startswith("llvm")
    # Generic CPU fallback: inline injective ops and fuse the output axes.
    s = tvm.create_schedule([x.op for x in outs])
    x = outs[0]
    tvm.schedule.AutoInlineInjective(s)
    s[x].fuse(s[x].op.axis)
    return s

# Wrap the Python schedule function so it can be registered with NNVM.
_fschedule_reduce = tvm.convert(_schedule_reduce)

def _compute_reduce(f):
    """Auxiliary function: wrap a topi reduction into an NNVM compute rule."""
    def _compute(attrs, inputs, out_info):
        axis = attrs.get_int_tuple("axis")
        keepdims = attrs.get_bool("keepdims")
        if axis:
            return f(inputs[0], axis=axis, keepdims=keepdims)
        # An empty axis tuple means reduce over all axes.
        return f(inputs[0], keepdims=keepdims)
    return _compute

# sum
reg.register_compute("sum", _compute_reduce(topi.sum))
reg.register_pattern("sum", OpPattern.COMM_REDUCE)
reg.register_schedule("sum", _fschedule_reduce)

# max
reg.register_compute("max", _compute_reduce(topi.max))
reg.register_pattern("max", OpPattern.COMM_REDUCE)
reg.register_schedule("max", _fschedule_reduce)

# min
reg.register_compute("min", _compute_reduce(topi.min))
reg.register_pattern("min", OpPattern.COMM_REDUCE)
reg.register_schedule("min", _fschedule_reduce)