Skip to content

Instantly share code, notes, and snippets.

@youkaichao
Last active June 27, 2024 21:32
Show Gist options
  • Save youkaichao/b33fcd70286eb45a4a2d5a6dc32d096b to your computer and use it in GitHub Desktop.
object broadcast comparison
import torch
import torch.distributed as dist
import os
import multiprocessing
import multiprocessing.shared_memory
import io
import pickle
# Benchmark configuration; BACKEND and ALIGN are read from the environment.
N_warmup = 10 # warmup N_warmup times
N = 100 # repeat N times
backend = os.getenv('BACKEND', 'gloo') # 'gloo' or 'nccl'
align_memory = int(os.getenv('ALIGN', '0')) # whether to align memory
def align_broadcast_obj(obj, src=0, rank=0, size_upper_bound=1024 * 1024):
    """Broadcast a picklable object via a pre-allocated, page-backed buffer.

    The source rank pickles ``obj`` into the module-level ``buffer_tensor``
    (a uint8 view over the module-level ``buffer`` bytearray); every rank then
    broadcasts the first ``size_upper_bound`` bytes of that tensor, and
    non-source ranks unpickle the result from the shared ``buffer``.

    NOTE: ``size_upper_bound`` must be large enough to hold the pickled
    object and no larger than the 1 MiB backing buffer — neither is checked
    here (this is benchmark code).

    Args:
        obj: object to broadcast (only meaningful on the source rank).
        src: rank that owns the object.
        rank: this process's rank.
        size_upper_bound: number of bytes to broadcast (should be a multiple
            of the chosen alignment for the aligned-memory experiment).

    Returns:
        The broadcast object on every rank.
    """
    if rank == src:
        s = pickle.dumps(obj)
        # tried to avoid this copy, but it seems pickle cannot directly dump
        # to a memory buffer
        buffer_tensor[:len(s)].copy_(torch.frombuffer(s, dtype=torch.uint8))
    # collective: every rank participates, so this sits outside the src branch
    dist.broadcast(buffer_tensor[:size_upper_bound], src=src)
    if rank != src:
        # zero-copy read of the pickled payload out of the shared bytearray
        obj = pickle.loads(memoryview(buffer))
    return obj
import time
def one_trial(size):
    """Time broadcasting a ``size``-element list from rank 0 to all ranks.

    Runs ``N_warmup`` untimed iterations followed by ``N`` timed ones, using
    either ``dist.broadcast_object_list`` (ALIGN=0) or the aligned-buffer
    path ``align_broadcast_obj`` (ALIGN=1), selected by the module-level
    ``align_memory`` flag.

    Args:
        size: number of elements in the broadcast list.

    Returns:
        (mean, std) of the per-broadcast latency in microseconds.
    """
    total_time = []
    for i in range(N + N_warmup):
        if dist.get_rank() == 0:
            d = [1] * size
        else:
            d = []
        if not align_memory:
            start = time.time()
            container = [d]
            dist.broadcast_object_list(container, src=0)
            d = container[0]
            end = time.time()
        else:
            # size of the pickled payload, rounded up to a 256-byte boundary
            size_upper_bound = len(pickle.dumps([1] * size))
            ALIGNMENT = 256
            if size_upper_bound % ALIGNMENT != 0:
                size_upper_bound += ALIGNMENT - size_upper_bound % ALIGNMENT
            start = time.time()
            d = align_broadcast_obj(d, src=0, rank=rank,
                                    size_upper_bound=size_upper_bound)
            end = time.time()
        # every rank must end up with the full list
        assert len(d) == size
        if i >= N_warmup:
            total_time.append(end - start)
    total_time = torch.tensor(total_time) * 1e6  # in microseconds
    mean = total_time.mean().item()
    std = total_time.std().item()
    return mean, std
# --- driver: initialize the process group and sweep over message sizes ---
dist.init_process_group(backend=backend)
if backend == 'nccl':
    # one GPU per rank
    torch.cuda.set_device(dist.get_rank())
rank = dist.get_rank()
if align_memory:
    # 1 MiB scratch buffer shared between align_broadcast_obj calls;
    # buffer_tensor is a uint8 view over the same memory, so the receiver
    # can unpickle straight from `buffer` without an extra copy
    buffer = bytearray(1024 * 1024)
    buffer_tensor = torch.frombuffer(memoryview(buffer), dtype=torch.uint8)
# print the header
if dist.get_rank() == 0:
    print("\t".join(["size", "pickle_bytes", "mean (us)", "std (us)"]))
# small non-power-of-two sizes, then powers of two from 1 KiB to 256 KiB
for size in list(range(101, 1024, 101)) + [2 ** i for i in range(10, 19)]:
    mean, std = one_trial(size)
    if dist.get_rank() == 0:
        pickle_bytes = len(pickle.dumps([1] * size))
        print("\t".join([f"{size}", f"{pickle_bytes}", f"{mean:.3f}", f"{std:.3f}"]))
@youkaichao
Copy link
Author

Run:

  • BACKEND=gloo ALIGN=0 torchrun --nproc-per-node=4 test.py
  • BACKEND=gloo ALIGN=1 torchrun --nproc-per-node=4 test.py
  • BACKEND=nccl ALIGN=0 torchrun --nproc-per-node=4 test.py

@youkaichao
Copy link
Author

youkaichao commented Jun 9, 2024

message queue implementation for broadcast:

import zmq
import time
import multiprocessing
import pickle
# import msgpack as pickle
import torch

# Benchmark configuration for the ZMQ pub/sub variant.
N_warmup = 10
N_message = 100

# small non-power-of-two sizes, then powers of two from 1 KiB to 256 KiB
sizes = list(range(101, 1024, 101)) + [2 ** i for i in range(10, 19)]

def subscriber_process(size, pub_port, ready_port):
    """Subscriber side of the ZMQ pub/sub latency benchmark.

    Connects a SUB socket to the publisher, synchronizes over a REQ/REP
    pair, drains the warmup messages, then times N_message broadcasts and
    reports the mean latency (microseconds) back through the sync socket.
    """
    ctx = zmq.Context()
    sub = ctx.socket(zmq.SUB)
    sub.connect(f"tcp://localhost:{pub_port}")
    sub.setsockopt_string(zmq.SUBSCRIBE, '')
    sub.setsockopt(zmq.RCVHWM, 1000)

    # REQ socket used purely as a barrier with the publisher.
    sync = ctx.socket(zmq.REQ)
    sync.connect(f"tcp://localhost:{ready_port}")

    # Barrier 1: announce readiness, wait for the go-ahead.
    sync.send(b'')
    sync.recv()

    # Drain warmup messages without timing them.
    for _ in range(N_warmup):
        sub.recv()

    # Barrier 2: ready for the timed run.
    sync.send(b'')
    sync.recv()

    total = 0
    payload = None
    for _ in range(N_message):
        payload = pickle.loads(sub.recv())
        # payload[0] carries the publisher's send timestamp; accumulate the
        # writer-to-reader delay.
        total += time.time() - payload[0]
    mean_us = total / N_message * 1e6
    assert len(payload) == size
    sync.send(pickle.dumps((mean_us, )))
    sync.recv()  # final go-ahead before tearing down

    sub.close()
    sync.close()
    ctx.term()

def main():
    """Publisher side: for each size, spawn subscribers, run the warmup and
    timed rounds over ZMQ PUB/SUB, and print the mean/std latency reported
    back by the subscribers."""

    print("\t".join(["size", "pickle_bytes", "mean (us)", "std (us)",]))

    for size in sizes:

        # NOTE(review): the same fixed ports are re-bound on every loop
        # iteration after close(); this assumes the previous context has
        # fully released them — confirm there is no intermittent bind
        # failure on slower machines.
        pub_port = "5559"
        ready_port = "5551"
        num_subscribers = 4
        context = zmq.Context()
        socket = context.socket(zmq.PUB)
        socket.bind(f"tcp://*:{pub_port}")
        socket.setsockopt(zmq.SNDHWM, 1000)

        # Sync socket to wait for subscribers to be ready
        sync_service = context.socket(zmq.REP)
        sync_service.bind(f"tcp://*:{ready_port}")

        processes = []
        for _ in range(num_subscribers):
            p = multiprocessing.Process(target=subscriber_process, args=(size, pub_port, ready_port))
            p.start()
            processes.append(p)

        # Barrier 1: wait for all subscribers to connect (avoids the PUB/SUB
        # slow-joiner problem where early messages are dropped).
        for _ in range(num_subscribers):
            sync_service.recv()
            sync_service.send(b'')

        # Warmup round: messages are received but not timed by subscribers.
        for i in range(N_warmup):
            d = [1] * size
            now = time.time()
            # slot 0 carries the send timestamp; subscribers subtract it
            # from their receive time to measure latency.
            d[0] = now
            data = pickle.dumps(d)
            socket.send(data)

        # Barrier 2: all subscribers done with warmup, ready for timed run.
        for _ in range(num_subscribers):
            sync_service.recv()
            sync_service.send(b'')

        # Timed round.
        for i in range(N_message):
            d = [1] * size
            now = time.time()
            d[0] = now
            data = pickle.dumps(d)
            socket.send(data)

        # Collect one mean-latency report (in microseconds) per subscriber.
        latencies = []
        for _ in range(num_subscribers):
            b_recv = sync_service.recv()
            d_recv = pickle.loads(b_recv)
            (latency, ) = d_recv
            latencies.append(latency)
            sync_service.send(b'')

        latencies = torch.tensor(latencies)
        l_mean = latencies.mean().item()
        l_std = latencies.std().item()

        pickle_bytes = len(pickle.dumps([1] * size))
        print("\t".join([f"{size}", f"{pickle_bytes}", f"{l_mean:.3f}", f"{l_std:.3f}"]))

        # Wait for all processes to complete
        for p in processes:
            p.join()

        socket.close()
        sync_service.close()
        context.term()

if __name__ == "__main__":
    main()

@youkaichao
Copy link
Author

shared memory implementation:

import time
import multiprocessing
import pickle
# import msgpack as pickle
import torch
from vllm.distributed.device_communicators.shm_broadcast import ShmRingBuffer, ShmRingBufferIO

# Benchmark configuration for the shared-memory ring-buffer variant.
# NOTE(review): N_warmup is defined but this script runs no warmup loop —
# confirm whether that is intentional (the zmq variant does warm up).
N_warmup = 10
N_message = 100

# small non-power-of-two sizes, then powers of two from 1 KiB to 256 KiB
sizes = list(range(101, 1024, 101)) + [2 ** i for i in range(10, 19)]

def subscriber_process(size, buffer, queue, rank):
    """Reader side of the shared-memory broadcast benchmark.

    Attaches a ring-buffer reader, waits for the go-ahead token from the
    main process, then times N_message dequeues and reports the mean
    latency (microseconds) back through the queue.
    """
    reader = ShmRingBufferIO(buffer, rank)

    # Block until the main process drops a go-ahead token in the queue.
    queue.get()

    elapsed = 0
    payload = None
    for _ in range(N_message):
        payload = reader.dequeue()
        # payload[0] holds the writer's send timestamp; accumulate the
        # writer-to-reader delay.
        elapsed += time.time() - payload[0]
    mean_us = elapsed / N_message * 1e6
    assert len(payload) == size
    queue.put(mean_us)

def main():
    """Writer side: for each size, spawn ring-buffer readers, broadcast
    N_message timestamped lists through shared memory, and print the
    mean/std of the latencies the readers report back."""

    print("\t".join(["size", "pickle_bytes", "mean (us)", "std (us)",]))

    for size in sizes:

        num_subscribers = 4
        queue = multiprocessing.Queue()
        # ring buffer shared by 4 readers: 1 MiB chunks, 10 slots
        buffer = ShmRingBuffer(num_subscribers, 1024 * 1024, 10)
        processes = []
        for i in range(num_subscribers):
            p = multiprocessing.Process(target=subscriber_process, args=(size, buffer, queue, i))
            p.start()
            processes.append(p)

        writer = ShmRingBufferIO(buffer, -1)

        # Release all subscribers by handing each one a go-ahead token.
        for _ in range(num_subscribers):
            queue.put(b'')
        
        # NOTE(review): busy-wait until every token is consumed.
        # multiprocessing.Queue.empty() is documented as unreliable (the
        # feeder thread may not have flushed puts yet), so this can return
        # True before all subscribers are actually released — and it also
        # spins a CPU core. Consider a counted handshake (subscribers put an
        # ack that the writer get()s num_subscribers times) instead.
        while not queue.empty():
            pass

        # Timed round: slot 0 carries the send timestamp; readers subtract
        # it from their receive time to measure latency.
        for i in range(N_message):
            d = [1] * size
            now = time.time()
            d[0] = now
            writer.enqueue(d)

        # Collect one mean-latency report (in microseconds) per subscriber.
        latencies = []
        for _ in range(num_subscribers):
            latency = queue.get()
            latencies.append(latency)

        latencies = torch.tensor(latencies)
        l_mean = latencies.mean().item()
        l_std = latencies.std().item()

        pickle_bytes = len(pickle.dumps([1] * size))
        print("\t".join([f"{size}", f"{pickle_bytes}", f"{l_mean:.3f}", f"{l_std:.3f}"]))

        # Wait for all processes to complete
        for p in processes:
            p.join()

if __name__ == "__main__":
    main()

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment