Created
January 3, 2019 07:04
-
-
Save SafeteeWoW/d5a105041e88ba47b47bf2990d26ebc2 to your computer and use it in GitHub Desktop.
Test a signed (s8) int8 convolution with Intel MKL-DNN, alongside the unsigned (u8) case for comparison.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

#include "mkldnn.hpp"

using namespace mkldnn;
using namespace std;
std::vector<int32_t> run_int8_conv(memory::data_type src_data_type) { | |
auto cpu_engine = engine(engine::cpu, 0); | |
std::vector<primitive> net; | |
std::vector<primitive> net_weight; | |
memory::dims conv_src_tz = { 1, 3, 64, 64 }; | |
memory::dims conv_weights_tz = { 8, 3, 5, 5 }; | |
memory::dims conv_bias_tz = { 8 }; | |
memory::dims conv_dst_tz = { 1, 8, 60, 60 }; | |
memory::dims conv_strides = { 1, 1 }; | |
auto conv_padding = { 0, 0 }; | |
auto conv_src_md = memory::desc( | |
{ conv_src_tz }, src_data_type, memory::format::any); | |
auto conv_weights_md = memory::desc( | |
{ conv_weights_tz }, memory::data_type::s8, memory::format::any); | |
auto conv_bias_md = memory::desc( | |
{ conv_bias_tz }, memory::data_type::s32, memory::format::any); | |
auto conv_dst_md = memory::desc( | |
{ conv_dst_tz }, memory::data_type::s32, memory::format::any); | |
std::vector<int8_t> conv_src(std::accumulate(conv_src_tz.begin(), | |
conv_src_tz.end(), 1, std::multiplies<uint32_t>())); | |
std::vector<int8_t> conv_weights(std::accumulate(conv_weights_tz.begin(), | |
conv_weights_tz.end(), 1, std::multiplies<uint32_t>())); | |
std::vector<int32_t> conv_bias(std::accumulate(conv_bias_tz.begin(), | |
conv_bias_tz.end(), 1, std::multiplies<uint32_t>())); | |
std::vector<int32_t> conv_dst(std::accumulate(conv_dst_tz.begin(), | |
conv_dst_tz.end(), 1, std::multiplies<uint32_t>())); | |
// Fill src, weights and bias with constant values. | |
std::fill(conv_src.begin(), conv_src.end(), 1); | |
std::fill(conv_weights.begin(), conv_weights.end(), 1); | |
std::fill(conv_bias.begin(), conv_bias.end(), 0); | |
auto conv_desc = convolution_forward::desc(prop_kind::forward, | |
convolution_direct, conv_src_md, conv_weights_md, conv_bias_md, | |
conv_dst_md, conv_strides, conv_padding, conv_padding, | |
padding_kind::zero); | |
auto conv_prim_desc | |
= convolution_forward::primitive_desc(conv_desc, cpu_engine); | |
auto user_src_memory | |
= memory({ { { conv_src_tz }, src_data_type, | |
memory::format::nhwc }, | |
cpu_engine }, | |
conv_src.data()); | |
auto user_weights_memory | |
= memory({ { { conv_weights_tz }, memory::data_type::s8, | |
memory::format::oihw }, | |
cpu_engine }, | |
conv_weights.data()); | |
auto conv_bias_memory = memory( | |
{ { { conv_bias_tz }, memory::data_type::s32, memory::format::x }, | |
cpu_engine }, | |
conv_bias.data()); | |
auto user_dst_memory = memory( | |
{ { { conv_dst_tz }, memory::data_type::s32, memory::format::nhwc }, | |
cpu_engine }, conv_dst.data()); | |
auto conv_src_memory = user_src_memory; | |
if (memory::primitive_desc(conv_prim_desc.src_primitive_desc()) | |
!= user_src_memory.get_primitive_desc()) { | |
conv_src_memory = memory(conv_prim_desc.src_primitive_desc()); | |
net.push_back(reorder(user_src_memory, conv_src_memory)); | |
} | |
auto conv_weights_memory = user_weights_memory; | |
if (memory::primitive_desc(conv_prim_desc.weights_primitive_desc()) | |
!= user_weights_memory.get_primitive_desc()) { | |
conv_weights_memory | |
= memory(conv_prim_desc.weights_primitive_desc()); | |
net_weight.push_back( | |
reorder(user_weights_memory, conv_weights_memory)); | |
} | |
auto conv_dst_memory = memory(conv_prim_desc.dst_primitive_desc()); | |
net.push_back(convolution_forward(conv_prim_desc, conv_src_memory, | |
conv_weights_memory, conv_bias_memory, | |
conv_dst_memory)); | |
if (conv_dst_memory != user_dst_memory) { | |
net.push_back(reorder(conv_dst_memory, user_dst_memory)); | |
} | |
stream(stream::kind::eager).submit(net_weight).wait(); | |
stream(stream::kind::eager).submit(net).wait(); | |
return conv_dst; | |
} | |
int main(int argc, char **argv) { | |
try { | |
auto conv_dst_u8 = run_int8_conv(memory::data_type::u8); | |
cout << "result[0][0][0][0] " << conv_dst_u8[0] << std::endl; | |
auto conv_dst_s8 = run_int8_conv(memory::data_type::s8); | |
cout << "result[0][0][0][0] " << conv_dst_s8[0] << std::endl; | |
} catch (error &e) { | |
std::cerr << "status: " << e.status << std::endl; | |
std::cerr << "message: " << e.message << std::endl; | |
} | |
return 0; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment