Skip to content

Instantly share code, notes, and snippets.

@dreiss
Created June 25, 2021 19:06
Show Gist options
Save dreiss/11461ce2b47c53378b5744372479e8ac to your computer and use it in GitHub Desktop.
diff --git i/test/test_nnapi.py w/test/test_nnapi.py
index 23165c1d78..4328a73112 100644
--- i/test/test_nnapi.py
+++ w/test/test_nnapi.py
@@ -68,6 +68,7 @@ class TestNNAPI(TestCase):
# Too many mismatches. Re-run the check with no tolerance
# to get a nice message.
self.assertEqual(eager_output, nnapi_output, atol=0, rtol=0)
+ return eager_output, nnapi_output
def float_and_quant_and_nhwc(self, inp_float, scale, zero_point):
torch.manual_seed(29)
@@ -220,6 +221,26 @@ class TestNNAPI(TestCase):
torch.tensor([[3.0, 4.0], [5.0, 6.0]]),
])
+ def test_pointwise_binary_const(self):
+ class ArgPlusConst(torch.nn.Module):
+ def forward(self, arg):
+ return arg + torch.ones(1, 1, 1, 1)
+ # Mirror of ArgPlusConst: the constant is the left operand of the add.
+ class ConstPlusArg(torch.nn.Module):
+ def forward(self, arg):
+ return torch.ones(1, 1, 1, 1) + arg
+
+ arg_cont = torch.randn(2, 4, 6, 6)
+ arg_nhwc = nhwc(arg_cont)
+ # Constant on either side x contiguous / channels-last input; the NNAPI output must keep the input's memory format.
+ for mod_class in [ArgPlusConst, ConstPlusArg]:
+ for use_nhwc in [False, True]:
+ with self.subTest(mod_class=mod_class.__name__, use_nhwc=use_nhwc):
+ arg = arg_nhwc if use_nhwc else arg_cont
+ _, nnapi_out = self.check(mod_class(), arg)
+ self.assertTrue(nnapi_out.is_contiguous(memory_format=(
+ torch.channels_last if use_nhwc else torch.contiguous_format)))
+
def test_hardtanh(self):
inp = torch.tensor([-2.0, -0.5, 0.5, 2.0, 7.0])
self.check(torch.nn.Hardtanh(), inp)
diff --git i/torch/backends/_nnapi/serializer.py w/torch/backends/_nnapi/serializer.py
index 80a52d5c16..88594d1019 100644
--- i/torch/backends/_nnapi/serializer.py
+++ w/torch/backends/_nnapi/serializer.py
@@ -529,9 +529,16 @@ class _NnapiSerializer(object):
def compute_operand_shape(self, op_id, dim, expr):
self.flexible_shape_computation_lines.append(f"{flex_name(op_id, dim)} = {expr}")
+ def coerce_to_contiguous(self, in_id, oper):
+ assert oper.dim_order == DimOrder.UNKNOWN_CONSTANT  # only constants of unknown layout may be relabeled
+ out_oper = oper._replace(dim_order=DimOrder.PRESUMED_CONTIGUOUS)
+ # Relabel a constant as PRESUMED_CONTIGUOUS without emitting a new operand; in_id is returned unchanged.
+ # XXX this is gross. NOTE(review): the operand stored under in_id presumably keeps UNKNOWN_CONSTANT; confirm callers only use the returned oper.
+ return in_id, out_oper
+
def transpose_to_nhwc(self, in_id, oper):
- if oper.shape[2:] != (1, 1):
- raise Exception("Automatic transpose only supported for H,W == 1,1")
+ if oper.dim_order != DimOrder.UNKNOWN_CONSTANT and oper.shape[2:] != (1, 1):
+ raise Exception("Automatic transpose only supported for constants or H,W == 1,1")
out_oper = oper._replace(dim_order=DimOrder.CHANNELS_LAST)
@@ -551,13 +558,27 @@ class _NnapiSerializer(object):
if in0_oper.dim_order == in1_oper.dim_order:
return in0_id, in0_oper, in1_id, in1_oper
- # Assume NHWC is preferred if there is a mismatch.
orders = (in0_oper.dim_order, in1_oper.dim_order)
+
+ # Assume NHWC is preferred if there is a mismatch.
if orders == (DimOrder.PRESUMED_CONTIGUOUS, DimOrder.CHANNELS_LAST):
return self.transpose_to_nhwc(in0_id, in0_oper) + (in1_id, in1_oper)
if orders == (DimOrder.CHANNELS_LAST, DimOrder.PRESUMED_CONTIGUOUS):
return (in0_id, in0_oper) + self.transpose_to_nhwc(in1_id, in1_oper)
+ # Constants can be treated as contiguous.
+ if orders == (DimOrder.UNKNOWN_CONSTANT, DimOrder.PRESUMED_CONTIGUOUS):
+ return self.coerce_to_contiguous(in0_id, in0_oper) + (in1_id, in1_oper)
+ if orders == (DimOrder.PRESUMED_CONTIGUOUS, DimOrder.UNKNOWN_CONSTANT):
+ return (in0_id, in0_oper) + self.coerce_to_contiguous(in1_id, in1_oper)
+
+ # Constants can be transposed to NHWC.
+ # TODO: Do this outside of the model execution.
+ if orders == (DimOrder.UNKNOWN_CONSTANT, DimOrder.CHANNELS_LAST):
+ return self.transpose_to_nhwc(in0_id, in0_oper) + (in1_id, in1_oper)
+ if orders == (DimOrder.CHANNELS_LAST, DimOrder.UNKNOWN_CONSTANT):
+ return (in0_id, in0_oper) + self.transpose_to_nhwc(in1_id, in1_oper)
+
raise Exception(
"Automatic transpose not supported for dim_orders: %r, %r" %
(in0_oper.dim_order, in1_oper.dim_order))
@@ -1069,9 +1090,8 @@ class _NnapiSerializer(object):
assert node.inputsAt(0).type().kind() == "TensorType"
assert node.inputsAt(1).type().kind() == "TensorType"
- # TODO: Should support constant as either operand.
- in0_id, in0_oper = self.get_tensor_operand_by_jitval_fixed_size(node.inputsAt(0))
- in1_id, in1_oper = self.get_tensor_operand_by_jitval_fixed_size(node.inputsAt(1))
+ in0_id, in0_oper = self.get_tensor_operand_or_constant(node.inputsAt(0))
+ in1_id, in1_oper = self.get_tensor_operand_or_constant(node.inputsAt(1))
assert in0_oper.op_type == in1_oper.op_type
in0_id, in0_oper, in1_id, in1_oper = self.transpose_for_broadcast(
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment