Debugging: Distributed InfoNCE Loss
# CLIP's contrastive loss is calculated using all the negative batch samples from all the GPUs.
# How do we implement that?
# For more info: https://github.com/openai/CLIP/issues/29
import os
import sys
import random
import tempfile

import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
class GatherLayer(torch.autograd.Function):
    '''Gather tensors from all processes, supporting backward propagation.
    https://github.com/open-mmlab/OpenSelfSup/blob/696d04950e55d504cf33bc83cfadbb4ece10fbae/openselfsup/models/utils/gather_layer.py
    '''
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        output = [torch.zeros_like(input)
                  for _ in range(dist.get_world_size())]
        dist.all_gather(output, input)
        return tuple(output)

    @staticmethod
    def backward(ctx, *grads):
        input, = ctx.saved_tensors
        grad_out = torch.zeros_like(input)
        # keep only the gradient slice that corresponds to the local rank
        grad_out[:] = grads[dist.get_rank()]
        return grad_out
def set_seed(s, reproducible=False):
    "Set random seed for `random`, `torch`, and `numpy` (where available)"
    try: torch.manual_seed(s)
    except NameError: pass
    try: torch.cuda.manual_seed_all(s)
    except NameError: pass
    try: np.random.seed(s % (2**32 - 1))
    except NameError: pass
    random.seed(s)
    if reproducible:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
def setup(rank, world_size):
    if sys.platform == 'win32':
        # The distributed package only covers collective communications with the
        # Gloo backend and FileStore on Windows. Set the init_method parameter
        # in init_process_group to a local file.
        # Example: init_method="file:///f:/libtmp/some_file"
        init_method = "file:///{your local file path}"
        # initialize the process group
        dist.init_process_group(
            "gloo",
            init_method=init_method,
            rank=rank,
            world_size=world_size
        )
    else:
        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = '12355'
        # initialize the process group
        dist.init_process_group("gloo", rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()
def get_partition(rank):
    # every rank generates the same full batch of 8 samples, then keeps its own slice of 4
    set_seed(42)
    image_inputs = torch.randn(8, 4)
    set_seed(42)
    text_inputs = torch.randn(8, 4)
    return image_inputs[rank*4:(rank+1)*4], text_inputs[rank*4:(rank+1)*4]

def get_encoders():
    # the fixed seed gives every rank identical encoder weights
    set_seed(42)
    image_encoder = nn.Linear(4, 2, bias=False)
    set_seed(42)
    text_encoder = nn.Linear(4, 2, bias=False)
    return image_encoder, text_encoder
def demo_basic(rank, world_size, bs):
    setup(rank, world_size)
    # local partition of data
    image_inputs, text_inputs = get_partition(rank)
    # local copy of encoders
    image_encoder, text_encoder = get_encoders()
    # calculate embeddings
    image_embeddings = image_encoder(image_inputs)
    text_embeddings = text_encoder(text_inputs)
    image_embeddings = F.normalize(image_embeddings)
    text_embeddings = F.normalize(text_embeddings)
    # gather text embeddings from all processes; image embeddings stay local
    # image_embeddings = torch.cat(GatherLayer.apply(image_embeddings))
    text_embeddings = torch.cat(GatherLayer.apply(text_embeddings))
    print(text_embeddings)
    # calculate contrastive loss against all batches across multiple devices
    image_loss = F.cross_entropy(image_embeddings @ text_embeddings.T,
                                 torch.arange(rank*bs, (rank+1)*bs))
    # get gradients
    image_loss.backward()
    print(f"Rank:{rank} Loss: {image_loss}")
    # average gradients from all devices
    average_gradients(image_encoder)
    # check gradients
    print(f"Rank:{rank} image_encoder.weight.grad: {image_encoder.weight.grad}")
    # average gradients from all devices
    average_gradients(text_encoder)
    # check gradients
    print(f"Rank:{rank} text_encoder.weight.grad: {text_encoder.weight.grad}")
def average_gradients(model):
    size = float(dist.get_world_size())
    for param in model.parameters():
        dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
        param.grad.data /= size

def run_demo(demo_fn, world_size, bs):
    mp.spawn(demo_fn,
             args=(world_size, bs),
             nprocs=world_size,
             join=True)

if __name__ == "__main__":
    run_demo(demo_basic, 2, 4)
In this post, @KeremTurgutlu noted that something is wrong with the text_encoder gradient.
I created the following toy snippet, which uses a single GPU, to get the "ground-truth" gradients for both the image and text encoders.
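It reuses set_seed and get_encoders from the gist above and computes the loss over the full batch of 8 in one process; a minimal sketch of the idea (the exact prints are illustrative):

import torch
import torch.nn.functional as F

# the full batch, with no per-rank slicing
set_seed(42)
image_inputs = torch.randn(8, 4)
set_seed(42)
text_inputs = torch.randn(8, 4)

# identical weights to each rank's local copy in the distributed run
image_encoder, text_encoder = get_encoders()

image_embeddings = F.normalize(image_encoder(image_inputs))
text_embeddings = F.normalize(text_encoder(text_inputs))

# full-batch InfoNCE: the mean over all 8 rows equals the average of the
# two per-rank losses, so these gradients are the ground truth that the
# distributed run (after average_gradients) should reproduce
image_loss = F.cross_entropy(image_embeddings @ text_embeddings.T, torch.arange(8))
image_loss.backward()

print(f"Loss: {image_loss}")
print(f"image_encoder.weight.grad: {image_encoder.weight.grad}")
print(f"text_encoder.weight.grad: {text_encoder.weight.grad}")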
This snippet gave output indicating that something is indeed wrong with the text_encoder gradient in the multi-GPU run.
I inspected his code further. The bug seems related to GatherLayer. I noticed minor differences between Kerem's GatherLayer and PyTorch's _AllGather (e.g., this line): in its backward pass, _AllGather reduces the gradient contributions from all ranks, whereas GatherLayer keeps only the slice for the local rank. Since _AllGather is already called by torch.distributed.nn.functional.all_gather, I decided to call this directly. Now, my multi_gpu code looks like the sketch below.
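Only the gather call changes; setup, get_partition, get_encoders, average_gradients, and run_demo are reused from the gist above (demo_basic_fixed is just a name for this variant):

from torch.distributed.nn.functional import all_gather

def demo_basic_fixed(rank, world_size, bs):
    setup(rank, world_size)
    image_inputs, text_inputs = get_partition(rank)
    image_encoder, text_encoder = get_encoders()
    image_embeddings = F.normalize(image_encoder(image_inputs))
    text_embeddings = F.normalize(text_encoder(text_inputs))
    # differentiable all_gather: unlike GatherLayer, its backward pass
    # reduces the gradient contributions from every rank instead of
    # keeping only the local slice
    text_embeddings = torch.cat(all_gather(text_embeddings))
    image_loss = F.cross_entropy(image_embeddings @ text_embeddings.T,
                                 torch.arange(rank*bs, (rank+1)*bs))
    image_loss.backward()
    average_gradients(image_encoder)
    average_gradients(text_encoder)
    print(f"Rank:{rank} Loss: {image_loss}")
    print(f"Rank:{rank} image_encoder.weight.grad: {image_encoder.weight.grad}")
    print(f"Rank:{rank} text_encoder.weight.grad: {text_encoder.weight.grad}")

if __name__ == "__main__":
    run_demo(demo_basic_fixed, 2, 4)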
This gives output that I believe is correct: the averaged gradients now match the single-GPU ground truth for both encoders. I hope this is the error that Kerem Turgutlu referred to and that there are no remaining ones.
I hope this helps!