Skip to content

Instantly share code, notes, and snippets.

@kice
Last active November 20, 2018 11:55
Show Gist options
  • Save kice/8ca12acaed41b06ba3d096ee3c63e143 to your computer and use it in GitHub Desktop.
Save kice/8ca12acaed41b06ba3d096ee3c63e143 to your computer and use it in GitHub Desktop.
#include <assert.h>
#include <cstdio>
#include <cstring>
#include <exception>
#include <map>
#include <memory>
#include <sstream>
#include <string>
#include <vector>
#include <opencv/cv.hpp>
/// Splits `str` on every occurrence of `delim`, using std::getline semantics.
///
/// Consecutive delimiters produce empty fields ("a::b" -> {"a", "", "b"}),
/// a leading delimiter produces a leading empty field, but a trailing
/// delimiter does NOT produce a trailing empty field ("a:" -> {"a"}), and an
/// empty input produces an empty vector.
///
/// @param str    string to split; taken by const reference so const strings
///               (e.g. std::map keys) and temporaries can be passed directly
/// @param delim  delimiter character
/// @return the fields in order of appearance
template <class CharT, class Traits, class Alloc>
inline std::vector<std::basic_string<CharT, Traits, Alloc>> split(
    const std::basic_string<CharT, Traits, Alloc> &str,
    const CharT delim)
{
    using string_type = std::basic_string<CharT, Traits, Alloc>;

    std::vector<string_type> fields;
    std::basic_istringstream<CharT, Traits, Alloc> in(str);
    string_type field;
    while (std::getline(in, field, delim)) {
        fields.push_back(field);
    }
    return fields;
}
#include <mxnet-cpp/MxNetCpp.h>
using namespace mxnet::cpp;
int main()
{
cv::Mat img = cv::imread("2631_x2_HR.png");
img.convertTo(img, CV_32FC3, 1 / 255.);
std::vector<cv::Mat> rgb(3);
cv::split(img, rgb);
size_t size = img.rows * img.cols;
mx_float *data = new mx_float[size * img.channels()];
memcpy(data, rgb[2].data, size);
memcpy(data + size, rgb[1].data, size);
memcpy(data + size*2, rgb[0].data, size);
auto ctx = Context::gpu(1);
Symbol net = Symbol::Load("int8-symbol.json");
std::map<std::string, NDArray> params;
NDArray::Load("int8-0000.params", nullptr, &params);
std::map<std::string, NDArray> _arg_map;
std::map<std::string, NDArray> _aux_map;
NDArray ndata = NDArray(data,
{ 1, index_t(img.channels()), index_t(img.rows), index_t(img.cols) }, ctx);
_arg_map["data"] = ndata;
for (auto &v : params) {
auto name = v.first;
auto data = v.second;
auto l = split(name, ':');
if (l[0] == "arg") {
_arg_map[l[1]] = data;
} else if (l[0] == "aux") {
_aux_map[l[1]] = data;
}
}
try {
std::vector<const char *> map_keys;
std::vector<int> dev_types, dev_ids;
ExecutorHandle *shared_exec_handle = nullptr;
std::vector<const char *> arg_shape_names = { "data" };
std::vector<mx_uint> input_shape_indptr = { 0, 4 };
std::vector<mx_uint> input_shape_data =
{
1,
static_cast<mx_uint>(img.channels()),
static_cast<mx_uint>(img.rows),
static_cast<mx_uint>(img.cols)
};
std::vector<const char *> arg_dtypes_name = { "data" };
std::vector<int> arg_dtypes = { 0 };
std::vector<const char *> arg_stype_names;
std::vector<int> arg_stypes{};
mx_uint size = 0;
const char **sarr = nullptr;
MXSymbolListArguments(net.GetHandle(), &size, &sarr);
if (size == 0) {
throw dmlc::Error("MXSymbolListArguments(net.GetHandle(), &size, &sarr); Error");
}
std::vector<std::string> arg_names;
for (int i = 0; i < size; ++i) {
arg_names.push_back(sarr[i]);
}
++sarr;
--size;
std::vector<int> shared_buffer_len = { 0 };
std::vector<const char *> shared_buffer_name_list = { nullptr };
std::vector<NDArrayHandle> shared_buffer_handle_list = { nullptr };
const char **updated_shared_buffer_name_list = nullptr;
NDArrayHandle *updated_shared_buffer_handles = nullptr;
mx_uint num_in_args = 0;
NDArrayHandle *in_args = nullptr, *arg_grads = nullptr;
mx_uint num_aux_states = 0;
NDArrayHandle *aux_states = nullptr;
ExecutorHandle handle = 0;
assert(MXExecutorSimpleBind(
net.GetHandle(), ctx.GetDeviceType(), ctx.GetDeviceId(),
0, map_keys.data(), dev_types.data(), dev_ids.data(),
0, nullptr, nullptr,
arg_shape_names.size(), arg_shape_names.data(), input_shape_data.data(), input_shape_indptr.data(),
arg_dtypes_name.size(), arg_dtypes_name.data(), arg_dtypes.data(),
0, arg_stype_names.data(), arg_stypes.data(),
size, sarr,
shared_buffer_len.data(), shared_buffer_name_list.data(), shared_buffer_handle_list.data(),
&updated_shared_buffer_name_list, &updated_shared_buffer_handles,
&num_in_args, &in_args, &arg_grads,
&num_aux_states, &aux_states,
nullptr, &handle
) == 0);
assert(handle != nullptr);
for (int i = 0; i < num_in_args; ++i) {
NDArray arg_dst = NDArray(in_args[i]);
NDArray arg_param = _arg_map[arg_names[i]];
auto shape1 = arg_dst.GetShape();
auto shape2 = arg_param.GetShape();
assert(shape1 == shape2);
auto dtype1 = arg_dst.GetDType();
auto dtype2 = arg_param.GetDType();
assert(dtype1 == dtype2);
arg_param.CopyTo(&arg_dst);
}
auto aux_names = net.ListAuxiliaryStates();
for (int i = 0; i < num_aux_states; ++i) {
NDArray aux_dst = NDArray(in_args[i]);
NDArray aux_param = _aux_map[aux_names[i]];
auto shape1 = aux_dst.GetShape();
auto shape2 = aux_param.GetShape();
assert(shape1 == shape2);
auto dtype1 = aux_dst.GetDType();
auto dtype2 = aux_param.GetDType();
assert(dtype1 == dtype2);
aux_param.CopyTo(&aux_dst);
}
assert(MXExecutorForward(handle, false) == 0);
NDArray::WaitAll();
NDArrayHandle *out;
mx_uint out_size;
assert(MXExecutorOutputs(handle, &out_size, &out) == 0);
NDArray res = NDArray(out[0]);
mx_float *res_data = new mx_float[res.Size()];
memset(res_data, 0, res.Size() * sizeof(mx_float));
res.SyncCopyToCPU(res_data);
MXExecutorFree(handle);
} catch (const dmlc::Error& err) {
printf("%s\n", MXGetLastError());
}
return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment