Created
August 8, 2018 07:26
-
-
Save log0div0/f1d343da3c0dec9c44277c256ae2f0f7 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <iostream> | |
#include <chrono> | |
#include <stdexcept> | |
#include <memory> | |
#include <fstream> | |
#include <iterator> | |
#include <algorithm> | |
#include <tvm/runtime/module.h> | |
#include <tvm/runtime/registry.h> | |
#include <tvm/runtime/packed_func.h> | |
#include "lodepng.h" | |
// One 4-byte pixel (r, g, b, a) as stored in the decoded image buffer.
struct pixel_t {
    uint8_t r;
    uint8_t g;
    uint8_t b;
    uint8_t a;

    // Gray level of the pixel: (r + g + b) / 768, so the result lies in
    // [0, 765/768]. Alpha is ignored.
    // NOTE(review): the divisor is 3*256 rather than 3*255, so pure white maps
    // to ~0.996 instead of 1.0. Presumably this matches the preprocessing the
    // network was trained with -- confirm before "fixing" it.
    float get_gray() const {
        const int channel_sum = r + g + b;
        const int scale = 256 + 256 + 256;
        return static_cast<float>(channel_sum) / static_cast<float>(scale);
    }
};
// Output alphabet: the network's score vector is indexed in the same order as
// this string. It holds the 95 printable ASCII characters (digits, lowercase,
// uppercase, punctuation, then space). It is only ever read, so it is constexpr.
constexpr char chars[] = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~ ";
// Number of classes: the alphabet length, excluding the literal's trailing '\0'.
constexpr size_t size_of_chars = sizeof(chars) - 1;
#pragma pack(push, 1) | |
struct prediction_t { | |
float probability[size_of_chars]; | |
char get_char() const { | |
auto it = std::max_element(std::begin(probability), std::end(probability)); | |
auto pos = it - std::begin(probability); | |
return chars[pos]; | |
} | |
}; | |
#pragma pack(pop) | |
int main(int argc, char* argv[]) { | |
try { | |
// load dll | |
tvm::runtime::Module mod_lib = tvm::runtime::Module::LoadFromFile("net.dll"); | |
// load json | |
std::ifstream json_in("net.json", std::ios::in); | |
std::string json_data((std::istreambuf_iterator<char>(json_in)), std::istreambuf_iterator<char>()); | |
json_in.close(); | |
// load params | |
std::ifstream params_in("net.params", std::ios::binary); | |
std::string params_data((std::istreambuf_iterator<char>(params_in)), std::istreambuf_iterator<char>()); | |
params_in.close(); | |
TVMByteArray params_arr; | |
params_arr.data = params_data.c_str(); | |
params_arr.size = params_data.length(); | |
// load data | |
std::vector<uint8_t> image; | |
unsigned width, height; | |
unsigned error = lodepng::decode(image, width, height, "1.png"); | |
unsigned channels = image.size() / (width * height); | |
pixel_t* pixels = (pixel_t*)image.data(); | |
std::cout << width << "x" << height << "x" << channels << std::endl; | |
if (error) { | |
throw std::runtime_error(lodepng_error_text(error)); | |
} | |
unsigned char_width = 8; | |
unsigned char_height = 16; | |
unsigned char_size = char_width * char_height; | |
unsigned chars_per_row = width / char_width; | |
unsigned chars_per_column = height / char_height; | |
unsigned chars_count = chars_per_column * chars_per_row; | |
int dtype_code = kDLFloat; | |
int dtype_bits = 32; | |
int dtype_lanes = 1; | |
int device_type = kDLCPU; | |
int device_id = 0; | |
DLTensor* in; | |
int in_ndim = 4; | |
int64_t in_shape[4] = {chars_count, 1, char_height, char_width}; | |
std::cout | |
<< in_shape[0] << "x" | |
<< in_shape[1] << "x" | |
<< in_shape[2] << "x" | |
<< in_shape[3] << std::endl; | |
TVMArrayAlloc(in_shape, in_ndim, dtype_code, dtype_bits, dtype_lanes, device_type, device_id, &in); | |
memset(in->data, 0, in_shape[0] * in_shape[1] * in_shape[2] * in_shape[3]); | |
DLTensor* out; | |
int out_ndim = 2; | |
int64_t out_shape[2] = {chars_count, size_of_chars}; | |
TVMArrayAlloc(out_shape, out_ndim, dtype_code, dtype_bits, dtype_lanes, device_type, device_id, &out); | |
memset(out->data, 0, out_shape[0] * out_shape[1]); | |
float* in_data = (float*)in->data; | |
prediction_t* predictions = (prediction_t*)out->data; | |
// create graph | |
auto tvm_graph_runtime_create = tvm::runtime::Registry::Get("tvm.graph_runtime.create"); | |
tvm::runtime::Module mod = (*tvm_graph_runtime_create)(json_data, mod_lib, device_type, device_id); | |
tvm::runtime::PackedFunc set_input = mod.GetFunction("set_input"); | |
tvm::runtime::PackedFunc get_output = mod.GetFunction("get_output"); | |
tvm::runtime::PackedFunc load_params = mod.GetFunction("load_params"); | |
tvm::runtime::PackedFunc run = mod.GetFunction("run"); | |
load_params(params_arr); | |
std::string result(chars_count + chars_per_column, '\0'); | |
auto start = std::chrono::high_resolution_clock::now(); | |
size_t index = 0; | |
for (size_t row = 0; row < chars_per_column; ++row) { | |
for (size_t column = 0; column < chars_per_row; ++ column) { | |
for (size_t y = 0; y < char_height; ++y) { | |
for (size_t x = 0; x < char_width; ++x) { | |
in_data[index * char_size + y * char_width + x] = pixels[(row * char_height + y) * width + column * char_width + x].get_gray(); | |
} | |
} | |
++index; | |
} | |
} | |
set_input("data", in); | |
run(); | |
get_output(0, out); | |
size_t i = 0; | |
size_t j = 0; | |
for (size_t row = 0; row < chars_per_column; ++row) { | |
for (size_t column = 0; column < chars_per_row; ++ column) { | |
result[i++] = predictions[j++].get_char(); | |
} | |
result[i++] = '\n'; | |
} | |
auto end = std::chrono::high_resolution_clock::now(); | |
std::cout << result << std::endl; | |
std::chrono::duration<double> diff = end-start; | |
std::cout << diff.count() << std::endl; | |
TVMArrayFree(in); | |
TVMArrayFree(out); | |
} | |
catch (const std::exception& error) { | |
std::cout << error.what() << std::endl; | |
} | |
return 0; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment