-
-
Save KeremTurgutlu/4a6f7078dc62f292c85b9903197c75f7 to your computer and use it in GitHub Desktop.
# The CLIP contrastive loss is computed against the negative samples from the batches on all GPUs.
# How do we implement that?
# For more info: https://github.com/openai/CLIP/issues/29
import os | |
import sys | |
import tempfile | |
import torch | |
import torch.distributed as dist | |
import torch.nn as nn | |
import torch.optim as optim | |
import torch.multiprocessing as mp | |
from torch.nn.parallel import DistributedDataParallel as DDP | |
import random | |
import torch.nn as nn | |
import torch.nn.functional as F | |
class GatherLayer(torch.autograd.Function):
    """Gather tensors from all processes, supporting backward propagation.

    Adapted from
    https://github.com/open-mmlab/OpenSelfSup/blob/696d04950e55d504cf33bc83cfadbb4ece10fbae/openselfsup/models/utils/gather_layer.py

    ``forward`` all-gathers ``input`` from every rank and returns a tuple with
    one tensor per rank (index ``i`` holds rank ``i``'s tensor).

    The gathered tensors are consumed by the loss on *every* rank. So in
    ``backward`` each rank must receive the sum, over all ranks, of the
    gradients computed for its own slice. Callers that then average
    parameter gradients across ranks (as DDP does) obtain the gradient of
    the mean loss over the global batch.
    """

    @staticmethod
    def forward(ctx, input):
        # all_gather works on contiguous buffers; every rank must pass a
        # tensor of the same shape and dtype.
        input = input.contiguous()
        output = [torch.zeros_like(input) for _ in range(dist.get_world_size())]
        dist.all_gather(output, input)
        return tuple(output)

    @staticmethod
    def backward(ctx, *grads):
        # grads[i] is the gradient of THIS rank's loss w.r.t. rank i's slice.
        # Summing across ranks leaves, at index `rank`, the gradient of the
        # summed loss w.r.t. this rank's input. Returning only the local
        # grads[rank] would drop every contribution computed on other ranks.
        all_grads = torch.stack(grads)
        dist.all_reduce(all_grads, op=dist.ReduceOp.SUM)
        return all_grads[dist.get_rank()]
def set_seed(s, reproducible=False):
    """Set random seed for `random`, `torch`, and `numpy` (where available).

    Args:
        s: integer seed. NumPy only accepts seeds below 2**32, so it
            receives ``s % (2**32 - 1)``.
        reproducible: if True, also force deterministic cuDNN kernels and
            disable cuDNN benchmarking (autotuning).
    """
    torch.manual_seed(s)
    # Documented as safe to call when CUDA is unavailable (silently ignored).
    torch.cuda.manual_seed_all(s)
    # numpy is not imported at module level here, so a bare `np` reference
    # would always raise NameError and numpy would never be seeded. Import it
    # locally and skip only when it is genuinely not installed.
    try:
        import numpy as np
    except ImportError:
        pass
    else:
        np.random.seed(s % (2**32 - 1))
    random.seed(s)
    if reproducible:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
def setup(rank, world_size):
    """Initialise the default gloo process group for this worker.

    Args:
        rank: index of this process, in ``[0, world_size)``.
        world_size: total number of participating processes.
    """
    if sys.platform == 'win32':
        # The distributed package only covers collective communications with
        # the Gloo backend and FileStore on Windows, so rendezvous through a
        # file that every spawned process can see. Every rank derives the
        # same path. PyTorch expects a fresh (absent or empty) file per
        # initialisation; a stale file left by a crashed run must be deleted
        # first, or point DIST_INIT_FILE elsewhere.
        from pathlib import Path
        init_file = os.environ.get(
            "DIST_INIT_FILE",
            os.path.join(tempfile.gettempdir(), "clip_gather_layer_demo_store"),
        )
        # as_uri() yields e.g. "file:///C:/Users/.../clip_gather_layer_demo_store".
        init_method = Path(init_file).resolve().as_uri()
        dist.init_process_group(
            "gloo",
            init_method=init_method,
            rank=rank,
            world_size=world_size,
        )
    else:
        # Respect an address/port the user has already exported; otherwise
        # fall back to the single-machine defaults.
        os.environ.setdefault('MASTER_ADDR', 'localhost')
        os.environ.setdefault('MASTER_PORT', '12355')
        dist.init_process_group("gloo", rank=rank, world_size=world_size)
def cleanup():
    """Destroy the default process group created by ``setup``."""
    dist.destroy_process_group(group=None)
def get_partition(rank, bs=4, world_size=2):
    """Return this rank's shard of a deterministic toy (image, text) dataset.

    The full dataset has ``world_size * bs`` rows of 4 features. Image and
    text inputs are each drawn right after ``set_seed(42)``, so they are
    identical. Every process regenerates the whole dataset and keeps rows
    ``[rank*bs, (rank+1)*bs)``, which is why the contrastive labels
    ``arange(rank*bs, (rank+1)*bs)`` point at the matching text rows.

    The defaults reproduce the original hard-coded layout (8 rows, shards of 4).

    Raises:
        ValueError: if ``rank`` is outside ``[0, world_size)``; such a rank
            would otherwise silently receive an empty shard.
    """
    if not 0 <= rank < world_size:
        raise ValueError(f"rank {rank} out of range for world_size {world_size}")
    n_rows = world_size * bs
    set_seed(42)
    image_inputs = torch.randn(n_rows, 4)
    set_seed(42)
    text_inputs = torch.randn(n_rows, 4)
    shard = slice(rank * bs, (rank + 1) * bs)
    return image_inputs[shard], text_inputs[shard]
def get_encoders():
    """Build the toy image and text encoders with identical initial weights.

    Each is a bias-free ``nn.Linear(4, 2)`` created immediately after
    ``set_seed(42)``, so both encoders start from the same parameters in
    every process.
    """
    built = []
    for _ in range(2):
        set_seed(42)
        built.append(nn.Linear(4, 2, bias=False))
    image_encoder, text_encoder = built
    return image_encoder, text_encoder
def demo_basic(rank, world_size, bs):
    """Per-process worker: image->text contrastive loss against all ranks' texts.

    Each rank embeds its local image/text shard and gathers the text
    embeddings from every rank with ``GatherLayer``. It then computes the
    cross-entropy of its local images against *all* texts, backpropagates,
    averages gradients across ranks and prints loss and gradients.

    The text-encoder gradient is only correct if ``GatherLayer.backward``
    sums each slice's gradient over all ranks. ``average_gradients`` then
    divides by world_size, giving the gradient of the mean loss over the
    global batch.

    Args:
        rank: this process's rank (supplied by ``mp.spawn``).
        world_size: number of processes.
        bs: per-rank batch size. It must equal the local shard size, because
            the positive-pair labels are ``rank*bs ... (rank+1)*bs - 1``.

    Raises:
        ValueError: if ``bs`` does not match the local shard size.
    """
    setup(rank, world_size)
    try:
        # Local partition of data and local copy of the encoders.
        image_inputs, text_inputs = get_partition(rank)
        if len(image_inputs) != bs:
            raise ValueError(
                f"bs={bs} does not match the local shard size {len(image_inputs)}"
            )
        image_encoder, text_encoder = get_encoders()

        image_embeddings = F.normalize(image_encoder(image_inputs))
        text_embeddings = F.normalize(text_encoder(text_inputs))
        # Only the text side is gathered: local images are scored against
        # the texts of every rank (concatenated in rank order).
        text_embeddings = torch.cat(GatherLayer.apply(text_embeddings))
        print(text_embeddings)

        # Local image i pairs with global text row rank*bs + i.
        labels = torch.arange(rank * bs, (rank + 1) * bs)
        image_loss = F.cross_entropy(image_embeddings @ text_embeddings.T, labels)
        image_loss.backward()
        print(f"Rank:{rank} Loss: {image_loss}")

        average_gradients(image_encoder)
        print(f"Rank:{rank} image_encoder.weight.grad: {image_encoder.weight.grad}")
        average_gradients(text_encoder)
        print(f"Rank:{rank} text_encoder.weight.grad: {text_encoder.weight.grad}")
    finally:
        # Always release the process group, even if the step above failed.
        cleanup()
def average_gradients(model):
    """Replace each parameter gradient of ``model`` with its mean over all ranks.

    Sums every ``param.grad`` across processes in place, then divides by the
    world size. Every parameter must already have a ``.grad`` (i.e.
    ``backward`` has run). The all-reduces are issued in
    ``model.parameters()`` order, so every rank must call this with the same
    model structure.
    """
    world = float(dist.get_world_size())
    for grad in (p.grad.data for p in model.parameters()):
        dist.all_reduce(grad, op=dist.ReduceOp.SUM)
        grad.div_(world)
def run_demo(demo_fn, world_size, bs):
    """Launch ``world_size`` processes, each running ``demo_fn(rank, world_size, bs)``.

    ``mp.spawn`` passes the process index as the first argument and, with
    ``join=True``, blocks until every worker has exited.
    """
    worker_args = (world_size, bs)
    mp.spawn(demo_fn, nprocs=world_size, args=worker_args, join=True)
if __name__ == "__main__":
    # 2 worker processes x 4 samples each, matching the 8-row toy dataset
    # that get_partition splits into shards of 4.
    world_size, per_rank_bs = 2, 4
    run_demo(demo_basic, world_size, per_rank_bs)
erikchwang
commented
Jun 25, 2023
•
- In `GatherLayer.backward()`, do an average reduce-scatter on the grads.
- In `average_gradients()`, do an average all-reduce on the image encoder's grads and a sum all-reduce on the text encoder's grads.
In this post, @KeremTurgutlu noted that something is wrong with the text_encoder gradient.
I created the following toy snippet, which uses a single GPU, to get the "ground-truth" gradients for both the image and text encoders:
def main():
    """Single-process reference for the 'ground-truth' encoder gradients.

    Uses the whole 8-sample toy batch and the same seeding as the
    distributed demo. It computes the image->text cross-entropy (diagonal
    targets), then prints the per-sample losses, their mean and both
    encoders' weight gradients.
    """
    def seeded(make):
        # Reseed before every construction so each object matches the
        # distributed version exactly.
        set_seed(42)
        return make()

    image_inputs = seeded(lambda: torch.randn(8, 4))
    text_inputs = seeded(lambda: torch.randn(8, 4))
    image_encoder = seeded(lambda: nn.Linear(4, 2, bias=False))
    text_encoder = seeded(lambda: nn.Linear(4, 2, bias=False))

    image_embeddings = F.normalize(image_encoder(image_inputs))
    text_embeddings = F.normalize(text_encoder(text_inputs))
    # Row i scores image i against every text; the positive is column i.
    cosine_sim = image_embeddings @ text_embeddings.T
    targets = torch.arange(cosine_sim.shape[0])
    per_sample = F.cross_entropy(cosine_sim, targets, reduction="none")
    print('Loss ', per_sample)
    loss = per_sample.mean()
    print('Mean ', loss)
    loss.backward()
    print('image_encoder ', image_encoder.weight.grad)
    print('text_encoder ', text_encoder.weight.grad)
This snippet gave the following output:
image_encoder tensor([[-0.0559, 0.0055, -0.0663, 0.0330],
[ 0.1006, -0.0770, 0.0099, 0.0591]])
text_encoder tensor([[-0.0575, 0.0046, -0.0505, 0.0319],
[ 0.0871, -0.0862, 0.0119, 0.0704]])
This indicates that something is indeed wrong with the text_encoder gradient.
I inspected his code further. The bug seems to be related to GatherLayer.
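I noticed differences between Kerem's GatherLayer
and PyTorch's _AllGather (e.g., this line). Most importantly, GatherLayer's backward keeps only the local rank's slice of the incoming gradients, whereas _AllGather's backward sums each slice's gradient across all ranks. Since _AllGather
is already called by torch.distributed.nn.functional.all_gather,
I decided to call that function directly.
Now my multi-GPU code looks like this: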
def demo_basic(rank, world_size, bs):
    """Per-process worker using PyTorch's autograd-aware all_gather.

    Each rank embeds its local shard and gathers the text embeddings from
    every rank with ``torch.distributed.nn.functional.all_gather``. It then
    computes the mean image->text cross-entropy of its local images against
    all texts, backpropagates, averages gradients across ranks and prints
    them. The process group is always destroyed on exit.

    Args:
        rank: this process's rank (supplied by ``mp.spawn``).
        world_size: number of processes.
        bs: per-rank batch size; local image i pairs with global text row
            ``rank*bs + i``.
    """
    import torch.distributed.nn.functional as dist_nn

    setup(rank, world_size)
    try:
        # Local partition of data and local copy of the encoders.
        image_inputs, text_inputs = get_partition(rank)
        image_encoder, text_encoder = get_encoders()

        image_embeddings = F.normalize(image_encoder(image_inputs))
        text_embeddings = F.normalize(text_encoder(text_inputs))
        # Unlike dist.all_gather, this variant participates in autograd. In
        # recent PyTorch releases its backward sums each slice's gradient
        # over all ranks (verify for the version in use).
        all_text_embeddings = torch.cat(dist_nn.all_gather(text_embeddings))
        print(text_embeddings.shape)

        # Contrastive loss of local images against the texts of all ranks.
        image_loss = F.cross_entropy(
            image_embeddings @ all_text_embeddings.T,
            torch.arange(rank * bs, (rank + 1) * bs),
            reduction="none",
        )
        print(image_loss)
        image_loss = image_loss.mean()
        print(image_loss)
        image_loss.backward()

        average_gradients(text_encoder)
        print(f"Rank:{rank} text_encoder.weight.grad: {text_encoder.weight.grad}")
        average_gradients(image_encoder)
        print(f"Rank:{rank} image_encoder.weight.grad: {image_encoder.weight.grad}")
    finally:
        # Release the process group even if a step above raised.
        cleanup()
This gives the following output:
Rank:1 image_encoder.weight.grad: tensor([[-0.0559, 0.0055, -0.0663, 0.0330],
[ 0.1006, -0.0770, 0.0099, 0.0591]])
Rank:0 image_encoder.weight.grad: tensor([[-0.0559, 0.0055, -0.0663, 0.0330],
[ 0.1006, -0.0770, 0.0099, 0.0591]])
Rank:0 text_encoder.weight.grad: tensor([[-0.0575, 0.0046, -0.0505, 0.0319],
[ 0.0871, -0.0862, 0.0119, 0.0704]])
Rank:1 text_encoder.weight.grad: tensor([[-0.0575, 0.0046, -0.0505, 0.0319],
[ 0.0871, -0.0862, 0.0119, 0.0704]])
This matches the single-GPU ground truth above, so I believe it is the correct output. I hope this is the error Kerem Turgutlu referred to and that there are no other errors.
I hope this helps!