Skip to content

Instantly share code, notes, and snippets.

@alannnna
Last active March 3, 2021 17:26
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save alannnna/008106cdee4ced848780ed1d829f57d2 to your computer and use it in GitHub Desktop.
Save alannnna/008106cdee4ced848780ed1d829f57d2 to your computer and use it in GitHub Desktop.
Pytorch JIT scripting
import time
import torch
from mnist import Net, DEFAULTS, train_example
def make_net():
torch.manual_seed(1234)
model = Net()
optimizer = torch.optim.Adadelta(model.parameters(), lr=DEFAULTS.lr)
return model, optimizer
def try_jit():
model, optimizer = make_net()
data = torch.rand(DEFAULTS.batch_size, 1, 28, 28)
labels = torch.randint(low=0, high=10, size=(DEFAULTS.batch_size,))
# non traced
start = time.time()
for i in range(100):
train_example(optimizer, model, data, labels)
end = time.time()
print(end - start)
# scripted
scripted_train_example = torch.jit.script(train_example)
start = time.time()
for i in range(100):
scripted_train_example(optimizer, model, data, labels)
end = time.time()
print(end - start)
if __name__ == "__main__":
try_jit()
$ python bench.py
7.002076864242554
Traceback (most recent call last):
File "bench.py", line 50, in <module>
try_jit()
File "bench.py", line 31, in try_jit
scripted_train_example = torch.jit.script(train_example)
File "/Users/me/blah/blah/venv/lib/python3.7/site-packages/torch/jit/_script.py", line 940, in script
qualified_name, ast, _rcb, get_default_args(obj)
RuntimeError:
Tried to access nonexistent attribute or method 'zero_grad' of type 'Tensor (inferred)'.:
File "/Users/me/blah/blah/mnist.py", line 66
def train_example(optimizer, model, data, target):
optimizer.zero_grad()
~~~~~~~~~~~~~~~~~~~ <--- HERE
loss = model.get_loss_and_do_backward(data, target)
optimizer.step()
from __future__ import print_function
import argparse
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
class Defaults:
batch_size = 300
test_batch_size = 1000
epochs = 3
lr = 0.035
seed = 1
log_interval = 10
save_model = True
use_jit = False
DEFAULTS = Defaults()
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 6, 5, padding = 2)
self.conv2 = nn.Conv2d(6, 16, 5, padding = 2)
self.fc1 = nn.Linear(784, 128)
self.fc2 = nn.Linear(128, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = self.conv1(x)
x = F.relu(x)
x = F.max_pool2d(x, 2)
x = self.conv2(x)
x = F.relu(x)
x = F.max_pool2d(x, 2)
x = torch.flatten(x, 1)
x = self.fc1(x)
x = self.fc2(x)
x = self.fc3(x)
output = F.log_softmax(x, dim=1)
return output
def get_loss_and_do_backward(self, data, target):
output = self(data)
loss = F.nll_loss(output, target)
loss.backward()
return loss
def train_example(optimizer, model, data, target):
optimizer.zero_grad()
loss = model.get_loss_and_do_backward(data, target)
optimizer.step()
return loss
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment