Created
November 28, 2018 11:37
-
-
Save emfomenk/1c5b78333d0eede05b796f85ccf48192 to your computer and use it in GitHub Desktop.
Adding iohw and giohw formats
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
From b401946179eba2675b35562f8457bc1424b5878d Mon Sep 17 00:00:00 2001 | |
From: "Fomenko, Evarist M" <evarist.m.fomenko@intel.com> | |
Date: Wed, 28 Nov 2018 03:13:53 -0800 | |
Subject: [PATCH] api: add iohw and giohw formats | |
These formats are useful as user-side weights formats for deconvolution. | |
For instance in PyTorch weights for convolution are kept in `oihw` | |
and `goihw` formats for groups = 1 and groups > 1 case respectively. | |
But weights for deconvolution are kept in `iohw` and `giohw` formats | |
for groups = 1 and groups > 1 respectively. | |
See #352 for more details. | |
--- | |
include/mkldnn.hpp | 2 ++ | |
include/mkldnn_types.h | 6 ++++++ | |
src/common/c_types_map.hpp | 2 ++ | |
src/common/format_traits.hpp | 2 ++ | |
src/common/memory_desc_wrapper.cpp | 16 ++++++++++++++++ | |
src/common/mkldnn_debug.cpp | 2 ++ | |
src/common/type_helpers.hpp | 2 ++ | |
tests/benchdnn/dnn_types.cpp | 1 + | |
tests/benchdnn/mkldnn_debug.cpp | 2 ++ | |
tests/gtests/mkldnn_test_common.hpp | 2 ++ | |
tests/gtests/test_reorder.cpp | 6 +++++- | |
11 files changed, 42 insertions(+), 1 deletion(-) | |
diff --git a/include/mkldnn.hpp b/include/mkldnn.hpp | |
index 92f176bc..d1b9ba7e 100644 | |
--- a/include/mkldnn.hpp | |
+++ b/include/mkldnn.hpp | |
@@ -624,6 +624,7 @@ struct memory: public primitive { | |
oihw = mkldnn_oihw, | |
ihwo = mkldnn_ihwo, | |
hwio = mkldnn_hwio, | |
+ iohw = mkldnn_iohw, | |
hwio_s8s8 = mkldnn_hwio_s8s8, | |
dhwio = mkldnn_dhwio, | |
oidhw = mkldnn_oidhw, | |
@@ -666,6 +667,7 @@ struct memory: public primitive { | |
gOIw8o16i2o = mkldnn_gOIw8o16i2o, | |
goihw = mkldnn_goihw, | |
hwigo = mkldnn_hwigo, | |
+ giohw = mkldnn_giohw, | |
hwigo_s8s8 = mkldnn_hwigo_s8s8, | |
gOIdhw8i8o = mkldnn_gOIdhw8i8o, | |
gOIdhw8o8i = mkldnn_gOIdhw8o8i, | |
diff --git a/include/mkldnn_types.h b/include/mkldnn_types.h | |
index a688cea2..87a7a083 100644 | |
--- a/include/mkldnn_types.h | |
+++ b/include/mkldnn_types.h | |
@@ -188,6 +188,9 @@ typedef enum { | |
/** 4D weights tensor with physical layout @c ihwo. | |
* Logical dimensions come in the order: (o, i, h, w) */ | |
mkldnn_ihwo, | |
+ /** 4D weights tensor with physical layout @c iohw. | |
+ * Logical dimensions come in the order: (o, i, h, w) */ | |
+ mkldnn_iohw, | |
/** 5D weights tensor with physical layout @c iodhw, used in Caffe. | |
* Logical dimensions come in the order: (o, i, d, h, w) */ | |
mkldnn_oidhw, | |
@@ -205,6 +208,9 @@ typedef enum { | |
* used in TensorFlow. | |
* Logical dimensions come in the order: (g, o, i, h, w) */ | |
mkldnn_hwigo, | |
+ /** 5D grouped weights tensor with the physical layout @c giohw. | |
+ * Logical dimensions come in the order: (g, o, i, h, w) */ | |
+ mkldnn_giohw, | |
/** 6D grouped weights tensor with the physical layout @c goidhw, | |
* used in Caffe. | |
* Logical dimensions come in the order: (g, o, i, d, h, w) */ | |
diff --git a/src/common/c_types_map.hpp b/src/common/c_types_map.hpp | |
index 6c00497a..bc32951c 100644 | |
--- a/src/common/c_types_map.hpp | |
+++ b/src/common/c_types_map.hpp | |
@@ -139,6 +139,7 @@ namespace memory_format { | |
const memory_format_t oihw = mkldnn_oihw; | |
const memory_format_t ihwo = mkldnn_ihwo; | |
const memory_format_t hwio = mkldnn_hwio; | |
+ const memory_format_t iohw = mkldnn_iohw; | |
const memory_format_t hwio_s8s8 = mkldnn_hwio_s8s8; | |
const memory_format_t dhwio = mkldnn_dhwio; | |
const memory_format_t oidhw = mkldnn_oidhw; | |
@@ -179,6 +180,7 @@ namespace memory_format { | |
const memory_format_t gOIw8o16i2o = mkldnn_gOIw8o16i2o; | |
const memory_format_t goihw = mkldnn_goihw; | |
const memory_format_t hwigo = mkldnn_hwigo; | |
+ const memory_format_t giohw = mkldnn_giohw; | |
const memory_format_t hwigo_s8s8 = mkldnn_hwigo_s8s8; | |
const memory_format_t gOIhw8i8o = mkldnn_gOIhw8i8o; | |
const memory_format_t gOIhw16i16o = mkldnn_gOIhw16i16o; | |
diff --git a/src/common/format_traits.hpp b/src/common/format_traits.hpp | |
index 329bd43b..6507c684 100644 | |
--- a/src/common/format_traits.hpp | |
+++ b/src/common/format_traits.hpp | |
@@ -125,6 +125,7 @@ DECL_TRAITS(OIw8o16i2o, wei, _8o16i2o, 3, 1); | |
DECL_TRAITS(oihw, wei, _, 4, 2); | |
DECL_TRAITS(ihwo, wei, _, 4, 2); | |
DECL_TRAITS(hwio, wei, _, 4, 2); | |
+DECL_TRAITS(iohw, wei, _, 4, 2); | |
DECL_TRAITS(hwio_s8s8, wei, _, 4, 2); | |
DECL_TRAITS(oIhw8i, wei, _8i, 4, 2); | |
DECL_TRAITS(oIhw16i, wei, _16i, 4, 2); | |
@@ -171,6 +172,7 @@ DECL_TRAITS(gOIw8o16i2o, gwei, _8o16i2o, 4, 1); | |
/* gwei: 5D */ | |
DECL_TRAITS(goihw, gwei, _, 5, 2); | |
DECL_TRAITS(hwigo, gwei, _, 5, 2); | |
+DECL_TRAITS(giohw, gwei, _, 5, 2); | |
DECL_TRAITS(hwigo_s8s8, gwei, _, 5, 2); | |
DECL_TRAITS(gOIhw8i8o, gwei, _8i8o, 5, 2); | |
DECL_TRAITS(gOIhw16i16o, gwei, _16i16o, 5, 2); | |
diff --git a/src/common/memory_desc_wrapper.cpp b/src/common/memory_desc_wrapper.cpp | |
index ff63189c..991be186 100644 | |
--- a/src/common/memory_desc_wrapper.cpp | |
+++ b/src/common/memory_desc_wrapper.cpp | |
@@ -384,6 +384,13 @@ status_t fill_hwio(memory_desc_t &md) { | |
return fill_nonblocked(md, perm); | |
} | |
+status_t fill_iohw(memory_desc_t &md) { | |
+ if (md.ndims != 4) return invalid_arguments; | |
+ | |
+ const int perm[4] = {1, 0, 2, 3}; | |
+ return fill_nonblocked(md, perm); | |
+} | |
+ | |
status_t fill_dhwio(memory_desc_t &md) { | |
if (md.ndims != 5) return invalid_arguments; | |
@@ -702,6 +709,13 @@ status_t fill_hwigo(memory_desc_t &md) { | |
return fill_nonblocked(md, perm); | |
} | |
+status_t fill_giohw(memory_desc_t &md) { | |
+ if (md.ndims != 5) return invalid_arguments; | |
+ | |
+ const int perm[5] = {0, 2, 1, 3, 4}; | |
+ return fill_nonblocked(md, perm); | |
+} | |
+ | |
status_t fill_gOIhw8i8o(memory_desc_t &md) { | |
if (md.ndims != 5) return invalid_arguments; | |
@@ -987,6 +1001,7 @@ status_t memory_desc_wrapper::compute_blocking(memory_desc_t &memory_desc) | |
case oihw: return fill_oihw(memory_desc); | |
case ihwo: return fill_ihwo(memory_desc); | |
case hwio: return fill_hwio(memory_desc); | |
+ case iohw: return fill_iohw(memory_desc); | |
case hwio_s8s8: return fill_hwio(memory_desc); | |
case dhwio: return fill_dhwio(memory_desc); | |
case OIhw8i8o: return fill_OIhw8i8o(memory_desc); | |
@@ -1015,6 +1030,7 @@ status_t memory_desc_wrapper::compute_blocking(memory_desc_t &memory_desc) | |
case gIOw16o16i: return fill_gIOw16o16i(memory_desc); | |
case goihw: return fill_goihw(memory_desc); | |
case hwigo: return fill_hwigo(memory_desc); | |
+ case giohw: return fill_giohw(memory_desc); | |
case hwigo_s8s8: return fill_hwigo(memory_desc); | |
case gOIhw8i8o: return fill_gOIhw8i8o(memory_desc); | |
case gOIhw16i16o: return fill_gOIhw16i16o(memory_desc); | |
diff --git a/src/common/mkldnn_debug.cpp b/src/common/mkldnn_debug.cpp | |
index 6157d857..cb61ace7 100644 | |
--- a/src/common/mkldnn_debug.cpp | |
+++ b/src/common/mkldnn_debug.cpp | |
@@ -72,6 +72,7 @@ const char *mkldnn_fmt2str(mkldnn_memory_format_t v) { | |
if (v == mkldnn_wio) return "wio"; | |
if (v == mkldnn_oihw) return "oihw"; | |
if (v == mkldnn_hwio) return "hwio"; | |
+ if (v == mkldnn_iohw) return "iohw"; | |
if (v == mkldnn_hwio_s8s8) return "hwio_s8s8"; | |
if (v == mkldnn_ihwo) return "ihwo"; | |
if (v == mkldnn_oidhw) return "oidhw"; | |
@@ -79,6 +80,7 @@ const char *mkldnn_fmt2str(mkldnn_memory_format_t v) { | |
if (v == mkldnn_goiw) return "goiw"; | |
if (v == mkldnn_goihw) return "goihw"; | |
if (v == mkldnn_hwigo) return "hwigo"; | |
+ if (v == mkldnn_giohw) return "giohw"; | |
if (v == mkldnn_hwigo_s8s8) return "hwigo_s8s8"; | |
if (v == mkldnn_goidhw) return "goidhw"; | |
if (v == mkldnn_ntc) return "ntc"; | |
diff --git a/src/common/type_helpers.hpp b/src/common/type_helpers.hpp | |
index 1a8f64a5..0245ff78 100644 | |
--- a/src/common/type_helpers.hpp | |
+++ b/src/common/type_helpers.hpp | |
@@ -122,6 +122,7 @@ inline memory_format_t format_normalize(const memory_format_t fmt) { | |
oihw, | |
ihwo, | |
hwio, | |
+ iohw, | |
hwio_s8s8, | |
dhwio, | |
oidhw, | |
@@ -162,6 +163,7 @@ inline memory_format_t format_normalize(const memory_format_t fmt) { | |
gIOw16o16i, | |
goihw, | |
hwigo, | |
+ giohw, | |
hwigo_s8s8, | |
gOIhw8i8o, | |
gOIhw16i16o, | |
diff --git a/tests/benchdnn/dnn_types.cpp b/tests/benchdnn/dnn_types.cpp | |
index e0c3ca44..e721ac9d 100644 | |
--- a/tests/benchdnn/dnn_types.cpp | |
+++ b/tests/benchdnn/dnn_types.cpp | |
@@ -102,6 +102,7 @@ data_kind_t fmt2data_kind(mkldnn_memory_format_t fmt) { | |
case mkldnn_gOIw8i16o2i: | |
case mkldnn_goihw: | |
case mkldnn_hwigo: | |
+ case mkldnn_giohw: | |
case mkldnn_hwigo_s8s8: | |
case mkldnn_gOIhw8i8o: | |
case mkldnn_gOIhw16i16o: | |
diff --git a/tests/benchdnn/mkldnn_debug.cpp b/tests/benchdnn/mkldnn_debug.cpp | |
index decf41bf..840c556b 100644 | |
--- a/tests/benchdnn/mkldnn_debug.cpp | |
+++ b/tests/benchdnn/mkldnn_debug.cpp | |
@@ -96,6 +96,7 @@ mkldnn_memory_format_t str2fmt(const char *str) { | |
CASE(oihw); | |
CASE(ihwo); | |
CASE(hwio); | |
+ CASE(iohw); | |
CASE(hwio_s8s8); | |
CASE(dhwio); | |
CASE(OIhw8i8o); | |
@@ -114,6 +115,7 @@ mkldnn_memory_format_t str2fmt(const char *str) { | |
CASE(goiw); | |
CASE(goihw); | |
CASE(hwigo); | |
+ CASE(giohw); | |
CASE(hwigo_s8s8); | |
CASE(goiw); | |
CASE(gOIw16i16o); | |
diff --git a/tests/gtests/mkldnn_test_common.hpp b/tests/gtests/mkldnn_test_common.hpp | |
index f2900833..de69dbcb 100644 | |
--- a/tests/gtests/mkldnn_test_common.hpp | |
+++ b/tests/gtests/mkldnn_test_common.hpp | |
@@ -237,6 +237,7 @@ inline mkldnn::memory::desc create_md(mkldnn::memory::dims dims, | |
case f::nChw16c: | |
case f::oihw: | |
case f::hwio: | |
+ case f::iohw: | |
case f::oIhw8i: | |
case f::oIhw16i: | |
case f::OIhw8i8o: | |
@@ -258,6 +259,7 @@ inline mkldnn::memory::desc create_md(mkldnn::memory::dims dims, | |
case f::oidhw: | |
case f::goihw: | |
case f::hwigo: | |
+ case f::giohw: | |
case f::oIdhw8i: | |
case f::oIdhw16i: | |
case f::OIdhw8i8o: | |
diff --git a/tests/gtests/test_reorder.cpp b/tests/gtests/test_reorder.cpp | |
index e182e91c..89bb4bf3 100644 | |
--- a/tests/gtests/test_reorder.cpp | |
+++ b/tests/gtests/test_reorder.cpp | |
@@ -333,7 +333,11 @@ TEST_P(reorder_simple_test_weights_f32_f32_1, TestsReorder) { } | |
INSTANTIATE_TEST_CASE_P(TestReorder, reorder_simple_test_weights_f32_f32_1, | |
::testing::Values( | |
cfg_f32{eng::cpu, fmt::goihw, fmt::Goihw16g, {32, 32, 32, 3, 3}}, | |
- cfg_f32{eng::cpu, fmt::Goihw16g, fmt::goihw, {32, 32, 32, 3, 3}} | |
+ cfg_f32{eng::cpu, fmt::Goihw16g, fmt::goihw, {32, 32, 32, 3, 3}}, | |
+ cfg_f32{eng::cpu, fmt::oihw, fmt::iohw, {32, 32, 3, 3}}, | |
+ cfg_f32{eng::cpu, fmt::iohw, fmt::oihw, {32, 32, 3, 3}}, | |
+ cfg_f32{eng::cpu, fmt::goihw, fmt::giohw, {2, 32, 32, 3, 3}}, | |
+ cfg_f32{eng::cpu, fmt::giohw, fmt::goihw, {2, 32, 32, 3, 3}} | |
) | |
); | |
-- | |
2.14.5 | |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment