|
""" |
|
Script for benchmarking performance of matmuls with `SparseTensor` vs `CSRSparseMatrix`. |
|
|
|
Requires tensorflow 2.3 and absl-py |
|
|
|
```bash |
|
pip install tensorflow==2.3 |
|
pip install absl-py |
|
``` |
|
""" |
|
import functools |
|
from typing import Callable |
|
|
|
import numpy as np |
|
import tensorflow as tf |
|
from absl import app, flags |
|
from tensorflow.python.ops.linalg.sparse import sparse as sparse_lib |
|
|
|
|
|
# Command-line flags controlling problem size and benchmark behavior.
flags.DEFINE_integer("n", default=20000, help="number of rows/columns of the square sparse matrix")
flags.DEFINE_integer("num_features", default=64, help="number of feature columns in the dense matrix")
flags.DEFINE_integer("num_matmuls", default=8, help="number of matmuls")
flags.DEFINE_float(
    "sparsity", default=0.001, help="fraction of nonzero entries (num_edges = n * n * sparsity)"
)
flags.DEFINE_integer("burn", default=10, help="number of burn iterations")
flags.DEFINE_integer("iters", default=100, help="number of iterations to average over")
flags.DEFINE_boolean("jit", default=False, help="XLA jit compilation")
flags.DEFINE_boolean("backward", default=False, help="include backwards pass")
|
|
|
|
|
def summarize(result, print_fn=print):
    """Report wall time and peak GPU memory for one benchmark result.

    Args:
        result: output of a tf.test.Benchmark.run_op_benchmark call.
        print_fn: print-like function.
    """
    wall_ms = result["wall_time"] * 1000
    print_fn("Wall time (ms): {}".format(wall_ms))
    # Key is absent when the op ran on CPU only; report 0 in that case.
    peak_bytes = result["extras"].get("allocator_maximum_num_bytes_GPU_0_bfc", 0)
    print_fn("Memory (Mb): {}".format(peak_bytes / 1024 ** 2))
|
|
|
|
|
def summarize_all(*args, print_fn=print):
    """Report several benchmark results, each prefixed by its name.

    Args:
        *args: (name, result) pairs, where `result` is the output of a
            tf.test.Benchmark.run_op_benchmark call.
        print_fn: print-like function.
    """
    for label, bench_result in args:
        print_fn(label)
        summarize(bench_result, print_fn)
|
|
|
|
|
def get_data(
    N: int, sparsity: float, num_features: int,
):
    """Build a random sparse matrix and a random dense feature matrix.

    Args:
        N: number of rows/columns of the square sparse matrix.
        sparsity: fraction of the N*N entries that are nonzero.
        num_features: number of columns of the dense feature matrix.

    Returns:
        (st, features): an `N x N` float32 `tf.SparseTensor` with
        `int(N * N * sparsity)` normally-distributed nonzeros, and an
        `N x num_features` float32 dense tensor.
    """
    num_edges = int(N * N * sparsity)
    # Rejection-sample flat indices: draw twice as many as needed so that,
    # after dropping duplicates, we almost always have `num_edges` left.
    while True:
        candidates = np.random.randint(
            0, high=N * N, size=num_edges * 2, dtype=np.int64
        )
        candidates = np.unique(candidates)
        if len(candidates) >= num_edges:
            flat_index = candidates[:num_edges]
            break
    flat_index = np.sort(flat_index)
    # pylint: disable=unbalanced-tuple-unpacking
    i, j = np.unravel_index(flat_index, (N, N))
    sparse_indices = tf.constant(np.stack((i, j), axis=-1), dtype=tf.int64)
    sparse_values = tf.constant(
        np.random.normal(size=(num_edges,)), dtype=tf.float32
    )
    dense_shape = [tf.constant(N), tf.constant(N)]
    st = tf.SparseTensor(sparse_indices, sparse_values, dense_shape)
    features = tf.constant(
        np.random.normal(size=(N, num_features)), dtype=tf.float32
    )
    return st, features
|
|
|
|
|
def sparse_matmul(
    N: int, sparsity: float, num_features: int, num_matmuls: int, backward: bool
):
    """Chain `num_matmuls` SparseTensor x dense matmuls on random data.

    Args:
        N: size of the square sparse matrix.
        sparsity: fraction of nonzero entries.
        num_features: width of the dense feature matrix.
        num_matmuls: number of times to multiply by the sparse matrix.
        backward: if True, also compute the gradient w.r.t. `features`.

    Returns:
        The final product `x`, or `(x, grad)` when `backward` is True.
    """
    st, features = get_data(N, sparsity, num_features)
    with tf.GradientTape() as tape:
        # `features` is a constant, so the tape must be told to track it.
        tape.watch(features)
        x = features
        for _ in range(num_matmuls):
            x = tf.sparse.sparse_dense_matmul(st, x)
    if backward:
        return x, tape.gradient(x, features)
    return x
|
|
|
|
|
def csr_matmul(
    N: int, sparsity: float, num_features: int, num_matmuls: int, backward: bool
):
    """Chain `num_matmuls` CSRSparseMatrix x dense matmuls on random data.

    Args:
        N: size of the square sparse matrix.
        sparsity: fraction of nonzero entries.
        num_features: width of the dense feature matrix.
        num_matmuls: number of times to multiply by the sparse matrix.
        backward: if True, also compute the gradient w.r.t. `features`.

    Returns:
        The final product `x`, or `(x, grad)` when `backward` is True.
    """
    st, features = get_data(N, sparsity, num_features)
    with tf.GradientTape() as tape:
        # `features` is a constant, so the tape must be told to track it.
        tape.watch(features)
        # Convert inside the tape so the conversion is part of the timed graph.
        csr = sparse_lib.CSRSparseMatrix(st)
        x = features
        for _ in range(num_matmuls):
            x = sparse_lib.matmul(csr, x)
    if backward:
        return x, tape.gradient(x, features)
    return x
|
|
|
|
|
def benchmark_matmul(
    matmul_fn: Callable, burn_iters: int, min_iters: int,
):
    """Time `matmul_fn` in graph mode via `tf.test.Benchmark`.

    Args:
        matmul_fn: zero-argument callable that builds the ops to benchmark.
        burn_iters: warm-up iterations excluded from timing.
        min_iters: iterations averaged for the reported wall time.

    Returns:
        The result dict from `tf.test.Benchmark.run_op_benchmark`.
    """
    graph = tf.Graph()
    with graph.as_default():
        output = matmul_fn()
    with tf.compat.v1.Session(graph=graph) as sess:
        benchmark = tf.test.Benchmark()
        print("Starting benchmarking...")
        result = benchmark.run_op_benchmark(
            sess, output, burn_iters=burn_iters, min_iters=min_iters
        )
        summarize(result)
    return result
|
|
|
|
|
def main(_):
    """Benchmark SparseTensor vs CSRSparseMatrix matmuls and report both."""
    FLAGS = flags.FLAGS
    tf.config.optimizer.set_jit(FLAGS.jit)
    # One shared problem configuration for both implementations.
    call_kwargs = {
        "N": FLAGS.n,
        "sparsity": FLAGS.sparsity,
        "num_features": FLAGS.num_features,
        "num_matmuls": FLAGS.num_matmuls,
        "backward": FLAGS.backward,
    }
    results = []
    for base_fn in (sparse_matmul, csr_matmul):
        bound_fn = functools.partial(base_fn, **call_kwargs)
        result = benchmark_matmul(
            bound_fn, burn_iters=FLAGS.burn, min_iters=FLAGS.iters
        )
        results.append((base_fn.__name__, result))
    # Repeat both summaries side by side at the end for easy comparison.
    summarize_all(*results)
|
|
|
|
|
if __name__ == "__main__":
    # Let absl parse the flags defined above and then invoke `main`.
    app.run(main)
# On my 12-core laptop with GTX-1050-TI GPU: