@ajtulloch
Created June 11, 2019 22:31
diff --git a/src/relay/pass/quantize.cc b/src/relay/pass/quantize.cc
index 3a2e54c8..4059dc3a 100644
--- a/src/relay/pass/quantize.cc
+++ b/src/relay/pass/quantize.cc
@@ -340,18 +340,9 @@ Expr MulRealize(const Call& ref_call,
     const auto* rhs = new_args[1].as<QRealizeIntExprNode>();
     Expr ldata = lhs->data;
     Expr rdata = rhs->data;
-
     DataType dtype = cfg->dtype_activation;
-    if (lhs->dtype == Float(32)) {
-      ldata = Cast(ldata, dtype);
-    } else {
-      CHECK_EQ(lhs->dtype, dtype);
-    }
-    if (rhs->dtype == Float(32)) {
-      rdata = Cast(rdata, dtype);
-    } else {
-      CHECK_EQ(rhs->dtype, dtype);
-    }
+    ldata = Cast(ldata, dtype);
+    rdata = Cast(rdata, dtype);
     Expr ret = ForwardOp(ref_call, {ldata, rdata});
     Expr dom_scale = FoldConstant(Multiply(lhs->dom_scale, rhs->dom_scale));
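The patch drops the special case for float32 inputs in MulRealize. Before it, an operand was cast only if its dtype was float32; any other dtype had to already equal cfg->dtype_activation, or the CHECK_EQ fired. After it, both operands are cast to the activation dtype unconditionally, so an operand in a different integer dtype (for example int8 when activations are int32) is now cast instead of aborting. Below is a minimal, self-contained sketch of that control flow under a simplified model. DType, QOperand, Cast, PrepareOld, PrepareNew and Mul are hypothetical stand-ins, not TVM types. The sketch also assumes the usual fixed-point reading, in which an operand represents the real value data * dom_scale. That reading is why the output scale is the product of the input scales.

```cpp
// Hypothetical model of MulRealize's dtype handling before and after this
// patch. DType, QOperand, Cast and the Prepare*/Mul helpers are illustrative
// stand-ins, not TVM's API.
#include <stdexcept>

enum class DType { kFloat32, kInt8, kInt32 };

// Stand-in for QRealizeIntExprNode; assumed real value = data * dom_scale.
struct QOperand {
  double data;       // payload (whole numbers once quantized)
  double dom_scale;  // real value represented by one unit of data
  DType dtype;
};

// Stand-in for Cast(): in this model only the dtype tag changes.
QOperand Cast(QOperand x, DType to) {
  x.dtype = to;
  return x;
}

// Before the patch: cast float32 inputs, otherwise require an exact match
// with the activation dtype (the CHECK_EQ in the removed lines).
QOperand PrepareOld(const QOperand& x, DType activation) {
  if (x.dtype == DType::kFloat32) return Cast(x, activation);
  if (x.dtype != activation) {
    throw std::runtime_error("CHECK_EQ(dtype, dtype_activation) failed");
  }
  return x;
}

// After the patch: cast every input unconditionally.
QOperand PrepareNew(const QOperand& x, DType activation) {
  return Cast(x, activation);
}

// Multiply payloads and scales, mirroring ForwardOp(...) and
// FoldConstant(Multiply(lhs->dom_scale, rhs->dom_scale)).
QOperand Mul(const QOperand& l, const QOperand& r) {
  return QOperand{l.data * r.data, l.dom_scale * r.dom_scale, l.dtype};
}
```

Under this model, PrepareOld rejects an int8 operand when the activation dtype is int32, while PrepareNew casts it. In both versions, Mul produces a payload of l.data * r.data with scale l.dom_scale * r.dom_scale, so the represented real value is the product of the inputs' real values. One consequence of the unconditional version is that a Cast is emitted even when the operand already has the activation dtype. Such a same-dtype cast leaves the value unchanged.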