@log0div0
Created August 8, 2018 07:26
#include <iostream>
#include <chrono>
#include <stdexcept>
#include <memory>
#include <fstream>
#include <iterator>
#include <algorithm>
#include <vector>
#include <cstring>
#include <cstdint>
#include <tvm/runtime/module.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/packed_func.h>
#include "lodepng.h"
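
// Runs a TVM-compiled character classifier over a PNG image: the image is cut
// into a grid of fixed-size glyph tiles, all tiles are classified in one batch,
// and the per-tile argmax over the alphabet below is printed as text.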

// One RGBA pixel as produced by lodepng::decode (8-bit RGBA by default).
struct pixel_t {
    uint8_t r;
    uint8_t g;
    uint8_t b;
    uint8_t a;

    // Collapse the colour channels into a single grayscale value in [0, 1).
    float get_gray() const {
        return float(r + g + b) / float(256 + 256 + 256);
    }
};
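
// The output alphabet: column i of the network output is the score of chars[i].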
char chars[] = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~ ";
constexpr size_t size_of_chars = sizeof(chars) - 1;

#pragma pack(push, 1)
// One row of the network output: one probability per character of the alphabet.
struct prediction_t {
    float probability[size_of_chars];

    // The predicted character is the argmax over the per-class probabilities.
    char get_char() const {
        auto it = std::max_element(std::begin(probability), std::end(probability));
        auto pos = it - std::begin(probability);
        return chars[pos];
    }
};
#pragma pack(pop)

int main(int argc, char* argv[]) {
    try {
        // load the operator library compiled by TVM
        tvm::runtime::Module mod_lib = tvm::runtime::Module::LoadFromFile("net.dll");
        // load the graph definition (JSON)
        std::ifstream json_in("net.json", std::ios::in);
        std::string json_data((std::istreambuf_iterator<char>(json_in)), std::istreambuf_iterator<char>());
        json_in.close();
        // load the trained parameters
        std::ifstream params_in("net.params", std::ios::binary);
        std::string params_data((std::istreambuf_iterator<char>(params_in)), std::istreambuf_iterator<char>());
        params_in.close();
        TVMByteArray params_arr;
        params_arr.data = params_data.c_str();
        params_arr.size = params_data.length();
        // load and decode the input image
        std::vector<uint8_t> image;
        unsigned width, height;
        unsigned error = lodepng::decode(image, width, height, "1.png");
        if (error) {
            throw std::runtime_error(lodepng_error_text(error));
        }
        unsigned channels = image.size() / (width * height);
        pixel_t* pixels = (pixel_t*)image.data();
        std::cout << width << "x" << height << "x" << channels << std::endl;
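        // the image is assumed to be a tight grid of fixed-size 8x16 pixel glyphs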
        unsigned char_width = 8;
        unsigned char_height = 16;
        unsigned char_size = char_width * char_height;
        unsigned chars_per_row = width / char_width;
        unsigned chars_per_column = height / char_height;
        unsigned chars_count = chars_per_column * chars_per_row;
        int dtype_code = kDLFloat;
        int dtype_bits = 32;
        int dtype_lanes = 1;
        int device_type = kDLCPU;
        int device_id = 0;
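        // input tensor: one batch entry per glyph, one grayscale channel (NCHW)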
        DLTensor* in;
        int in_ndim = 4;
        int64_t in_shape[4] = {chars_count, 1, char_height, char_width};
        std::cout
            << in_shape[0] << "x"
            << in_shape[1] << "x"
            << in_shape[2] << "x"
            << in_shape[3] << std::endl;
        TVMArrayAlloc(in_shape, in_ndim, dtype_code, dtype_bits, dtype_lanes, device_type, device_id, &in);
        memset(in->data, 0, in_shape[0] * in_shape[1] * in_shape[2] * in_shape[3] * sizeof(float));
        // output tensor: one row of class probabilities per glyph
        DLTensor* out;
        int out_ndim = 2;
        int64_t out_shape[2] = {chars_count, size_of_chars};
        TVMArrayAlloc(out_shape, out_ndim, dtype_code, dtype_bits, dtype_lanes, device_type, device_id, &out);
        memset(out->data, 0, out_shape[0] * out_shape[1] * sizeof(float));
        float* in_data = (float*)in->data;
        prediction_t* predictions = (prediction_t*)out->data;
        // create graph runtime
        auto tvm_graph_runtime_create = tvm::runtime::Registry::Get("tvm.graph_runtime.create");
        if (!tvm_graph_runtime_create) {
            throw std::runtime_error("tvm.graph_runtime.create is not registered");
        }
        tvm::runtime::Module mod = (*tvm_graph_runtime_create)(json_data, mod_lib, device_type, device_id);
        tvm::runtime::PackedFunc set_input = mod.GetFunction("set_input");
        tvm::runtime::PackedFunc get_output = mod.GetFunction("get_output");
        tvm::runtime::PackedFunc load_params = mod.GetFunction("load_params");
        tvm::runtime::PackedFunc run = mod.GetFunction("run");
        load_params(params_arr);
        std::string result(chars_count + chars_per_column, '\0');
        auto start = std::chrono::high_resolution_clock::now();
        size_t index = 0;
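        // fill the input batch: one normalized grayscale tile per glyph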
        for (size_t row = 0; row < chars_per_column; ++row) {
            for (size_t column = 0; column < chars_per_row; ++column) {
                for (size_t y = 0; y < char_height; ++y) {
                    for (size_t x = 0; x < char_width; ++x) {
                        in_data[index * char_size + y * char_width + x] =
                            pixels[(row * char_height + y) * width + column * char_width + x].get_gray();
                    }
                }
                ++index;
            }
        }
        set_input("data", in);
        run();
        get_output(0, out);
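        // decode: argmax per glyph, with a newline after every row of the grid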
        size_t i = 0;
        size_t j = 0;
        for (size_t row = 0; row < chars_per_column; ++row) {
            for (size_t column = 0; column < chars_per_row; ++column) {
                result[i++] = predictions[j++].get_char();
            }
            result[i++] = '\n';
        }
        auto end = std::chrono::high_resolution_clock::now();
        std::cout << result << std::endl;
        // elapsed time for preprocessing, inference and decoding, in seconds
        std::chrono::duration<double> diff = end - start;
        std::cout << diff.count() << std::endl;
        TVMArrayFree(in);
        TVMArrayFree(out);
    }
    catch (const std::exception& error) {
        std::cout << error.what() << std::endl;
    }
    return 0;
}