Skip to content

Instantly share code, notes, and snippets.

@jcjohnson
Last active December 17, 2018 23:00
Show Gist options
  • Save jcjohnson/b03a0275e64681bb7587bbc7399a645a to your computer and use it in GitHub Desktop.
Save jcjohnson/b03a0275e64681bb7587bbc7399a645a to your computer and use it in GitHub Desktop.
import argparse
import time
import torch
import numpy as np
def int_list(s):
    """Parse a comma-separated string such as ``'1,16,256'`` into a list of ints.

    Used as an argparse ``type=`` converter; a malformed entry raises
    ``ValueError`` from ``int``, which argparse reports as a usage error.
    """
    pieces = s.split(',')
    return list(map(int, pieces))
# Command-line options; every list-valued option takes a comma-separated string
# (parsed by int_list) and the benchmark runs over their Cartesian product.
parser = argparse.ArgumentParser(
    description='Benchmark x[idx] (indexing) against x.gather(0, ...) '
                'for selecting K rows of an N x D tensor.')
parser.add_argument('--Ns', type=int_list, default=[16, 64, 128, 256, 512],
                    help='comma-separated row counts N of the input tensor')
parser.add_argument('--Ds', type=int_list, default=[1, 16, 256, 1024, 4096],
                    help='comma-separated column counts D of the input tensor')
parser.add_argument('--Ks', type=int_list, default=[1, 16, 256, 1024, 4096],
                    help='comma-separated numbers K of rows to select')
parser.add_argument('--tolerance', type=float, default=1e-8,
                    help='outputs count as equal when the sum of absolute '
                         'differences between them is below this value')
parser.add_argument('--verbose', action='store_true',
                    help='enable verbose output')
parser.add_argument('--with_backward', action='store_true',
                    help='also time the backward pass')
parser.add_argument('--device', type=str, default='cpu',
                    help="torch device to run on, e.g. 'cpu' or 'cuda'")
def main(args):
    """Benchmark indexing vs. gather over the Ns x Ds x Ks grid and print a report.

    Args:
        args: Parsed CLI namespace with ``Ns``, ``Ds``, ``Ks`` (lists of ints),
            ``tolerance`` (float), ``verbose``, ``with_backward`` (bools) and
            ``device`` (str).
    """
    import itertools

    tol = args.tolerance
    with_backward = args.with_backward
    # Same iteration order as nested loops: N outermost, K innermost.
    grid = list(itertools.product(args.Ns, args.Ds, args.Ks))
    num_experiments = len(grid)
    if not grid:
        # Nothing was run; the summary statistics below would fail on empty input.
        print('No experiments to run: --Ns, --Ds and --Ks must all be non-empty')
        return

    device = args.device
    print('Benchmarking with device = %s' % device)
    rows, sames, f_speedups = [], [], []
    for i, (N, D, K) in enumerate(grid, start=1):
        if i % 10 == 0:
            print('Running experiment %d / %d' % (i, num_experiments))
        same, f_speedup, f_time_us, g_time_us = benchmark(
            N, D, K, tol, with_backward, device, args.verbose)
        rows.append((N, D, K, f_time_us, g_time_us))
        sames.append(same)
        f_speedups.append(f_speedup)
    _report_results(device, tol, with_backward, rows, sames, f_speedups)


def _report_results(device, tol, with_backward, rows, sames, f_speedups):
    """Print the summary for one device from the collected benchmark results.

    ``rows`` holds ``(N, D, K, index_time_us, gather_time_us)`` tuples and
    ``f_speedups`` the matching ``index_time / gather_time`` ratios.
    """
    print()
    print('Results with device = %s' % device)
    # %g rather than %f: the default tolerance (1e-8) would print as 0.000000.
    print('Differences within tolerance (%g)' % tol, all(sames))
    if with_backward:
        print('Forward+backward gather speedup:')
    else:
        print('Forward gather speedup:')
    print(' Min: ', np.min(f_speedups))
    print(' Max: ', np.max(f_speedups))
    print(' Mean: ', np.mean(f_speedups))
    print(' Median: ', np.median(f_speedups))

    # speedup = t_index / t_gather, so a value below 1 means indexing won.
    f_better = [row for row, s in zip(rows, f_speedups) if s < 1.0]
    g_better = [row for row, s in zip(rows, f_speedups) if not s < 1.0]
    total = len(rows)
    row_fmt = 'N: {} D: {} K: {} index: {:4.0f} us gather: {:4.0f} us'
    print('Test cases with faster indexing: {}/{}'.format(len(f_better), total))
    for row in f_better:
        print(row_fmt.format(*row))
    print('Test cases with faster gather: {}/{}'.format(len(g_better), total))
    for row in g_better:
        print(row_fmt.format(*row))

    # delta > 0 means indexing took longer than gather on that problem size.
    deltas = [t_index - t_gather for _, _, _, t_index, t_gather in rows]
    idx_fastest = int(np.argmin(deltas))
    print('Indexing is faster by at most {} us on N: {} D: {} K: {}'.format(
        -deltas[idx_fastest],
        *rows[idx_fastest][:3]))
    idx_slowest = int(np.argmax(deltas))
    print('Indexing is slower by at most {} us on N: {} D: {} K: {}'.format(
        deltas[idx_slowest],
        *rows[idx_slowest][:3]))
def timeit(f, x, idx, with_backward=False, budget_ms=100.0):
    """Return the output of ``f(x, idx)`` and its mean time per call in microseconds.

    One warm-up call estimates the per-call cost. The call is then repeated
    enough times to fill roughly ``budget_ms`` milliseconds, and the mean is
    reported.

    Args:
        f: Callable taking ``(x, idx)`` and returning a tensor.
        x: Input tensor. When ``with_backward`` is set and ``x`` does not yet
            require grad, ``requires_grad`` is enabled on it in place.
        idx: Index tensor forwarded to ``f``.
        with_backward: If True, each call also runs ``y.backward`` with an
            all-ones upstream gradient. Gradients accumulate into ``x.grad``
            across repetitions; ``x.grad`` is zeroed once at the start.
        budget_ms: Approximate wall-clock budget for the timed loop.

    Returns:
        ``(y, t_us)``: the output of the last call and the mean time per call
        in microseconds.
    """
    def _sync():
        # CUDA work is launched asynchronously; wait for it so that wall-clock
        # timestamps bracket the actual computation.
        if x.is_cuda:
            torch.cuda.synchronize()

    if with_backward and not x.requires_grad:
        # Must happen *before* the forward pass, otherwise y has no grad_fn
        # and y.backward() raises. (A tensor that does not require grad is
        # always a leaf, so this is safe.)
        x.requires_grad_()
    if x.grad is not None:
        x.grad.zero_()

    # Warm-up call, also used to estimate how many iterations fit the budget.
    _sync()
    t0 = time.perf_counter()
    y = f(x, idx)
    if with_backward:
        dy = torch.ones_like(y)
        y.backward(gradient=dy)
    _sync()
    delta_ms = 1000.0 * (time.perf_counter() - t0)
    # Floor the estimate at 1 us so a coarse clock can never divide by zero.
    iters = max(1, int(budget_ms / max(delta_ms, 1e-3)))

    _sync()
    t0 = time.perf_counter()
    for _ in range(iters):
        y = f(x, idx)
        if with_backward:
            y.backward(gradient=dy)
    _sync()
    t1 = time.perf_counter()
    t_us = 1000000.0 * (t1 - t0) / iters
    return y, t_us
def benchmark(N, D, K, tol, with_backward, device='cuda', verbose=False,
              num_trials=1):
    """Compare ``index`` (``x[idx]``) against ``gather`` on one problem size.

    Each trial draws a fresh ``x`` of shape ``(N, D)`` and ``K`` random row
    indices in ``[0, N)``. It times both implementations with ``timeit`` and
    compares their outputs.

    Args:
        N, D: Shape of the random input ``x``.
        K: Number of rows to select.
        tol: The outputs count as the same when the sum of absolute differences
            (maximum over trials) is below this value.
        with_backward: Also time the backward pass.
        device: Torch device string for the tensors.
        verbose: If True, print the timings of every trial.
        num_trials: Number of independent trials whose timings are averaged.

    Returns:
        ``(same, speedup, t_index, t_gather)``. ``speedup`` is
        ``t_index / t_gather``, so values above 1 mean gather was faster. The
        times are mean microseconds per call.

    Raises:
        ValueError: If ``num_trials`` is less than 1.
    """
    if num_trials < 1:
        raise ValueError('num_trials must be >= 1, got %d' % num_trials)
    index_times, gather_times, y_diffs = [], [], []
    for trial in range(num_trials):
        x = torch.randn(N, D, requires_grad=True, device=device)
        idx = torch.randint(N, size=(K,), device=device)
        # Time forward (and optionally backward) for both implementations.
        y_index, t_index = timeit(index, x, idx, with_backward)
        y_gather, t_gather = timeit(gather, x, idx, with_backward)
        index_times.append(t_index)
        gather_times.append(t_gather)
        with torch.no_grad():
            y_diffs.append((y_index - y_gather).abs().sum().item())
        if verbose:
            print('N: {} D: {} K: {} trial {}: index {:.1f} us, '
                  'gather {:.1f} us'.format(N, D, K, trial, t_index, t_gather))
    y_diff = np.max(y_diffs)
    t_index = np.mean(index_times)
    t_gather = np.mean(gather_times)
    same = y_diff < tol
    speedup = t_index / t_gather
    return same, speedup, t_index, t_gather
def index(x, idx):
    """Select rows of ``x`` by advanced indexing (the ``x[idx]`` variant timed here)."""
    selected = x[idx]
    return selected
def gather(x, idx):
    """Select the rows ``x[idx]`` using ``torch.gather`` along dim 0.

    Args:
        x: Tensor of any rank >= 1.
        idx: 1-D integer tensor of ``K`` row indices into ``x``.

    Returns:
        Tensor of shape ``(K, *x.shape[1:])``, equal to ``x[idx]``.
    """
    # gather needs an index with the same rank as x. Reshape idx to
    # (K, 1, ..., 1) and broadcast it to (K, *x.shape[1:]) so that every
    # element of a selected row uses that row's index. For 2-D x this is
    # idx[:, None].expand(K, D).
    view_shape = (idx.shape[0],) + (1,) * (x.dim() - 1)
    expanded = idx.reshape(view_shape).expand(idx.shape[0], *x.shape[1:])
    return x.gather(0, expanded)
if __name__ == '__main__':
    # Script entry point: parse the command line and run the benchmark.
    cli_args = parser.parse_args()
    main(cli_args)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment