Last active
June 27, 2024 21:32
-
-
Save youkaichao/b33fcd70286eb45a4a2d5a6dc32d096b to your computer and use it in GitHub Desktop.
object broadcast comparison
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch
import torch.distributed as dist
import os
import multiprocessing
import multiprocessing.shared_memory
import io
import pickle

# Benchmark configuration, overridable through environment variables.
N_warmup = 10  # iterations discarded before timing starts
N = 100        # timed iterations per payload size
backend = os.getenv('BACKEND', 'gloo')       # 'gloo' or 'nccl'
align_memory = int(os.getenv('ALIGN', '0'))  # nonzero -> use the aligned-buffer broadcast path
def align_broadcast_obj(obj, src=0, rank=0, size_upper_bound=1024 * 1024):
    """Broadcast a picklable object through the preallocated aligned buffer.

    The source rank pickles ``obj`` into the module-level ``buffer_tensor``;
    every rank broadcasts the first ``size_upper_bound`` bytes of that
    tensor, and non-source ranks unpickle the result from the shared
    ``buffer`` (a zero-copy view over the same memory).
    """
    if rank == src:
        payload = pickle.dumps(obj)
        # pickle cannot dump straight into a preallocated buffer, so one
        # copy into the staging tensor is unavoidable here.
        buffer_tensor[:len(payload)].copy_(
            torch.frombuffer(payload, dtype=torch.uint8))
    dist.broadcast(buffer_tensor[:size_upper_bound], src=src)
    if rank == src:
        return obj
    return pickle.loads(memoryview(buffer))
import time | |
def one_trial(size):
    """Broadcast a list of ``size`` ones N times; return (mean, std) in us.

    Rank 0 owns the payload; the other ranks start from an empty list and
    receive it. The first N_warmup iterations are excluded from the stats.
    """
    samples = []
    for trial in range(N + N_warmup):
        payload = [1] * size if dist.get_rank() == 0 else []
        if align_memory:
            # Round the serialized size up to the next 256-byte boundary
            # before timing, so alignment work is not measured.
            bound = len(pickle.dumps([1] * size))
            alignment = 256
            remainder = bound % alignment
            if remainder != 0:
                bound += alignment - remainder
            start = time.time()
            payload = align_broadcast_obj(
                payload, src=0, rank=rank, size_upper_bound=bound)
            end = time.time()
        else:
            start = time.time()
            container = [payload]
            dist.broadcast_object_list(container, src=0)
            payload = container[0]
            end = time.time()
        assert len(payload) == size
        if trial >= N_warmup:
            samples.append(end - start)
    samples = torch.tensor(samples) * 1e6  # seconds -> microseconds
    return samples.mean().item(), samples.std().item()
dist.init_process_group(backend=backend)
if backend == 'nccl':
    torch.cuda.set_device(dist.get_rank())
rank = dist.get_rank()
if align_memory:
    # Preallocated 1 MiB staging area reused by every aligned broadcast;
    # buffer_tensor is a zero-copy uint8 view over the bytearray.
    buffer = bytearray(1024 * 1024)
    buffer_tensor = torch.frombuffer(memoryview(buffer), dtype=torch.uint8)

# Only rank 0 prints the header and the per-size results.
if rank == 0:
    print("\t".join(["size", "pickle_bytes", "mean (us)", "std (us)"]))
for size in list(range(101, 1024, 101)) + [2 ** i for i in range(10, 19)]:
    mean, std = one_trial(size)
    if rank == 0:
        pickle_bytes = len(pickle.dumps([1] * size))
        print("\t".join([f"{size}", f"{pickle_bytes}", f"{mean:.3f}", f"{std:.3f}"]))
Message-queue (ZeroMQ pub/sub) implementation of the same broadcast benchmark:
import zmq
import time
import multiprocessing
import pickle
# import msgpack as pickle
import torch
N_warmup = 10    # messages sent before timing starts
N_message = 100  # timed messages per payload size
# Payload sizes: 101..1010 in steps of 101, then powers of two up to 2**18.
sizes = [101 * k for k in range(1, 11)] + [1 << p for p in range(10, 19)]
def subscriber_process(size, pub_port, ready_port):
    """Worker: subscribe to the publisher and measure per-message latency.

    Synchronizes with the publisher through a REQ/REP barrier before and
    after the warmup round, then reports its mean latency (microseconds)
    back over the same sync channel.
    """
    context = zmq.Context()

    sub = context.socket(zmq.SUB)
    sub.connect(f"tcp://localhost:{pub_port}")
    sub.setsockopt_string(zmq.SUBSCRIBE, '')
    sub.setsockopt(zmq.RCVHWM, 1000)

    # REQ socket used purely as a barrier with the publisher.
    sync_client = context.socket(zmq.REQ)
    sync_client.connect(f"tcp://localhost:{ready_port}")

    # Barrier 1: connected and subscribed.
    sync_client.send(b'')
    sync_client.recv()

    for _ in range(N_warmup):
        sub.recv()

    # Barrier 2: warmup drained.
    sync_client.send(b'')
    sync_client.recv()

    latency = 0
    for _ in range(N_message):
        data = pickle.loads(sub.recv())
        # data[0] holds the publisher-side send timestamp; the difference
        # is the writer-to-reader delivery time.
        latency += time.time() - data[0]
    latency = latency / N_message * 1e6  # mean latency in microseconds

    assert len(data) == size
    # Report the measured latency back through the sync channel.
    sync_client.send(pickle.dumps((latency, )))
    sync_client.recv()

    sub.close()
    sync_client.close()
    context.term()
