Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
Debugging: Distributed InfoNCE Loss
# CLIP contrastive loss is calculated all the negative batch samples from all the GPUs
# How to implement that?
# For more info: https://github.com/openai/CLIP/issues/29
import os
import sys
import tempfile
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
import random
import torch.nn as nn
import torch.nn.functional as F
class GatherLayer(torch.autograd.Function):
'''Gather tensors from all process, supporting backward propagation.
https://github.com/open-mmlab/OpenSelfSup/blob/696d04950e55d504cf33bc83cfadbb4ece10fbae/openselfsup/models/utils/gather_layer.py
'''
@staticmethod
def forward(ctx, input):
ctx.save_for_backward(input)
output = [torch.zeros_like(input) \
for _ in range(dist.get_world_size())]
dist.all_gather(output, input)
return tuple(output)
@staticmethod
def backward(ctx, *grads):
input, = ctx.saved_tensors
grad_out = torch.zeros_like(input)
grad_out[:] = grads[dist.get_rank()]
return grad_out
def set_seed(s, reproducible=False):
"Set random seed for `random`, `torch`, and `numpy` (where available)"
try: torch.manual_seed(s)
except NameError: pass
try: torch.cuda.manual_seed_all(s)
except NameError: pass
try: np.random.seed(s%(2**32-1))
except NameError: pass
random.seed(s)
if reproducible:
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
def setup(rank, world_size):
if sys.platform == 'win32':
# Distributed package only covers collective communications with Gloo
# backend and FileStore on Windows platform. Set init_method parameter
# in init_process_group to a local file.
# Example init_method="file:///f:/libtmp/some_file"
init_method="file:///{your local file path}"
# initialize the process group
dist.init_process_group(
"gloo",
init_method=init_method,
rank=rank,
world_size=world_size
)
else:
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'
# initialize the process group
dist.init_process_group("gloo", rank=rank, world_size=world_size)
def cleanup():
dist.destroy_process_group()
def get_partition(rank):
set_seed(42)
image_inputs = torch.randn(8, 4)
set_seed(42)
text_inputs = torch.randn(8, 4)
return image_inputs[rank*4:(rank+1)*4], text_inputs[rank*4:(rank+1)*4]
def get_encoders():
set_seed(42)
image_encoder = nn.Linear(4, 2, bias=False)
set_seed(42)
text_encoder = nn.Linear(4, 2, bias=False)
return image_encoder, text_encoder
def demo_basic(rank, world_size, bs):
setup(rank, world_size)
# local partition of data
image_inputs, text_inputs = get_partition(rank)
# print(image_inputs)
# print(text_inputs)
# local copy of encoders
image_encoder, text_encoder = get_encoders()
# print(image_encoder.weight)
# print(text_encoder.weight)
# calculate embeddings
image_embeddings = image_encoder(image_inputs)
text_embeddings = text_encoder(text_inputs)
image_embeddings = F.normalize(image_embeddings)
text_embeddings = F.normalize(text_embeddings)
# image_embeddings = torch.cat(GatherLayer.apply(image_embeddings))
text_embeddings = torch.cat(GatherLayer.apply(text_embeddings))
print(text_embeddings)
# calculate contrastive loss with all batches across multiple devices
image_loss = F.cross_entropy(image_embeddings @ text_embeddings.T, torch.arange(rank*bs, (rank+1)*bs))
# get gradients
image_loss.backward()
print(f"Rank:{rank} Loss: {image_loss}")
# average gradient from all devices
average_gradients(image_encoder)
# check gradients
print(f"Rank:{rank} image_encoder.weight.grad: {image_encoder.weight.grad}")
# average gradient from all devices
average_gradients(text_encoder)
# check gradients
print(f"Rank:{rank} text_encoder.weight.grad: {text_encoder.weight.grad}")
def average_gradients(model):
size = float(dist.get_world_size())
for param in model.parameters():
dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
param.grad.data /= size
def run_demo(demo_fn, world_size, bs):
mp.spawn(demo_fn,
args=(world_size,bs),
nprocs=world_size,
join=True)
if __name__ == "__main__":
run_demo(demo_basic, 2, 4)
@hanjf12
Copy link

hanjf12 commented Jun 10, 2021

After using GatherLayer text_encoder.weight.grad is there but values are wrong.
Output from new version:

Rank:0 Loss: 1.2145355939865112
Rank:1 Loss: 1.5723316669464111
Rank:0 image_encoder.weight.grad: tensor([[-0.0559,  0.0055, -0.0663,  0.0330],
        [ 0.1006, -0.0770,  0.0099,  0.0591]])
Rank:1 image_encoder.weight.grad: tensor([[-0.0559,  0.0055, -0.0663,  0.0330],
        [ 0.1006, -0.0770,  0.0099,  0.0591]])
Rank:0 text_encoder.weight.grad: tensor([[-0.0657,  0.0156, -0.0082,  0.0552],
        [ 0.1148, -0.1080,  0.0123,  0.0640]])
Rank:1 text_encoder.weight.grad: tensor([[-0.0657,  0.0156, -0.0082,  0.0552],
        [ 0.1148, -0.1080,  0.0123,  0.0640]])

hello, have you fix the problem?

I ended up using this: https://github.com/KeremTurgutlu/self_supervised/blob/main/self_supervised/dist.py

The two classes are same, so the problem of different text_encoder.weight.grad is still there?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment