@ptrblck
Created October 6, 2019 19:01
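
Benchmark comparing a full parameter update of ResNet-152 with updating only its final fully connected layer, timed per iteration on the GPU.
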
import torch
import torch.nn as nn
import torchvision.models as models
import time
# Create dummy data
data = torch.randn(1, 3, 224, 224, device='cuda')
target = torch.randint(0, 100, (1,), device='cuda')
model = models.resnet152()
model.fc = nn.Linear(in_features=2048, out_features=100)  # replace the head for 100 classes
model.cuda()
# Train whole model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()
nb_epochs = 10  # each "epoch" here is one forward/backward pass over the same dummy batch
# Synchronize before starting the timer so queued GPU work doesn't skew the measurement
torch.cuda.synchronize()
t0 = time.time()
for _ in range(nb_epochs):
    optimizer.zero_grad()
    output = model(data)
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()
# Wait for all queued GPU work to finish before stopping the timer
torch.cuda.synchronize()
t1 = time.time()
print('full update took {}s per epoch'.format((t1-t0)/nb_epochs))
# Only train last layer
for param in model.parameters():
    param.requires_grad_(False)
model.fc.weight.requires_grad_(True)
model.fc.bias.requires_grad_(True)
optimizer = torch.optim.Adam(model.fc.parameters(), lr=1e-3)
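# Optional sanity check before timing: confirm that only the final layer's
# parameters still require gradients.
n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
n_total = sum(p.numel() for p in model.parameters())
print('trainable parameters: {} / {}'.format(n_trainable, n_total))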
torch.cuda.synchronize()
t0 = time.time()
for _ in range(nb_epochs):
    optimizer.zero_grad()
    output = model(data)
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()
torch.cuda.synchronize()
t1 = time.time()
print('frozen update took {}s per epoch'.format((t1-t0)/nb_epochs))
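
# A possible further speedup, sketched here as an assumption rather than as
# part of the benchmark above: with the backbone frozen, its forward pass can
# run under torch.no_grad() so autograd stores no intermediate activations,
# leaving only the small fc layer in the graph. `backbone` and `features` are
# names introduced for illustration.
backbone = nn.Sequential(*list(model.children())[:-1])  # everything except fc
with torch.no_grad():
    features = backbone(data).flatten(1)  # shape (1, 2048) after the avgpool
output = model.fc(features)  # only the head participates in autograd
loss = criterion(output, target)
loss.backward()  # gradients flow into model.fc alone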