Skip to content

Instantly share code, notes, and snippets.

@Unbinilium
Last active March 2, 2022 05:25
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Unbinilium/d9de5b2d544ad5fe222b0133981687ae to your computer and use it in GitHub Desktop.
Save Unbinilium/d9de5b2d544ad5fe222b0133981687ae to your computer and use it in GitHub Desktop.
Inference with an MNIST TorchScript model
#pragma once
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
#include <torch/torch.h>
#include <torch/script.h>
#include <opencv2/core.hpp>
#include <opencv2/imgproc.hpp>
#include <opencv2/highgui.hpp>
namespace dnn {
template <typename T>
class ts_mnist {
public:
ts_mnist(
const std::string& path,
const std::vector<T>& labels
) : _path(path), _labels(labels) {
_module = torch::jit::load(_path);
_inputs.resize(1);
}
auto inferring(const cv::Mat& image) noexcept {
cv::cvtColor(image, _gray, cv::COLOR_BGR2GRAY);
cv::resize(_gray, _gray, cv::Size(28, 28));
_tensor_image = torch::from_blob(_gray.data, { _gray.rows, _gray.cols }, torch::kUInt8);
_tensor_image_normed = (_tensor_image / 255.f).sub_(0.5f).div_(0.5f);
_inputs[0] = _tensor_image_normed.unsqueeze_(0).unsqueeze_(0);
_output = _module.forward(_inputs).toTensor();
_index = _output.argmax().item<int32_t>();
return std::pair<T, torch::Tensor>(_labels.at(_index), _output.index({0, _index}));
}
private:
const std::string _path;
const std::vector<T> _labels;
torch::jit::script::Module _module;
torch::Tensor _tensor_image;
torch::Tensor _tensor_image_normed;
torch::Tensor _output;
int32_t _index;
std::vector<torch::jit::IValue> _inputs;
cv::Mat _gray;
};
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment