Skip to content

Instantly share code, notes, and snippets.

@mdlockyer
Last active March 30, 2019 00:07
Show Gist options
  • Save mdlockyer/1b728751113067c47ef104a5ecf1691d to your computer and use it in GitHub Desktop.
Save mdlockyer/1b728751113067c47ef104a5ecf1691d to your computer and use it in GitHub Desktop.
import torch
from torch.nn import Module, Conv2d, Sequential, MSELoss
from torch.optim import SGD
import gc
class Model(Module):
    """Deep stack of convolutions used by this script's memory experiment.

    The network maps a 3-channel input to a 1-channel output with the same
    spatial size. It has one 3 -> ``channels`` 3x3 conv, then ``depth``
    ``channels`` -> ``channels`` 3x3 convs (all with padding 1), then a
    1x1 conv down to a single channel.

    Args:
        depth: number of hidden ``channels`` -> ``channels`` 3x3
            convolutions (100 by default, as in the original script).
        channels: width of the hidden layers (512 by default).
    """

    def __init__(self, depth=100, channels=512):
        super().__init__()
        layers = [Conv2d(3, channels, 3, padding=1)]
        layers += [Conv2d(channels, channels, 3, padding=1) for _ in range(depth)]
        layers.append(Conv2d(channels, 1, 1))
        # Attribute name kept as ``model`` so state_dict keys are unchanged.
        self.model = Sequential(*layers)

    def forward(self, x):
        """Apply the convolution stack to ``x``; spatial size is preserved."""
        return self.model(x)
def train(model, criterion, optim, device):
    """Run one optimisation step on a random batch and report peak CUDA memory.

    A random ``(1, 3, 8, 8)`` input and an all-ones ``(1, 1, 8, 8)`` target
    are created on ``device``. The step is forward, loss, backward and
    ``optim.step()``. All step tensors are then dropped and garbage-collected
    before the peak-memory counters are printed.

    Args:
        model: module mapping the input batch to an output comparable with
            the target under ``criterion``.
        criterion: loss function, called as ``criterion(out, y)``.
        optim: optimizer over ``model``'s parameters.
        device: device (``torch.device`` or string) for the batch. Memory
            statistics are printed only when it is a CUDA device, and are
            read for that device rather than the current default one.
    """
    x = torch.rand(1, 3, 8, 8, device=device)
    y = torch.ones(1, 1, 8, 8, device=device)
    out = model(x)
    loss = criterion(out, y)
    optim.zero_grad()
    loss.backward()
    optim.step()
    # Clear gradients again so they are not held between steps (with
    # set_to_none=True, the default since PyTorch 2.0, this frees them).
    optim.zero_grad()
    del x, y, out, loss
    gc.collect()

    device = torch.device(device)
    if device.type != 'cuda':
        # The torch.cuda counters say nothing about non-CUDA allocations.
        return
    # max_memory_cached() was deprecated in favour of max_memory_reserved()
    # in PyTorch 1.4; fall back to the old name only on older releases.
    max_reserved = getattr(torch.cuda, 'max_memory_reserved', None)
    if max_reserved is None:
        max_reserved = torch.cuda.max_memory_cached
    print('Max memory allocated: {0:.2f} MB'
          .format(torch.cuda.max_memory_allocated(device) / 1e6))
    print('Max memory cached: {0:.2f} MB'
          .format(max_reserved(device) / 1e6))
def main():
    """Train ``Model`` on the first CUDA GPU for five steps.

    Each step is delegated to ``train`` with plain SGD (lr 0.001, no
    momentum) and an MSE loss. The device is fixed to ``cuda:0``, so a
    CUDA-capable PyTorch install is required.
    """
    gpu = torch.device('cuda:0')
    net = Model().to(gpu)
    loss_fn = MSELoss()
    optimizer = SGD(net.parameters(), lr=0.001, momentum=0)
    num_steps = 5
    for _ in range(num_steps):
        train(net, loss_fn, optimizer, gpu)


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment