Created November 29, 2019 08:12
-
-
Save emfomenk/7a28ff7332f3987179d882a1786b0c4c to your computer and use it in GitHub Desktop.
Fix for mkl-dnn#606: cpu: concat: include padded dims when computing the physical tensor size
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
From ed11371ba3c9e546aeb72cf6636dd36ccabbaf5f Mon Sep 17 00:00:00 2001 | |
From: "Fomenko, Evarist M" <evarist.m.fomenko@intel.com> | |
Date: Fri, 29 Nov 2019 08:07:27 +0000 | |
Subject: [PATCH] cpu: concat: include padded dims when computing the physical | |
tensor size | |
this closes #606 | |
--- | |
src/cpu/simple_concat.cpp | 2 +- | |
src/cpu/simple_concat.hpp | 2 +- | |
tests/benchdnn/inputs/concat/test_concat_all | 10 ++++++++++ | |
3 files changed, 12 insertions(+), 2 deletions(-) | |
diff --git a/src/cpu/simple_concat.cpp b/src/cpu/simple_concat.cpp | |
index 10dd16f69..d3d4a67e7 100644 | |
--- a/src/cpu/simple_concat.cpp | |
+++ b/src/cpu/simple_concat.cpp | |
@@ -62,7 +62,7 @@ status_t simple_concat_t<data_type>::execute(const exec_ctx_t &ctx) const { | |
dims_t phys_dims; | |
for (size_t i = 0; i < sizeof(phys_dims) / sizeof(phys_dims[0]); i++) | |
phys_dims[i] = (i < (size_t)perm[concat_dim]) | |
- ? o_d.dims()[iperm[i]] / pd()->blocks_[iperm[i]] | |
+ ? o_d.padded_dims()[iperm[i]] / pd()->blocks_[iperm[i]] | |
: 1; | |
if (perm[concat_dim] == 0) { | |
diff --git a/src/cpu/simple_concat.hpp b/src/cpu/simple_concat.hpp | |
index 236c888d9..8e01b9626 100644 | |
--- a/src/cpu/simple_concat.hpp | |
+++ b/src/cpu/simple_concat.hpp | |
@@ -109,7 +109,7 @@ struct simple_concat_t : public primitive_impl_t { | |
dim_t nelems = 1; | |
for (int i = perm_[concat_dim()]; i < ndims; i++) | |
- nelems *= data_d.dims()[iperm_[i]] / blocks_[iperm_[i]]; | |
+ nelems *= data_d.padded_dims()[iperm_[i]] / blocks_[iperm_[i]]; | |
for (int i = 0; i < ndims; i++) | |
nelems *= blocks_[i]; | |
diff --git a/tests/benchdnn/inputs/concat/test_concat_all b/tests/benchdnn/inputs/concat/test_concat_all | |
index dfe73e99d..d299b11d8 100644 | |
--- a/tests/benchdnn/inputs/concat/test_concat_all | |
+++ b/tests/benchdnn/inputs/concat/test_concat_all | |
@@ -25,5 +25,15 @@ | |
6x8x3x4:6x1x3x4:6x0x3x4:6x3x3x4 | |
6x0x3x4:6x3x3x4:6x5x3x4:6x5x3x4 | |
+# sizes are not multiples of the block size; concat along a non-blocked axis | |
+--sdt=f32,s32,s8 | |
+--ddt=f32,s32,s8 | |
+--dtag=undef,nhwc,nChw16c | |
+--axis=2 | |
+--stag=nChw16c:nChw16c | |
+6x5x3x4:6x5x3x4 | |
+6x25x3x4:6x25x3x4 | |
+6x23x0x4:6x23x3x4 | |
+ | |
# bf16 | |
--batch=test_concat_bfloat16 | |
-- | |
2.22.0 | |
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.