Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save emfomenk/7a28ff7332f3987179d882a1786b0c4c to your computer and use it in GitHub Desktop.
Save emfomenk/7a28ff7332f3987179d882a1786b0c4c to your computer and use it in GitHub Desktop.
fix for mkl-dnn/#606: cpu: concat: include padded dims when computing the physical tensor size
From ed11371ba3c9e546aeb72cf6636dd36ccabbaf5f Mon Sep 17 00:00:00 2001
From: "Fomenko, Evarist M" <evarist.m.fomenko@intel.com>
Date: Fri, 29 Nov 2019 08:07:27 +0000
Subject: [PATCH] cpu: concat: include padded dims when computing the physical
tensor size
This closes #606.
---
src/cpu/simple_concat.cpp | 2 +-
src/cpu/simple_concat.hpp | 2 +-
tests/benchdnn/inputs/concat/test_concat_all | 10 ++++++++++
3 files changed, 12 insertions(+), 2 deletions(-)
diff --git a/src/cpu/simple_concat.cpp b/src/cpu/simple_concat.cpp
index 10dd16f69..d3d4a67e7 100644
--- a/src/cpu/simple_concat.cpp
+++ b/src/cpu/simple_concat.cpp
@@ -62,7 +62,7 @@ status_t simple_concat_t<data_type>::execute(const exec_ctx_t &ctx) const {
dims_t phys_dims;
for (size_t i = 0; i < sizeof(phys_dims) / sizeof(phys_dims[0]); i++)
phys_dims[i] = (i < (size_t)perm[concat_dim])
- ? o_d.dims()[iperm[i]] / pd()->blocks_[iperm[i]]
+ ? o_d.padded_dims()[iperm[i]] / pd()->blocks_[iperm[i]]
: 1;
if (perm[concat_dim] == 0) {
diff --git a/src/cpu/simple_concat.hpp b/src/cpu/simple_concat.hpp
index 236c888d9..8e01b9626 100644
--- a/src/cpu/simple_concat.hpp
+++ b/src/cpu/simple_concat.hpp
@@ -109,7 +109,7 @@ struct simple_concat_t : public primitive_impl_t {
dim_t nelems = 1;
for (int i = perm_[concat_dim()]; i < ndims; i++)
- nelems *= data_d.dims()[iperm_[i]] / blocks_[iperm_[i]];
+ nelems *= data_d.padded_dims()[iperm_[i]] / blocks_[iperm_[i]];
for (int i = 0; i < ndims; i++)
nelems *= blocks_[i];
diff --git a/tests/benchdnn/inputs/concat/test_concat_all b/tests/benchdnn/inputs/concat/test_concat_all
index dfe73e99d..d299b11d8 100644
--- a/tests/benchdnn/inputs/concat/test_concat_all
+++ b/tests/benchdnn/inputs/concat/test_concat_all
@@ -25,5 +25,15 @@
6x8x3x4:6x1x3x4:6x0x3x4:6x3x3x4
6x0x3x4:6x3x3x4:6x5x3x4:6x5x3x4
+# sizes are not a multiple of the block size + (non-blocked) concat axis
+--sdt=f32,s32,s8
+--ddt=f32,s32,s8
+--dtag=undef,nhwc,nChw16c
+--axis=2
+--stag=nChw16c:nChw16c
+6x5x3x4:6x5x3x4
+6x25x3x4:6x25x3x4
+6x23x0x4:6x23x3x4
+
# bf16
--batch=test_concat_bfloat16
--
2.22.0
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment