Created December 18, 2018 13:08
-
-
Save emfomenk/2c3cf85979fd52df17445181505c2f23 to your computer and use it in GitHub Desktop.
Simple batch normalization example
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <math.h>
#include <stdio.h>

#include <vector>

#include "mkldnn.hpp"

using namespace mkldnn;
// Fills the first `size` floats of `dat` with the synthetic signal
// 1 + 2 * sin(0.2 * idx). A non-positive `size` writes nothing.
void init_data(float *dat, int size) {
    int idx = 0;
    while (idx < size) {
        const float wave = sinf(0.2 * idx);
        dat[idx] = 1.f + 2.f * wave;
        ++idx;
    }
}
void exec_bnrm_fwd() { | |
engine eng(engine::cpu, 0); | |
const int N = 4, C = 4, H = 27, W = 27; | |
const int sz = N * C * H * W; | |
float *src = (float *)malloc(sizeof(float) * sz); | |
init_data(src, sz); // ideally mean ~= 1.f, var ~= 2.f | |
float *dst = (float *)malloc(sizeof(float) * sz); | |
float *mean = (float *)malloc(sizeof(float) * C); | |
float *var = (float *)malloc(sizeof(float) * C); | |
memory::desc data_desc({N, C, H, W}, memory::f32, memory::nchw); | |
memory src_mem({data_desc, eng}, src); | |
memory dst_mem({data_desc, eng}, dst); | |
memory::desc stat_desc({C}, memory::f32, memory::x); | |
memory mean_mem({stat_desc, eng}, mean); | |
memory var_mem({stat_desc, eng}, var); | |
unsigned flags = 0; | |
// set flags for different flavors (use | to combine flags) | |
// use_global_stats -- do not compute mean and variance in the primitive, user has to provide them | |
// use_scale_shift -- in addition to batch norm also scale and shift the result | |
batch_normalization_forward::desc bnrm_fwd_d( | |
prop_kind::forward_training, // might be forward_inference, backward, backward_data | |
data_desc, // data descriptor (i.e. sizes, data type, and layout) | |
0.001f, // eps | |
flags); | |
batch_normalization_forward::primitive_desc bnrm_fwd_pd(bnrm_fwd_d, eng); | |
batch_normalization_forward bnrm_fwd(bnrm_fwd_pd, | |
src_mem, dst_mem, mean_mem, var_mem); | |
stream(stream::kind::eager).submit({bnrm_fwd}).wait(); // execute bnrm | |
for (int c = 0; c < C; ++c) { | |
printf("[%d] mean:%f var:%f\n", c, mean[c], var[c]); | |
} | |
} | |
int main() { | |
exec_bnrm_fwd(); | |
return 0; | |
} | |
// $ ( g++ example_bnrm_fwd.cpp -lmkldnn -lm && ./a.out )
// [0] mean:0.994635 var:1.997512
// [1] mean:1.000734 var:2.003489
// [2] mean:1.005777 var:1.996543
// [3] mean:1.002504 var:2.002276
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment