Skip to content

Instantly share code, notes, and snippets.

@KeremTurgutlu
Last active September 20, 2023 00:57
Show Gist options
  • Save KeremTurgutlu/4a6f7078dc62f292c85b9903197c75f7 to your computer and use it in GitHub Desktop.
Save KeremTurgutlu/4a6f7078dc62f292c85b9903197c75f7 to your computer and use it in GitHub Desktop.
Debugging: Distributed InfoNCE Loss
# CLIP contrastive loss is calculated all the negative batch samples from all the GPUs
# How to implement that?
# For more info: https://github.com/openai/CLIP/issues/29
import os
import sys
import tempfile
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
import random
import torch.nn as nn
import torch.nn.functional as F
class GatherLayer(torch.autograd.Function):
'''Gather tensors from all process, supporting backward propagation.
https://github.com/open-mmlab/OpenSelfSup/blob/696d04950e55d504cf33bc83cfadbb4ece10fbae/openselfsup/models/utils/gather_layer.py
'''
@staticmethod
def forward(ctx, input):
ctx.save_for_backward(input)
output = [torch.zeros_like(input) \
for _ in range(dist.get_world_size())]
dist.all_gather(output, input)
return tuple(output)
@staticmethod
def backward(ctx, *grads):
input, = ctx.saved_tensors
grad_out = torch.zeros_like(input)
grad_out[:] = grads[dist.get_rank()]
return grad_out
def set_seed(s, reproducible=False):
"Set random seed for `random`, `torch`, and `numpy` (where available)"
try: torch.manual_seed(s)
except NameError: pass
try: torch.cuda.manual_seed_all(s)
except NameError: pass
try: np.random.seed(s%(2**32-1))
except NameError: pass
random.seed(s)
if reproducible:
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
def setup(rank, world_size):
if sys.platform == 'win32':
# Distributed package only covers collective communications with Gloo
# backend and FileStore on Windows platform. Set init_method parameter
# in init_process_group to a local file.
# Example init_method="file:///f:/libtmp/some_file"
init_method="file:///{your local file path}"
# initialize the process group
dist.init_process_group(
"gloo",
init_method=init_method,
rank=rank,
world_size=world_size
)
else:
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'
# initialize the process group
dist.init_process_group("gloo", rank=rank, world_size=world_size)
def cleanup():
dist.destroy_process_group()
def get_partition(rank):
set_seed(42)
image_inputs = torch.randn(8, 4)
set_seed(42)
text_inputs = torch.randn(8, 4)
return image_inputs[rank*4:(rank+1)*4], text_inputs[rank*4:(rank+1)*4]
def get_encoders():
set_seed(42)
image_encoder = nn.Linear(4, 2, bias=False)
set_seed(42)
text_encoder = nn.Linear(4, 2, bias=False)
return image_encoder, text_encoder
def demo_basic(rank, world_size, bs):
setup(rank, world_size)
# local partition of data
image_inputs, text_inputs = get_partition(rank)
# print(image_inputs)
# print(text_inputs)
# local copy of encoders
image_encoder, text_encoder = get_encoders()
# print(image_encoder.weight)
# print(text_encoder.weight)
# calculate embeddings
image_embeddings = image_encoder(image_inputs)
text_embeddings = text_encoder(text_inputs)
image_embeddings = F.normalize(image_embeddings)
text_embeddings = F.normalize(text_embeddings)
# image_embeddings = torch.cat(GatherLayer.apply(image_embeddings))
text_embeddings = torch.cat(GatherLayer.apply(text_embeddings))
print(text_embeddings)
# calculate contrastive loss with all batches across multiple devices
image_loss = F.cross_entropy(image_embeddings @ text_embeddings.T, torch.arange(rank*bs, (rank+1)*bs))
# get gradients
image_loss.backward()
print(f"Rank:{rank} Loss: {image_loss}")
# average gradient from all devices
average_gradients(image_encoder)
# check gradients
print(f"Rank:{rank} image_encoder.weight.grad: {image_encoder.weight.grad}")
# average gradient from all devices
average_gradients(text_encoder)
# check gradients
print(f"Rank:{rank} text_encoder.weight.grad: {text_encoder.weight.grad}")
def average_gradients(model):
size = float(dist.get_world_size())
for param in model.parameters():
dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
param.grad.data /= size
def run_demo(demo_fn, world_size, bs):
mp.spawn(demo_fn,
args=(world_size,bs),
nprocs=world_size,
join=True)
if __name__ == "__main__":
run_demo(demo_basic, 2, 4)
@erikchwang
Copy link

erikchwang commented Jun 25, 2023

  1. In GatherLayer. backward(), do average reduce-scatter on the grads.
  2. In average_gradients(), do average all-reduce on the image encoder, and sum all-reduce on the text encoder.

@ahmdtaha
Copy link

In this post, @KeremTurgutlu noted that something is wrong with text_encoder grad.

I created the following toy snippet -- that uses a single GPU -- to get the "groundtruth" gradient for both the image and text encoder

def main():
    set_seed(42)
    image_inputs = torch.randn(8, 4)
    set_seed(42)
    text_inputs = torch.randn(8, 4)
    set_seed(42)
    image_encoder = nn.Linear(4, 2, bias=False)
    set_seed(42)
    text_encoder = nn.Linear(4, 2, bias=False)

    image_embeddings = image_encoder(image_inputs)
    text_embeddings = text_encoder(text_inputs)

    image_embeddings = F.normalize(image_embeddings)
    text_embeddings = F.normalize(text_embeddings)

    # image_embeddings.shape, text_embeddings.shape

    cosine_sim = image_embeddings @ text_embeddings.T

    loss = F.cross_entropy(cosine_sim, torch.arange(len(cosine_sim)), reduction="none")
    print('Loss ', loss)

    loss = loss.mean()
    print('Mean ', loss)

    loss.backward()
    print('image_encoder ', image_encoder.weight.grad)
    print('text_encoder ', text_encoder.weight.grad)

This snippet gave the following output

image_encoder  tensor([[-0.0559,  0.0055, -0.0663,  0.0330],
        [ 0.1006, -0.0770,  0.0099,  0.0591]])
text_encoder  tensor([[-0.0575,  0.0046, -0.0505,  0.0319],
        [ 0.0871, -0.0862,  0.0119,  0.0704]])

Which indicates something indeed wrong with the text_encoder gradient.

I inspected his code further. The bug seems related to GatherLayer
I noticed minor differences between Kerem's GatherLayer and PyTorch's _AllGather (e.g., this line). Since _AllGather is already called by torch.distributed.nn.functional.all_gather, I decided to call this directly.

Now, my multi_gpu code looks like

def demo_basic(rank, world_size, bs):
    setup(rank, world_size)

    # local partition of data
    image_inputs, text_inputs = get_partition(rank)

    # local copy of encoders
    image_encoder, text_encoder = get_encoders()

    # calculate embeddings
    image_embeddings = image_encoder(image_inputs)
    text_embeddings = text_encoder(text_inputs)
    image_embeddings = F.normalize(image_embeddings)
    text_embeddings = F.normalize(text_embeddings)


    # >>>>>>>>>>>>>>>> Changes Here <<<<<<<<<<<<<<<<
    import torch.distributed.nn.functional as dist_nn                    
    all_text_embeddings = dist_nn.all_gather(text_embeddings) 
    all_text_embeddings = torch.cat(all_text_embeddings) 

    print(text_embeddings.shape)

    # calculate contrastive loss with all batches across multiple devices
    image_loss = F.cross_entropy(
        image_embeddings @ all_text_embeddings.T,
        torch.arange(rank * bs, (rank + 1) * bs),
        reduction="none",
    )
    print(image_loss)
    image_loss = image_loss.mean()
    print(image_loss)
    # get gradients
    image_loss.backward()

    # # average gradient from all devices
    average_gradients(text_encoder)
    # # check gradients
    print(f"Rank:{rank} text_encoder.weight.grad: {text_encoder.weight.grad}")

    # # average gradient from all devices
    average_gradients(image_encoder)
    # # check gradients
    print(f"Rank:{rank} image_encoder.weight.grad: {image_encoder.weight.grad}")

This gives the following output


Rank:1 image_encoder.weight.grad: tensor([[-0.0559,  0.0055, -0.0663,  0.0330],
        [ 0.1006, -0.0770,  0.0099,  0.0591]])
Rank:0 image_encoder.weight.grad: tensor([[-0.0559,  0.0055, -0.0663,  0.0330],
        [ 0.1006, -0.0770,  0.0099,  0.0591]])

Rank:0 text_encoder.weight.grad: tensor([[-0.0575,  0.0046, -0.0505,  0.0319],
        [ 0.0871, -0.0862,  0.0119,  0.0704]])
Rank:1 text_encoder.weight.grad: tensor([[-0.0575,  0.0046, -0.0505,  0.0319],
        [ 0.0871, -0.0862,  0.0119,  0.0704]])

Which I believe is the correct output. I hope this is the error that Kerem Turgutlu referred to and there are no more errors
I hope this helps

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment