Last active
March 20, 2019 09:11
-
-
Save emfomenk/f2d6f328ce8787406ed9c3b9c657462c to your computer and use it in GitHub Desktop.
Simple concat example
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <stdio.h> | |
#include "mkldnn.hpp" | |
using namespace mkldnn; | |
using format = mkldnn::memory::format; | |
// Writes the value `v` into each of the first `size` floats starting at `dat`.
// Nothing is written when `size` is zero or negative.
void init_data(float *dat, int size, float v) {
    float *cursor = dat;
    for (int remaining = size; remaining > 0; --remaining) {
        *cursor = v;
        ++cursor;
    }
}
// Runs one concat scenario over the channel axis and validates the result.
//
// Two f32 sources of logical shape {N, C1, H, W} and {N, C2, H, W} are filled
// with 1 and 2 respectively, concatenated along axis 1, and the destination is
// compared element by element against the expected pattern.
//
//   fmt1, fmt2   memory formats of the first and second source
//   fmt          destination format; only consulted when enforce_dst is true
//   enforce_dst  true  -> the destination descriptor is handed to
//                         concat::primitive_desc
//                false -> the overload without a destination descriptor is
//                         used
//
// Prints the chosen formats and "example passed"/"example failed". Destination
// formats other than nchw / nChw8c are not validated and are reported as
// passed after printing a notice.
void exec_concat(format fmt1, format fmt2, format fmt, bool enforce_dst) {
    engine eng(engine::cpu, 0);

    const int N = 2, H = 3, W = 3;
    const int C1 = 16, C2 = 24, C = C1 + C2;
    const int sz1 = N * C1 * H * W;
    const int sz2 = N * C2 * H * W;

    // The vectors own the source buffers, so they are released on every exit
    // path (the original malloc'd buffers were never freed).
    std::vector<float> src1(sz1);
    std::vector<float> src2(sz2);
    init_data(src1.data(), sz1, 1);
    init_data(src2.data(), sz2, 2);

    memory::desc src1_md({N, C1, H, W}, memory::f32, fmt1);
    memory::desc src2_md({N, C2, H, W}, memory::f32, fmt2);
    memory src1_m({src1_md, eng}, src1.data());
    memory src2_m({src2_md, eng}, src2.data());

    memory::desc enforced_dst_md({N, C, H, W}, memory::f32, fmt);
    auto concat_pd = enforce_dst
            ? concat::primitive_desc(enforced_dst_md, 1, {{src1_md, eng}, {src2_md, eng}})
            : concat::primitive_desc(1, {{src1_md, eng}, {src2_md, eng}});

    // let's figure out what dst looks like
    memory::desc dst_md = concat_pd.dst_primitive_desc().desc();
    printf("formats src1:%d src2:%d dst:%d\n",
            static_cast<int>(src1_md.data.format),
            static_cast<int>(src2_md.data.format),
            static_cast<int>(dst_md.data.format));

    memory dst_m({dst_md, eng});
    std::vector<primitive::at> inputs = {{src1_m}, {src2_m}};
    concat concat_p(concat_pd, inputs, dst_m);
    stream(stream::kind::eager).submit({concat_p}).wait(); // execute concat

    // check: walk dst linearly in the order implied by its format
    const float *dst = static_cast<const float *>(dst_m.get_data_handle());
    bool ok = true;
    if (dst_md.data.format == mkldnn_nchw) {
        for (int n = 0; n < N; ++n)
            for (int c = 0; c < C; ++c)
                for (int h = 0; h < H; ++h)
                    for (int w = 0; w < W; ++w)
                        if (*dst++ != (c < C1 ? 1 : 2)) ok = false;
    } else if (dst_md.data.format == mkldnn_nChw8c) {
        // C is a multiple of 8 here, so there is no padded tail block.
        for (int n = 0; n < N; ++n)
            for (int c_b = 0; c_b < C / 8; ++c_b)
                for (int h = 0; h < H; ++h)
                    for (int w = 0; w < W; ++w)
                        for (int c8 = 0; c8 < 8; ++c8)
                            if (*dst++ != (c_b * 8 + c8 < C1 ? 1 : 2)) ok = false;
    } else {
        // Newline added so the notice does not run into the result line.
        printf("no validation code is available\n");
    }
    printf("example %s\n\n", ok ? "passed" : "failed");
}
int main() { | |
printf("dst is defined by concat, use dedicated API for that (dst is not passed to concat::pd)\n"); | |
exec_concat(format::nchw, format::nChw8c, format::any, false); | |
printf("dst is defined by concat, use regular API, but dst.fmt == any\n"); | |
exec_concat(format::nchw, format::nChw8c, format::any, true); | |
printf("dst is defined by user, use regular API\n"); | |
exec_concat(format::nchw, format::nChw8c, format::nchw, true); | |
return 0; | |
} | |
// $ g++ cpp_concat.cpp -lmkldnn && MKLDNN_VERBOSE=1 ./a.out
// | |
// dst is defined by concat, use dedicated API for that (dst is not passed to concat::pd) | |
// formats src1:7 src2:35 dst:35 | |
// mkldnn_verbose,exec,concat,ref:any,undef,in:f32_nchw out:f32_nChw8c,num:2,2x40x3x3,0.195068 | |
// example passed | |
// | |
// dst is defined by concat, use regular API, but dst.fmt == any | |
// formats src1:7 src2:35 dst:35 | |
// mkldnn_verbose,exec,concat,ref:any,undef,in:f32_nchw out:f32_nChw8c,num:2,2x40x3x3,0.00585938 | |
// example passed | |
// | |
// dst is defined by user, use regular API | |
// formats src1:7 src2:35 dst:7 | |
// mkldnn_verbose,exec,concat,ref:any,undef,in:f32_nchw out:f32_nchw,num:2,2x40x3x3,0.15918 | |
// example passed |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <stdio.h> | |
#include <iostream> | |
#include "mkldnn.hpp" | |
using namespace mkldnn; | |
using format = mkldnn::memory::format; | |
// Sets dat[0] .. dat[size - 1] to `v`; a non-positive `size` writes nothing.
void init_data(float *dat, int size, float v) {
    int idx = 0;
    while (idx < size) {
        dat[idx] = v;
        ++idx;
    }
}
// Runs one concat scenario over the channel axis and validates the result.
//
// Two f32 sources of logical shape {N, C1, H, W} and {N, C2, H, W} are filled
// with 1 and 2 respectively, concatenated along axis 1, and the destination is
// compared element by element against the expected pattern. C2 (and therefore
// C) is deliberately not a multiple of 8, so 8-channel blocked layouts carry a
// padded tail.
//
//   fmt1, fmt2   memory formats of the first and second source
//   fmt          destination format; only consulted when enforce_dst is true
//   enforce_dst  true  -> the destination descriptor is handed to
//                         concat::primitive_desc
//                false -> the overload without a destination descriptor is
//                         used
//
// Exceptions thrown by the library calls are not caught here; main() handles
// mkldnn::error for the scenario that is expected to fail. Destination formats
// other than nchw / nChw8c are not validated and are reported as passed after
// printing a notice.
void exec_concat(format fmt1, format fmt2, format fmt, bool enforce_dst) {
    engine eng(engine::cpu, 0);

    const int N = 2, H = 3, W = 3;
    const int C1 = 16, C2 = 14, C = C1 + C2;

    memory::desc src1_md({N, C1, H, W}, memory::f32, fmt1);
    memory::desc src2_md({N, C2, H, W}, memory::f32, fmt2);
    memory::primitive_desc src1_mpd(src1_md, eng);
    memory::primitive_desc src2_mpd(src2_md, eng);

    // get_size() is a byte count (padding of blocked layouts included), so it
    // has to be divided by the element size to get the number of floats. The
    // original treated it as an element count and over-allocated 4x.
    const size_t sz1 = src1_mpd.get_size() / sizeof(float);
    const size_t sz2 = src2_mpd.get_size() / sizeof(float);

    // The vectors own the source buffers, so they are released even when a
    // call below throws (the original malloc'd buffers leaked).
    std::vector<float> src1(sz1);
    std::vector<float> src2(sz2);
    init_data(src1.data(), static_cast<int>(sz1), 1);
    init_data(src2.data(), static_cast<int>(sz2), 2);

    memory src1_m(src1_mpd, src1.data());
    memory src2_m(src2_mpd, src2.data());

    memory::desc enforced_dst_md({N, C, H, W}, memory::f32, fmt);
    auto concat_pd = enforce_dst
            ? concat::primitive_desc(enforced_dst_md, 1, {src1_mpd, src2_mpd})
            : concat::primitive_desc(1, {src1_mpd, src2_mpd});

    // let's figure out what dst looks like
    memory::desc dst_md = concat_pd.dst_primitive_desc().desc();
    printf("formats src1:%d src2:%d dst:%d\n",
            static_cast<int>(src1_md.data.format),
            static_cast<int>(src2_md.data.format),
            static_cast<int>(dst_md.data.format));

    memory dst_m({dst_md, eng});
    std::vector<primitive::at> inputs = {{src1_m}, {src2_m}};
    concat concat_p(concat_pd, inputs, dst_m);
    stream(stream::kind::eager).submit({concat_p}).wait(); // execute concat

    // check: walk dst linearly in the order implied by its format
    const float *dst = static_cast<const float *>(dst_m.get_data_handle());
    bool ok = true;
    if (dst_md.data.format == mkldnn_nchw) {
        for (int n = 0; n < N; ++n)
            for (int c = 0; c < C; ++c)
                for (int h = 0; h < H; ++h)
                    for (int w = 0; w < W; ++w)
                        if (*dst++ != (c < C1 ? 1 : 2)) ok = false;
    } else if (dst_md.data.format == mkldnn_nChw8c) {
        for (int n = 0; n < N; ++n)
            for (int c_b = 0; c_b < (C + 8 - 1) / 8; ++c_b)
                for (int h = 0; h < H; ++h)
                    for (int w = 0; w < W; ++w)
                        for (int c8 = 0; c8 < 8; ++c8) {
                            const int c = c_b * 8 + c8;
                            // Pick the single expected value first. The old
                            // if / else-if chain fell through to the next
                            // range whenever a value was correct and then
                            // flagged it as a mismatch.
                            const float expected = c < C1 ? 1.f
                                    : (c < C ? 2.f : 0.f /* padded tail */);
                            if (*dst != expected) ok = false;
                            ++dst;
                        }
    } else {
        // Newline added so the notice does not run into the result line.
        printf("no validation code is available\n");
    }
    printf("example %s\n\n", ok ? "passed" : "failed");
}
int main() { | |
printf("dst is defined by concat, use dedicated API for that (dst is not passed to concat::pd)\n"); | |
exec_concat(format::nchw, format::nChw8c, format::any, false); | |
printf("dst is defined by concat, use regular API, but dst.fmt == any\n"); | |
exec_concat(format::nchw, format::nChw8c, format::any, true); | |
printf("dst is defined by user, use regular API (plain)\n"); | |
exec_concat(format::nchw, format::nChw8c, format::nchw, true); | |
printf("dst is defined by user, use regular API (blocked)\n"); | |
try { | |
exec_concat(format::nchw, format::nChw8c, format::nChw8c, true); | |
} catch (const mkldnn::error &e) { | |
std::cout << "error: " << e.message << std::endl; | |
} | |
return 0; | |
} | |
// $ g++ cpp_concat.cpp -lmkldnn && MKLDNN_VERBOSE=1 ./a.out
// | |
// dst is defined by concat, use dedicated API for that (dst is not passed to concat::pd) | |
// formats src1:7 src2:35 dst:7 | |
// mkldnn_verbose,exec,concat,ref:any,undef,in:f32_nchw out:f32_nchw,num:2,2x30x3x3,0.258789 | |
// example passed | |
// | |
// dst is defined by concat, use regular API, but dst.fmt == any | |
// formats src1:7 src2:35 dst:7 | |
// mkldnn_verbose,exec,concat,ref:any,undef,in:f32_nchw out:f32_nchw,num:2,2x30x3x3,0.129883 | |
// example passed | |
// | |
// dst is defined by user, use regular API (plain) | |
// formats src1:7 src2:35 dst:7 | |
// mkldnn_verbose,exec,concat,ref:any,undef,in:f32_nchw out:f32_nchw,num:2,2x30x3x3,0.198975 | |
// example passed | |
// | |
// dst is defined by user, use regular API (blocked) | |
// error: could not create a concat primitive descriptor |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment