@zeryx
Created October 11, 2018 17:21
Minimal PyTorch 1.0 Python -> C++ full example. Demo image at: https://i.imgur.com/hiWRITj.jpg
CMakeLists.txt
cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
project(cpp_shim)
set(CMAKE_PREFIX_PATH ../libtorch)
find_package(Torch REQUIRED)
find_package(OpenCV REQUIRED)
add_executable(testing main.cpp)
message(STATUS "OpenCV library status:")
message(STATUS " config: ${OpenCV_DIR}")
message(STATUS " version: ${OpenCV_VERSION}")
message(STATUS " libraries: ${OpenCV_LIBS}")
message(STATUS " include path: ${OpenCV_INCLUDE_DIRS}")
message(STATUS "TORCHLIB: ${TORCH_LIBRARIES}")
#target_include_directories(testing PRIVATE ${TORCH_INCLUDE_DIRS} ${OpenCV_INCLUDE_DIRS})
target_link_libraries(testing ${OpenCV_LIBS})
target_link_libraries(testing ${TORCH_LIBRARIES})
target_compile_definitions(testing PRIVATE -D_GLIBCXX_USE_CXX11_ABI=0)
generator.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.jit import ScriptModule, script_method, trace

class MyScriptModule(ScriptModule):
# class MyScriptModule(nn.Module):
    def __init__(self):
        super(MyScriptModule, self).__init__()
        # trace produces a ScriptModule for conv1 and conv2
        self.conv1 = trace(nn.Conv2d(3, 2, 5).to("cpu"), torch.rand(1, 3, 1266, 1900))
        self.conv2 = trace(nn.Conv2d(2, 1, 5).to("cpu"), torch.rand(1, 2, 1266, 1900))
        self.lin = trace(nn.Linear(1258*1892, 5), torch.rand(1258*1892))

    @script_method
    def forward(self, input):
        input = F.relu(self.conv1(input))
        input = F.relu(self.conv2(input))
        input = input.squeeze()
        input = input.view(1258*1892)
        output = self.lin(input)
        return output

test_module = MyScriptModule()
print(test_module.graph)

if __name__ == "__main__":
    test_module.save("/tmp/model.pl")

# if __name__ == "__main__":
#     import numpy as np
#     from PIL import Image
#     img_path = "/tmp/cat_image.jpg"
#     img = np.asarray(Image.open(img_path))
#     tensor = torch.from_numpy(img).float()
#     tensor = tensor.view(1, 3, tensor.shape[0], tensor.shape[1])
#     test_module.forward(tensor)
main.cpp
//
// Created by zeryx on 10/5/18.
//
#include <torch/script.h>
#include <iostream>
#include <fstream>
#include <memory>
#include <opencv2/core/core.hpp>
#include <opencv2/highgui/highgui.hpp>

using namespace cv;

int main() {
    std::string model_path = "/tmp/model.pl";
    std::string image_path = "/tmp/cat_image.jpg";

    // Load the image with OpenCV and wrap it as a tensor of shape {1, 3, H, W}.
    Mat image = imread(image_path);
    std::vector<int64_t> sizes = {1, 3, image.rows, image.cols};
    at::TensorOptions options(at::ScalarType::Byte);
    at::Tensor tensor_image = torch::from_blob(image.data, at::IntList(sizes), options);
    tensor_image = tensor_image.toType(at::kFloat);

    // Deserialize the TorchScript module saved by generator.py.
    std::ifstream is(model_path, std::ifstream::binary);
    std::shared_ptr<torch::jit::script::Module> module = torch::jit::load(is);

    std::vector<torch::jit::IValue> inputs;
    inputs.emplace_back(tensor_image);

    at::Tensor result = module->forward(inputs).toTensor();
    auto max_result = result.max(0, true);
    auto max_index = std::get<1>(max_result).item<float>();
    std::cout << max_index << std::endl;
}
@Arnold1 commented Dec 4, 2018

@zeryx thanks for providing your example.

However, I get this error with generator.py:
$ python3 generator.py 
Traceback (most recent call last):
  File "generator.py", line 26, in <module>
    test_module = MyScriptModule()
  File "/Users/geri/.pyenv/versions/3.6.0/lib/python3.6/site-packages/torch/jit/__init__.py", line 594, in init_then_register
    self._create_methods(asts, rcbs)
RuntimeError: 
arguments for call are not valid:
  
  for operator aten::view(Tensor self, int[] size):
  expected a value of type int[] for argument 'size' but found int
  @script_method
  def forward(self, input):
    input = F.relu(self.conv1(input))
    input = F.relu(self.conv2(input))
    input = input.squeeze()
    input = input.view(1258*1892)
                       ~~~~~~~~~ <--- HERE
    output = self.lin(input)
    return output
for call at:
@script_method
def forward(self, input):
  input = F.relu(self.conv1(input))
  input = F.relu(self.conv2(input))
  input = input.squeeze()
  input = input.view(1258*1892)
          ~~~~~~~~~~ <--- HERE
  output = self.lin(input)
  return output
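
(Note: the message says aten::view expects an int[] for size but got a bare int. A possible workaround, sketched here and not part of the original gist, is to pass the flattened size as a one-element list:)

# Hypothetical tweak to MyScriptModule.forward, not from the gist itself:
# TorchScript's aten::view wants an int[] for `size`, so wrap it in a list.
@script_method
def forward(self, input):
    input = F.relu(self.conv1(input))
    input = F.relu(self.conv2(input))
    input = input.squeeze()
    input = input.view([1258 * 1892])  # list instead of a bare int
    output = self.lin(input)
    return output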

I installed PyTorch:

pip3 install torch_nightly -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
Looking in links: https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
Collecting torch_nightly
  Downloading https://download.pytorch.org/whl/nightly/cpu/torch_nightly-1.0.0.dev20181203-cp36-none-macosx_10_7_x86_64.whl (61.3MB)
    100% |████████████████████████████████| 61.3MB 33kB/s 
Installing collected packages: torch-nightly
Successfully installed torch-nightly-1.0.0.dev20181203

I installed libtorch following these steps:
https://pytorch.org/cppdocs/installing.html
but I downloaded libtorch from here:
https://download.pytorch.org/libtorch/nightly/cpu/libtorch-macos-latest.zip
instead of:
https://download.pytorch.org/libtorch/nightly/cpu/libtorch-shared-with-deps-latest.zip

@Arnold1 commented Dec 5, 2018

generator.py works now; here is the output:

$ python3 generator.py
graph(%input.1 : Tensor
      %1 : Tensor
      %2 : Tensor
      %25 : Tensor
      %26 : Tensor
      %55 : Tensor
      %56 : Tensor) {
  %19 : bool = prim::Constant[value=1](), scope: Conv2d
  %12 : bool = prim::Constant[value=0](), scope: Conv2d
  %6 : int = prim::Constant[value=0](), scope: Conv2d
  %3 : int = prim::Constant[value=1](), scope: Conv2d
  %50 : int = prim::Constant[value=1258]()
  %51 : int = prim::Constant[value=1892]()
  %5 : int[] = prim::ListConstruct(%3, %3), scope: Conv2d
  %8 : int[] = prim::ListConstruct(%6, %6), scope: Conv2d
  %11 : int[] = prim::ListConstruct(%3, %3), scope: Conv2d
  %15 : int[] = prim::ListConstruct(%6, %6), scope: Conv2d
  %20 : Float(1, 2, 1262, 1896) = aten::_convolution(%input.1, %1, %2, %5, %8, %11, %12, %15, %3, %12, %12, %19), scope: Conv2d
  %result.3 : Tensor = prim::If(%12)
    block0() {
      %result.4 : Tensor = aten::relu_(%20)
      -> (%result.4)
    }
    block1() {
      %result.5 : Tensor = aten::relu(%20)
      -> (%result.5)
    }
  %29 : int[] = prim::ListConstruct(%3, %3), scope: Conv2d
  %32 : int[] = prim::ListConstruct(%6, %6), scope: Conv2d
  %35 : int[] = prim::ListConstruct(%3, %3), scope: Conv2d
  %39 : int[] = prim::ListConstruct(%6, %6), scope: Conv2d
  %44 : Float(1, 1, 1262, 1896) = aten::_convolution(%result.3, %25, %26, %29, %32, %35, %12, %39, %3, %12, %12, %19), scope: Conv2d
  %result : Tensor = prim::If(%12)
    block0() {
      %result.1 : Tensor = aten::relu_(%44)
      -> (%result.1)
    }
    block1() {
      %result.2 : Tensor = aten::relu(%44)
      -> (%result.2)
    }
  %input.2 : Tensor = aten::squeeze(%result)
  %52 : int = aten::mul(%50, %51)
  %53 : int[] = prim::ListConstruct(%52)
  %input : Tensor = aten::view(%input.2, %53)
  %57 : Float(2380136!, 5!) = aten::t(%55), scope: Linear
  %output.1 : Float(5) = aten::matmul(%input, %57), scope: Linear
  %output : Float(5) = aten::add_(%output.1, %56, %3), scope: Linear
  return (%output);
}

@zeryx when I call the binary and try to load the model I get the error below. Why isn't it able to load the model?

$ make clean; make; ./testing
Scanning dependencies of target testing
[ 50%] Building CXX object CMakeFiles/testing.dir/main.cpp.o
[100%] Linking CXX executable testing
[100%] Built target testing
libc++abi.dylib: terminating with uncaught exception of type c10::Error: INVALID_ARGUMENT:tensors[5]: Cannot find field. (deserialize at /Users/administrator/nightlies/2018_12_03/wheel_build_dirs/libtorch_2.7/pytorch/torch/csrc/jit/import.cpp:92)
frame #0: c10::Error::Error(c10::SourceLocation, std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> > const&) + 135 (0x11351df87 in libc10.dylib)
frame #1: torch::jit::(anonymous namespace)::ScriptModuleDeserializer::deserialize(std::__1::function<std::__1::shared_ptr<torch::jit::script::Module> (std::__1::vector<std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> >, std::__1::allocator<std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> > > > const&)>) + 6408 (0x10ca94c78 in libtorch.1.dylib)
frame #2: torch::jit::load(std::__1::basic_istream<char, std::__1::char_traits<char> >&) + 538 (0x10ca95d4a in libtorch.1.dylib)
frame #3: main + 2188 (0x10b61f13c in testing)
frame #4: start + 1 (0x7fff8a8605ad in libdyld.dylib)

Abort trap: 6

@LvJC commented Dec 26, 2018

(LvJC quotes @Arnold1's Dec 5 comment above: the generator.py graph output and the c10::Error: INVALID_ARGUMENT:tensors[5]: Cannot find field crash from ./testing.)
This is due to a mismatch between the CUDA versions of LibTorch and PyTorch. For example, your PyTorch might be built against CUDA 10 while LibTorch is built against CUDA 9.
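
One way to check for that kind of mismatch (a sketch, not from the original thread; it only covers the Python side) is to print the version and CUDA build that PyTorch reports and compare them against the libtorch distribution the C++ binary was linked against:

# Hypothetical check, not part of the original gist: print the PyTorch build
# info so it can be compared with the libtorch download used by the C++ side.
import torch

print(torch.__version__)          # e.g. 1.0.0.dev20181203
print(torch.version.cuda)         # None for a CPU-only build
print(torch.cuda.is_available())  # should match the libtorch variant in use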
