@dnlcrl
Last active August 27, 2018 08:14
PyTorch Model Export (Python) Import (Python, C++) Snippets
import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear = nn.Linear(1, 1)

    def forward(self, x):
        return self.linear(x)

x = torch.zeros(1)
model = Model()
# Trace the model with an example input and serialize the resulting script module
module = torch.jit.trace(model, x)
module.save('trace.pt')
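As a quick sanity check before saving (a minimal sketch, not part of the original gist), the traced module can be called directly and compared against the original model:

```python
import torch
import torch.nn as nn

# Minimal model matching the export snippet above
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear = nn.Linear(1, 1)

    def forward(self, x):
        return self.linear(x)

x = torch.zeros(1)
model = Model()
traced = torch.jit.trace(model, x)

# The traced module should reproduce the original model's output exactly
assert torch.equal(traced(x), model(x))
```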
// C++ import of the serialized script module
#include "torch/csrc/jit/import.h"
#include <ATen/ATen.h>

using namespace torch;

int main(int argc, char const *argv[])
{
    // Load the serialized script module
    std::shared_ptr<jit::script::Module> module = jit::load("trace.pt");

    // Build the input stack
    at::Tensor x = at::zeros({1});
    jit::Stack stack;
    stack.push_back(autograd::make_variable(x));

    // Run forward; the outputs replace the inputs on the stack
    module->get_method("forward").run(stack);
    at::Tensor y = autograd::Variable(stack[0].toTensor()).data();
    return 0;
}
import torch

module = torch.jit.load('trace.pt')
x = torch.zeros(1)
with torch.no_grad():
    y = module(x)
@lantiga

lantiga commented Aug 11, 2018

Notes:

  1. What gets saved is a PyTorch script module (technically it's more than just a trace), so the variable holding the imported object should be called module or something like it.
  2. When running the imported model, for now we should use the with torch.no_grad(): context; otherwise printing the resulting tensor will fail because PyTorch can't deal with its grad_fn (I'm opening an issue).
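The second note can be sketched as follows (a self-contained example that re-traces a hypothetical one-layer model so the load step works; detaching the output is an alternative workaround, an assumption not stated in the original notes):

```python
import torch

# Hypothetical one-layer model, traced and saved here so the snippet is self-contained
model = torch.nn.Linear(1, 1)
torch.jit.trace(model, torch.zeros(1)).save('trace.pt')

module = torch.jit.load('trace.pt')
x = torch.zeros(1)

# Recommended for now: run under no_grad so the result carries no grad_fn
with torch.no_grad():
    y = module(x)

# Alternative: detach the output from the autograd graph before printing
y2 = module(x).detach()

assert not y.requires_grad
assert not y2.requires_grad
```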

@dnlcrl

dnlcrl commented Aug 12, 2018

fixed ;)
