// Gist log0div0/d8928d6633062048a732468b32536a9f, created August 8, 2018 07:24.
// (Save it to your computer and use it in GitHub Desktop.)
// GitHub notice: this file contains bidirectional Unicode text that may be
// interpreted or compiled differently than it appears. To review it, open the
// file in an editor that reveals hidden Unicode characters; see GitHub's
// documentation on bidirectional Unicode characters to learn more.
#include <algorithm>
#include <chrono>
#include <cstdint>
#include <fstream>
#include <iostream>
#include <iterator>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>

#include <mxnet/c_predict_api.h>

#include "lodepng.h"
// A 4-byte pixel with channels in r, g, b, a order. main() reinterprets the
// decoded image buffer as a contiguous array of these.
struct pixel_t {
    uint8_t r;
    uint8_t g;
    uint8_t b;
    uint8_t a;

    // Gray level fed to the network: (r + g + b) / 768, so black is 0.0.
    // Alpha is ignored.
    // NOTE(review): white maps to 765/768 (~0.996), not exactly 1.0 -- confirm
    // this matches the normalization used when the model was trained.
    float get_gray() const {
        const int channel_sum = r + g + b;
        constexpr float scale = 256.0f * 3.0f;
        return static_cast<float>(channel_sum) / scale;
    }
};
// Output alphabet of the classifier: score i of a prediction row belongs to
// chars[i]. The table holds the 95 printable ASCII characters (digits,
// lower case, upper case, punctuation, then space). It is a constant lookup
// table, so it is constexpr rather than a mutable global.
constexpr char chars[] = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~ ";

// Number of classes (the string literal's terminating '\0' is excluded).
constexpr size_t size_of_chars = sizeof(chars) - 1;

// One row of network output: one score per entry of `chars`.
// main() reinterprets the flat float output buffer as an array of these, so
// the struct must be exactly size_of_chars floats with no padding. pack(1)
// requests that, and the static_assert below enforces it at compile time.
#pragma pack(push, 1)
struct prediction_t {
    float probability[size_of_chars];

    // Returns the character with the highest score. On ties, the earliest
    // entry in `chars` wins, because std::max_element returns the first maximum.
    char get_char() const {
        const float* best = std::max_element(std::begin(probability), std::end(probability));
        return chars[best - std::begin(probability)];
    }
};
#pragma pack(pop)

static_assert(sizeof(prediction_t) == size_of_chars * sizeof(float),
              "prediction_t must overlay exactly size_of_chars floats");
int main(int argc, char* argv[]) { | |
try { | |
// load data | |
std::vector<uint8_t> image; | |
unsigned width, height; | |
std::string image_file_path = "1.png"; | |
unsigned error = lodepng::decode(image, width, height, image_file_path); | |
if (error) { | |
throw std::runtime_error(lodepng_error_text(error)); | |
} | |
unsigned channels = image.size() / (width * height); | |
pixel_t* pixels = (pixel_t*)image.data(); | |
std::cout << width << "x" << height << "x" << channels << std::endl; | |
unsigned char_width = 8; | |
unsigned char_height = 16; | |
unsigned char_size = char_width * char_height; | |
unsigned chars_per_row = width / char_width; | |
unsigned chars_per_column = height / char_height; | |
unsigned chars_count = chars_per_column * chars_per_row; | |
int dev_type = 1; // 1: cpu, 2: gpu | |
int dev_id = 0; | |
mx_uint num_input_nodes = 1; // 1 for feedforward | |
const char* input_key[1] = { "data" }; | |
const char** input_keys = input_key; | |
const mx_uint input_shape_indptr[2] = { 0, 4 }; | |
const mx_uint input_shape_data[4] = { | |
chars_count, | |
1, | |
char_height, | |
char_width | |
}; | |
std::string json_file_path = "net-symbol.json"; | |
std::ifstream json_file(json_file_path, std::ios::in); | |
if (!json_file) { | |
throw std::runtime_error("failed to load " + json_file_path); | |
} | |
std::string json((std::istreambuf_iterator<char>(json_file)), std::istreambuf_iterator<char>()); | |
json_file.close(); | |
std::string params_file_path = "net-0000.params"; | |
std::ifstream params_file(params_file_path, std::ios::binary); | |
if (!params_file) { | |
throw std::runtime_error("failed to load " + params_file_path); | |
} | |
std::string params((std::istreambuf_iterator<char>(params_file)), std::istreambuf_iterator<char>()); | |
params_file.close(); | |
PredictorHandle predictor = nullptr; | |
MXPredCreate( | |
static_cast<const char*>(json.data()), | |
static_cast<const char*>(params.data()), | |
static_cast<int>(params.size()), | |
dev_type, | |
dev_id, | |
num_input_nodes, | |
input_keys, | |
input_shape_indptr, | |
input_shape_data, | |
&predictor); | |
if (!predictor) { | |
throw std::runtime_error(MXGetLastError()); | |
} | |
std::cout << "predictor created" << std::endl; | |
std::vector<mx_float> in(chars_count * 1 * char_height * char_width); | |
std::vector<mx_float> out(chars_count * size_of_chars); | |
prediction_t* predictions = (prediction_t*)out.data(); | |
std::string result(chars_count + chars_per_column, '\0'); | |
auto start = std::chrono::high_resolution_clock::now(); | |
size_t index = 0; | |
for (size_t row = 0; row < chars_per_column; ++row) { | |
for (size_t column = 0; column < chars_per_row; ++ column) { | |
for (size_t y = 0; y < char_height; ++y) { | |
for (size_t x = 0; x < char_width; ++x) { | |
in[index * char_size + y * char_width + x] = pixels[(row * char_height + y) * width + column * char_width + x].get_gray(); | |
} | |
} | |
++index; | |
} | |
} | |
MXPredSetInput(predictor, "data", in.data(), static_cast<mx_uint>(in.size())); | |
MXPredForward(predictor); | |
mx_uint output_index = 0; | |
MXPredGetOutput(predictor, output_index, &(out[0]), static_cast<mx_uint>(out.size())); | |
size_t i = 0; | |
size_t j = 0; | |
for (size_t row = 0; row < chars_per_column; ++row) { | |
for (size_t column = 0; column < chars_per_row; ++ column) { | |
result[i++] = predictions[j++].get_char(); | |
} | |
result[i++] = '\n'; | |
} | |
auto end = std::chrono::high_resolution_clock::now(); | |
std::cout << result << std::endl; | |
std::chrono::duration<double> diff = end-start; | |
std::cout << diff.count() << std::endl; | |
MXPredFree(predictor); | |
} | |
catch (const std::exception& error) { | |
std::cout << error.what() << std::endl; | |
} | |
return 0; | |
} |
// Sign up for free to join this conversation on GitHub.
// Already have an account? Sign in to comment.