Skip to content

Instantly share code, notes, and snippets.

@emfomenk
Last active March 20, 2019 09:11
Show Gist options
  • Save emfomenk/f2d6f328ce8787406ed9c3b9c657462c to your computer and use it in GitHub Desktop.
Save emfomenk/f2d6f328ce8787406ed9c3b9c657462c to your computer and use it in GitHub Desktop.
Simple concat example
#include <stdio.h>
#include "mkldnn.hpp"
using namespace mkldnn;
using format = mkldnn::memory::format;
// Stores `v` into each of the first `size` floats of `dat`.
// A non-positive `size` leaves the buffer untouched.
void init_data(float *dat, int size, float v) {
    int idx = 0;
    while (idx < size) {
        dat[idx] = v;
        ++idx;
    }
}
// Concatenates two f32 sources along dimension 1 (channels) and checks the
// result.
//
// Both sources have logical dims {N, C1 or C2, H, W}. The first is filled with
// 1 and the second with 2, so every destination element tells which source it
// was copied from.
//
//   fmt1, fmt2   memory formats of the first and second source
//   fmt          destination format; only consulted when enforce_dst is true
//   enforce_dst  true  -> a destination descriptor built from `fmt` is passed
//                         to the concat primitive descriptor
//                false -> the overload without a destination descriptor is
//                         used
//
// Prints the format codes of both sources and the destination, followed by
// "example passed" / "example failed". Element-wise validation is implemented
// only for nchw and nChw8c destinations.
void exec_concat(format fmt1, format fmt2, format fmt, bool enforce_dst) {
    engine eng(engine::cpu, 0);

    const int N = 2, H = 3, W = 3;
    const int C1 = 16, C2 = 24, C = C1 + C2;
    const int sz1 = N * C1 * H * W;
    const int sz2 = N * C2 * H * W;

    // The source buffers are owned by vectors so they are released on every
    // exit path (the original malloc'd buffers were never freed). They are
    // declared before the memory objects that wrap them and therefore outlive
    // them.
    std::vector<float> src1(sz1);
    std::vector<float> src2(sz2);
    init_data(src1.data(), sz1, 1);
    init_data(src2.data(), sz2, 2);

    memory::desc src1_md({N, C1, H, W}, memory::f32, fmt1);
    memory::desc src2_md({N, C2, H, W}, memory::f32, fmt2);
    memory src1_m({src1_md, eng}, src1.data());
    memory src2_m({src2_md, eng}, src2.data());

    memory::desc enforced_dst_md({N, C, H, W}, memory::f32, fmt);
    auto concat_pd = enforce_dst
        ? concat::primitive_desc(enforced_dst_md, 1, {{src1_md, eng}, {src2_md, eng}})
        : concat::primitive_desc(1, {{src1_md, eng}, {src2_md, eng}});

    // Ask the primitive descriptor which destination layout it ended up with.
    memory::desc dst_md = concat_pd.dst_primitive_desc().desc();
    printf("formats src1:%d src2:%d dst:%d\n",
            (int)src1_md.data.format,
            (int)src2_md.data.format,
            (int)dst_md.data.format);

    memory dst_m({dst_md, eng});
    std::vector<primitive::at> inputs = {{src1_m}, {src2_m}};
    concat concat_p(concat_pd, inputs, dst_m);
    stream(stream::kind::eager).submit({concat_p}).wait(); // execute concat

    // Walk the destination in its physical order and compare every element
    // with the fill value of the source that owns that channel.
    const float *dst = (const float *)dst_m.get_data_handle();
    bool ok = true;
    if (dst_md.data.format == mkldnn_nchw) {
        for (int n = 0; n < N; ++n)
            for (int c = 0; c < C; ++c)
                for (int h = 0; h < H; ++h)
                    for (int w = 0; w < W; ++w)
                        if (*dst++ != (c < C1 ? 1 : 2)) ok = false;
    } else if (dst_md.data.format == mkldnn_nChw8c) {
        // Channels are visited in blocks of 8; C / 8 is exact here because
        // C == 40.
        for (int n = 0; n < N; ++n)
            for (int c_b = 0; c_b < C / 8; ++c_b)
                for (int h = 0; h < H; ++h)
                    for (int w = 0; w < W; ++w)
                        for (int c8 = 0; c8 < 8; ++c8)
                            if (*dst++ != (c_b * 8 + c8 < C1 ? 1 : 2)) ok = false;
    } else {
        // NOTE: `ok` stays true in this branch, so the verdict below still
        // reads "passed" even though nothing was checked.
        printf("no validation code is available\n");
    }
    printf("example %s\n\n", ok ? "passed" : "failed");
}
int main() {
printf("dst is defined by concat, use dedicated API for that (dst is not passed to concat::pd)\n");
exec_concat(format::nchw, format::nChw8c, format::any, false);
printf("dst is defined by concat, use regular API, but dst.fmt == any\n");
exec_concat(format::nchw, format::nChw8c, format::any, true);
printf("dst is defined by user, use regular API\n");
exec_concat(format::nchw, format::nChw8c, format::nchw, true);
return 0;
}
// $ g++ cpp_concat.cpp -lmkldnn && MKLDNN_VERBOSE=1 ./a.out
//
// dst is defined by concat, use dedicated API for that (dst is not passed to concat::pd)
// formats src1:7 src2:35 dst:35
// mkldnn_verbose,exec,concat,ref:any,undef,in:f32_nchw out:f32_nChw8c,num:2,2x40x3x3,0.195068
// example passed
//
// dst is defined by concat, use regular API, but dst.fmt == any
// formats src1:7 src2:35 dst:35
// mkldnn_verbose,exec,concat,ref:any,undef,in:f32_nchw out:f32_nChw8c,num:2,2x40x3x3,0.00585938
// example passed
//
// dst is defined by user, use regular API
// formats src1:7 src2:35 dst:7
// mkldnn_verbose,exec,concat,ref:any,undef,in:f32_nchw out:f32_nchw,num:2,2x40x3x3,0.15918
// example passed
#include <stdio.h>
#include <iostream>
#include "mkldnn.hpp"
using namespace mkldnn;
using format = mkldnn::memory::format;
// Fills dat[0 .. size-1] with the value `v`; does nothing when size <= 0.
void init_data(float *dat, int size, float v) {
    float *cursor = dat;
    for (int left = size; left > 0; --left) {
        *cursor = v;
        ++cursor;
    }
}
// Concatenates two f32 sources along dimension 1 (channels) and checks the
// result, with a channel count (C == 30) that is not a multiple of 8.
//
// Both sources have logical dims {N, C1 or C2, H, W}. The first is filled with
// 1 and the second with 2, so every destination element tells which source it
// was copied from.
//
//   fmt1, fmt2   memory formats of the first and second source
//   fmt          destination format; only consulted when enforce_dst is true
//   enforce_dst  true  -> a destination descriptor built from `fmt` is passed
//                         to the concat primitive descriptor
//                false -> the overload without a destination descriptor is
//                         used
//
// Prints the format codes of both sources and the destination, followed by
// "example passed" / "example failed". Element-wise validation is implemented
// only for nchw and nChw8c destinations. Exceptions thrown by the library
// while building the primitive propagate to the caller.
void exec_concat(format fmt1, format fmt2, format fmt, bool enforce_dst) {
    engine eng(engine::cpu, 0);

    const int N = 2, H = 3, W = 3;
    const int C1 = 16, C2 = 14, C = C1 + C2;

    memory::desc src1_md({N, C1, H, W}, memory::f32, fmt1);
    memory::desc src2_md({N, C2, H, W}, memory::f32, fmt2);
    memory::primitive_desc src1_mpd(src1_md, eng);
    memory::primitive_desc src2_mpd(src2_md, eng);

    // get_size() reports bytes, not elements. The original treated it as an
    // element count and multiplied by sizeof(float) again, allocating four
    // times more than needed. Convert to a float count once and size the
    // buffers from that.
    const size_t cnt1 = src1_mpd.get_size() / sizeof(float);
    const size_t cnt2 = src2_mpd.get_size() / sizeof(float);

    // Vectors own the source buffers so they are released on every exit path,
    // including when the library throws (the original malloc'd buffers were
    // never freed). They are declared before the memory objects that wrap
    // them and therefore outlive them.
    std::vector<float> src1(cnt1);
    std::vector<float> src2(cnt2);
    // The whole buffer is filled, as in the original; for a blocked source
    // this presumably includes any padded channels -- TODO confirm that the
    // library does not require those to be zero.
    init_data(src1.data(), (int)cnt1, 1);
    init_data(src2.data(), (int)cnt2, 2);

    memory src1_m({src1_md, eng}, src1.data());
    memory src2_m({src2_md, eng}, src2.data());

    memory::desc enforced_dst_md({N, C, H, W}, memory::f32, fmt);
    auto concat_pd = enforce_dst
        ? concat::primitive_desc(enforced_dst_md, 1, {{src1_md, eng}, {src2_md, eng}})
        : concat::primitive_desc(1, {{src1_md, eng}, {src2_md, eng}});

    // let's figure out what dst looks like
    memory::desc dst_md = concat_pd.dst_primitive_desc().desc();
    printf("formats src1:%d src2:%d dst:%d\n",
            (int)src1_md.data.format,
            (int)src2_md.data.format,
            (int)dst_md.data.format);

    memory dst_m({dst_md, eng});
    std::vector<primitive::at> inputs = {{src1_m}, {src2_m}};
    concat concat_p(concat_pd, inputs, dst_m);
    stream(stream::kind::eager).submit({concat_p}).wait(); // execute concat

    // Walk the destination in its physical order and compare every element
    // with the value expected for its channel.
    const float *dst = (const float *)dst_m.get_data_handle();
    bool ok = true;
    if (dst_md.data.format == mkldnn_nchw) {
        for (int n = 0; n < N; ++n)
            for (int c = 0; c < C; ++c)
                for (int h = 0; h < H; ++h)
                    for (int w = 0; w < W; ++w)
                        if (*dst++ != (c < C1 ? 1 : 2)) ok = false;
    } else if (dst_md.data.format == mkldnn_nChw8c) {
        // The channel count is rounded up to whole blocks of 8, so the last
        // block contains tail channels beyond C.
        const int c_blocks = (C + 8 - 1) / 8;
        for (int n = 0; n < N; ++n)
            for (int c_b = 0; c_b < c_blocks; ++c_b)
                for (int h = 0; h < H; ++h)
                    for (int w = 0; w < W; ++w)
                        for (int c8 = 0; c8 < 8; ++c8) {
                            const int c = c_b * 8 + c8;
                            // Pick exactly one expected value per channel.
                            // The original else-if chain fell through to the
                            // next comparison whenever a value was correct,
                            // so correct data was reported as a failure.
                            const float expected = c < C1 ? 1.f
                                    : c < C ? 2.f
                                    : 0.f; // tail is expected to be zero
                            if (*dst != expected) ok = false;
                            ++dst;
                        }
    } else {
        // NOTE: `ok` stays true in this branch, so the verdict below still
        // reads "passed" even though nothing was checked.
        printf("no validation code is available\n");
    }
    printf("example %s\n\n", ok ? "passed" : "failed");
}
int main() {
printf("dst is defined by concat, use dedicated API for that (dst is not passed to concat::pd)\n");
exec_concat(format::nchw, format::nChw8c, format::any, false);
printf("dst is defined by concat, use regular API, but dst.fmt == any\n");
exec_concat(format::nchw, format::nChw8c, format::any, true);
printf("dst is defined by user, use regular API (plain)\n");
exec_concat(format::nchw, format::nChw8c, format::nchw, true);
printf("dst is defined by user, use regular API (blocked)\n");
try {
exec_concat(format::nchw, format::nChw8c, format::nChw8c, true);
} catch (const mkldnn::error &e) {
std::cout << "error: " << e.message << std::endl;
}
return 0;
}
// $ g++ cpp_concat.cpp -lmkldnn && MKLDNN_VERBOSE=1 ./a.out
//
// dst is defined by concat, use dedicated API for that (dst is not passed to concat::pd)
// formats src1:7 src2:35 dst:7
// mkldnn_verbose,exec,concat,ref:any,undef,in:f32_nchw out:f32_nchw,num:2,2x30x3x3,0.258789
// example passed
//
// dst is defined by concat, use regular API, but dst.fmt == any
// formats src1:7 src2:35 dst:7
// mkldnn_verbose,exec,concat,ref:any,undef,in:f32_nchw out:f32_nchw,num:2,2x30x3x3,0.129883
// example passed
//
// dst is defined by user, use regular API (plain)
// formats src1:7 src2:35 dst:7
// mkldnn_verbose,exec,concat,ref:any,undef,in:f32_nchw out:f32_nchw,num:2,2x30x3x3,0.198975
// example passed
//
// dst is defined by user, use regular API (blocked)
// error: could not create a concat primitive descriptor
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment