Last active
May 17, 2024 23:53
-
-
Save youkaichao/b33fcd70286eb45a4a2d5a6dc32d096b to your computer and use it in GitHub Desktop.
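Benchmark comparing object broadcast approaches in torch.distributed: broadcast_object_list versus pickling into a preallocated, size-aligned buffer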
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.distributed as dist | |
import os | |
import multiprocessing | |
import multiprocessing.shared_memory | |
import io | |
import pickle | |
# Benchmark configuration; BACKEND and ALIGN are read from the environment.
N_warmup = 10  # leading iterations per size whose timings are discarded
N = 100  # measured iterations per size
backend = os.environ.get('BACKEND', 'gloo')  # process-group backend: 'gloo' or 'nccl'
align_memory = int(os.environ.get('ALIGN', '0'))  # non-zero selects the aligned-buffer broadcast path
def align_broadcast_obj(obj, src=0, rank=0, size_upper_bound=1024 * 1024):
    """Broadcast a picklable object through the preallocated byte buffer.

    The source rank pickles ``obj`` and copies the bytes into the
    module-level ``buffer_tensor`` (a uint8 tensor over ``buffer``). The
    first ``size_upper_bound`` bytes of that tensor are then broadcast and
    every other rank unpickles the object from ``buffer``.

    Args:
        obj: Object to send. Only read on the source rank.
        src: Rank that provides the payload.
        rank: Rank of the calling process.
        size_upper_bound: Number of leading buffer bytes to broadcast.
            Callers are expected to pass the same value on every rank,
            since it determines the size of the broadcast tensor.

    Returns:
        ``obj`` itself on the source rank, the unpickled copy elsewhere.

    Raises:
        ValueError: On the source rank, if the pickled payload does not fit
            in ``size_upper_bound`` bytes.
    """
    if rank == src:
        s = pickle.dumps(obj)
        if len(s) > size_upper_bound:
            # Only the first size_upper_bound bytes are broadcast. A larger
            # payload would arrive truncated and be followed by whatever
            # bytes an earlier call left in the receivers' buffers.
            raise ValueError(
                f"pickled object needs {len(s)} bytes, but only "
                f"{size_upper_bound} bytes are broadcast"
            )
        # tried to avoid this copy, but it seems pickle cannot directly dump to a memory buffer
        buffer_tensor[:len(s)].copy_(torch.frombuffer(s, dtype=torch.uint8))
    dist.broadcast(buffer_tensor[:size_upper_bound], src=src)
    if rank != src:
        # pickle.loads stops at the end of the pickled representation, so
        # the padding after the payload is ignored.
        obj = pickle.loads(memoryview(buffer))
    return obj
import time | |
def one_trial(size):
    """Time repeated broadcasts of a ``size``-element list from rank 0.

    Runs ``N_warmup + N`` broadcasts. Depending on ``align_memory`` each one
    goes through ``dist.broadcast_object_list`` or ``align_broadcast_obj``.
    Timings of the first ``N_warmup`` iterations are discarded.

    Args:
        size: Number of elements in the list that rank 0 broadcasts.

    Returns:
        Tuple ``(mean, std)`` of the measured per-broadcast time on this
        rank, in microseconds.
    """
    if align_memory:
        # The payload size depends only on ``size``, so compute the padded
        # bound once instead of re-pickling on every iteration.
        size_upper_bound = len(pickle.dumps([1] * size))
        ALIGNMENT = 256
        if size_upper_bound % ALIGNMENT != 0:
            size_upper_bound += ALIGNMENT - size_upper_bound % ALIGNMENT

    # measure the time for broadcasting an object
    total_time = []
    for i in range(N + N_warmup):
        if dist.get_rank() == 0:
            d = [1] * size
        else:
            d = []
        # perf_counter is monotonic and high-resolution, unlike the
        # wall-clock time.time(), which can jump when the clock is adjusted.
        if not align_memory:
            start = time.perf_counter()
            container = [d]
            dist.broadcast_object_list(container, src=0)
            d = container[0]
            end = time.perf_counter()
        else:
            start = time.perf_counter()
            d = align_broadcast_obj(d, src=0, rank=rank, size_upper_bound=size_upper_bound)
            end = time.perf_counter()
        # sanity check that every rank ended up with the full list
        assert len(d) == size
        if i >= N_warmup:
            total_time.append(end - start)
    total_time = torch.tensor(total_time) * 1e6  # in microseconds
    mean = total_time.mean().item()
    std = total_time.std().item()
    return mean, std
# Driver: set up the process group, run one trial per payload size, and
# print a tab-separated results table on rank 0.
dist.init_process_group(backend=backend)
if backend == 'nccl':
    # The device index is per node. torchrun exports LOCAL_RANK for this;
    # the global rank only matches it on a single-node run, so it is used
    # as a fallback when LOCAL_RANK is not set.
    torch.cuda.set_device(int(os.environ.get('LOCAL_RANK', dist.get_rank())))
rank = dist.get_rank()
if align_memory:
    # Reusable 1 MiB staging buffer and a uint8 tensor sharing its memory;
    # both are read by align_broadcast_obj.
    buffer = bytearray(1024 * 1024)
    buffer_tensor = torch.frombuffer(memoryview(buffer), dtype=torch.uint8)
# print the header
if rank == 0:
    print("\t".join(["size", "pickle_bytes", "mean (us)", "std (us)"]))
# multiples of 101 below 1024, then powers of two from 2**10 to 2**18
for size in list(range(101, 1024, 101)) + [2 ** i for i in range(10, 19)]:
    mean, std = one_trial(size)
    if rank == 0:
        pickle_bytes = len(pickle.dumps([1] * size))
        print("\t".join([f"{size}", f"{pickle_bytes}", f"{mean:.3f}", f"{std:.3f}"]))
# release the process group's resources before the interpreter exits
dist.destroy_process_group()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
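Run with torchrun (4 processes on one node); BACKEND selects gloo or nccl, and ALIGN=1 enables the aligned-buffer broadcast: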
BACKEND=gloo ALIGN=0 torchrun --nproc-per-node=4 test.py
BACKEND=gloo ALIGN=1 torchrun --nproc-per-node=4 test.py
BACKEND=nccl ALIGN=0 torchrun --nproc-per-node=4 test.py