@t-vi
Created March 21, 2017 21:22
Torch bidirectional LSTM memory consumption
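The script below stacks four 400-unit LSTM layers and repeatedly pushes a 32x200x400 batch through them, reporting after each epoch how much the host resident set has grown. With bidirectional=True the process steadily leaks host memory; with bidirectional=False it barely grows.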
#!/usr/bin/python3
import torch
from torch.autograd import Variable
import torch.nn as nn
import gc

# helper function to get rss size, see stat(5) under statm. This is in pages (4k on my linux)
def memory_usage():
    return int(open('/proc/self/statm').read().split()[1])

class DummyModule(nn.Module):
    def __init__(self, bidirectional=True):
        super(DummyModule, self).__init__()
        self.bidirectional = bidirectional
        self.rnn1 = nn.LSTM(input_size=400, hidden_size=400,
                            bidirectional=bidirectional, bias=False, batch_first=True)
        self.rnn2 = nn.LSTM(input_size=400, hidden_size=400,
                            bidirectional=bidirectional, bias=False, batch_first=True)
        self.rnn3 = nn.LSTM(input_size=400, hidden_size=400,
                            bidirectional=bidirectional, bias=False, batch_first=True)
        self.rnn4 = nn.LSTM(input_size=400, hidden_size=400,
                            bidirectional=bidirectional, bias=False, batch_first=True)

    def forward(self, x):
        # after each bidirectional layer, sum the forward and backward halves
        # of the output so the next layer sees 400 input features again
        x, _ = self.rnn1(x)
        x = x.contiguous()
        if self.bidirectional:
            x = x.view(x.size(0), x.size(1), 2, -1).sum(2).view(x.size(0), x.size(1), -1)
        x, _ = self.rnn2(x)
        x = x.contiguous()
        if self.bidirectional:
            x = x.view(x.size(0), x.size(1), 2, -1).sum(2).view(x.size(0), x.size(1), -1)
        x, _ = self.rnn3(x)
        x = x.contiguous()
        if self.bidirectional:
            x = x.view(x.size(0), x.size(1), 2, -1).sum(2).view(x.size(0), x.size(1), -1)
        x, _ = self.rnn4(x)
        return x

for bidirectional in [True, False]:
    model = DummyModule(bidirectional=bidirectional)
    model.cuda()
    print(model)
    initial_m = None
    for i in range(20):
        for j in range(100):
            inputs = torch.zeros(32, 200, 400)
            inputs = inputs.cuda()
            inputs = Variable(inputs)
            out = model(inputs)
        gc.collect()
        m = memory_usage()
        if initial_m is None:
            initial_m = m
        # statm reports 4k pages, so dividing the page count by 256 gives MB
        print("at epoch {}: consuming extra {:.1f} MB".format(i, (m - initial_m) / 256))
        if (m - initial_m) > 25600:
            print("consumed more than 100 MB extra after {} loops".format(i))
            break
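The /256 in the report line converts statm pages to MB under the assumption of 4 KiB pages (pages * 4096 / 2**20 = pages / 256). A minimal page-size-aware variant of the helper might look like this; the name memory_usage_mb is mine, and it assumes a Linux /proc filesystem:

    import os

    # query the actual page size instead of hard-coding 4 KiB
    PAGE_SIZE = os.sysconf('SC_PAGE_SIZE')  # bytes per page, usually 4096

    def memory_usage_mb():
        # the second field of /proc/self/statm is the resident set size in pages
        pages = int(open('/proc/self/statm').read().split()[1])
        return pages * PAGE_SIZE / (1024 * 1024)

Running the script prints each model followed by the per-epoch growth in host resident memory: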
DummyModule (
(rnn1): LSTM(400, 400, bias=False, batch_first=True, bidirectional=True)
(rnn2): LSTM(400, 400, bias=False, batch_first=True, bidirectional=True)
(rnn3): LSTM(400, 400, bias=False, batch_first=True, bidirectional=True)
(rnn4): LSTM(400, 400, bias=False, batch_first=True, bidirectional=True)
)
at epoch 0: consuming extra 0.0 MB
at epoch 1: consuming extra 1.9 MB
at epoch 2: consuming extra 22.8 MB
at epoch 3: consuming extra 23.1 MB
at epoch 4: consuming extra 23.9 MB
at epoch 5: consuming extra 34.0 MB
at epoch 6: consuming extra 34.3 MB
at epoch 7: consuming extra 46.3 MB
at epoch 8: consuming extra 46.6 MB
at epoch 9: consuming extra 57.0 MB
at epoch 10: consuming extra 67.6 MB
at epoch 11: consuming extra 67.7 MB
at epoch 12: consuming extra 79.0 MB
at epoch 13: consuming extra 90.1 MB
at epoch 14: consuming extra 101.2 MB
consumed more than 100 MB extra after 14 loops
DummyModule (
(rnn1): LSTM(400, 400, bias=False, batch_first=True)
(rnn2): LSTM(400, 400, bias=False, batch_first=True)
(rnn3): LSTM(400, 400, bias=False, batch_first=True)
(rnn4): LSTM(400, 400, bias=False, batch_first=True)
)
at epoch 0: consuming extra 0.0 MB
at epoch 1: consuming extra 0.3 MB
at epoch 2: consuming extra 0.5 MB
at epoch 3: consuming extra 0.8 MB
at epoch 4: consuming extra 1.0 MB
at epoch 5: consuming extra 1.3 MB
at epoch 6: consuming extra 1.5 MB
at epoch 7: consuming extra 1.8 MB
at epoch 8: consuming extra 2.0 MB
at epoch 9: consuming extra 2.5 MB
at epoch 10: consuming extra 2.8 MB
at epoch 11: consuming extra 3.0 MB
at epoch 12: consuming extra 3.3 MB
at epoch 13: consuming extra 3.5 MB
at epoch 14: consuming extra 3.8 MB
at epoch 15: consuming extra 4.0 MB
at epoch 16: consuming extra 4.3 MB
at epoch 17: consuming extra 4.5 MB
at epoch 18: consuming extra 5.0 MB
at epoch 19: consuming extra 5.3 MB
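So the bidirectional variant gains roughly 7 MB of host RSS per epoch and crosses the 100 MB threshold after 15 epochs, while the unidirectional variant accumulates only about 5 MB over all 20 epochs.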