def main():
    """Publisher: fan timestamped messages out to 4 subscriber processes
    for each payload size and print the mean/std one-way latency."""
    print("\t".join(["size", "pickle_bytes", "mean (us)", "std (us)",]))
    for size in sizes:
        pub_port = "5559"
        ready_port = "5551"
        num_subscribers = 4

        context = zmq.Context()
        pub_socket = context.socket(zmq.PUB)
        pub_socket.bind(f"tcp://*:{pub_port}")
        pub_socket.setsockopt(zmq.SNDHWM, 1000)

        # REP socket acting as a barrier with every subscriber.
        sync_service = context.socket(zmq.REP)
        sync_service.bind(f"tcp://*:{ready_port}")

        workers = []
        for _ in range(num_subscribers):
            proc = multiprocessing.Process(
                target=subscriber_process, args=(size, pub_port, ready_port))
            proc.start()
            workers.append(proc)

        # Barrier: all subscribers connected and subscribed.
        for _ in range(num_subscribers):
            sync_service.recv()
            sync_service.send(b'')

        # Warmup round (not measured).
        for _ in range(N_warmup):
            payload = [1] * size
            payload[0] = time.time()  # first slot carries the send timestamp
            pub_socket.send(pickle.dumps(payload))

        # Barrier: warmup fully consumed.
        for _ in range(num_subscribers):
            sync_service.recv()
            sync_service.send(b'')

        # Timed round.
        for _ in range(N_message):
            payload = [1] * size
            payload[0] = time.time()
            pub_socket.send(pickle.dumps(payload))

        # Collect each subscriber's mean latency (already in microseconds).
        latencies = []
        for _ in range(num_subscribers):
            (reported, ) = pickle.loads(sync_service.recv())
            latencies.append(reported)
            sync_service.send(b'')

        stats = torch.tensor(latencies)
        l_mean = stats.mean().item()
        l_std = stats.std().item()
        pickle_bytes = len(pickle.dumps([1] * size))
        print("\t".join([f"{size}", f"{pickle_bytes}", f"{l_mean:.3f}", f"{l_std:.3f}"]))

        for proc in workers:
            proc.join()
        pub_socket.close()
        sync_service.close()
        context.term()


if __name__ == "__main__":
    main()
Shared-memory (ring-buffer) implementation of the same benchmark:
import time
import multiprocessing
import pickle
# import msgpack as pickle
import torch
from vllm.distributed.device_communicators.shm_broadcast import ShmRingBuffer, ShmRingBufferIO
N_warmup = 10    # warmup iterations (kept for parity with the other scripts)
N_message = 100  # timed messages per payload size
# Payload sizes: 101..1010 in steps of 101, then powers of two up to 2**18.
sizes = [101 * k for k in range(1, 11)] + [1 << p for p in range(10, 19)]
def subscriber_process(size, buffer, queue, rank):
    """Worker: drain N_message entries from the shared-memory ring buffer
    and report the mean one-way latency (microseconds) through ``queue``."""
    reader = ShmRingBufferIO(buffer, rank)
    # Wait for the writer's go-ahead token.
    queue.get()
    latency = 0
    for _ in range(N_message):
        data = reader.dequeue()
        # data[0] holds the writer-side enqueue timestamp; the difference
        # is the writer-to-reader delivery time.
        latency += time.time() - data[0]
    latency = latency / N_message * 1e6
    assert len(data) == size
    queue.put(latency)
def main():
    """Writer: enqueue timestamped messages into a shared-memory ring read
    by 4 subscriber processes and print the mean/std one-way latency."""
    print("\t".join(["size", "pickle_bytes", "mean (us)", "std (us)",]))
    for size in sizes:
        num_subscribers = 4
        queue = multiprocessing.Queue()
        buffer = ShmRingBuffer(num_subscribers, 1024 * 1024, 10)

        workers = []
        for reader_rank in range(num_subscribers):
            proc = multiprocessing.Process(
                target=subscriber_process, args=(size, buffer, queue, reader_rank))
            proc.start()
            workers.append(proc)

        writer = ShmRingBufferIO(buffer, -1)

        # Release the subscribers, then spin until every token is consumed.
        # NOTE(review): Queue.empty() is documented as unreliable across
        # processes; this busy-wait matches the original behavior as-is.
        for _ in range(num_subscribers):
            queue.put(b'')
        while not queue.empty():
            pass

        for _ in range(N_message):
            payload = [1] * size
            payload[0] = time.time()  # first slot carries the enqueue timestamp
            writer.enqueue(payload)

        # Collect each subscriber's mean latency (already in microseconds).
        latencies = [queue.get() for _ in range(num_subscribers)]
        stats = torch.tensor(latencies)
        l_mean = stats.mean().item()
        l_std = stats.std().item()
        pickle_bytes = len(pickle.dumps([1] * size))
        print("\t".join([f"{size}", f"{pickle_bytes}", f"{l_mean:.3f}", f"{l_std:.3f}"]))

        for proc in workers:
            proc.join()


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Run:
BACKEND=gloo ALIGN=0 torchrun --nproc-per-node=4 test.py
BACKEND=gloo ALIGN=1 torchrun --nproc-per-node=4 test.py
BACKEND=nccl ALIGN=0 torchrun --nproc-per-node=4 test.py