Skip to content

Instantly share code, notes, and snippets.

@driazati
Created August 26, 2019 17:44
Show Gist options
  • Star (0) — you must be signed in to star a gist.
  • Fork (0) — you must be signed in to fork a gist.
  • Save driazati/a4c26bf6660e97c066b2cd1faba295cd to your computer and use it in GitHub Desktop.
import torch
import torch.nn as nn
from typing import Dict, Optional
import torchvision
class Model(nn.Module):
    """Classifier that wraps a traced ResNet-18 with a ``num_cats``-way head.

    The torchvision ResNet-18 backbone has its ``fc`` layer replaced by a new
    ``nn.Linear`` producing ``num_cats`` outputs. The backbone is then traced
    with ``torch.jit.trace`` at construction time. This lets the wrapper itself
    be compiled afterwards with ``torch.jit.script``.
    """

    def __init__(self, num_cats, pretrained=True, example_input=None):
        """Build, re-head and trace the ResNet-18 backbone.

        Args:
            num_cats: number of output features of the replacement ``fc`` layer.
            pretrained: forwarded to ``torchvision.models.resnet18``. The
                default ``True`` matches the original hard-coded behavior.
            example_input: tensor used to trace the backbone. When ``None``
                (the default), a random ``(1, 3, 224, 224)`` tensor is drawn,
                matching the original hard-coded example.
        """
        super().__init__()
        backbone = torchvision.models.resnet18(pretrained=pretrained)
        backbone.fc = nn.Linear(backbone.fc.in_features, num_cats)
        if example_input is None:
            # Drawn after the layers are built, so RNG consumption order is unchanged.
            example_input = torch.rand(1, 3, 224, 224)
        # NOTE(review): .eval() is never called before tracing. The backbone is
        # therefore still in nn.Module's default training mode, and the trace
        # likely records training-mode batch norm. Call backbone.eval() first if
        # this model is meant only for inference. TODO confirm intent.
        self.model = torch.jit.trace(backbone, example_input)

    def forward(self, x):
        """Run the traced backbone on ``x`` and return its ``num_cats`` outputs."""
        return self.model(x)
# Build the wrapper, compile it with TorchScript, and round-trip it through disk.
original = Model(2)
scripted = torch.jit.script(original)
torch.jit.save(scripted, 'scripted_model.pth')
loaded = torch.jit.load('scripted_model.pth')

# Compare the loaded, scripted and eager variants on fresh random inputs.
# NOTE(review): this loop never terminates on its own (stop it with Ctrl-C).
# Presumably that is intentional, so an intermittent mismatch eventually shows
# up. Confirm before reusing.
# Gradients are never needed here. no_grad() avoids building an autograd graph
# for each of the three forward passes on every iteration.
with torch.no_grad():
    while True:
        sample = torch.randn(1, 3, 224, 224)  # renamed from `input` to stop shadowing the builtin
        out1 = loaded(sample)
        out2 = scripted(sample)
        out3 = original(sample)
        # Each line prints True when the pair agrees within allclose's default tolerances.
        print(out1.allclose(out2))
        print(out2.allclose(out3))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment