Skip to content

Instantly share code, notes, and snippets.

@skeeet
Forked from zeryx/CMakeLists.txt
Created April 1, 2019 22:27
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save skeeet/3cd6cefc233706a6c62afe16c776b6b5 to your computer and use it in GitHub Desktop.
Save skeeet/3cd6cefc233706a6c62afe16c776b6b5 to your computer and use it in GitHub Desktop.
Minimal PyTorch 1.0 example: export a TorchScript model from Python and run it from C++. Demo image at: https://i.imgur.com/hiWRITj.jpg
cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
project(cpp_shim)

# Where to look for the unpacked libtorch distribution (same relative path as before).
# Appending instead of set() keeps any -DCMAKE_PREFIX_PATH=... given on the command
# line; set() would create a normal variable that hides that cache entry.
list(APPEND CMAKE_PREFIX_PATH ../libtorch)

find_package(Torch REQUIRED)
find_package(OpenCV REQUIRED)

add_executable(testing main.cpp)
# libtorch headers require at least C++11; older compilers default to an earlier standard.
set_property(TARGET testing PROPERTY CXX_STANDARD 11)

message(STATUS "OpenCV library status:")
message(STATUS " config: ${OpenCV_DIR}")
message(STATUS " version: ${OpenCV_VERSION}")
message(STATUS " libraries: ${OpenCV_LIBS}")
message(STATUS " include path: ${OpenCV_INCLUDE_DIRS}")
message(STATUS "TORCHLIB: ${TORCH_LIBRARIES}")

#target_include_directories(testing PRIVATE ${TORCH_INCLUDE_DIRS} ${OpenCV_INCLUDE_DIRS})
target_link_libraries(testing ${OpenCV_LIBS} ${TORCH_LIBRARIES})

# Forces the pre-C++11 libstdc++ string/list ABI.
# NOTE(review): this must match how BOTH libtorch and OpenCV were built (the
# pre-cxx11-ABI libtorch download needs 0); a mismatch shows up as undefined
# std::string symbols at link time. TORCH_CXX_FLAGS carries libtorch's own setting.
target_compile_definitions(testing PRIVATE _GLIBCXX_USE_CXX11_ABI=0)
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.jit import ScriptModule, script_method, trace
class MyScriptModule(ScriptModule):
    """Toy TorchScript model: two 5x5 convolutions followed by a linear layer.

    Each unpadded 5x5 convolution trims 4 pixels from both spatial
    dimensions, and ``forward`` flattens to a hard-coded ``1258 * 1892``
    features. In practice the input must therefore be ``1 x 3 x 1266 x 1900``
    (anything whose second conv output has exactly that many elements).
    Returns a 1-D tensor of 5 scores.
    """
    # Alternative eager base class: class MyScriptModule(nn.Module):

    def __init__(self):
        super(MyScriptModule, self).__init__()
        # trace() turns each eager layer into a ScriptModule that the
        # script_method below can call. Layer init and the example
        # torch.rand tensors share torch's global RNG, so keep this order.
        self.conv1 = trace(nn.Conv2d(3, 2, 5).to("cpu"),
                           torch.rand(1, 3, 1266, 1900))
        self.conv2 = trace(nn.Conv2d(2, 1, 5).to("cpu"),
                           torch.rand(1, 2, 1266, 1900))
        self.lin = trace(nn.Linear(1258 * 1892, 5),
                         torch.rand(1258 * 1892))

    @script_method
    def forward(self, input):
        features = F.relu(self.conv1(input))
        features = F.relu(self.conv2(features))
        # Drop every size-1 dim, then flatten to the Linear layer's input width.
        flat = features.squeeze().view(1258 * 1892)
        return self.lin(flat)
# Build the model at import time (this runs the tracing in __init__)
# and dump the compiled TorchScript graph.
test_module = MyScriptModule()
graph = test_module.graph
print(graph)

if __name__ == "__main__":
    # Serialize to the path the C++ program (main.cpp) loads from.
    test_module.save("/tmp/model.pl")

# Manual smoke test against a real image, kept for reference.
# NOTE(review): np.asarray(Image.open(...)) is H x W x 3, so the .view(1, 3, H, W)
# below scrambles pixels across channels; tensor.permute(2, 0, 1).unsqueeze(0)
# gives the intended 1 x 3 x H x W layout. The image must also be 1900x1266.
# if __name__ == "__main__":
#     import numpy as np
#     from PIL import Image
#     img_path = "/tmp/cat_image.jpg"
#     img = np.asarray(Image.open(img_path))
#     tensor = torch.from_numpy(img).float()
#     tensor = tensor.view(1, 3, tensor.shape[0], tensor.shape[1])
#     test_module.forward(tensor)
//
// Created by zeryx on 10/5/18.
//
#include <cstdint>
#include <fstream>
#include <iostream>
#include <memory>
#include <string>
#include <tuple>
#include <vector>

#include <torch/script.h>
#include <opencv2/core/core.hpp>
#include <opencv2/highgui/highgui.hpp>
using namespace cv;
int main() {
std::string model_path = "/tmp/model.pl";
std::string image_path = "/tmp/cat_image.jpg";
Mat image = imread(image_path);
std::vector<int64_t> sizes = {1, 3, image.rows, image.cols};
at::TensorOptions options(at::ScalarType::Byte);
at::Tensor tensor_image = torch::from_blob(image.data, at::IntList(sizes), options);
tensor_image = tensor_image.toType(at::kFloat);
std::ifstream is (model_path, std::ifstream::binary);
std::shared_ptr<torch::jit::script::Module> module = torch::jit::load(is);
std::vector<torch::jit::IValue> inputs;
inputs.emplace_back(tensor_image);
at::Tensor result = module->forward(inputs).toTensor();
auto max_result = result.max(0, true);
auto max_index = std::get<1>(max_result).item<float>();
std::cout << max_index << std::endl;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment