# Source: GitHub gist rohan-varma/904b7559133b8158f627534aa065528a
# Author: @rohan-varma
# Created: March 16, 2021 22:29
# (GitHub page chrome -- "Skip to content", star/fork/save prompts -- omitted.)
import argparse
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
import torch
#print(torch.__file__) ; exit()
import torch.nn as nn
import torch.nn.parallel
import torch.distributed as dist
import torch.multiprocessing as mp
import torchvision
import torchvision.transforms as transforms
class ConvNet(nn.Module):
    """Small CNN classifier: two conv/pool/ReLU stages followed by two linear layers.

    ``fc1`` expects ``100 * 5 * 5`` input features, which is what the two
    5x5-conv + 2x2-pool stages produce for 3-channel 32x32 images. The network
    emits 10 logits per sample.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 50, kernel_size=5)
        self.conv2 = nn.Conv2d(50, 100, kernel_size=5)
        self.fc1 = nn.Linear(100 * 5 * 5, 300)
        self.fc2 = nn.Linear(300, 10)
        self.maxpool = nn.MaxPool2d(2, stride=2)
        self.act_fn = nn.ReLU(inplace=True)

    def forward(self, x, i=0, gpu=0, debug=False):
        """Compute class logits for a batch of images.

        Args:
            x: input batch; flattened per-sample features after the conv
                stages must number ``100 * 5 * 5`` (e.g. ``(N, 3, 32, 32)``).
            i: batch index, used only in diagnostic messages.
            gpu: GPU/rank id of the caller, used only in diagnostic messages.
            debug: when true, print device-placement diagnostics.

        Returns:
            Tensor of shape ``(N, 10)`` with unnormalised class scores.
        """
        if debug and i == 0:
            print(f'\n{i} inside forward: GPU {gpu} input {x.device} model {self.conv1.weight.device}\n')
        x = self.act_fn(self.maxpool(self.conv1(x)))
        x = self.act_fn(self.maxpool(self.conv2(x)))
        # Flatten everything but the batch dimension for the linear layers.
        x = x.reshape(x.shape[0], -1)
        x = self.act_fn(self.fc1(x))
        x = self.fc2(x)
        # Diagnostics are opt-in: printing on every batch of every rank
        # floods stdout during normal training.
        if debug:
            print(f"{i} done with forward gpu {gpu} input {x.device} model {self.conv1.weight.device}")
        return x
def run_inference(model, test_loader, args, debug=False):
    """Evaluate ``model`` on ``test_loader`` and return top-1 accuracy in percent.

    Args:
        model: module whose forward accepts ``(inputs, i=, gpu=, debug=)``.
            May be a plain module or a DistributedDataParallel wrapper.
        test_loader: iterable yielding ``(inputs, labels)`` batches.
        args: namespace providing ``args.gpu``; passed through to the model
            and used as the target CUDA device if the model has no parameters.
        debug: when true, print per-batch diagnostics and forward the flag
            to the model.

    Returns:
        0-dim tensor holding ``100 * correct / total``.

    Raises:
        ValueError: if ``test_loader`` yields no samples.
    """
    model.eval()

    # Send batches to wherever the model's weights live. ``nn.Module`` has no
    # ``.device`` attribute, so derive it from the parameters instead.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device('cuda', args.gpu)

    # Only synchronise ranks when a process group actually exists, so the
    # function also works in a single, non-distributed process.
    in_process_group = dist.is_available() and dist.is_initialized()

    test_correct = 0
    test_total = 0
    with torch.no_grad():
        for i, (inputs, labels) in enumerate(test_loader):
            inputs = inputs.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            if debug:
                print(f"GPU {args.gpu} inputs on device {inputs.device}, type is {type(inputs)} model on device {device}. CALLING FORWARD")
            outputs = model(inputs, i=i, gpu=args.gpu, debug=debug)
            if in_process_group:
                # All ranks wait for each other after every batch.
                dist.barrier()
                if debug:
                    print("Done with barrier")
            predicted = outputs.argmax(dim=1)
            test_total += labels.size(0)
            # Accumulate as a tensor to avoid a host sync on every batch.
            test_correct += (predicted == labels).sum()

    if test_total == 0:
        raise ValueError("test_loader produced no samples; accuracy is undefined")
    return 100. * test_correct / test_total
def _unwrap_ddp(model):
    """Return the module inside a DistributedDataParallel wrapper, else ``model`` itself."""
    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
        return model.module
    return model


def main(gpu_id, args, num_gpus):
    """Per-process worker: evaluate a checkpoint, or train and save checkpoints.

    Meant to be started once per GPU (e.g. through ``mp.spawn``). Each process
    joins an NCCL process group, builds a DDP-wrapped ``ConvNet`` and CIFAR-10
    loaders, then either

    * evaluates ``args.checkpoint`` on the test set and returns, or
    * trains for a fixed number of epochs, with rank 0 writing
      ``model_full.pt`` (whole module) and ``model.pt`` (state dicts).

    Args:
        gpu_id: index of this process; used both as CUDA device and as rank.
        args: parsed command-line namespace. ``args.checkpoint`` is read and
            ``args.gpu`` is set to ``gpu_id``.
        num_gpus: number of participating processes (world size).
    """
    args.gpu = gpu_id
    print(f"Using GPU {args.gpu}")
    dist.init_process_group(backend='nccl', init_method='tcp://127.0.0.1:12345',
                            world_size=num_gpus, rank=args.gpu)
    try:
        torch.cuda.set_device(args.gpu)
        model = ConvNet()
        model.cuda(args.gpu)
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        optimizer = torch.optim.SGD(model.parameters(), lr=0.1, weight_decay=0.0001)

        transform = transforms.Compose([transforms.ToTensor()])
        train_dataset = torchvision.datasets.CIFAR10(root='.', train=True, download=True, transform=transform)
        test_dataset = torchvision.datasets.CIFAR10(root='.', train=False, download=True, transform=transform)
        # The sampler shards and shuffles the training set per rank, so the
        # loader itself must not shuffle.
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
        train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=500, shuffle=False,
                                                   num_workers=4, pin_memory=True, drop_last=True,
                                                   sampler=train_sampler)
        test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=500, shuffle=False, num_workers=4)

        if args.checkpoint is not None:
            if 'full' in args.checkpoint:
                # NOTE: loading a whole pickled model runs pickle; only load
                # checkpoints from a trusted source.
                loaded = torch.load(args.checkpoint, map_location=f'cuda:{args.gpu}')
                print(f"is_ddp {isinstance(loaded, torch.nn.parallel.DistributedDataParallel)}")
                # A checkpoint saved from a DDP-wrapped model is itself a DDP
                # object; unwrap it so DDP is never nested inside DDP.
                model = _unwrap_ddp(loaded)
                model.cuda(args.gpu)
                model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
            else:
                checkpoint = torch.load(args.checkpoint, map_location=f'cuda:{args.gpu}')
                model.load_state_dict(checkpoint['state_dict'])
                optimizer.load_state_dict(checkpoint['optimizer'])
            test_acc = run_inference(model, test_loader, args, debug=True)
            print(f'\nTest Accuracy {test_acc:.2f}\n')
            return

        criterion = nn.CrossEntropyLoss()
        n_epochs = 2
        for epoch in range(n_epochs):
            model.train()
            # Re-seed the sampler so each epoch sees a different shuffle.
            train_sampler.set_epoch(epoch)
            train_correct = 0
            train_total = 0
            for i, (inputs, labels) in enumerate(train_loader):
                inputs = inputs.cuda(args.gpu)
                labels = labels.cuda(args.gpu)
                outputs = model(inputs, i=i, gpu=args.gpu)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
                predicted = outputs.argmax(dim=1)
                train_total += labels.size(0)
                train_correct += (predicted == labels).sum()
            # max() guards against a rank whose shard yields no full batch.
            train_acc = 100. * train_correct / max(train_total, 1)
            test_acc = run_inference(model, test_loader, args)
            if args.gpu == 0:
                print(f"Epoch {epoch:>3d} train {train_acc:.2f} test {test_acc:.2f}")
                # Only rank 0 writes, so processes never race on the same file.
                # Store the bare module so the file can be re-wrapped on load.
                torch.save(_unwrap_ddp(model), 'model_full.pt')
                # Keys keep the DDP "module." prefix, matching the
                # load_state_dict call on the DDP-wrapped model above.
                torch.save({'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict()}, 'model.pt')
    finally:
        dist.destroy_process_group()
if __name__ == '__main__':
    # Script entry point: parse CLI options and start one worker per visible GPU.
    parser = argparse.ArgumentParser(description='')
    parser.add_argument('--checkpoint', type=str, default=None, metavar='', help='path to a model checkpoint')
    args = parser.parse_args()
    num_gpus = torch.cuda.device_count()
    if num_gpus < 1:
        # With nprocs=0 mp.spawn starts nothing and returns, which would look
        # like a successful run; fail loudly instead.
        raise SystemExit('No CUDA devices available: this script needs at least one GPU (NCCL backend).')
    # mp.spawn passes the process index as the first argument of main.
    mp.spawn(main, nprocs=num_gpus, args=(args, num_gpus))
# (GitHub page footer -- sign-up / sign-in prompt for commenting -- omitted.)