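# Minimal example: computing an input-gradient penalty (which requires a
# double-backward pass) while training in mixed precision with NVIDIA apex
# amp at opt_level "O1", on CIFAR10.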
import torch
from torch import nn
import torchvision
from torchvision import transforms

# apex is installed separately, from https://github.com/NVIDIA/apex
from apex import amp

class Model(nn.Module):
    """A tiny conv -> batchnorm -> relu network."""

    def __init__(self):
        super(Model, self).__init__()
        self.conv = nn.Conv2d(3, 10, kernel_size=3)
        self.bn = nn.BatchNorm2d(10)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x

model = Model().cuda()

# Normalize CIFAR10 images to roughly [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

CIFAR10_DIR = "YOUR_DIR_HERE"
device = torch.device("cuda:0")

data = torchvision.datasets.CIFAR10(CIFAR10_DIR, train=True, download=True,
                                    transform=transform)
dataloader = torch.utils.data.DataLoader(data, batch_size=64)
opt = torch.optim.RMSprop(model.parameters(), lr=0.001)

# Patch the model and optimizer for mixed-precision training. "O1" casts the
# inputs of whitelisted ops (e.g. convolutions) to fp16.
model, opt = amp.initialize(model, opt, opt_level="O1")

# Loop over the dataset indefinitely.
while True:
    for x, _ in dataloader:
        x = x.to(device)
        # Track gradients w.r.t. the input so it can be penalized.
        x.requires_grad_(True)

        model.zero_grad()
        y = model(x)

        # d(sum(y))/dx; create_graph=True makes the penalty below itself
        # differentiable w.r.t. the model parameters.
        gradients = torch.autograd.grad(
            outputs=y,
            inputs=x,
            grad_outputs=torch.ones_like(y),
            create_graph=True,
            retain_graph=True)[0]
        gradients = gradients.view(gradients.size(0), -1)

        # Mean over the batch of the squared L2 norm of each example's
        # input gradient.
        penalty = (gradients.norm(2, dim=1) ** 2).mean()

        # amp scales the loss to avoid fp16 gradient underflow, and
        # unscales the gradients before opt.step().
        with amp.scale_loss(penalty, opt) as scaled_loss:
            scaled_loss.backward()
        opt.step()
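
# Note: here the penalty is the entire loss. In a full training setup it
# would typically be weighted and added to a task loss, e.g.
#   loss = task_loss + lam * penalty   # lam is a hypothetical coefficient
# as in the WGAN-GP gradient penalty.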