Skip to content

Instantly share code, notes, and snippets.

@mdlockyer
Last active March 30, 2019 00:08
Show Gist options
  • Save mdlockyer/3ff43f00ad7b7e2c2a3a7f33469658da to your computer and use it in GitHub Desktop.
Save mdlockyer/3ff43f00ad7b7e2c2a3a7f33469658da to your computer and use it in GitHub Desktop.
import torch
from torch.nn import Module, Conv2d, Sequential, MSELoss
from torch.optim import SGD
from memory_profiler import profile
import gc
class Model(Module):
    """A deep, purely convolutional stack used as a memory-profiling workload.

    Layout: one ``in_channels -> width`` 3x3 conv, then ``depth`` further
    ``width -> width`` 3x3 convs, then a ``width -> out_channels`` 1x1 conv.
    Every 3x3 conv uses ``padding=1``, so the spatial size of the input is
    preserved. No activation functions are inserted between layers.

    The defaults reproduce the original hard-coded network
    (3 -> 512, 32 x (512 -> 512), 512 -> 1).

    Args:
        in_channels: Channel count of the input tensor.
        width: Channel count of every hidden conv layer.
        depth: Number of hidden ``width -> width`` conv layers.
        out_channels: Channel count produced by the final 1x1 conv.
    """

    def __init__(
        self,
        *,
        in_channels: int = 3,
        width: int = 512,
        depth: int = 32,
        out_channels: int = 1,
    ):
        super().__init__()
        layers = [Conv2d(in_channels, width, 3, padding=1)]
        # Layers are created in the same order as before, so parameter
        # initialisation under a fixed seed is unchanged.
        layers.extend(Conv2d(width, width, 3, padding=1) for _ in range(depth))
        layers.append(Conv2d(width, out_channels, 1))
        # Attribute name kept as ``model`` so state_dict keys stay stable.
        self.model = Sequential(*layers)

    def forward(self, x):
        """Apply every layer of the stack to ``x`` in order."""
        return self.model(x)
@profile
def train(model, criterion, optim):
    """Run one optimisation step on a random input under memory_profiler.

    The function is wrapped with ``memory_profiler.profile``, which reports
    memory usage line by line for each call. Statement order therefore
    matters: each line below is a separate row in that report.

    Args:
        model: Module called on a ``(1, 3, 8, 8)`` input tensor.
        criterion: Loss callable invoked as ``criterion(out, y)`` where ``y``
            is a ``(1, 1, 8, 8)`` tensor of ones.
        optim: Optimizer that is zeroed and stepped; presumably built over
            ``model``'s parameters (as ``main`` does) — confirm for other
            callers.
    """
    # Conv2d input layout is (N, C, H, W); C=3 must match the model's first conv.
    x = torch.rand(1, 3, 8, 8)
    # Target of ones; shape assumed to match the model output (1, 1, 8, 8).
    y = torch.ones(1, 1, 8, 8)
    out = model(x)
    loss = criterion(out, y)
    optim.zero_grad()
    loss.backward()
    optim.step()
    # Second zero_grad after the step: presumably so the report shows what
    # happens to gradient memory once it is no longer needed.
    # NOTE(review): whether this frees or merely zeroes .grad depends on the
    # PyTorch version's zero_grad default (set_to_none) — confirm intent.
    optim.zero_grad()
    # Drop the local references explicitly so their release shows up on
    # these lines of the profile rather than at function return.
    del x, y, out, loss
    # Force a collection so any freed memory is reflected in the report.
    gc.collect()
def main(*, num_iterations: int = 5):
    """Build the model and run ``num_iterations`` profiled training steps.

    Each call to ``train`` emits its own memory_profiler report, so running
    several iterations shows whether memory usage grows from step to step.

    Args:
        num_iterations: How many times to call ``train``. Defaults to 5,
            the count that was previously hard-coded.
    """
    model = Model()
    criterion = MSELoss()
    # momentum=0 is SGD's default; presumably spelled out so it is obvious no
    # momentum state is expected to appear in the memory report.
    optim = SGD(model.parameters(), lr=0.001, momentum=0)
    for _ in range(num_iterations):
        train(model, criterion, optim)


# Run the profiling loop only when executed as a script, not on import.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment