Skip to content

Instantly share code, notes, and snippets.

@log0div0
Created August 8, 2018 07:24
Show Gist options
  • Save log0div0/d8928d6633062048a732468b32536a9f to your computer and use it in GitHub Desktop.
Save log0div0/d8928d6633062048a732468b32536a9f to your computer and use it in GitHub Desktop.
#include <algorithm>
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <fstream>
#include <iostream>
#include <iterator>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>

#include <mxnet/c_predict_api.h>

#include "lodepng.h"
// One pixel stored as four consecutive 8-bit channels in r, g, b, a order.
struct pixel_t {
    uint8_t r;
    uint8_t g;
    uint8_t b;
    uint8_t a;

    // Brightness of the pixel: the sum of the three colour channels divided
    // by 768 (256 * 3). The alpha channel does not take part.
    float get_gray() const {
        const int channel_sum = r + g + b;
        const float full_scale = float(256 + 256 + 256);
        return static_cast<float>(channel_sum) / full_scale;
    }
};
// Alphabet of recognisable characters; the position of a character in this
// array is the class index it corresponds to in prediction_t::probability.
char chars[] = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~ ";

// Number of characters in the alphabet (the terminating '\0' is excluded).
constexpr size_t size_of_chars = sizeof(chars) / sizeof(chars[0]) - 1;

// Packing is set to 1 so that the struct carries no padding and is exactly
// size_of_chars consecutive floats.
#pragma pack(push, 1)
// Per-class scores for a single character cell.
struct prediction_t {
    float probability[size_of_chars];

    // Returns the character whose score is the largest; when several scores
    // tie for the maximum, the one with the lowest index wins.
    char get_char() const {
        const float* const first = probability;
        const float* const last = probability + size_of_chars;
        const float* const best = std::max_element(first, last);
        return chars[std::distance(first, best)];
    }
};
#pragma pack(pop)
int main(int argc, char* argv[]) {
try {
// load data
std::vector<uint8_t> image;
unsigned width, height;
std::string image_file_path = "1.png";
unsigned error = lodepng::decode(image, width, height, image_file_path);
if (error) {
throw std::runtime_error(lodepng_error_text(error));
}
unsigned channels = image.size() / (width * height);
pixel_t* pixels = (pixel_t*)image.data();
std::cout << width << "x" << height << "x" << channels << std::endl;
unsigned char_width = 8;
unsigned char_height = 16;
unsigned char_size = char_width * char_height;
unsigned chars_per_row = width / char_width;
unsigned chars_per_column = height / char_height;
unsigned chars_count = chars_per_column * chars_per_row;
int dev_type = 1; // 1: cpu, 2: gpu
int dev_id = 0;
mx_uint num_input_nodes = 1; // 1 for feedforward
const char* input_key[1] = { "data" };
const char** input_keys = input_key;
const mx_uint input_shape_indptr[2] = { 0, 4 };
const mx_uint input_shape_data[4] = {
chars_count,
1,
char_height,
char_width
};
std::string json_file_path = "net-symbol.json";
std::ifstream json_file(json_file_path, std::ios::in);
if (!json_file) {
throw std::runtime_error("failed to load " + json_file_path);
}
std::string json((std::istreambuf_iterator<char>(json_file)), std::istreambuf_iterator<char>());
json_file.close();
std::string params_file_path = "net-0000.params";
std::ifstream params_file(params_file_path, std::ios::binary);
if (!params_file) {
throw std::runtime_error("failed to load " + params_file_path);
}
std::string params((std::istreambuf_iterator<char>(params_file)), std::istreambuf_iterator<char>());
params_file.close();
PredictorHandle predictor = nullptr;
MXPredCreate(
static_cast<const char*>(json.data()),
static_cast<const char*>(params.data()),
static_cast<int>(params.size()),
dev_type,
dev_id,
num_input_nodes,
input_keys,
input_shape_indptr,
input_shape_data,
&predictor);
if (!predictor) {
throw std::runtime_error(MXGetLastError());
}
std::cout << "predictor created" << std::endl;
std::vector<mx_float> in(chars_count * 1 * char_height * char_width);
std::vector<mx_float> out(chars_count * size_of_chars);
prediction_t* predictions = (prediction_t*)out.data();
std::string result(chars_count + chars_per_column, '\0');
auto start = std::chrono::high_resolution_clock::now();
size_t index = 0;
for (size_t row = 0; row < chars_per_column; ++row) {
for (size_t column = 0; column < chars_per_row; ++ column) {
for (size_t y = 0; y < char_height; ++y) {
for (size_t x = 0; x < char_width; ++x) {
in[index * char_size + y * char_width + x] = pixels[(row * char_height + y) * width + column * char_width + x].get_gray();
}
}
++index;
}
}
MXPredSetInput(predictor, "data", in.data(), static_cast<mx_uint>(in.size()));
MXPredForward(predictor);
mx_uint output_index = 0;
MXPredGetOutput(predictor, output_index, &(out[0]), static_cast<mx_uint>(out.size()));
size_t i = 0;
size_t j = 0;
for (size_t row = 0; row < chars_per_column; ++row) {
for (size_t column = 0; column < chars_per_row; ++ column) {
result[i++] = predictions[j++].get_char();
}
result[i++] = '\n';
}
auto end = std::chrono::high_resolution_clock::now();
std::cout << result << std::endl;
std::chrono::duration<double> diff = end-start;
std::cout << diff.count() << std::endl;
MXPredFree(predictor);
}
catch (const std::exception& error) {
std::cout << error.what() << std::endl;
}
return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment