""" | |
This example script shows how to handle a database sharded over n GPUs. | |
Each GPU issues a set of queries simultaneously. The queries are performed | |
over the sharded dataset and the results are sent back to the issuing GPU. | |
""" | |
import os
import argparse
import time
import sys

import numpy as np
import torch
import torch.multiprocessing as mp
import torch.distributed
from torch.distributed import destroy_process_group, init_process_group

import faiss
import faiss.contrib.torch_utils
from faiss.contrib.datasets import SyntheticDataset

import triton
import triton.language as tl

def merge_result_table_torch(D0, I0, D1, I1):
    """
    Pure torch implementation of merging sorted result tables.
    Each row of (D0, I0) and (D1, I1) is sorted by increasing distance; the
    output rows contain the k smallest distances (and their ids) of each
    merged pair of rows.
    """
    n, k = D0.shape
    assert I0.shape == (n, k) and D1.shape == (n, k) and I1.shape == (n, k)
    device = D0.device
    shifts = torch.arange(n, device=device) * k
    D = torch.empty_like(D0)
    I = torch.empty_like(I0)
    # per-row read cursors into (D0, I0) and (D1, I1)
    i0 = torch.zeros(n, dtype=torch.int64, device=device)
    i1 = torch.zeros(n, dtype=torch.int64, device=device)
    for i in range(k):
        D0s = D0.ravel()[i0 + shifts]
        D1s = D1.ravel()[i1 + shifts]
        keep_0 = D0s < D1s
        D[:, i] = torch.where(keep_0, D0s, D1s)
        I0s = I0.ravel()[i0 + shifts]
        I1s = I1.ravel()[i1 + shifts]
        I[:, i] = torch.where(keep_0, I0s, I1s)
        # advance the cursor of the table the result was taken from
        i0[keep_0] += 1
        i1[torch.logical_not(keep_0)] += 1
    return D, I
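
# Worked example (values are illustrative): merging two sorted top-2 tables
# row by row keeps the 2 smallest of the 4 candidates and their ids:
#   D0 = [[0.1, 0.5]], I0 = [[3, 7]]
#   D1 = [[0.2, 0.3]], I1 = [[1, 4]]
#   => D = [[0.1, 0.2]], I = [[3, 1]]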

# direct translation of the torch implementation
@triton.jit
def merge_result_tables_triton_kernel(
    D0, I0, D1, I1,
    D, I,
    stride: int,
    n: int,
    bs: tl.constexpr,
    k: tl.constexpr
):
    bno = tl.program_id(0)
    imin = bno * bs
    # per-row offsets of the read cursors (i0, i1) and the write cursor (ii)
    i0 = (tl.arange(0, bs) + imin) * stride
    i1 = (tl.arange(0, bs) + imin) * stride
    ii = (tl.arange(0, bs) + imin) * stride
    # avoid reading or writing out of bounds
    mask = imin + tl.arange(0, bs) < n
    for i in range(k):
        D0s = tl.load(D0 + i0, mask=mask)
        D1s = tl.load(D1 + i1, mask=mask)
        keep_0 = D0s < D1s
        Di = tl.where(keep_0, D0s, D1s)
        tl.store(D + ii, Di, mask=mask)
        I0s = tl.load(I0 + i0, mask=mask)
        I1s = tl.load(I1 + i1, mask=mask)
        Ii = tl.where(keep_0, I0s, I1s)
        tl.store(I + ii, Ii, mask=mask)
        i0 += keep_0
        i1 += 1 - keep_0
        ii += 1

def merge_result_table_triton(D0, I0, D1, I1):
    """
    Triton implementation of merging result tables. It is assumed that it will
    be called often with the same k (number of columns of the result tables),
    since the kernel is specialized for each distinct constexpr k.
    """
    n, k = D0.shape
    assert I0.shape == (n, k) and D1.shape == (n, k) and I1.shape == (n, k)
    assert D0.stride(0) == I0.stride(0) == D1.stride(0) == I1.stride(0)
    # number of rows to process per kernel instance
    bs = 256
    grid = ((n + bs - 1) // bs, )
    D = torch.empty_like(D0)
    I = torch.empty_like(I0)
    merge_result_tables_triton_kernel[grid](
        D0, I0, D1, I1,
        D, I,
        D.stride(0), n,
        bs, k
    )
    return D, I

def merge_result_tables(D0, I0, D1, I1, variant='triton'):
    """ flatten the leading dimensions and call the selected merge implementation on 2D tensors """
    shape = I0.shape
    mshape = (np.prod(shape[:-1]), shape[-1])
    func = merge_result_table_torch if variant == 'torch' else merge_result_table_triton
    D, I = func(
        D0.reshape(mshape), I0.reshape(mshape),
        D1.reshape(mshape), I1.reshape(mshape)
    )
    return D.reshape(D0.shape), I.reshape(I0.shape)
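
# Hypothetical self-check, not called anywhere: a minimal sketch (assuming a
# CUDA device is available; the name and sizes are illustrative) that the
# torch and triton variants produce identical results on random sorted input.
def _check_merge_variants(n=1000, k=10, device="cuda:0"):
    D0, _ = torch.sort(torch.rand(n, k, device=device))
    D1, _ = torch.sort(torch.rand(n, k, device=device))
    I0 = torch.randint(10 ** 6, (n, k), device=device)
    I1 = torch.randint(10 ** 6, (n, k), device=device)
    Dt, It = merge_result_tables(D0, I0, D1, I1, variant='torch')
    Dr, Ir = merge_result_tables(D0, I0, D1, I1, variant='triton')
    assert torch.equal(Dt, Dr) and torch.equal(It, Ir)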

class DistributedSearch:
    """
    Object that contains a database as a flat torch array on GPU.
    It is assumed that it is instantiated on several GPUs with different
    arrays. The combined objects offer a search functionality from a single
    GPU or parallel search from all GPUs.
    The search is performed on the concatenation of the arrays of the
    different GPUs (i.e. the database is sharded).
    """

    def __init__(self, rank, world_size, xb, merge_variant):
        self.rank = rank
        self.world_size = world_size
        self.xb = xb
        self.merge_variant = merge_variant
        device = "cuda:%d" % rank
        # gather the shard sizes from all ranks
        shard_sizes = torch.zeros(world_size, dtype=torch.int64).to(device)
        tensor_list = [shard_sizes[i:i + 1] for i in range(world_size)]
        shard_sizes[rank] = len(xb)
        torch.distributed.all_gather(tensor_list, tensor_list[rank])
        shard_sizes = shard_sizes.cpu().numpy()
        # prefix sums = offsets of the shards in the concatenated database
        shard_lims = np.zeros(world_size + 1, dtype=int)
        shard_lims[1:] = np.cumsum(shard_sizes)
        self.shard_lims = shard_lims
        self.res = faiss.StandardGpuResources()
        # make sure Faiss uses the same stream as pytorch
        self.res.setDefaultNullStreamAllDevices()

    def search(self, src, xq_or_nq, k):
        """
        Perform a search issued from GPU #src.
        On src, xq_or_nq holds the query vectors xq; on the other ranks it
        holds the number of queries (nq).
        """
        rank = self.rank
        world_size = self.world_size
        device = "cuda:%d" % rank
        if self.rank == src:
            xq = xq_or_nq
        else:
            assert type(xq_or_nq) == int
            nq = xq_or_nq
            xq = torch.empty(nq, self.xb.shape[1]).to(device)
        torch.distributed.broadcast(xq, src=src)
        D, I = faiss.knn_gpu(self.res, xq, self.xb, k, device=rank)
        # convert shard-local ids to ids in the concatenated database
        I += self.shard_lims[rank]
        # tree merge, the result ends up on src
        rank2 = (rank - src + world_size) % world_size
        for round in range(50):
            if 2 ** round >= world_size:
                break
            mask = 2 ** round - 1
            if (rank2 & mask) != 0:
                continue
            peer = rank2 ^ (1 << round)
            if peer >= world_size:
                continue
            peer2 = (peer + src) % world_size
            if ((rank2 >> round) & 1) == 1:
                torch.distributed.send(D, peer2)
                torch.distributed.send(I, peer2)
            else:
                D1 = torch.empty_like(D)
                I1 = torch.empty_like(I)
                torch.distributed.recv(D1, peer2)
                torch.distributed.recv(I1, peer2)
                D, I = merge_result_table_torch(D, I, D1, I1)
        # the result is returned only on GPU #src
        if rank == src:
            return D, I
        else:
            return None, None
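
    # Illustration of the tree merge above for world_size=4 and src=0:
    # round 0: rank 1 sends its result table to rank 0 and rank 3 to rank 2,
    #          so ranks 0 and 2 each hold a merge over 2 shards;
    # round 1: rank 2 sends its merged table to rank 0, which merges again and
    #          ends up with the top-k results over all 4 shards.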

    def search_n2n(self, xq, k):
        """
        All GPUs generate queries xq and each receives the results of its own
        queries.
        """
        rank = self.rank
        world_size = self.world_size
        assert (world_size & (world_size - 1)) == 0, "works only for powers of 2"
        device = "cuda:%d" % rank
        nq, d = xq.shape
        # TODO: tile computation and all_gather (possibly via rotations)
        all_xq = torch.empty((world_size, nq, d), device=device)
        tensor_list = [all_xq[i] for i in range(world_size)]
        torch.distributed.all_gather(tensor_list, xq)
        # perform all searches in one big batch
        D, I = faiss.knn_gpu(self.res, all_xq.reshape(nq * world_size, d), self.xb, k, device=rank)
        # separate the results by the rank the queries come from
        D = D.reshape(world_size, nq, k)
        I = I.reshape(world_size, nq, k)
        I += self.shard_lims[rank]
        # do rounds of merging
        for round in range(50):
            if (1 << round) >= world_size:
                break
            D0 = D[0::2]
            D1 = D[1::2]
            I0 = I[0::2]
            I1 = I[1::2]
            peer = rank ^ (1 << round)
            bit = (rank >> round) & 1
            if bit == 0:
                other_D0 = torch.empty_like(D0)
                other_I0 = torch.empty_like(I0)
                # print(f"{rank=:} merge from {peer} shape {other_I0.shape}", file=sys.stderr)
                torch.distributed.recv(other_D0, peer)
                torch.distributed.recv(other_I0, peer)
                # print(f"{rank=:} send to {peer} shape {I1.shape}", file=sys.stderr)
                torch.distributed.send(D1.contiguous(), peer)
                torch.distributed.send(I1.contiguous(), peer)
                D, I = merge_result_tables(D0, I0, other_D0, other_I0, variant=self.merge_variant)
            else:
                # print(f"{rank=:} send to {peer} shape {I0.shape}", file=sys.stderr)
                torch.distributed.send(D0.contiguous(), peer)
                torch.distributed.send(I0.contiguous(), peer)
                other_D1 = torch.empty_like(D1)
                other_I1 = torch.empty_like(I1)
                # print(f"{rank=:} merge from {peer} shape {other_I1.shape}", file=sys.stderr)
                torch.distributed.recv(other_D1, peer)
                torch.distributed.recv(other_I1, peer)
                D, I = merge_result_tables(D1, I1, other_D1, other_I1, variant=self.merge_variant)
            # print(f"after merge {rank=:} {I=:}", file=sys.stderr)
        return D[0], I[0]
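
    # Illustration of the merge rounds above for world_size=4: each rank
    # starts with result tables for the queries of all 4 ranks against its own
    # shard.
    # round 0: ranks exchange with peer rank^1; each keeps half of the query
    #          sources (even ranks keep sources {0, 2}, odd ranks {1, 3}),
    #          merged over a pair of shards;
    # round 1: ranks exchange with peer rank^2; each keeps only its own query
    #          source, merged over all 4 shards, and returns it as D[0], I[0].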

def search_job(rank, port, args, ds, res_queue):
    """
    Function that gets called on one process per GPU.
    Instantiates a database shard and a DistributedSearch object.
    Performs a few rounds of searching with different queries.
    Reports timings and verifies the results.
    """
    world_size = args.ngpu
    print(f"Start search_job {rank=:}")
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(port)
    init_process_group(backend="nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
    device = "cuda:%d" % rank
    if args.nthread != -1:
        faiss.omp_set_num_threads(args.nthread)
    if args.cache_dataset:
        if rank == 0:
            print("caching dataset")
        ds.get_database(rank, 0)
        for step in range(args.niter):
            ds.get_queries(step, rank, 0)
    # prepare the local shard of the dataset
    head = 0
    xb = torch.from_numpy(ds.get_database(rank, head)).to(device)
    distrib = DistributedSearch(rank, world_size, xb, merge_variant=args.merge_table)
    print(f"xb ok {rank=:}")
    all_I = []
    times = []
    for step in range(args.niter):
        # this can be slow ==> exclude from timing
        xq = torch.from_numpy(ds.get_queries(step, rank, head)).to(device)
        t0 = time.time()
        if args.merge_type == "sequential":   # sequential implementation (slow)
            for q_src in range(world_size):
                if q_src == rank:
                    D, I = distrib.search(q_src, xq, args.k)
                else:
                    # works because all xq's have the same size
                    distrib.search(q_src, len(xq), args.k)
        elif args.merge_type == "n2n":        # faster implementation
            D, I = distrib.search_n2n(xq, args.k)
        else:
            assert False
        all_I.append(I.cpu().numpy())
        ti = time.time() - t0
        if rank == 0:
            print(f"End {step=:} time={ti:.3f} s")
        times.append(ti)
    destroy_process_group()
    # report timings and verify, on rank 0 only
    if rank == 0:
        times = np.array(times)
        # the first 2 iterations are treated as warm-up and excluded
        print(f"mean time on {len(times)-2} iterations: {times[2:].mean():.3f} s")
        print("checking results....")
        # verification: brute-force search on the full concatenated database
        xb = np.vstack([ds.get_database(r, head) for r in range(world_size)])
        for step in range(args.niter):
            xq = ds.get_queries(step, rank, head)
            D, ref_I = faiss.knn_gpu(distrib.res, xq, xb, args.k, device=rank)
            ndiff = (all_I[step] != ref_I).sum()
            # tolerate a small fraction of differences (e.g. distance ties)
            assert ndiff < ref_I.size * 0.001, f"{ndiff=:} / {ref_I.size}"

class Dataset:
    """
    Random dataset that provides database vectors and query vectors in a
    reproducible way.
    """

    def __init__(self, args):
        self.args = args
        enable_cache = args.cache_dataset
        self.cache_queries = {} if enable_cache else None
        self.cache_database = {} if enable_cache else None

    def get_database(self, rank, head):
        if self.cache_database is not None and (rank, head) in self.cache_database:
            return self.cache_database[(rank, head)]
        seed = hash((rank, head))
        args = self.args
        bshape = (args.klen, args.head_dim)
        x = faiss.rand_smooth_vectors(1, int(np.prod(bshape)), seed=seed).reshape(bshape)
        if self.cache_database is not None:
            self.cache_database[(rank, head)] = x
        return x

    def get_queries(self, step, rank, head):
        if self.cache_queries is not None and (step, rank, head) in self.cache_queries:
            return self.cache_queries[(step, rank, head)]
        seed = hash((step, rank, head))
        args = self.args
        qshape = (args.bs_qlen, args.head_dim)
        x = faiss.rand_smooth_vectors(1, int(np.prod(qshape)), seed=seed).reshape(qshape)
        if self.cache_queries is not None:
            self.cache_queries[(step, rank, head)] = x
        return x

def main():
    parser = argparse.ArgumentParser()

    def aa(*args, **kwargs):
        group.add_argument(*args, **kwargs)

    group = parser.add_argument_group("problem size")
    aa("--bs_qlen", default=4096, type=int, help="batch size for queries")
    aa("--klen", default=10000, type=int, help="length of keys")
    aa("--nhead", default=32, type=int, help="number of heads (not implemented)")
    aa("--head_dim", default=128, type=int, help="head dimension")
    aa("--niter", default=10, type=int, help="number of iterations")
    aa("--k", default=10, type=int, help="nb results per query")
    group = parser.add_argument_group("system parameters")
    aa("--nthread", default=32, type=int, help="set nb threads")
    aa("--ngpu", default=8, type=int, help="number of GPUs to use")
    aa("--merge_type", default="n2n", choices=["sequential", "n2n"])
    aa("--cache_dataset", default=False, action="store_true", help="enable caching of the dataset")
    aa("--merge_table", default="triton", choices=["triton", "torch"], help="implementation used to merge result tables")
    args = parser.parse_args()

    print(f"Running on {args.ngpu} GPUs")
    assert torch.cuda.device_count() >= args.ngpu
    ctx = mp.get_context('spawn')
    res_queue = ctx.SimpleQueue()
    ds = Dataset(args)
    # pick a random port for the process group rendezvous
    port = np.random.randint(50000, 65000)
    mp.spawn(
        search_job,
        args=(port, args, ds, res_queue),
        nprocs=args.ngpu,
    )


if __name__ == "__main__":
    main()