@calebh
Created May 3, 2020 19:41
Bizarre PyTorch memory issue
import torch
import pickle
import os
import random
import sourcenode
import torch.nn as nn
import torch.utils.checkpoint
import torch.utils.data
import math
import objprocessor
import ctypes
import gc
# Maximum number of asm instructions: 12216
# Maximum number of c nodes: 74035
if os.name == 'nt':
    ctypes.cdll.LoadLibrary('caffe2_nvrtc.dll')
torch.manual_seed(0)
random.seed(0)
MAX_INPUT_SIZE = 5000
use_cuda = True
if use_cuda:
    current_device = torch.device('cuda')
else:
    current_device = torch.device('cpu')
class PickledDataset(torch.utils.data.Dataset):
    def __init__(self, pickled_data_directory):
        self.data_file_paths = []
        for dir_name, subdir_list, file_list in os.walk(pickled_data_directory):
            for file_name in file_list:
                (name, extension) = os.path.splitext(file_name)
                if extension == ".pickle":
                    self.data_file_paths.append(os.path.join(dir_name, file_name))
        self.data_file_paths.sort()
    def __getitem__(self, item):
        with open(self.data_file_paths[item], 'rb') as f:
            return pickle.load(f)
    def __len__(self):
        return len(self.data_file_paths)
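# Each pickle file is assumed to hold one (asm_tensor, c_tensor) training pair; that is what
# collator() below unpacks as data[0] / data[1]. Loading is lazy: nothing is read from disk
# until the DataLoader calls __getitem__.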
neg_one = torch.tensor([[-1.0]])
padding_id = torch.tensor([sourcenode.Padding().get_id()], dtype=torch.long)
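# neg_one and padding_id are the fill values used when padding a batch: padded asm rows are
# filled with -1.0 and padded C positions get the Padding node id, with the corresponding
# mask entries set to True.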
def collator(batch):
    # an ASM tensor has shape (S, objprocessor.MAX_INSTR_SIZE)
    asm_tensors = [data[0] for data in batch]
    max_num_asm = max([ten.shape[0] for ten in asm_tensors])
    asm_mask = torch.zeros((len(asm_tensors), max_num_asm), dtype=torch.bool)
    for i in range(len(asm_tensors)):
        ten = asm_tensors[i]
        asm_mask[i, ten.shape[0]:max_num_asm] = True
    asm_tensors = [torch.cat([ten, neg_one.expand((max_num_asm - ten.shape[0], ten.shape[1]))]) for ten in asm_tensors]
    # a C tensor is a 1-D tensor of shape (T,) holding node ids in [0, sourcenode.NODE_ID_END)
    c_tensors = [data[1] for data in batch]
    max_num_c = max([ten.shape[0] for ten in c_tensors])
    c_mask = torch.zeros((len(c_tensors), max_num_c), dtype=torch.bool)
    for i in range(len(c_tensors)):
        ten = c_tensors[i]
        c_mask[i, ten.shape[0]:max_num_c] = True
    c_tensors = [torch.cat([ten, padding_id.expand(max_num_c - ten.shape[0])]) for ten in c_tensors]
    return (torch.stack(asm_tensors, dim=1), torch.stack(c_tensors, dim=1), asm_mask, c_mask)
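# A rough sketch of what collator() returns, assuming a hypothetical batch of two samples
# with 3 and 5 asm instructions and 4 and 2 C nodes respectively:
#   stacked asm: (5, 2, objprocessor.MAX_INSTR_SIZE)   padded rows filled with -1.0
#   stacked c:   (4, 2)                                padded entries = Padding id
#   asm_mask:    (2, 5), c_mask: (2, 4)                True marks padding positions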
dataset = PickledDataset("../pickledtrainingdata")
ten_percent = int(0.1 * len(dataset))
training_size = len(dataset) - 2 * ten_percent
(testing_dataset, validation_dataset, training_dataset) = torch.utils.data.random_split(dataset, [ten_percent, ten_percent, training_size])
batch_size = 4
training_loader = torch.utils.data.DataLoader(training_dataset, batch_size=batch_size, collate_fn=collator, shuffle=True)
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)
    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)
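# PositionalEncoding above is the standard sinusoidal encoding (as used in the PyTorch
# transformer tutorial): pe is registered as a (max_len, 1, d_model) buffer, so
# pe[:x.size(0), :] broadcasts across the batch dimension of an (S, N, d_model) input.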
class TransformerModel(nn.Module):
    def __init__(self, nhead=8, dim_feedforward=1024, num_layers=6, dropout=0.1):
        super(TransformerModel, self).__init__()
        self.d_model = 512
        self.input_embedding1 = nn.Linear(objprocessor.MAX_INSTR_SIZE, self.d_model)
        self.relu1 = nn.ReLU()
        self.input_embedding2 = nn.Linear(self.d_model, self.d_model)
        self.relu2 = nn.ReLU()
        self.input_embedding3 = nn.Linear(self.d_model, self.d_model)
        self.output_embedding1 = nn.Linear(sourcenode.NODE_ID_END, self.d_model)
        self.pos_encoder = PositionalEncoding(self.d_model, dropout, max_len=MAX_INPUT_SIZE)
        self.transformer = nn.Transformer(d_model=self.d_model, nhead=nhead, dim_feedforward=dim_feedforward,
                                          num_encoder_layers=num_layers, num_decoder_layers=num_layers, dropout=dropout)
        self.output_linear = nn.Linear(self.d_model, sourcenode.NODE_ID_END)
    # The transformer body runs under torch.utils.checkpoint.checkpoint (see forward below) to
    # trade recomputation for activation memory. Checkpointing only produces gradients when at
    # least one of its tensor inputs has requires_grad=True; here src and tgt come out of the
    # embedding layers above, so that condition is already satisfied and no dummy tensor is needed.
    def forward(self, src, tgt, src_padding_mask=None, tgt_padding_mask=None):
        # src is a tensor of shape (S, N, objprocessor.MAX_INSTR_SIZE), where N = batch size,
        # S = sequence length of input, objprocessor.MAX_INSTR_SIZE = channel size
        # tgt is a tensor of shape (T, N, sourcenode.NODE_ID_END), where N = batch size,
        # T = sequence length of output, sourcenode.NODE_ID_END = channel size
        tgt_len = tgt.shape[0]
        tgt_mask = self.transformer.generate_square_subsequent_mask(tgt_len)
        tgt_mask = tgt_mask.to(tgt.device)
        src = self.input_embedding1(src)
        src = self.relu1(src)
        src = self.input_embedding2(src)
        src = self.relu2(src)
        src = self.input_embedding3(src)
        src = self.pos_encoder(src)
        tgt = self.output_embedding1(tgt)
        tgt = self.pos_encoder(tgt)
        def wrapper(src, tgt, tgt_mask, src_padding_mask, tgt_padding_mask):
            return self.transformer(src, tgt, tgt_mask=tgt_mask, src_key_padding_mask=src_padding_mask, tgt_key_padding_mask=tgt_padding_mask)
        output = torch.utils.checkpoint.checkpoint(wrapper, src, tgt, tgt_mask, src_padding_mask, tgt_padding_mask)
        output = self.output_linear(output)
        return output
model = TransformerModel(dropout=0.1)
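# A quick shape sketch with made-up sizes (batch of 2, 10 asm instructions in, 7 C nodes out),
# kept as comments only:
#   src = torch.rand(10, 2, objprocessor.MAX_INSTR_SIZE)
#   tgt = torch.rand(7, 2, sourcenode.NODE_ID_END)
#   out = model(src, tgt)   # -> (7, 2, sourcenode.NODE_ID_END) unnormalized logits over node ids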
#model = nn.DataParallel(model, dim=1)
if use_cuda:
    model.cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()
model.train()
MAX_NUM_EPOCHS = 50
num_batches = math.ceil(len(training_dataset) / batch_size)
MODEL_SAVE_DIR = "models"
#MODEL_SAVE_DIR = "/content/drive/My Drive/models"
#torch.save(model.state_dict(), MODEL_SAVE_DIR + "/model_0.pt")
for epoch in range(MAX_NUM_EPOCHS):
    i = 0
    loss_sum = 0.0
    loss_n = 0
    exp_ma_loss = None
    for (src, tgt_indices, src_padding_mask, tgt_padding_mask) in training_loader:
        # src has shape (S, N, objprocessor.MAX_INSTR_SIZE) where S = sequence length of input, N = batch size,
        # objprocessor.MAX_INSTR_SIZE = number of input channels
        # tgt_indices has shape (T, N) where T = sequence length of output, N = batch size
        # src_padding_mask has shape (N, S)
        # tgt_padding_mask has shape (N, T)
        if src.shape[0] > MAX_INPUT_SIZE or tgt_indices.shape[0] > MAX_INPUT_SIZE:
            print("Input " + str(i) + " / " + str(num_batches) + " exceeded maximum input size")
            continue
        try:
            # Really make sure that no local variables escape the scope
            def run():
                global src, tgt_indices, src_padding_mask, tgt_padding_mask, i, exp_ma_loss, loss_sum, loss_n
                if use_cuda:
                    src = src.cuda()
                    tgt_indices = tgt_indices.cuda()
                    src_padding_mask = src_padding_mask.cuda()
                    tgt_padding_mask = tgt_padding_mask.cuda()
                # Convert the indices to one hot vectors for use as input in the transformer model
                tgt = torch.zeros((tgt_indices.shape[0], tgt_indices.shape[1], sourcenode.NODE_ID_END), device=current_device)
                r1 = torch.arange(0, tgt_indices.shape[0], device=current_device).unsqueeze(1).expand_as(tgt_indices)
                r2 = torch.arange(0, tgt_indices.shape[1], device=current_device).unsqueeze(0).expand_as(tgt_indices)
                tgt[r1, r2, tgt_indices] = 1.0
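                # (The tgt/r1/r2 construction above is a hand-rolled one-hot encoding of
                # tgt_indices; with PyTorch >= 1.1 it is roughly equivalent to
                # torch.nn.functional.one_hot(tgt_indices, sourcenode.NODE_ID_END).float().)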
                del r1
                del r2
                optimizer.zero_grad()
                output = model(src, tgt, src_padding_mask=src_padding_mask, tgt_padding_mask=tgt_padding_mask)
                del src
                del tgt
                del src_padding_mask
                # Now remove the output that corresponds to padding entries
                tgt_indices_padding_mask = (~tgt_padding_mask).t()
                del tgt_padding_mask
                tgt_indices_no_padding = torch.masked_select(tgt_indices, tgt_indices_padding_mask)
                del tgt_indices
                output_padding_mask = tgt_indices_padding_mask.unsqueeze(2)
                output = torch.masked_select(output, output_padding_mask).view(-1, sourcenode.NODE_ID_END)
                del output_padding_mask
                print(torch.argmax(output.view(-1, sourcenode.NODE_ID_END), dim=1))
                loss = criterion(output, tgt_indices_no_padding)
                del tgt_indices_no_padding
                del output
                print("Loss " + str(epoch) + " - " + str(i) + " / " + str(num_batches), loss.item())
                if exp_ma_loss is None:
                    exp_ma_loss = loss.item()
                else:
                    coefficient = 0.001
                    exp_ma_loss = coefficient * loss.item() + (1.0 - coefficient) * exp_ma_loss
                if i % 2500 == 0:
                    print("Exp MA Loss: " + str(epoch) + " - " + str(i) + " / " + str(num_batches) + " - " + str(exp_ma_loss))
                loss_sum += loss.item()
                loss_n += 1
                loss.backward()
                del loss
                optimizer.step()
                if i % 20000 == 0:
                    torch.save(model.state_dict(), MODEL_SAVE_DIR + "/model_" + str(epoch) + "_" + str(i) + ".pt")
                i += 1
            run()
            gc.collect(generation=0)
            gc.collect(generation=1)
            gc.collect(generation=2)
            torch.cuda.ipc_collect()
            torch.cuda.synchronize(device=None)
            if use_cuda:
                torch.cuda.empty_cache()
        except RuntimeError as exc:
            if use_cuda and str(exc).startswith("CUDA out of memory"):
                # Somehow asking for the memory summary fixes the issue
                print(torch.cuda.memory_summary(device=None, abbreviated=True))
                print("CUDA ran out of memory on " + str(i) + " / " + str(num_batches))
                i += 1
            else:
                raise
print("Average loss at the end of epoch " + str(epoch) + ": " + str(loss_sum / loss_n))
torch.save(model.state_dict(), MODEL_SAVE_DIR + "/model_" + str(epoch) + ".pt")