Skip to content

Instantly share code, notes, and snippets.

@Ushk
Last active September 15, 2018 14:49
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Ushk/2cc0567d83a794a04c321c82a0facf63 to your computer and use it in GitHub Desktop.
Save Ushk/2cc0567d83a794a04c321c82a0facf63 to your computer and use it in GitHub Desktop.
Demo script that reproduces a segmentation fault in PyTorch when DataParallel is combined with gradient checkpointing
import faulthandler
import os

# Dump the Python traceback of all threads if the process dies on a fatal
# signal (e.g. SIGSEGV, the crash this script is meant to reproduce).
# Enabled first so that a crash during the imports below is covered too.
faulthandler.enable()

# Expose only physical GPUs 0 and 3 to this process (visible as cuda:0 and
# cuda:1). Set before torch is imported so the mask is in place before any
# code path can initialise CUDA; assigning it after the import only works
# while CUDA initialisation happens to stay lazy.
os.environ['CUDA_VISIBLE_DEVICES'] = '0,3'

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as chk
from torch import optim
class model(nn.Module):
    """Four-convolution network whose two middle branches use checkpointing.

    Layout (every conv is 3x3, stride 1, padding 1, with bias):
      * ``conv0``: 3 -> 32 channels (stem)
      * ``conv1`` and ``conv2``: two parallel 32 -> 32 branches over the stem
        output, each run through ``torch.utils.checkpoint.checkpoint``
      * ``conv3``: 64 -> 20 channels over the concatenated branch outputs

    Each conv is registered twice, as attribute ``convN`` and under key
    ``'N'`` of the ``blocks`` ModuleDict; both names refer to the same module.
    """

    def __init__(self):
        super(model, self).__init__()
        self.blocks = nn.ModuleDict()
        # (in_channels, out_channels) for conv0..conv3, created in this order.
        channel_plan = ((3, 32), (32, 32), (32, 32), (64, 20))
        for idx, (c_in, c_out) in enumerate(channel_plan):
            layer = nn.Conv2d(c_in, c_out, kernel_size=3, stride=1,
                              padding=1, bias=True)
            setattr(self, 'conv%d' % idx, layer)
            self.blocks[str(idx)] = layer

    def forward(self, x):
        """Map a 3-channel batch ``x`` to a 20-channel output.

        Concatenation is along dim 1, so ``x`` is expected to be a 4-D
        (N, 3, H, W) batch. The spatial size is preserved.
        """
        stem = self.blocks['0'](x)
        # Both branches read the same stem activation. Checkpointing
        # recomputes them in backward instead of storing their activations.
        left = chk.checkpoint(self.blocks['1'], stem)
        right = chk.checkpoint(self.blocks['2'], stem)
        merged = torch.cat((left, right), 1)
        return self.blocks['3'](merged)
# Wrap the network in DataParallel (replicated over the visible GPUs when
# CUDA is available) and move its parameters to the GPU.
test_model = nn.DataParallel(model()).cuda()
loss = nn.MSELoss()
# Optimise the parameters of the wrapped module itself.
optimizer = optim.SGD(test_model.module.parameters(), lr=0.01)

for i in range(100):
    # Print the iteration index so the output shows how far training got
    # before any crash.
    print(i)
    # Inputs are created on the CPU and handed straight to the DataParallel
    # wrapper; the random targets live on the current CUDA device.
    data = torch.rand(4, 3, 15, 15)
    labels = torch.rand(4, 20, 15, 15).cuda()
    test_preds = test_model(data)
    optimizer.zero_grad()
    test_loss = loss(test_preds, labels)
    test_loss.backward()
    optimizer.step()

print('Finished')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment