Last active
June 27, 2024 21:32
-
-
Save youkaichao/b33fcd70286eb45a4a2d5a6dc32d096b to your computer and use it in GitHub Desktop.
object broadcast comparison
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.distributed as dist | |
import os | |
import multiprocessing | |
import multiprocessing.shared_memory | |
import io | |
import pickle | |
# Benchmark configuration (overridable through environment variables).
N_warmup = 10  # iterations discarded before timing begins
N = 100  # timed iterations per trial
backend = os.getenv('BACKEND', 'gloo')  # distributed backend: 'gloo' or 'nccl'
align_memory = int(os.getenv('ALIGN', '0'))  # nonzero -> use the aligned-buffer broadcast path
def align_broadcast_obj(obj, src=0, rank=0, size_upper_bound=1024 * 1024):
    """Broadcast a picklable object via a pre-allocated staging buffer.

    The source rank serializes ``obj`` into the module-level ``buffer_tensor``;
    all other ranks deserialize the result out of the shared ``buffer``.
    Returns the (possibly received) object on every rank.
    """
    if rank == src:
        payload = pickle.dumps(obj)
        # NOTE: pickle cannot serialize directly into an existing buffer,
        # so one copy into the staging tensor is unavoidable here.
        buffer_tensor[:len(payload)].copy_(torch.frombuffer(payload, dtype=torch.uint8))
    dist.broadcast(buffer_tensor[:size_upper_bound], src=src)
    if rank != src:
        # pickle stops at the end of the serialized stream, so handing it the
        # whole 1 MiB buffer is fine.
        obj = pickle.loads(memoryview(buffer))
    return obj
import time | |
def one_trial(size):
    """Time broadcasting a list of ``size`` ones from rank 0.

    Runs ``N_warmup`` untimed iterations followed by ``N`` timed ones, using
    either ``dist.broadcast_object_list`` or the aligned-buffer path depending
    on the module-level ``align_memory`` flag.  Returns ``(mean, std)`` of the
    per-iteration latency in microseconds.
    """
    timings = []
    for iteration in range(N_warmup + N):
        payload = [1] * size if dist.get_rank() == 0 else []
        if align_memory:
            # Round the serialized size up to the next 256-byte boundary
            # (done before the timer starts, so it is not measured).
            ALIGNMENT = 256
            upper = len(pickle.dumps([1] * size))
            remainder = upper % ALIGNMENT
            if remainder:
                upper += ALIGNMENT - remainder
            start = time.time()
            payload = align_broadcast_obj(payload, src=0, rank=rank, size_upper_bound=upper)
            end = time.time()
        else:
            start = time.time()
            box = [payload]
            dist.broadcast_object_list(box, src=0)
            payload = box[0]
            end = time.time()
        assert len(payload) == size
        if iteration >= N_warmup:
            timings.append(end - start)
    timings_us = torch.tensor(timings) * 1e6  # seconds -> microseconds
    return timings_us.mean().item(), timings_us.std().item()
# --- driver ----------------------------------------------------------------
dist.init_process_group(backend=backend)
if backend == 'nccl':
    torch.cuda.set_device(dist.get_rank())
rank = dist.get_rank()
if align_memory:
    # Staging area shared with align_broadcast_obj: a 1 MiB bytearray viewed
    # as a uint8 tensor so dist.broadcast can write into it in place.
    buffer = bytearray(1024 * 1024)
    buffer_tensor = torch.frombuffer(memoryview(buffer), dtype=torch.uint8)
# Rank 0 prints the results table; other ranks only participate in the collective.
if rank == 0:
    print("\t".join(["size", "pickle_bytes", "mean (us)", "std (us)"]))
for size in list(range(101, 1024, 101)) + [2 ** i for i in range(10, 19)]:
    mean, std = one_trial(size)
    if rank == 0:
        pickle_bytes = len(pickle.dumps([1] * size))
        print("\t".join([f"{size}", f"{pickle_bytes}", f"{mean:.3f}", f"{std:.3f}"]))
Shared-memory implementation:
import time
import multiprocessing
import pickle
# import msgpack as pickle
import torch
from vllm.distributed.device_communicators.shm_broadcast import ShmRingBuffer, ShmRingBufferIO
# Benchmark parameters.
N_warmup = 10  # warmup message count
N_message = 100  # timed messages per trial
# Payload sizes: 101..1010 in steps of 101, then powers of two from 1 KiB to 256 KiB.
sizes = list(range(101, 1024, 101)) + [2 ** i for i in range(10, 19)]
def subscriber_process(size, buffer, queue, rank):
    """Reader side of the benchmark: consume N_message payloads, report latency.

    Blocks on ``queue`` until the writer signals the start, then accumulates
    the writer-timestamp-to-read delay for each message and puts the mean
    latency (in microseconds) back on ``queue``.
    """
    reader = ShmRingBufferIO(buffer, rank)
    # Block until the writer posts the start token.
    queue.get()
    elapsed = 0
    for _ in range(N_message):
        msg = reader.dequeue()
        # msg[0] carries the writer's send timestamp.
        elapsed += time.time() - msg[0]
    mean_latency_us = elapsed / N_message * 1e6
    assert len(msg) == size
    queue.put(mean_latency_us)
def main():
    """Drive the shared-memory broadcast benchmark over every payload size."""
    print("\t".join(["size", "pickle_bytes", "mean (us)", "std (us)"]))
    num_subscribers = 4
    for size in sizes:
        queue = multiprocessing.Queue()
        buffer = ShmRingBuffer(num_subscribers, 1024 * 1024, 10)
        # Launch one reader process per subscriber rank.
        processes = [
            multiprocessing.Process(
                target=subscriber_process, args=(size, buffer, queue, i)
            )
            for i in range(num_subscribers)
        ]
        for proc in processes:
            proc.start()
        writer = ShmRingBufferIO(buffer, -1)
        # Post one start token per subscriber, then spin until they have all
        # consumed theirs.
        for _ in range(num_subscribers):
            queue.put(b'')
        while not queue.empty():
            pass
        for _ in range(N_message):
            payload = [1] * size
            payload[0] = time.time()  # stamp the send time into slot 0
            writer.enqueue(payload)
        # Collect one mean latency (us) per subscriber and summarize.
        latencies = torch.tensor([queue.get() for _ in range(num_subscribers)])
        pickle_bytes = len(pickle.dumps([1] * size))
        print("\t".join([
            f"{size}", f"{pickle_bytes}",
            f"{latencies.mean().item():.3f}", f"{latencies.std().item():.3f}",
        ]))
        # Wait for all reader processes to exit before the next trial.
        for proc in processes:
            proc.join()


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Message-queue implementation for broadcast: