Skip to content

Instantly share code, notes, and snippets.

@SafeteeWoW
Created January 3, 2019 07:04
Show Gist options
  • Save SafeteeWoW/d5a105041e88ba47b47bf2990d26ebc2 to your computer and use it in GitHub Desktop.
Save SafeteeWoW/d5a105041e88ba47b47bf2990d26ebc2 to your computer and use it in GitHub Desktop.
Test int8 convolution with a signed (s8) versus unsigned (u8) source tensor using Intel MKL-DNN.
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

#include "mkldnn.hpp"

using namespace mkldnn;
using namespace std;
std::vector<int32_t> run_int8_conv(memory::data_type src_data_type) {
auto cpu_engine = engine(engine::cpu, 0);
std::vector<primitive> net;
std::vector<primitive> net_weight;
memory::dims conv_src_tz = { 1, 3, 64, 64 };
memory::dims conv_weights_tz = { 8, 3, 5, 5 };
memory::dims conv_bias_tz = { 8 };
memory::dims conv_dst_tz = { 1, 8, 60, 60 };
memory::dims conv_strides = { 1, 1 };
auto conv_padding = { 0, 0 };
auto conv_src_md = memory::desc(
{ conv_src_tz }, src_data_type, memory::format::any);
auto conv_weights_md = memory::desc(
{ conv_weights_tz }, memory::data_type::s8, memory::format::any);
auto conv_bias_md = memory::desc(
{ conv_bias_tz }, memory::data_type::s32, memory::format::any);
auto conv_dst_md = memory::desc(
{ conv_dst_tz }, memory::data_type::s32, memory::format::any);
std::vector<int8_t> conv_src(std::accumulate(conv_src_tz.begin(),
conv_src_tz.end(), 1, std::multiplies<uint32_t>()));
std::vector<int8_t> conv_weights(std::accumulate(conv_weights_tz.begin(),
conv_weights_tz.end(), 1, std::multiplies<uint32_t>()));
std::vector<int32_t> conv_bias(std::accumulate(conv_bias_tz.begin(),
conv_bias_tz.end(), 1, std::multiplies<uint32_t>()));
std::vector<int32_t> conv_dst(std::accumulate(conv_dst_tz.begin(),
conv_dst_tz.end(), 1, std::multiplies<uint32_t>()));
// Fill src, weights and bias with constant values.
std::fill(conv_src.begin(), conv_src.end(), 1);
std::fill(conv_weights.begin(), conv_weights.end(), 1);
std::fill(conv_bias.begin(), conv_bias.end(), 0);
auto conv_desc = convolution_forward::desc(prop_kind::forward,
convolution_direct, conv_src_md, conv_weights_md, conv_bias_md,
conv_dst_md, conv_strides, conv_padding, conv_padding,
padding_kind::zero);
auto conv_prim_desc
= convolution_forward::primitive_desc(conv_desc, cpu_engine);
auto user_src_memory
= memory({ { { conv_src_tz }, src_data_type,
memory::format::nhwc },
cpu_engine },
conv_src.data());
auto user_weights_memory
= memory({ { { conv_weights_tz }, memory::data_type::s8,
memory::format::oihw },
cpu_engine },
conv_weights.data());
auto conv_bias_memory = memory(
{ { { conv_bias_tz }, memory::data_type::s32, memory::format::x },
cpu_engine },
conv_bias.data());
auto user_dst_memory = memory(
{ { { conv_dst_tz }, memory::data_type::s32, memory::format::nhwc },
cpu_engine }, conv_dst.data());
auto conv_src_memory = user_src_memory;
if (memory::primitive_desc(conv_prim_desc.src_primitive_desc())
!= user_src_memory.get_primitive_desc()) {
conv_src_memory = memory(conv_prim_desc.src_primitive_desc());
net.push_back(reorder(user_src_memory, conv_src_memory));
}
auto conv_weights_memory = user_weights_memory;
if (memory::primitive_desc(conv_prim_desc.weights_primitive_desc())
!= user_weights_memory.get_primitive_desc()) {
conv_weights_memory
= memory(conv_prim_desc.weights_primitive_desc());
net_weight.push_back(
reorder(user_weights_memory, conv_weights_memory));
}
auto conv_dst_memory = memory(conv_prim_desc.dst_primitive_desc());
net.push_back(convolution_forward(conv_prim_desc, conv_src_memory,
conv_weights_memory, conv_bias_memory,
conv_dst_memory));
if (conv_dst_memory != user_dst_memory) {
net.push_back(reorder(conv_dst_memory, user_dst_memory));
}
stream(stream::kind::eager).submit(net_weight).wait();
stream(stream::kind::eager).submit(net).wait();
return conv_dst;
}
int main(int argc, char **argv) {
try {
auto conv_dst_u8 = run_int8_conv(memory::data_type::u8);
cout << "result[0][0][0][0] " << conv_dst_u8[0] << std::endl;
auto conv_dst_s8 = run_int8_conv(memory::data_type::s8);
cout << "result[0][0][0][0] " << conv_dst_s8[0] << std::endl;
} catch (error &e) {
std::cerr << "status: " << e.status << std::endl;
std::cerr << "message: " << e.message << std::endl;
}
return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment