Created
March 21, 2017 21:22
-
-
Save t-vi/60515a24e1cbc4dc87897a6d8c224698 to your computer and use it in GitHub Desktop.
Torch bidirectional LSTM memory consumption
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/python3 | |
import torch | |
from torch.autograd import Variable | |
import torch.nn as nn | |
import gc | |
# helper function to get rss size, see statm in proc(5). This is in pages (4k on my linux)
def memory_usage():
    """Return this process's resident set size (RSS) in pages, Linux only.

    Reads the second field ("resident") of /proc/self/statm.
    """
    # "with" guarantees the file handle is closed; the original version
    # leaked the handle until garbage collection picked it up.
    with open('/proc/self/statm') as statm:
        return int(statm.read().split()[1])
class DummyModule(nn.Module):
    """Stack of four 400->400 LSTMs used to reproduce a memory-growth issue.

    After each of the first three LSTMs, the forward and backward halves of a
    bidirectional layer's output are summed, so the next layer again receives
    400 features.  The last layer's output is returned unmerged (so it is
    2*400 wide when bidirectional=True).
    """

    def __init__(self, bidirectional=True):
        super(DummyModule, self).__init__()
        self.bidirectional = bidirectional
        # Four identical layers, kept as individual attributes rnn1..rnn4 so
        # the module repr and state_dict keys match the original hand-written
        # version (nn.Module.__setattr__ registers each one as a submodule).
        for idx in range(1, 5):
            setattr(self, 'rnn{}'.format(idx),
                    nn.LSTM(input_size=400, hidden_size=400,
                            bidirectional=bidirectional, bias=False,
                            batch_first=True))

    def _sum_directions(self, x):
        # Collapse the two direction halves of a bidirectional output:
        # (B, T, 2*H) -> view (B, T, 2, H) -> sum over directions -> (B, T, H).
        return x.view(x.size(0), x.size(1), 2, -1).sum(2).view(
            x.size(0), x.size(1), -1)

    def forward(self, x):
        """Run x of shape (batch, seq, 400) through the four LSTM layers."""
        layers = [self.rnn1, self.rnn2, self.rnn3, self.rnn4]
        last = len(layers) - 1
        for i, rnn in enumerate(layers):
            x, _ = rnn(x)
            # The original merges directions after every layer but the last
            # (and only calls .contiguous() on those intermediate outputs).
            if i < last:
                x = x.contiguous()
                if self.bidirectional:
                    x = self._sum_directions(x)
        return x
# Drive the reproduction: run the model forward repeatedly and watch host RSS
# grow.  Done once with a bidirectional stack and once unidirectional for
# comparison.  NOTE(review): Variable is the pre-0.4 autograd wrapper; kept
# as-is because this script reproduces behavior of that torch version.
for bidirectional in [True, False]:
    model = DummyModule(bidirectional=bidirectional)
    model.cuda()
    print (model)
    initial_m = None
    for i in range(20):
        for j in range(100):
            inputs = torch.zeros(32, 200, 400)
            inputs = inputs.cuda()
            inputs = Variable(inputs)
            out = model(inputs)
        gc.collect()
        m = memory_usage()
        if initial_m is None:
            initial_m = m
        # memory_usage() is in 4 KiB pages: 256 pages = 1 MiB.
        print ("at epoch {}: consuming extra {:.1f} MB".format(i,(m-initial_m)/256))
        # 25600 pages = 100 MiB of growth beyond the first measurement.
        if (m-initial_m)>25600:
            print ("consumed more than 100 MB extra after {} loops".format(i))
            # Stop this configuration once the leak is demonstrated.  The
            # captured output ends at epoch 14 despite range(20), so this
            # break was present originally and lost in the paste.
            break
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
DummyModule ( | |
(rnn1): LSTM(400, 400, bias=False, batch_first=True, bidirectional=True) | |
(rnn2): LSTM(400, 400, bias=False, batch_first=True, bidirectional=True) | |
(rnn3): LSTM(400, 400, bias=False, batch_first=True, bidirectional=True) | |
(rnn4): LSTM(400, 400, bias=False, batch_first=True, bidirectional=True) | |
) | |
at epoch 0: consuming extra 0.0 MB | |
at epoch 1: consuming extra 1.9 MB | |
at epoch 2: consuming extra 22.8 MB | |
at epoch 3: consuming extra 23.1 MB | |
at epoch 4: consuming extra 23.9 MB | |
at epoch 5: consuming extra 34.0 MB | |
at epoch 6: consuming extra 34.3 MB | |
at epoch 7: consuming extra 46.3 MB | |
at epoch 8: consuming extra 46.6 MB | |
at epoch 9: consuming extra 57.0 MB | |
at epoch 10: consuming extra 67.6 MB | |
at epoch 11: consuming extra 67.7 MB | |
at epoch 12: consuming extra 79.0 MB | |
at epoch 13: consuming extra 90.1 MB | |
at epoch 14: consuming extra 101.2 MB | |
consumed more than 100 MB extra after 14 loops | |
DummyModule ( | |
(rnn1): LSTM(400, 400, bias=False, batch_first=True) | |
(rnn2): LSTM(400, 400, bias=False, batch_first=True) | |
(rnn3): LSTM(400, 400, bias=False, batch_first=True) | |
(rnn4): LSTM(400, 400, bias=False, batch_first=True) | |
) | |
at epoch 0: consuming extra 0.0 MB | |
at epoch 1: consuming extra 0.3 MB | |
at epoch 2: consuming extra 0.5 MB | |
at epoch 3: consuming extra 0.8 MB | |
at epoch 4: consuming extra 1.0 MB | |
at epoch 5: consuming extra 1.3 MB | |
at epoch 6: consuming extra 1.5 MB | |
at epoch 7: consuming extra 1.8 MB | |
at epoch 8: consuming extra 2.0 MB | |
at epoch 9: consuming extra 2.5 MB | |
at epoch 10: consuming extra 2.8 MB | |
at epoch 11: consuming extra 3.0 MB | |
at epoch 12: consuming extra 3.3 MB | |
at epoch 13: consuming extra 3.5 MB | |
at epoch 14: consuming extra 3.8 MB | |
at epoch 15: consuming extra 4.0 MB | |
at epoch 16: consuming extra 4.3 MB | |
at epoch 17: consuming extra 4.5 MB | |
at epoch 18: consuming extra 5.0 MB | |
at epoch 19: consuming extra 5.3 MB |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment