Skip to content

Instantly share code, notes, and snippets.

@emfomenk
Created November 28, 2018 11:37
Show Gist options
  • Save emfomenk/1c5b78333d0eede05b796f85ccf48192 to your computer and use it in GitHub Desktop.
Save emfomenk/1c5b78333d0eede05b796f85ccf48192 to your computer and use it in GitHub Desktop.
Adding iohw and giohw formats
From b401946179eba2675b35562f8457bc1424b5878d Mon Sep 17 00:00:00 2001
From: "Fomenko, Evarist M" <evarist.m.fomenko@intel.com>
Date: Wed, 28 Nov 2018 03:13:53 -0800
Subject: [PATCH] api: add iohw and giohw formats
These formats are useful as user weights formats for deconvolution.
For instance, in PyTorch, weights for convolution are kept in the `oihw`
and `goihw` formats for the groups = 1 and groups > 1 cases, respectively,
whereas weights for deconvolution are kept in the `iohw` and `giohw`
formats for the groups = 1 and groups > 1 cases, respectively.
See #352 for more details.
---
include/mkldnn.hpp | 2 ++
include/mkldnn_types.h | 6 ++++++
src/common/c_types_map.hpp | 2 ++
src/common/format_traits.hpp | 2 ++
src/common/memory_desc_wrapper.cpp | 16 ++++++++++++++++
src/common/mkldnn_debug.cpp | 2 ++
src/common/type_helpers.hpp | 2 ++
tests/benchdnn/dnn_types.cpp | 1 +
tests/benchdnn/mkldnn_debug.cpp | 2 ++
tests/gtests/mkldnn_test_common.hpp | 2 ++
tests/gtests/test_reorder.cpp | 6 +++++-
11 files changed, 42 insertions(+), 1 deletion(-)
diff --git a/include/mkldnn.hpp b/include/mkldnn.hpp
index 92f176bc..d1b9ba7e 100644
--- a/include/mkldnn.hpp
+++ b/include/mkldnn.hpp
@@ -624,6 +624,7 @@ struct memory: public primitive {
oihw = mkldnn_oihw,
ihwo = mkldnn_ihwo,
hwio = mkldnn_hwio,
+ iohw = mkldnn_iohw,
hwio_s8s8 = mkldnn_hwio_s8s8,
dhwio = mkldnn_dhwio,
oidhw = mkldnn_oidhw,
@@ -666,6 +667,7 @@ struct memory: public primitive {
gOIw8o16i2o = mkldnn_gOIw8o16i2o,
goihw = mkldnn_goihw,
hwigo = mkldnn_hwigo,
+ giohw = mkldnn_giohw,
hwigo_s8s8 = mkldnn_hwigo_s8s8,
gOIdhw8i8o = mkldnn_gOIdhw8i8o,
gOIdhw8o8i = mkldnn_gOIdhw8o8i,
diff --git a/include/mkldnn_types.h b/include/mkldnn_types.h
index a688cea2..87a7a083 100644
--- a/include/mkldnn_types.h
+++ b/include/mkldnn_types.h
@@ -188,6 +188,9 @@ typedef enum {
/** 4D weights tensor with physical layout @c ihwo.
* Logical dimensions come in the order: (o, i, h, w) */
mkldnn_ihwo,
+ /** 4D weights tensor with physical layout @c iohw.
+ * Logical dimensions come in the order: (o, i, h, w) */
+ mkldnn_iohw,
/** 5D weights tensor with physical layout @c iodhw, used in Caffe.
* Logical dimensions come in the order: (o, i, d, h, w) */
mkldnn_oidhw,
@@ -205,6 +208,9 @@ typedef enum {
* used in TensorFlow.
* Logical dimensions come in the order: (g, o, i, h, w) */
mkldnn_hwigo,
+ /** 5D grouped weights tensor with the physical layout @c giohw.
+ * Logical dimensions come in the order: (g, o, i, h, w) */
+ mkldnn_giohw,
/** 6D grouped weights tensor with the physical layout @c goidhw,
* used in Caffe.
* Logical dimensions come in the order: (g, o, i, d, h, w) */
diff --git a/src/common/c_types_map.hpp b/src/common/c_types_map.hpp
index 6c00497a..bc32951c 100644
--- a/src/common/c_types_map.hpp
+++ b/src/common/c_types_map.hpp
@@ -139,6 +139,7 @@ namespace memory_format {
const memory_format_t oihw = mkldnn_oihw;
const memory_format_t ihwo = mkldnn_ihwo;
const memory_format_t hwio = mkldnn_hwio;
+ const memory_format_t iohw = mkldnn_iohw;
const memory_format_t hwio_s8s8 = mkldnn_hwio_s8s8;
const memory_format_t dhwio = mkldnn_dhwio;
const memory_format_t oidhw = mkldnn_oidhw;
@@ -179,6 +180,7 @@ namespace memory_format {
const memory_format_t gOIw8o16i2o = mkldnn_gOIw8o16i2o;
const memory_format_t goihw = mkldnn_goihw;
const memory_format_t hwigo = mkldnn_hwigo;
+ const memory_format_t giohw = mkldnn_giohw;
const memory_format_t hwigo_s8s8 = mkldnn_hwigo_s8s8;
const memory_format_t gOIhw8i8o = mkldnn_gOIhw8i8o;
const memory_format_t gOIhw16i16o = mkldnn_gOIhw16i16o;
diff --git a/src/common/format_traits.hpp b/src/common/format_traits.hpp
index 329bd43b..6507c684 100644
--- a/src/common/format_traits.hpp
+++ b/src/common/format_traits.hpp
@@ -125,6 +125,7 @@ DECL_TRAITS(OIw8o16i2o, wei, _8o16i2o, 3, 1);
DECL_TRAITS(oihw, wei, _, 4, 2);
DECL_TRAITS(ihwo, wei, _, 4, 2);
DECL_TRAITS(hwio, wei, _, 4, 2);
+DECL_TRAITS(iohw, wei, _, 4, 2);
DECL_TRAITS(hwio_s8s8, wei, _, 4, 2);
DECL_TRAITS(oIhw8i, wei, _8i, 4, 2);
DECL_TRAITS(oIhw16i, wei, _16i, 4, 2);
@@ -171,6 +172,7 @@ DECL_TRAITS(gOIw8o16i2o, gwei, _8o16i2o, 4, 1);
/* gwei: 5D */
DECL_TRAITS(goihw, gwei, _, 5, 2);
DECL_TRAITS(hwigo, gwei, _, 5, 2);
+DECL_TRAITS(giohw, gwei, _, 5, 2);
DECL_TRAITS(hwigo_s8s8, gwei, _, 5, 2);
DECL_TRAITS(gOIhw8i8o, gwei, _8i8o, 5, 2);
DECL_TRAITS(gOIhw16i16o, gwei, _16i16o, 5, 2);
diff --git a/src/common/memory_desc_wrapper.cpp b/src/common/memory_desc_wrapper.cpp
index ff63189c..991be186 100644
--- a/src/common/memory_desc_wrapper.cpp
+++ b/src/common/memory_desc_wrapper.cpp
@@ -384,6 +384,13 @@ status_t fill_hwio(memory_desc_t &md) {
return fill_nonblocked(md, perm);
}
+status_t fill_iohw(memory_desc_t &md) {
+    if (md.ndims != 4) return invalid_arguments;
+    /* plain iohw: logical dims (o, i, h, w) stored physically as i, o, h, w */
+    const int perm[4] = {1, 0, 2, 3};
+    return fill_nonblocked(md, perm);
+}
+
status_t fill_dhwio(memory_desc_t &md) {
if (md.ndims != 5) return invalid_arguments;
@@ -702,6 +709,13 @@ status_t fill_hwigo(memory_desc_t &md) {
return fill_nonblocked(md, perm);
}
+status_t fill_giohw(memory_desc_t &md) {
+    if (md.ndims != 5) return invalid_arguments;
+    /* plain giohw: logical dims (g, o, i, h, w) stored as g, i, o, h, w */
+    const int perm[5] = {0, 2, 1, 3, 4};
+    return fill_nonblocked(md, perm);
+}
+
status_t fill_gOIhw8i8o(memory_desc_t &md) {
if (md.ndims != 5) return invalid_arguments;
@@ -987,6 +1001,7 @@ status_t memory_desc_wrapper::compute_blocking(memory_desc_t &memory_desc)
case oihw: return fill_oihw(memory_desc);
case ihwo: return fill_ihwo(memory_desc);
case hwio: return fill_hwio(memory_desc);
+ case iohw: return fill_iohw(memory_desc);
case hwio_s8s8: return fill_hwio(memory_desc);
case dhwio: return fill_dhwio(memory_desc);
case OIhw8i8o: return fill_OIhw8i8o(memory_desc);
@@ -1015,6 +1030,7 @@ status_t memory_desc_wrapper::compute_blocking(memory_desc_t &memory_desc)
case gIOw16o16i: return fill_gIOw16o16i(memory_desc);
case goihw: return fill_goihw(memory_desc);
case hwigo: return fill_hwigo(memory_desc);
+ case giohw: return fill_giohw(memory_desc);
case hwigo_s8s8: return fill_hwigo(memory_desc);
case gOIhw8i8o: return fill_gOIhw8i8o(memory_desc);
case gOIhw16i16o: return fill_gOIhw16i16o(memory_desc);
diff --git a/src/common/mkldnn_debug.cpp b/src/common/mkldnn_debug.cpp
index 6157d857..cb61ace7 100644
--- a/src/common/mkldnn_debug.cpp
+++ b/src/common/mkldnn_debug.cpp
@@ -72,6 +72,7 @@ const char *mkldnn_fmt2str(mkldnn_memory_format_t v) {
if (v == mkldnn_wio) return "wio";
if (v == mkldnn_oihw) return "oihw";
if (v == mkldnn_hwio) return "hwio";
+ if (v == mkldnn_iohw) return "iohw";
if (v == mkldnn_hwio_s8s8) return "hwio_s8s8";
if (v == mkldnn_ihwo) return "ihwo";
if (v == mkldnn_oidhw) return "oidhw";
@@ -79,6 +80,7 @@ const char *mkldnn_fmt2str(mkldnn_memory_format_t v) {
if (v == mkldnn_goiw) return "goiw";
if (v == mkldnn_goihw) return "goihw";
if (v == mkldnn_hwigo) return "hwigo";
+ if (v == mkldnn_giohw) return "giohw";
if (v == mkldnn_hwigo_s8s8) return "hwigo_s8s8";
if (v == mkldnn_goidhw) return "goidhw";
if (v == mkldnn_ntc) return "ntc";
diff --git a/src/common/type_helpers.hpp b/src/common/type_helpers.hpp
index 1a8f64a5..0245ff78 100644
--- a/src/common/type_helpers.hpp
+++ b/src/common/type_helpers.hpp
@@ -122,6 +122,7 @@ inline memory_format_t format_normalize(const memory_format_t fmt) {
oihw,
ihwo,
hwio,
+ iohw,
hwio_s8s8,
dhwio,
oidhw,
@@ -162,6 +163,7 @@ inline memory_format_t format_normalize(const memory_format_t fmt) {
gIOw16o16i,
goihw,
hwigo,
+ giohw,
hwigo_s8s8,
gOIhw8i8o,
gOIhw16i16o,
diff --git a/tests/benchdnn/dnn_types.cpp b/tests/benchdnn/dnn_types.cpp
index e0c3ca44..e721ac9d 100644
--- a/tests/benchdnn/dnn_types.cpp
+++ b/tests/benchdnn/dnn_types.cpp
@@ -102,6 +102,7 @@ data_kind_t fmt2data_kind(mkldnn_memory_format_t fmt) {
case mkldnn_gOIw8i16o2i:
case mkldnn_goihw:
case mkldnn_hwigo:
+ case mkldnn_giohw:
case mkldnn_hwigo_s8s8:
case mkldnn_gOIhw8i8o:
case mkldnn_gOIhw16i16o:
diff --git a/tests/benchdnn/mkldnn_debug.cpp b/tests/benchdnn/mkldnn_debug.cpp
index decf41bf..840c556b 100644
--- a/tests/benchdnn/mkldnn_debug.cpp
+++ b/tests/benchdnn/mkldnn_debug.cpp
@@ -96,6 +96,7 @@ mkldnn_memory_format_t str2fmt(const char *str) {
CASE(oihw);
CASE(ihwo);
CASE(hwio);
+ CASE(iohw);
CASE(hwio_s8s8);
CASE(dhwio);
CASE(OIhw8i8o);
@@ -114,6 +115,7 @@ mkldnn_memory_format_t str2fmt(const char *str) {
CASE(goiw);
CASE(goihw);
CASE(hwigo);
+ CASE(giohw);
CASE(hwigo_s8s8);
CASE(goiw);
CASE(gOIw16i16o);
diff --git a/tests/gtests/mkldnn_test_common.hpp b/tests/gtests/mkldnn_test_common.hpp
index f2900833..de69dbcb 100644
--- a/tests/gtests/mkldnn_test_common.hpp
+++ b/tests/gtests/mkldnn_test_common.hpp
@@ -237,6 +237,7 @@ inline mkldnn::memory::desc create_md(mkldnn::memory::dims dims,
case f::nChw16c:
case f::oihw:
case f::hwio:
+ case f::iohw:
case f::oIhw8i:
case f::oIhw16i:
case f::OIhw8i8o:
@@ -258,6 +259,7 @@ inline mkldnn::memory::desc create_md(mkldnn::memory::dims dims,
case f::oidhw:
case f::goihw:
case f::hwigo:
+ case f::giohw:
case f::oIdhw8i:
case f::oIdhw16i:
case f::OIdhw8i8o:
diff --git a/tests/gtests/test_reorder.cpp b/tests/gtests/test_reorder.cpp
index e182e91c..89bb4bf3 100644
--- a/tests/gtests/test_reorder.cpp
+++ b/tests/gtests/test_reorder.cpp
@@ -333,7 +333,11 @@ TEST_P(reorder_simple_test_weights_f32_f32_1, TestsReorder) { }
INSTANTIATE_TEST_CASE_P(TestReorder, reorder_simple_test_weights_f32_f32_1,
::testing::Values(
cfg_f32{eng::cpu, fmt::goihw, fmt::Goihw16g, {32, 32, 32, 3, 3}},
- cfg_f32{eng::cpu, fmt::Goihw16g, fmt::goihw, {32, 32, 32, 3, 3}}
+ cfg_f32{eng::cpu, fmt::Goihw16g, fmt::goihw, {32, 32, 32, 3, 3}},
+ cfg_f32{eng::cpu, fmt::oihw, fmt::iohw, {32, 32, 3, 3}},
+ cfg_f32{eng::cpu, fmt::iohw, fmt::oihw, {32, 32, 3, 3}},
+ cfg_f32{eng::cpu, fmt::goihw, fmt::giohw, {2, 32, 32, 3, 3}},
+ cfg_f32{eng::cpu, fmt::giohw, fmt::goihw, {2, 32, 32, 3, 3}}
)
);
--
2.14.5
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment