SparseTensor vs CSRSparseMatrix matmul benchmark

Sparse Matmul Benchmarks

Script for benchmarking matmuls with SparseTensor vs CSRSparseMatrix.

Installation and usage

```bash
pip install tensorflow==2.3
pip install absl-py
git clone https://gist.github.com/jackd/93cf681c67e48c1a25de01bf860c6ab2
cd 93cf681c67e48c1a25de01bf860c6ab2
python sparse_benchmark.py
```
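For context, the two paths being timed differ only in the sparse representation: `tf.sparse.sparse_dense_matmul` consumes a COO-style `SparseTensor`, while `sparse_lib.matmul` consumes a `CSRSparseMatrix` built from it. A minimal sketch of the two APIs on a toy matrix (not part of the benchmark script; `sparse_lib` is the private TF module imported below):

```python
import tensorflow as tf
from tensorflow.python.ops.linalg.sparse import sparse as sparse_lib

# A toy 3x3 sparse matrix in COO form and a dense feature matrix.
st = tf.SparseTensor(
    indices=[[0, 0], [1, 2], [2, 1]], values=[1.0, 2.0, 3.0], dense_shape=[3, 3]
)
features = tf.ones((3, 4))

# Path 1: matmul directly on the SparseTensor.
out_coo = tf.sparse.sparse_dense_matmul(st, features)

# Path 2: convert to CSR once, then matmul.
csr = sparse_lib.CSRSparseMatrix(st)
out_csr = sparse_lib.matmul(csr, features)
```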
"""
Script for benchmarking performance of matmuls with `SparseTensor` vs `CSRSparseMatrix`.
Requires tensorflow 2.3 and absl-py
```bash
pip install tensorflow==2.3
pip install absl-py
```
"""
import functools
from typing import Callable

import numpy as np
import tensorflow as tf
from absl import app, flags
from tensorflow.python.ops.linalg.sparse import sparse as sparse_lib

flags.DEFINE_integer("n", default=20000, help="number of points per cloud")
flags.DEFINE_integer("num_features", default=64, help="number of features per point")
flags.DEFINE_integer("num_matmuls", default=8, help="number of chained matmuls")
flags.DEFINE_float(
    "sparsity",
    default=0.001,
    help="fraction of nonzero entries (num_edges = n**2 * sparsity)",
)
flags.DEFINE_integer("burn", default=10, help="number of burn-in iterations")
flags.DEFINE_integer("iters", default=100, help="number of iterations to average over")
flags.DEFINE_boolean("jit", default=False, help="enable XLA jit compilation")
flags.DEFINE_boolean("backward", default=False, help="include the backward (gradient) pass")

def summarize(result, print_fn=print):
    """
    Args:
        result: output of a `tf.test.Benchmark.run_op_benchmark` call.
        print_fn: print-like function.
    """
    print_fn("Wall time (ms): {}".format(result["wall_time"] * 1000))
    gpu_mem = result["extras"].get("allocator_maximum_num_bytes_GPU_0_bfc", 0)
    print_fn("Memory (Mb):    {}".format(gpu_mem / 1024 ** 2))

def summarize_all(*args, print_fn=print):
    """
    Applies `summarize` to (name, result) pairs.

    Args:
        *args: (name, result) pairs.
        print_fn: print-like function.
    """
    for name, result in args:
        print_fn(name)
        summarize(result, print_fn)

def get_data(N: int, sparsity: float, num_features: int):
    # Sample `num_edges` unique (row, col) indices of an N x N matrix: draw
    # twice as many flat indices as needed (to allow for duplicates), dedupe,
    # and retry in the unlikely event that too few unique indices remain.
    num_edges = int(N * N * sparsity)
    while True:
        flat_index = np.random.randint(
            0, high=N * N, size=num_edges * 2, dtype=np.int64  # extras for duplicates
        )
        flat_index = np.unique(flat_index)
        if len(flat_index) >= num_edges:
            flat_index = flat_index[:num_edges]
            break
    flat_index = np.sort(flat_index)
    i, j = np.unravel_index(  # pylint: disable=unbalanced-tuple-unpacking
        flat_index, (N, N)
    )
    sparse_indices = tf.constant(np.stack((i, j), axis=-1), dtype=tf.int64)
    sparse_values = tf.constant(np.random.normal(size=(num_edges,)), dtype=tf.float32)
    dense_shape = [tf.constant(N), tf.constant(N)]
    st = tf.SparseTensor(sparse_indices, sparse_values, dense_shape)
    features = tf.constant(np.random.normal(size=(N, num_features)), dtype=tf.float32)
    return st, features

def sparse_matmul(
    N: int, sparsity: float, num_features: int, num_matmuls: int, backward: bool
):
    st, features = get_data(N, sparsity, num_features)
    with tf.GradientTape() as tape:
        tape.watch(features)
        x = features
        for _ in range(num_matmuls):
            # chain of COO SparseTensor * dense matmuls
            x = tf.sparse.sparse_dense_matmul(st, x)
    if backward:
        grad = tape.gradient(x, features)
        return x, grad
    return x

def csr_matmul(
    N: int, sparsity: float, num_features: int, num_matmuls: int, backward: bool
):
    st, features = get_data(N, sparsity, num_features)
    with tf.GradientTape() as tape:
        tape.watch(features)
        # convert the COO SparseTensor to CSR once, outside the matmul loop
        csr = sparse_lib.CSRSparseMatrix(st)
        x = features
        for _ in range(num_matmuls):
            x = sparse_lib.matmul(csr, x)
    if backward:
        grad = tape.gradient(x, features)
        return x, grad
    return x

def benchmark_matmul(matmul_fn: Callable, burn_iters: int, min_iters: int):
    with tf.Graph().as_default() as graph:
        output = matmul_fn()
    with tf.compat.v1.Session(graph=graph) as sess:
        bm = tf.test.Benchmark()
        print("Starting benchmarking...")
        result = bm.run_op_benchmark(
            sess, output, burn_iters=burn_iters, min_iters=min_iters
        )
        summarize(result)
        return result

def main(_):
    FLAGS = flags.FLAGS
    tf.config.optimizer.set_jit(FLAGS.jit)
    call_kwargs = dict(
        N=FLAGS.n,
        sparsity=FLAGS.sparsity,
        num_features=FLAGS.num_features,
        num_matmuls=FLAGS.num_matmuls,
        backward=FLAGS.backward,
    )
    results = []
    for base_fn in (sparse_matmul, csr_matmul):
        name = base_fn.__name__
        matmul_fn = functools.partial(base_fn, **call_kwargs)
        result = benchmark_matmul(
            matmul_fn, burn_iters=FLAGS.burn, min_iters=FLAGS.iters
        )
        results.append((name, result))
    summarize_all(*results)


if __name__ == "__main__":
    app.run(main)
jackd commented Sep 22, 2020

On my 12-core laptop with GTX-1050-TI GPU:

>> python sparse_benchmark.py
sparse_matmul
Wall time (ms): 30.946969985961914
Memory (Mb):    17.395034790039062
csr_matmul
Wall time (ms): 9.126067161560059
Memory (Mb):    12.512222290039062
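
If you want to sanity-check that the two kernels agree numerically before comparing timings, a minimal eager-mode sketch (assuming the gist's `get_data` helper is importable from `sparse_benchmark.py`):

```python
import tensorflow as tf
from tensorflow.python.ops.linalg.sparse import sparse as sparse_lib

from sparse_benchmark import get_data  # helper defined in the gist above

st, features = get_data(N=1000, sparsity=0.01, num_features=8)
csr = sparse_lib.CSRSparseMatrix(st)

out_st = tf.sparse.sparse_dense_matmul(st, features)
out_csr = sparse_lib.matmul(csr, features)

# Both paths should agree to within float32 tolerance.
tf.debugging.assert_near(out_st, out_csr, rtol=1e-4, atol=1e-4)
print("max abs diff:", tf.reduce_max(tf.abs(out_st - out_csr)).numpy())
```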
