-
-
Save alannnna/008106cdee4ced848780ed1d829f57d2 to your computer and use it in GitHub Desktop.
Pytorch JIT scripting
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import time | |
import torch | |
from mnist import Net, DEFAULTS, train_example | |
def make_net(): | |
torch.manual_seed(1234) | |
model = Net() | |
optimizer = torch.optim.Adadelta(model.parameters(), lr=DEFAULTS.lr) | |
return model, optimizer | |
def try_jit(): | |
model, optimizer = make_net() | |
data = torch.rand(DEFAULTS.batch_size, 1, 28, 28) | |
labels = torch.randint(low=0, high=10, size=(DEFAULTS.batch_size,)) | |
# non traced | |
start = time.time() | |
for i in range(100): | |
train_example(optimizer, model, data, labels) | |
end = time.time() | |
print(end - start) | |
# scripted | |
scripted_train_example = torch.jit.script(train_example) | |
start = time.time() | |
for i in range(100): | |
scripted_train_example(optimizer, model, data, labels) | |
end = time.time() | |
print(end - start) | |
if __name__ == "__main__": | |
try_jit() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
$ python bench.py | |
7.002076864242554 | |
Traceback (most recent call last): | |
File "bench.py", line 50, in <module> | |
try_jit() | |
File "bench.py", line 31, in try_jit | |
scripted_train_example = torch.jit.script(train_example) | |
File "/Users/me/blah/blah/venv/lib/python3.7/site-packages/torch/jit/_script.py", line 940, in script | |
qualified_name, ast, _rcb, get_default_args(obj) | |
RuntimeError: | |
Tried to access nonexistent attribute or method 'zero_grad' of type 'Tensor (inferred)'.: | |
File "/Users/me/blah/blah/mnist.py", line 66 | |
def train_example(optimizer, model, data, target): | |
optimizer.zero_grad() | |
~~~~~~~~~~~~~~~~~~~ <--- HERE | |
loss = model.get_loss_and_do_backward(data, target) | |
optimizer.step() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import print_function | |
import argparse | |
import time | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import torch.optim as optim | |
from torchvision import datasets, transforms | |
class Defaults: | |
batch_size = 300 | |
test_batch_size = 1000 | |
epochs = 3 | |
lr = 0.035 | |
seed = 1 | |
log_interval = 10 | |
save_model = True | |
use_jit = False | |
DEFAULTS = Defaults() | |
class Net(nn.Module): | |
def __init__(self): | |
super(Net, self).__init__() | |
self.conv1 = nn.Conv2d(1, 6, 5, padding = 2) | |
self.conv2 = nn.Conv2d(6, 16, 5, padding = 2) | |
self.fc1 = nn.Linear(784, 128) | |
self.fc2 = nn.Linear(128, 84) | |
self.fc3 = nn.Linear(84, 10) | |
def forward(self, x): | |
x = self.conv1(x) | |
x = F.relu(x) | |
x = F.max_pool2d(x, 2) | |
x = self.conv2(x) | |
x = F.relu(x) | |
x = F.max_pool2d(x, 2) | |
x = torch.flatten(x, 1) | |
x = self.fc1(x) | |
x = self.fc2(x) | |
x = self.fc3(x) | |
output = F.log_softmax(x, dim=1) | |
return output | |
def get_loss_and_do_backward(self, data, target): | |
output = self(data) | |
loss = F.nll_loss(output, target) | |
loss.backward() | |
return loss | |
def train_example(optimizer, model, data, target): | |
optimizer.zero_grad() | |
loss = model.get_loss_and_do_backward(data, target) | |
optimizer.step() | |
return loss |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment