Skip to content

Instantly share code, notes, and snippets.

@yzhliu
Last active April 11, 2019 01:34
Show Gist options
Save yzhliu/0d2cdfc9fa92127b81a1298d5bec55a0 to your computer and use it in GitHub Desktop.
diff --git a/src/arithmetic/const_fold.h b/src/arithmetic/const_fold.h
index fbf8fe7e..1c397f40 100644
--- a/src/arithmetic/const_fold.h
+++ b/src/arithmetic/const_fold.h
@@ -101,33 +101,28 @@ inline bool IsIndexType(const Type& type) {
// specialization of constant folders.
template<>
inline Expr TryConstFold<ir::Add>(Expr a, Expr b) {
- TVM_ARITH_CONST_PROPAGATION({
+ TVM_INDEX_CONST_PROPAGATION({
const Type& rtype = a.type();
if (pa && pb) return IntImm::make(rtype, pa->value + pb->value);
if (pa && pa->value == 0) return b;
if (pb && pb->value == 0) return a;
- if (fa && fb) return FloatImm::make(rtype, fa->value + fb->value);
- if (fa && fa->value == 0) return b;
- if (fb && fb->value == 0) return a;
});
return Expr();
}
template<>
inline Expr TryConstFold<ir::Sub>(Expr a, Expr b) {
- TVM_ARITH_CONST_PROPAGATION({
+ TVM_INDEX_CONST_PROPAGATION({
const Type& rtype = a.type();
if (pa && pb) return IntImm::make(rtype, pa->value - pb->value);
if (pb && pb->value == 0) return a;
- if (fa && fb) return FloatImm::make(rtype, fa->value - fb->value);
- if (fb && fb->value == 0) return a;
});
return Expr();
}
template<>
inline Expr TryConstFold<ir::Mul>(Expr a, Expr b) {
- TVM_ARITH_CONST_PROPAGATION({
+ TVM_INDEX_CONST_PROPAGATION({
const Type& rtype = a.type();
if (pa && pb) return IntImm::make(rtype, pa->value * pb->value);
if (pa) {
@@ -138,22 +133,13 @@ inline Expr TryConstFold<ir::Mul>(Expr a, Expr b) {
if (pb->value == 1) return a;
if (pb->value == 0) return b;
}
- if (fa && fb) return FloatImm::make(rtype, fa->value * fb->value);
- if (fa) {
- if (fa->value == 1) return b;
- if (fa->value == 0) return a;
- }
- if (fb) {
- if (fb->value == 1) return a;
- if (fb->value == 0) return b;
- }
});
return Expr();
}
template<>
inline Expr TryConstFold<ir::Div>(Expr a, Expr b) {
- TVM_ARITH_CONST_PROPAGATION({
+ TVM_INDEX_CONST_PROPAGATION({
const Type& rtype = a.type();
// due to division and mod can have different modes
// only constant fold positive number where rule is fixed.
@@ -167,14 +153,6 @@ inline Expr TryConstFold<ir::Div>(Expr a, Expr b) {
if (pb->value == 1) return a;
CHECK_NE(pb->value, 0) << "Divide by zero";
}
- if (fa && fb && fb->value != 0) {
- return FloatImm::make(rtype, fa->value / fb->value);
- }
- if (fa && fa->value == 0) return a;
- if (fb) {
- if (fb->value == 1) return a;
- CHECK_NE(fb->value, 0) << "Divide by zero";
- }
});
return Expr();
}
@@ -201,20 +179,18 @@ inline Expr TryConstFold<ir::Mod>(Expr a, Expr b) {
template<>
inline Expr TryConstFold<ir::Min>(Expr a, Expr b) {
- TVM_ARITH_CONST_PROPAGATION({
+ TVM_INDEX_CONST_PROPAGATION({
const Type& rtype = a.type();
if (pa && pb) return IntImm::make(rtype, std::min(pa->value, pb->value));
- if (fa && fb) return FloatImm::make(rtype, std::min(fa->value, fb->value));
});
return Expr();
}
template<>
inline Expr TryConstFold<ir::Max>(Expr a, Expr b) {
- TVM_ARITH_CONST_PROPAGATION({
+ TVM_INDEX_CONST_PROPAGATION({
const Type& rtype = a.type();
if (pa && pb) return IntImm::make(rtype, std::max(pa->value, pb->value));
- if (fa && fb) return FloatImm::make(rtype, std::max(fa->value, fb->value));
});
return Expr();
}
diff --git a/src/lang/expr_operator.cc b/src/lang/expr_operator.cc
index 4504ee23..753ad6a8 100644
--- a/src/lang/expr_operator.cc
+++ b/src/lang/expr_operator.cc
@@ -106,14 +106,11 @@ bool is_const_power_of_two_integer(const Expr& x, int* shift) {
Expr cast(const Type& t, Expr value) {
using ir::IntImm;
- using ir::FloatImm;
if (value.type() == t) return value;
// const fold IntImm as they are used in index computations
if (t.lanes() == 1) {
if (const IntImm* op = value.as<IntImm>()) {
return make_const(t, op->value);
- } else if (const FloatImm* op = value.as<FloatImm>()) {
- return make_const(t, op->value);
}
return ir::Cast::make(t, value);
} else {
@@ -123,8 +120,6 @@ Expr cast(const Type& t, Expr value) {
if (value.type() != vtype) {
if (const IntImm* op = value.as<IntImm>()) {
value = make_const(vtype, op->value);
- } else if (const FloatImm* op = value.as<FloatImm>()) {
- value = make_const(vtype, op->value);
} else {
value = ir::Cast::make(vtype, value);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment