Skip to content

Instantly share code, notes, and snippets.

@BarclayII
Last active November 5, 2018 21:36
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save BarclayII/6cf6c5acf1d4ecf92bbda3dbb4636662 to your computer and use it in GitHub Desktop.
PyTorch gather-scatter/SPMV benchmarks
"""Benchmark three GPU implementations of graph message passing.

For a random directed graph with N nodes and E edges (edge i: src[i] -> dst[i])
and node features x of shape (N, D), each variant computes

    z[d] = sum over edges (s -> d) of x[s]

1. gather rows by src, then scatter_add into dst          -> t_scatter
2. two SPMMs with the (E, N) out- and (N, E) in-incidence
   matrices                                               -> t_in_incident
3. one SPMM with the (N, N) adjacency matrix              -> t_adjacent

Requires a CUDA-capable GPU.
"""
import torch
import time

N = 10000    # number of nodes
D = 50       # feature dimension
E = 500000   # number of edges
T = 10       # timing repetitions

t_scatter = 0.0       # gather + scatter_add
t_in_incident = 0.0   # incidence-matrix SPMM pair
t_adjacent = 0.0      # adjacency-matrix SPMM

# Estimate the cost of a time.time() call pair (timer overhead).
# NOTE(review): this estimate is never subtracted from any measurement below;
# it is kept only as a sanity reference.
t_time_overhead = 0.0
for _ in range(10000):
    t0 = time.time()
    t_time_overhead += time.time() - t0
t_time_overhead /= 10000

for _ in range(T):
    x = torch.randn(N, D).cuda()
    # Random edge list; endpoints drawn uniformly from [0, N).
    src = (torch.rand(E).cuda() * N).long()
    dst = (torch.rand(E).cuda() * N).long()

    # --- 1. gather + scatter_add -------------------------------------------
    t0 = time.time()
    y = x.gather(0, src[:, None].expand(src.shape[0], *x.shape[1:]))
    z = torch.zeros_like(x).scatter_add(
        0, dst[:, None].expand(dst.shape[0], *x.shape[1:]), y)
    # CUDA kernels launch asynchronously; synchronize before reading the clock.
    torch.cuda.synchronize()
    t_scatter += time.time() - t0

    # --- 2. incidence-matrix SPMM ------------------------------------------
    t0 = time.time()
    out_incident_coo = torch.stack([torch.arange(E).cuda(), src], 0)
    out_incident = torch.cuda.sparse.FloatTensor(
        out_incident_coo, torch.ones(E).cuda(), (E, N))
    y = torch.spmm(out_incident, x)   # (E, N) @ (N, D): per-edge source features
    in_incident_coo = torch.stack([dst, torch.arange(E).cuda()], 0)
    in_incident = torch.cuda.sparse.FloatTensor(
        in_incident_coo, torch.ones(E).cuda(), (N, E))
    z = torch.spmm(in_incident, y)    # (N, E) @ (E, D): sum edges into destinations
    torch.cuda.synchronize()
    t_in_incident += time.time() - t0

    # --- 3. adjacency-matrix SPMM ------------------------------------------
    t0 = time.time()
    adj_coo = torch.stack([dst, src], 0)
    adj = torch.cuda.sparse.FloatTensor(adj_coo, torch.ones(E).cuda(), (N, N))
    z = torch.spmm(adj, x)
    torch.cuda.synchronize()
    t_adjacent += time.time() - t0

t_scatter /= T
t_in_incident /= T
t_adjacent /= T
print(t_scatter, t_in_incident, t_adjacent)
# Result:
# 0.015599870681762695 0.15931901931762696 0.03305981159210205
# Even if I moved the sparse FloatTensor constructions out completely, the result is
# 0.015459918975830078 0.12213940620422363 0.0214599609375
# ????????
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment