Skip to content

Instantly share code, notes, and snippets.

@emfomenk
Created December 18, 2018 13:08
Show Gist options
  • Save emfomenk/2c3cf85979fd52df17445181505c2f23 to your computer and use it in GitHub Desktop.
Save emfomenk/2c3cf85979fd52df17445181505c2f23 to your computer and use it in GitHub Desktop.
Simple batch normalization example
#include <math.h>
#include <stdio.h>
#include <stdlib.h>

#include <vector>

#include "mkldnn.hpp"
using namespace mkldnn;
// Fills dat[0 .. size) with the synthetic signal 1 + 2*sin(0.2*i).
// Over many samples this averages to about 1, with a variance of about 2
// (sin^2 averages 1/2, scaled by 2^2). A non-positive size writes nothing.
void init_data(float *dat, int size) {
    for (int idx = 0; idx < size; idx++) {
        const double phase = 0.2 * idx;  // computed in double, then passed to sinf(float)
        const float wave = sinf(phase);
        dat[idx] = 1.f + 2.f * wave;
    }
}
// Demonstrates forward-training batch normalization with the MKL-DNN C++ API.
// It builds an N x C x H x W f32 tensor in nchw layout, fills it with
// init_data(), runs the primitive on CPU engine 0, and prints the per-channel
// mean and variance that the primitive wrote out.
void exec_bnrm_fwd() {
    engine eng(engine::cpu, 0);

    const int N = 4, C = 4, H = 27, W = 27;
    const int sz = N * C * H * W;

    // std::vector owns the buffers, so they are released on every exit path,
    // including an exception thrown by the library. The previous malloc'd
    // buffers were never freed.
    std::vector<float> src(sz);
    std::vector<float> dst(sz);
    std::vector<float> mean(C);
    std::vector<float> var(C);

    init_data(src.data(), sz); // ideally mean ~= 1.f, var ~= 2.f

    // The memory objects are created over the user buffers. The statistics are
    // read straight back from `mean`/`var` after execution, so the vectors must
    // stay alive (and unresized) until then.
    memory::desc data_desc({N, C, H, W}, memory::f32, memory::nchw);
    memory src_mem({data_desc, eng}, src.data());
    memory dst_mem({data_desc, eng}, dst.data());

    // Statistics descriptor: one value per channel.
    memory::desc stat_desc({C}, memory::f32, memory::x);
    memory mean_mem({stat_desc, eng}, mean.data());
    memory var_mem({stat_desc, eng}, var.data());

    unsigned flags = 0;
    // set flags for different flavors (use | to combine flags)
    // use_global_stats -- do not compute mean and variance in the primitive, user has to provide them
    // use_scale_shift -- in addition to batch norm also scale and shift the result
    batch_normalization_forward::desc bnrm_fwd_d(
            prop_kind::forward_training, // or forward_inference; backward passes use batch_normalization_backward
            data_desc, // data descriptor (i.e. sizes, data type, and layout)
            0.001f, // eps
            flags);
    batch_normalization_forward::primitive_desc bnrm_fwd_pd(bnrm_fwd_d, eng);
    batch_normalization_forward bnrm_fwd(bnrm_fwd_pd,
            src_mem, dst_mem, mean_mem, var_mem);

    stream(stream::kind::eager).submit({bnrm_fwd}).wait(); // execute bnrm

    for (int c = 0; c < C; ++c) {
        printf("[%d] mean:%f var:%f\n", c, mean[c], var[c]);
    }
}
// Program entry point: runs the batch-normalization demo once.
// Falling off the end of main returns 0.
int main() {
    exec_bnrm_fwd();
}
// Build and run (sample output shown below):
// $ ( g++ example_bnrm_fwd.cpp -lmkldnn -lm && ./a.out )
// [0] mean:0.994635 var:1.997512
// [1] mean:1.000734 var:2.003489
// [2] mean:1.005777 var:1.996543
// [3] mean:1.002504 var:2.002276
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment