Created
June 11, 2021 02:58
-
-
Save luistung/855397afe554377f74cf56cd9bb547c1 to your computer and use it in GitHub Desktop.
PyTorch to C++: export a TorchScript module from Python and run it with LibTorch
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Build the LibTorch example that loads and runs a TorchScript module.
# Configure with: cmake -DCMAKE_PREFIX_PATH=/path/to/libtorch ..
#
# CMake 4.x refuses cmake_minimum_required() below 3.5; 3.18 is the minimum
# used by the LibTorch C++ distribution instructions.
cmake_minimum_required(VERSION 3.18 FATAL_ERROR)
project(custom_ops)

# Locates LibTorch via CMAKE_PREFIX_PATH; provides TORCH_LIBRARIES and
# TORCH_CXX_FLAGS (e.g. the libstdc++ ABI define LibTorch was built with).
find_package(Torch REQUIRED)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")

add_executable(example-app example-app.cpp)
target_link_libraries(example-app "${TORCH_LIBRARIES}")
# Recent LibTorch releases (2.1+) require C++17; do not silently decay to an
# older standard if the compiler lacks it.
set_property(TARGET example-app PROPERTY CXX_STANDARD 17)
set_property(TARGET example-app PROPERTY CXX_STANDARD_REQUIRED ON)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <cstdlib>
#include <exception>
#include <iostream>
#include <memory>
#include <vector>

#include <torch/script.h> // One-stop header.
int main(int argc, const char* argv[]) { | |
if (argc != 2) { | |
std::cerr << "usage: example-app <path-to-exported-script-module>\n"; | |
return -1; | |
} | |
torch::jit::script::Module module; | |
try { | |
// Deserialize the ScriptModule from a file using torch::jit::load(). | |
module = torch::jit::load(argv[1]); | |
} | |
catch (const c10::Error& e) { | |
std::cerr << "error loading the model\n"; | |
return -1; | |
} | |
std::cout << "ok\n"; | |
//Create a vector of inputs. | |
std::vector<torch::jit::IValue> inputs; | |
//inputs.push_back(torch::ones({1, 3, 224, 224})); | |
inputs.push_back(torch::tensor({-3., 0., 0.})); | |
// Execute the model and turn its output into a tensor. | |
at::Tensor output = module.forward(inputs).toTensor(); | |
std::cout << "###" << std::endl; | |
std::cout << output << std::endl; | |
std::cout << "###" << std::endl; | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
class MyModule(torch.nn.Module):
    """Small scriptable module whose forward() branches on its input data.

    Owns one learnable weight of shape (N, M), initialised to zeros.
    ``forward`` prints the sum of its input and then:

    * if ``input[0] > 0``: returns ``weight - input`` (broadcast subtraction);
    * otherwise: returns the matrix product ``weight @ input``.

    NOTE(review): the matmul branch assumes ``input`` is 1-D with length M
    (or otherwise shape-compatible with (N, M)) — confirm for other callers.
    """

    def __init__(self, N, M):
        super().__init__()
        # Learnable (N, M) weight, all zeros at construction.
        self.weight = torch.nn.Parameter(torch.zeros(N, M))

    def forward(self, input):
        # Printed on every call; TorchScript keeps print() in the graph.
        print(input.sum())
        first_is_positive = input[0] > 0
        if first_is_positive:
            return self.weight - input
        return self.weight @ input
if __name__ == "__main__":
    # Only export when run as a script, so importing this file (e.g. to reuse
    # MyModule) does not write a model file as a side effect.
    #
    # Compile with torch.jit.script (annotation) rather than tracing so the
    # data-dependent `if` in MyModule.forward() is preserved in the graph.
    my_module = MyModule(2, 3)
    sm = torch.jit.script(my_module)
    # Eager-mode sanity check (uncomment to run):
    # my_module(torch.tensor([-1., 0., 0.]))
    # Serialize for torch::jit::load() in C++. The file name is kept as-is for
    # compatibility, although the model is scripted (not traced) and is not a
    # ResNet.
    sm.save("traced_resnet_model.pt")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Out-of-source build of example-app. Point CMAKE_PREFIX_PATH at the unzipped
# LibTorch distribution. `-p` keeps re-runs from failing when build/ exists.
mkdir -p build
cd build
cmake -DCMAKE_PREFIX_PATH=/path/to/libtorch ..
cmake --build . --config Release
# Then run: ./example-app <path-to-exported-script-module>
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Converting to TorchScript via Annotation
https://pytorch.org/tutorials/advanced/cpp_export.html#step-1-converting-your-pytorch-model-to-torch-script