Skip to content

Instantly share code, notes, and snippets.

@mdouze
Created March 8, 2024 08:34
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save mdouze/93854e55e210a03c9ca3475b09d7c3e7 to your computer and use it in GitHub Desktop.
Save mdouze/93854e55e210a03c9ca3475b09d7c3e7 to your computer and use it in GitHub Desktop.
"""
This example script shows how to handle a database sharded over n GPUs.
Each GPU issues a set of queries simultaneously. The queries are performed
over the sharded dataset and the results are sent back to the issuing GPU.
"""
import os
import argparse
import time
import sys
import numpy as np
import torch
import torch.multiprocessing as mp
import torch.distributed
from torch.distributed import destroy_process_group, init_process_group
import faiss
import faiss.contrib.torch_utils
from faiss.contrib.datasets import SyntheticDataset
import triton
import triton.language as tl
def merge_result_table_torch(D0, I0, D1, I1):
    """
    Pure torch implementation of merging sorted result tables.

    D0/I0 and D1/I1 are (n, k) distance/id tables whose rows are sorted by
    increasing distance. For every row, the k smallest distances among the
    2k candidates are selected in order, together with their ids. On equal
    distances the entry from the second table (D1/I1) is taken first.

    Returns (D, I), allocated with empty_like from D0 and I0.
    """
    n, k = D0.shape
    for other in (I0, D1, I1):
        assert other.shape == (n, k)

    out_D = torch.empty_like(D0)
    out_I = torch.empty_like(I0)

    # per-row read positions into each input, kept as (n, 1) column indices
    # so they can be used directly with gather along dim 1
    cur0 = torch.zeros((n, 1), dtype=torch.int64, device=D0.device)
    cur1 = torch.zeros_like(cur0)

    for col in range(k):
        head0 = D0.gather(1, cur0)
        head1 = D1.gather(1, cur1)
        take_first = head0 < head1
        out_D[:, col:col + 1] = torch.where(take_first, head0, head1)
        out_I[:, col:col + 1] = torch.where(
            take_first, I0.gather(1, cur0), I1.gather(1, cur1)
        )
        # advance the cursor of whichever table supplied this column
        cur0 += take_first.long()
        cur1 += (~take_first).long()

    return out_D, out_I
# Direct translation of merge_result_table_torch.
# Each program instance merges `bs` consecutive rows (program id * bs ...).
# All six tables are addressed as base + row * stride + col: they share the
# single row stride `stride`, and columns are assumed contiguous (unit step).
# bs and k are tl.constexpr, so each distinct value is compiled separately.
@triton.jit
def merge_result_tables_triton_kernel(
    D0, I0, D1, I1,
    D, I,
    stride: int,
    n: int,
    bs: tl.constexpr,
    k: tl.constexpr
):
    bno = tl.program_id(0)
    imin = bno * bs
    # read cursors into the two inputs and write cursor into the outputs,
    # all starting at column 0 of this program's rows
    i0 = (tl.arange(0, bs) + imin) * stride
    i1 = (tl.arange(0, bs) + imin) * stride
    ii = (tl.arange(0, bs) + imin) * stride
    # avoid reading or writing out of bounds (last block may be partial)
    mask = imin + tl.arange(0, bs) < n
    for i in range(k):
        D0s = tl.load(D0 + i0, mask=mask)
        D1s = tl.load(D1 + i1, mask=mask)
        # strict < : on ties the entry from the second table is taken
        keep_0 = D0s < D1s
        Di = tl.where(keep_0, D0s, D1s)
        # the store must be masked too, otherwise rows >= n of the last block
        # are written past the end of D
        tl.store(D + ii, Di, mask=mask)
        I0s = tl.load(I0 + i0, mask=mask)
        I1s = tl.load(I1 + i1, mask=mask)
        Ii = tl.where(keep_0, I0s, I1s)
        tl.store(I + ii, Ii, mask=mask)
        # advance the cursor of whichever table supplied this column
        i0 += keep_0
        i1 += 1 - keep_0
        ii += 1
def merge_result_table_triton(D0, I0, D1, I1):
    """
    Triton code to merge result tables. It is assumed it will be called often
    with the same k (number of columns of the result tables), since the kernel
    is specialized on k.

    D0/I0 and D1/I1 are (n, k) distance/id tables with rows sorted by
    increasing distance. Returns (D, I) of shape (n, k): per row, the k
    smallest of the 2k candidates.

    Inputs of any strides are accepted; they are made contiguous first.
    """
    n, k = D0.shape
    assert I0.shape == (n, k) and D1.shape == (n, k) and I1.shape == (n, k)
    # The kernel addresses every table (inputs and outputs) with one shared
    # row stride and unit column stride. Contiguous inputs guarantee that,
    # and the outputs allocated from them below inherit the same layout.
    D0, I0, D1, I1 = (t.contiguous() for t in (D0, I0, D1, I1))
    D = torch.empty_like(D0)
    I = torch.empty_like(I0)
    if n == 0 or k == 0:
        # nothing to merge; also avoids launching a zero-size grid
        return D, I
    # number of rows to process in 1 kernel call
    bs = 256
    grid = ((n + bs - 1) // bs, )
    merge_result_tables_triton_kernel[grid](
        D0, I0, D1, I1,
        D, I,
        D.stride(0), n,
        bs, k
    )
    return D, I
def merge_result_tables(D0, I0, D1, I1, variant='triton'):
    """
    Merge two batches of sorted result tables with arbitrary leading dims.

    All dimensions but the last are flattened, so that the 2D merge routine
    sees tables of shape (prod(shape[:-1]), k). The results are reshaped back
    to the input shapes.

    variant: 'torch' selects merge_result_table_torch; any other value
        selects merge_result_table_triton.
    """
    shape = I0.shape
    # int(): np.prod of an empty tuple (1-D input) is the float 1.0, which
    # reshape rejects
    nrows = int(np.prod(shape[:-1]))
    mshape = (nrows, shape[-1])
    func = merge_result_table_torch if variant == 'torch' else merge_result_table_triton
    D, I = func(
        D0.reshape(mshape), I0.reshape(mshape),
        D1.reshape(mshape), I1.reshape(mshape)
    )
    return D.reshape(D0.shape), I.reshape(I0.shape)
class DistibutedSearch:
    """
    Object that contains a database as a flat torch array on GPU.
    It is assumed that it is instantiated on several GPUs (one process per
    GPU, GPU number == rank) with different arrays.
    The combined objects offer a search functionality from a single GPU or
    parallel search from all GPUs.
    The search is performed on the concatenation of the arrays of the
    different GPUs (i.e. the database is sharded). Returned ids are row
    numbers in that concatenation.
    """

    def __init__(self, rank, world_size, xb, merge_variant):
        """
        rank, world_size: position of this process in the default process
            group; the process works on device "cuda:<rank>".
        xb: the local database shard, searched with faiss.knn_gpu.
        merge_variant: 'torch' or 'triton', passed to merge_result_tables.

        Collective call: every rank must construct its object (the shard
        sizes are exchanged with all_gather).
        """
        self.rank = rank
        self.world_size = world_size
        self.xb = xb
        self.merge_variant = merge_variant
        device = "cuda:%d" % rank

        # all-gather the shard sizes; shard r then owns the global ids
        # [shard_lims[r], shard_lims[r + 1])
        shard_sizes = torch.zeros(world_size, dtype=torch.int64, device=device)
        tensor_list = [shard_sizes[i:i + 1] for i in range(world_size)]
        shard_sizes[rank] = len(xb)
        torch.distributed.all_gather(tensor_list, tensor_list[rank])
        shard_lims = np.zeros(world_size + 1, dtype=int)
        shard_lims[1:] = np.cumsum(shard_sizes.cpu().numpy())
        self.shard_lims = shard_lims

        self.res = faiss.StandardGpuResources()
        # make sure it uses the same stream as pytorch
        self.res.setDefaultNullStreamAllDevices()

    def search(self, src, xq_or_nq, k):
        """
        Perform search from GPU #src.

        Collective call: all ranks must call it with the same src and k.
        On src, xq_or_nq is the query tensor xq.
        On other ranks it is the number of queries nq (an int); it is used to
        allocate the buffer that receives the broadcast queries.

        Returns (D, I) with shape (nq, k) on rank src, (None, None) elsewhere.
        """
        rank = self.rank
        world_size = self.world_size
        device = "cuda:%d" % rank

        if rank == src:
            xq = xq_or_nq
        else:
            assert isinstance(xq_or_nq, int)
            nq = xq_or_nq
            # default dtype (float32); must match the dtype of src's xq
            xq = torch.empty(nq, self.xb.shape[1], device=device)
        torch.distributed.broadcast(xq, src=src)

        D, I = faiss.knn_gpu(self.res, xq, self.xb, k, device=rank)
        # local row numbers -> ids in the concatenated database
        I += self.shard_lims[rank]

        # Tree merge towards src, using ranks renumbered so that src is 0.
        # At level `step`, the ranks whose low bits (below step) are all 0
        # pair up as rel / rel ^ step. The one with the `step` bit set sends
        # its partial result and drops out; the other receives and merges.
        rel = (rank - src) % world_size
        step = 1
        while step < world_size:
            peer = rel ^ step
            if (rel & (step - 1)) == 0 and peer < world_size:
                peer_rank = (peer + src) % world_size
                if rel & step:
                    torch.distributed.send(D, peer_rank)
                    torch.distributed.send(I, peer_rank)
                else:
                    D1 = torch.empty_like(D)
                    I1 = torch.empty_like(I)
                    torch.distributed.recv(D1, peer_rank)
                    torch.distributed.recv(I1, peer_rank)
                    D, I = merge_result_tables(
                        D, I, D1, I1, variant=self.merge_variant)
            step <<= 1

        # the result is returned only on GPU #src
        if rank == src:
            return D, I
        return None, None

    def search_n2n(self, xq, k):
        """
        All GPUs issue queries xq simultaneously and each receives the
        results of its own queries over the whole sharded database.

        Collective call; every rank must pass an xq of the same shape
        (nq, d). Requires world_size to be a power of 2.
        Returns (D, I) of shape (nq, k).
        """
        rank = self.rank
        world_size = self.world_size
        assert (world_size & (world_size - 1)) == 0, "works only for powers of 2"
        device = "cuda:%d" % rank
        nq, d = xq.shape

        # TODO: tile computation and all_gather (possibly via rotations)
        all_xq = torch.empty((world_size, nq, d), device=device)
        tensor_list = [all_xq[i] for i in range(world_size)]
        torch.distributed.all_gather(tensor_list, xq)

        # perform all searches in a big batch against the local shard
        D, I = faiss.knn_gpu(self.res, all_xq.reshape(nq * world_size, d),
                             self.xb, k, device=rank)
        # D[j] / I[j]: local-shard results for the queries issued by rank j
        D = D.reshape(world_size, nq, k)
        I = I.reshape(world_size, nq, k)
        I += self.shard_lims[rank]

        # Rounds of merging. At each level, split the remaining tables by the
        # parity of their index. Keep the half matching this rank's bit at
        # that level, and swap the other half with the peer that differs only
        # in that bit. After log2(world_size) levels one table remains: the
        # results for this rank's own queries, merged over all shards.
        step = 1
        while step < world_size:
            peer = rank ^ step
            D_even, D_odd = D[0::2], D[1::2]
            I_even, I_odd = I[0::2], I[1::2]
            if (rank & step) == 0:
                # receive first, send second: the peer does the opposite, so
                # the blocking point-to-point calls pair up
                other_D = torch.empty_like(D_even)
                other_I = torch.empty_like(I_even)
                torch.distributed.recv(other_D, peer)
                torch.distributed.recv(other_I, peer)
                torch.distributed.send(D_odd.contiguous(), peer)
                torch.distributed.send(I_odd.contiguous(), peer)
                D, I = merge_result_tables(
                    D_even, I_even, other_D, other_I, variant=self.merge_variant)
            else:
                torch.distributed.send(D_even.contiguous(), peer)
                torch.distributed.send(I_even.contiguous(), peer)
                other_D = torch.empty_like(D_odd)
                other_I = torch.empty_like(I_odd)
                torch.distributed.recv(other_D, peer)
                torch.distributed.recv(other_I, peer)
                D, I = merge_result_tables(
                    D_odd, I_odd, other_D, other_I, variant=self.merge_variant)
            step <<= 1
        return D[0], I[0]
def search_job(rank, port, args, ds, res_queue):
    """
    Function that gets called on one process per GPU.

    Joins the NCCL process group and instantiates the local database shard
    and a DistibutedSearch object. Performs args.niter rounds of searching
    with different queries, using args.merge_type ("sequential" or "n2n").
    On rank 0, reports timings and verifies the results against a
    brute-force search over the full database.

    res_queue is accepted but currently unused.
    """
    world_size = args.ngpu
    print(f"Start search_job {rank=:}")
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(port)
    init_process_group(backend="nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
    device = "cuda:%d" % rank
    if args.nthread != -1:
        faiss.omp_set_num_threads(args.nthread)

    if args.cache_dataset:
        if rank == 0:
            print("caching queries")
        ds.get_database(rank, 0)
        for step in range(args.niter):
            ds.get_queries(step, rank, 0)

    # prepare local shard of dataset
    head = 0
    xb = torch.from_numpy(ds.get_database(rank, head)).to(device)
    distrib = DistibutedSearch(rank, world_size, xb, merge_variant=args.merge_table)
    print(f"xb ok {rank=:}")

    all_I = []
    times = []
    for step in range(args.niter):
        # this can be slow ==> exclude from timing
        xq = torch.from_numpy(ds.get_queries(step, rank, head)).to(device)
        t0 = time.time()
        if args.merge_type == "sequential":  # sequential implem (slow)
            for q_src in range(world_size):
                if q_src == rank:
                    D, I = distrib.search(q_src, xq, args.k)
                else:
                    # because all xq's have the same size
                    distrib.search(q_src, len(xq), args.k)
        elif args.merge_type == "n2n":  # faster implem
            D, I = distrib.search_n2n(xq, args.k)
        else:
            raise ValueError(f"unknown merge_type {args.merge_type!r}")
        # .cpu() waits for the GPU result, so it belongs inside the timing
        all_I.append(I.cpu().numpy())
        ti = time.time() - t0
        if rank == 0:
            print(f"End {step=:} time={ti:.3f} s")
        times.append(ti)

    destroy_process_group()

    # report timings
    if rank == 0:
        times = np.array(times)
        # the first 2 iterations are excluded as warm-up when there are more
        # than 2; otherwise average over whatever was run
        timed = times[2:] if len(times) > 2 else times
        if len(timed) > 0:
            print(f"mean time on {len(timed)} iterations: {timed.mean():.3f} s")
        print("checking results....")

        # verification code: single-GPU search over the concatenated shards
        xb = np.vstack([ds.get_database(r, head) for r in range(world_size)])
        for step in range(args.niter):
            xq = ds.get_queries(step, rank, head)
            _, ref_I = faiss.knn_gpu(distrib.res, xq, xb, args.k, device=rank)
            # a small fraction of mismatching ids is tolerated (presumably
            # distance ties / floating-point differences between the paths)
            ndiff = (all_I[step] != ref_I).sum()
            assert ndiff < ref_I.size * 0.001, f"{ndiff=:} / {ref_I.size}"
class Dataset:
    """
    Random dataset that provides database vectors and query vectors in a
    reproducible way.

    Each array is generated with faiss.rand_smooth_vectors, seeded with
    hash() of its identifying key. hash() of a tuple of ints is not
    randomized per process (PYTHONHASHSEED only salts str/bytes), so every
    worker derives the same data for the same key. With args.cache_dataset
    set, generated arrays are memoized per key.
    """

    def __init__(self, args):
        self.args = args
        # a dict per kind of data when caching is enabled, None otherwise
        if args.cache_dataset:
            self.cache_queries = {}
            self.cache_database = {}
        else:
            self.cache_queries = None
            self.cache_database = None

    def _cached(self, cache, key, nrows_field):
        """
        Return the array for `key` from `cache`, or generate it.

        The generated array has shape (args.<nrows_field>, args.head_dim) and
        is seeded by hash(key). It is stored in `cache` when caching is on.
        """
        if cache is not None and key in cache:
            return cache[key]
        shape = (getattr(self.args, nrows_field), self.args.head_dim)
        flat = faiss.rand_smooth_vectors(1, int(np.prod(shape)), seed=hash(key))
        out = flat.reshape(shape)
        if cache is not None:
            cache[key] = out
        return out

    def get_database(self, rank, head):
        """Database shard of `rank` for `head`: shape (args.klen, args.head_dim)."""
        return self._cached(self.cache_database, (rank, head), "klen")

    def get_queries(self, step, rank, head):
        """Query batch of iteration `step` issued by `rank` for `head`: shape (args.bs_qlen, args.head_dim)."""
        return self._cached(self.cache_queries, (step, rank, head), "bs_qlen")
def main():
    """Parse the command line and spawn one search_job process per GPU."""
    parser = argparse.ArgumentParser()

    def aa(*args, **kwargs):
        # adds to whichever argument group `group` currently refers to
        group.add_argument(*args, **kwargs)

    group = parser.add_argument_group("problem size")
    aa("--bs_qlen", default=4096, type=int, help="batch size for queries")
    aa("--klen", default=10000, type=int, help="length of keys")
    aa("--nhead", default=32, type=int, help="number of heads (not implemented)")
    aa("--head_dim", default=128, type=int, help="head dimension")
    aa("--niter", default=10, type=int, help="number of iterations")
    aa("--k", default=10, type=int, help="nb results per query")

    group = parser.add_argument_group("system parameters")
    aa("--nthread", default=32, type=int, help="set nb threads")
    aa("--ngpu", default=8, type=int, help="number of GPUs to use")
    aa("--merge_type", default="n2n", choices=["sequential", "n2n"])
    aa("--cache_dataset", default=False, action="store_true", help="caching of dataset")
    aa("--merge_table", default="triton", choices=["triton", "torch"],
       help="implementation used to merge result tables")

    args = parser.parse_args()
    print(f"Running on {args.ngpu} GPUs")

    # validate here rather than inside the spawned workers
    if args.ngpu < 1:
        parser.error("--ngpu must be at least 1")
    available = torch.cuda.device_count()
    if available < args.ngpu:
        parser.error(f"--ngpu {args.ngpu} requested but only {available} GPU(s) available")
    if args.merge_type == "n2n" and (args.ngpu & (args.ngpu - 1)) != 0:
        parser.error("--merge_type n2n requires --ngpu to be a power of 2")

    ctx = mp.get_context('spawn')
    res_queue = ctx.SimpleQueue()
    ds = Dataset(args)
    port = np.random.randint(50000, 65000)
    mp.spawn(
        search_job,
        args=(port, args, ds, res_queue),
        nprocs=args.ngpu,
    )


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment