Created
June 25, 2021 19:05
-
-
Save dreiss/e3ace30b8316b364dc997547d579f116 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git i/test/test_nnapi.py w/test/test_nnapi.py | |
index 23165c1d78..4328a73112 100644 | |
--- i/test/test_nnapi.py | |
+++ w/test/test_nnapi.py | |
@@ -68,6 +68,7 @@ class TestNNAPI(TestCase): | |
# Too many mismatches. Re-run the check with no tolerance | |
# to get a nice message. | |
self.assertEqual(eager_output, nnapi_output, atol=0, rtol=0) | |
+ return eager_output, nnapi_output | |
def float_and_quant_and_nhwc(self, inp_float, scale, zero_point): | |
torch.manual_seed(29) | |
@@ -220,6 +221,26 @@ class TestNNAPI(TestCase): | |
torch.tensor([[3.0, 4.0], [5.0, 6.0]]), | |
]) | |
+ def test_pointwise_binary_const(self): | |
+ class ArgPlusConst(torch.nn.Module): | |
+ def forward(self, arg): | |
+ return arg + torch.ones(1, 1, 1, 1) | |
+ | |
+ class ConstPlusArg(torch.nn.Module): | |
+ def forward(self, arg): | |
+ return torch.ones(1, 1, 1, 1) + arg | |
+ # Exercise both a contiguous input and its nhwc() (channels-last) copy. | |
+ arg_cont = torch.randn(2, 4, 6, 6) | |
+ arg_nhwc = nhwc(arg_cont) | |
+ # The constant may sit on either side; output layout must follow the tensor argument. | |
+ for mod_class in [ArgPlusConst, ConstPlusArg]: | |
+ for use_nhwc in [False, True]: | |
+ with self.subTest(mod_class=mod_class.__name__, use_nhwc=use_nhwc): | |
+ arg = arg_nhwc if use_nhwc else arg_cont | |
+ _, nnapi_out = self.check(mod_class(), arg) | |
+ self.assertTrue(nnapi_out.is_contiguous(memory_format=( | |
+ torch.channels_last if use_nhwc else torch.contiguous_format))) | |
+ | |
def test_hardtanh(self): | |
inp = torch.tensor([-2.0, -0.5, 0.5, 2.0, 7.0]) | |
self.check(torch.nn.Hardtanh(), inp) | |
diff --git i/torch/backends/_nnapi/serializer.py w/torch/backends/_nnapi/serializer.py | |
index 80a52d5c16..88594d1019 100644 | |
--- i/torch/backends/_nnapi/serializer.py | |
+++ w/torch/backends/_nnapi/serializer.py | |
@@ -529,9 +529,16 @@ class _NnapiSerializer(object): | |
def compute_operand_shape(self, op_id, dim, expr): | |
self.flexible_shape_computation_lines.append(f"{flex_name(op_id, dim)} = {expr}") | |
+ def coerce_to_contiguous(self, in_id, oper): | |
+ assert oper.dim_order == DimOrder.UNKNOWN_CONSTANT | |
+ out_oper = oper._replace(dim_order=DimOrder.PRESUMED_CONTIGUOUS) | |
+ # Only UNKNOWN_CONSTANT operands may be relabeled; no new operand or op is emitted. | |
+ # TODO: in_id is returned unchanged, only the returned oper carries the new dim_order. | |
+ return in_id, out_oper | |
+ | |
def transpose_to_nhwc(self, in_id, oper): | |
- if oper.shape[2:] != (1, 1): | |
- raise Exception("Automatic transpose only supported for H,W == 1,1") | |
+ if oper.dim_order != DimOrder.UNKNOWN_CONSTANT and oper.shape[2:] != (1, 1): | |
+ raise Exception("Automatic transpose only supported for constants or H,W == 1,1") | |
out_oper = oper._replace(dim_order=DimOrder.CHANNELS_LAST) | |
@@ -551,13 +558,27 @@ class _NnapiSerializer(object): | |
if in0_oper.dim_order == in1_oper.dim_order: | |
return in0_id, in0_oper, in1_id, in1_oper | |
- # Assume NHWC is preferred if there is a mismatch. | |
orders = (in0_oper.dim_order, in1_oper.dim_order) | |
+ | |
+ # Assume NHWC is preferred if there is a mismatch. | |
if orders == (DimOrder.PRESUMED_CONTIGUOUS, DimOrder.CHANNELS_LAST): | |
return self.transpose_to_nhwc(in0_id, in0_oper) + (in1_id, in1_oper) | |
if orders == (DimOrder.CHANNELS_LAST, DimOrder.PRESUMED_CONTIGUOUS): | |
return (in0_id, in0_oper) + self.transpose_to_nhwc(in1_id, in1_oper) | |
+ # Constants can be treated as contiguous. | |
+ if orders == (DimOrder.UNKNOWN_CONSTANT, DimOrder.PRESUMED_CONTIGUOUS): | |
+ return self.coerce_to_contiguous(in0_id, in0_oper) + (in1_id, in1_oper) | |
+ if orders == (DimOrder.PRESUMED_CONTIGUOUS, DimOrder.UNKNOWN_CONSTANT): | |
+ return (in0_id, in0_oper) + self.coerce_to_contiguous(in1_id, in1_oper) | |
+ | |
+ # Constants can be transposed to NHWC. | |
+ # TODO: Do this outside of the model execution. | |
+ if orders == (DimOrder.UNKNOWN_CONSTANT, DimOrder.CHANNELS_LAST): | |
+ return self.transpose_to_nhwc(in0_id, in0_oper) + (in1_id, in1_oper) | |
+ if orders == (DimOrder.CHANNELS_LAST, DimOrder.UNKNOWN_CONSTANT): | |
+ return (in0_id, in0_oper) + self.transpose_to_nhwc(in1_id, in1_oper) | |
+ | |
raise Exception( | |
"Automatic transpose not supported for dim_orders: %r, %r" % | |
(in0_oper.dim_order, in1_oper.dim_order)) | |
@@ -1069,9 +1090,8 @@ class _NnapiSerializer(object): | |
assert node.inputsAt(0).type().kind() == "TensorType" | |
assert node.inputsAt(1).type().kind() == "TensorType" | |
- # TODO: Should support constant as either operand. | |
- in0_id, in0_oper = self.get_tensor_operand_by_jitval_fixed_size(node.inputsAt(0)) | |
- in1_id, in1_oper = self.get_tensor_operand_by_jitval_fixed_size(node.inputsAt(1)) | |
+ in0_id, in0_oper = self.get_tensor_operand_or_constant(node.inputsAt(0)) | |
+ in1_id, in1_oper = self.get_tensor_operand_or_constant(node.inputsAt(1)) | |
assert in0_oper.op_type == in1_oper.op_type | |
in0_id, in0_oper, in1_id, in1_oper = self.transpose_for_broadcast( |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment