Multi-GPU sync-batch-norm test
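
This script spawns two worker processes with torch.multiprocessing.spawn, initializes an NCCL process group in each, converts the model's BatchNorm2d layer to nn.SyncBatchNorm, wraps the model in DistributedDataParallel, runs a single training step per rank, and reduces the per-rank losses onto rank 0. It assumes a machine with at least two CUDA devices.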
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP


class Model(nn.Module):
    """A minimal conv -> batch-norm -> ReLU block."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.bn = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x


def main():
    # Rendezvous settings used by the default process-group store.
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29508'
    world_size = 2
    # Launch one process per GPU; spawn passes each process its rank
    # as the first argument.
    mp.spawn(create_process,
             args=(world_size,),
             nprocs=world_size,
             join=True)


def create_process(rank, world_size):
    dist.init_process_group('nccl', rank=rank, world_size=world_size)
    model = Model(2, 2).to(rank)
    # Replace every BatchNorm layer with SyncBatchNorm so batch statistics
    # are computed across all processes. This must happen before the model
    # is wrapped in DDP.
    model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model = DDP(model, device_ids=[rank])
    optimizer = optim.Adam(params=model.parameters(), lr=0.0003)
    # Each rank draws its own random batch; SyncBatchNorm still normalizes
    # with statistics pooled over both ranks.
    tensor = torch.rand(1, 2, 5, 5)
    label = torch.rand(1, 2, 5, 5)
    train(model, optimizer, rank, tensor, label)
    dist.destroy_process_group()


def train(model, optimizer, rank, tensor, label):
    criterion = nn.L1Loss()
    tensor = tensor.to(rank)
    label = label.to(rank)
    output = model(tensor)
    loss = criterion(output, label)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # Sum the per-rank losses onto rank 0; only rank 0 is guaranteed to
    # hold the reduced value, other ranks print their local loss.
    dist.reduce(loss, 0, dist.ReduceOp.SUM)
    print('exit:', rank, loss.item())


if __name__ == '__main__':
    main()
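
To confirm that the conversion and synchronization actually took effect, one can inspect the BatchNorm layer after training. Below is a minimal sketch under the assumptions that it is called from create_process after train() and before destroy_process_group(), and that the DDP wrapper exposes the converted model as model.module; the helper name check_sync is hypothetical and not part of the original gist.

def check_sync(model, rank):
    # convert_sync_batchnorm should have replaced the BatchNorm2d layer.
    bn = model.module.bn  # unwrap the DDP container (assumed layout)
    assert isinstance(bn, nn.SyncBatchNorm)
    # running_mean is updated from statistics pooled across both ranks,
    # so every process should print identical values after train().
    print(f'rank {rank} running_mean:', bn.running_mean.tolist())

If SyncBatchNorm is working, both ranks print the same running_mean even though each rank trained on a different random batch; with plain BatchNorm2d the values would diverge.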