@laci37
Last active May 21, 2020 20:22
TensorFlow dense_sparse_matmul
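import tensorflow as tf


def matmul_dense_sparse(sparse_indices, sparse_shape):
    """Build a function computing dense @ sparse for a sparse matrix with fixed indices and shape.

    sparse_indices: (nnz, 2) integer (row, col) positions of the non-zeros of the (k, n) sparse matrix.
    sparse_shape: the dense shape (k, n) of the sparse matrix.

    The returned function f(dense, sparse_values) has a custom gradient, so gradients
    flow to both `dense` (shape (m, k)) and `sparse_values` (shape (nnz,)), which
    lets the values of the sparse tensor be optimized.
    """
    # SparseTensor requires int64 indices; this also accepts Python lists and NumPy arrays.
    sparse_indices = tf.cast(sparse_indices, tf.int64)

    def result_fun(dense, sparse_values):
        # dense @ B is computed as (B^T @ dense^T)^T so that the sparse matrix is
        # the first argument of tf.sparse.sparse_dense_matmul.
        ta = tf.transpose(dense)
        b = tf.sparse.SparseTensor(sparse_indices, sparse_values, sparse_shape)
        tb = tf.sparse.transpose(b)
        # Re-wrap B^T with an explicit (n, k) dense shape.
        tb = tf.sparse.SparseTensor(tb.indices, tb.values, (sparse_shape[1], sparse_shape[0]))
        res = tf.transpose(tf.sparse.sparse_dense_matmul(tb, ta))

        def grad_fn(grad_res):
            # Gradient w.r.t. dense: grad_res @ B^T = (B @ grad_res^T)^T, shape (m, k).
            tgrad = tf.transpose(grad_res)
            grad_dense = tf.transpose(tf.sparse.sparse_dense_matmul(b, tgrad))
            # Gradient w.r.t. B[i, j] is sum_r dense[r, i] * grad_res[r, j];
            # evaluate it only at the stored (i, j) positions, one entry per value.
            dense_edge_starts = tf.gather(dense, sparse_indices[:, 0], axis=1)
            grad_res_edge_ends = tf.gather(grad_res, sparse_indices[:, 1], axis=1)
            grad_values = tf.reduce_sum(tf.multiply(dense_edge_starts, grad_res_edge_ends), axis=0)
            return grad_dense, grad_values

        return res, grad_fn

    return tf.custom_gradient(result_fun)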
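The gradient formulas in `grad_fn` can be sanity-checked without TensorFlow. For C = A @ B with upstream gradient G = dL/dC, the gradient with respect to A is G @ B^T, and the gradient with respect to a stored entry B[i, j] is sum_r A[r, i] * G[r, j], i.e. entry (i, j) of A^T @ G. The NumPy sketch below is a hypothetical check, not part of the original gist: the sizes, indices, random data, and the helper names `to_dense` and `loss` are made up for illustration. It mirrors the transpose trick in the forward pass and both gradient formulas, then compares them with central finite differences of a loss chosen so that dL/dC is exactly G.

```python
import numpy as np

# Hypothetical sizes and random data, chosen only for this check.
rng = np.random.default_rng(0)
m, k, n = 3, 4, 5
# Stored (row, col) positions of the sparse (k, n) matrix B (all distinct).
indices = np.array([[0, 1], [1, 3], [2, 0], [3, 4], [3, 1]])
values = rng.standard_normal(len(indices))
A = rng.standard_normal((m, k))  # the dense operand
G = rng.standard_normal((m, n))  # upstream gradient dL/d(A @ B)


def to_dense(vals):
    """Scatter the sparse values into a dense (k, n) matrix."""
    B = np.zeros((k, n))
    B[indices[:, 0], indices[:, 1]] = vals
    return B


def loss(A_, vals):
    """Linear loss whose gradient with respect to A_ @ B is exactly G."""
    return np.sum(G * (A_ @ to_dense(vals)))


B = to_dense(values)

# Forward pass as in the gist: (B^T @ A^T)^T, which should equal A @ B.
forward = (B.T @ A.T).T

# Gradient formulas mirroring grad_fn.
grad_dense = (B @ G.T).T                     # equals G @ B^T
starts = A[:, indices[:, 0]]                 # like tf.gather(dense, idx[:, 0], axis=1)
ends = G[:, indices[:, 1]]                   # like tf.gather(grad_res, idx[:, 1], axis=1)
grad_values = np.sum(starts * ends, axis=0)  # one gradient per stored value

# Central finite differences for comparison.
eps = 1e-6
fd_values = np.array([
    (loss(A, values + eps * e) - loss(A, values - eps * e)) / (2 * eps)
    for e in np.eye(len(values))
])
fd_dense = np.zeros_like(A)
for r in range(m):
    for c in range(k):
        E = np.zeros_like(A)
        E[r, c] = eps
        fd_dense[r, c] = (loss(A + E, values) - loss(A - E, values)) / (2 * eps)

print(np.allclose(forward, A @ B))
print(np.allclose(grad_dense, fd_dense, atol=1e-6))
print(np.allclose(grad_values, fd_values, atol=1e-6))
print(np.allclose(grad_values, (A.T @ G)[indices[:, 0], indices[:, 1]]))
```

Because the loss is bilinear in A and the sparse values, the central differences are exact up to floating-point rounding, so all four checks should print True. The last check shows why the gather-based formula is cheap: it reads only the nnz entries of A^T @ G that correspond to stored positions, without forming the full (k, n) product.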
@kevmanderson

Hi laci37, thanks for posting this code! Super useful. I'm working to incorporate something similar into a research project, would you be open to getting in touch? My e-mail is kevin [dot] anderson [at] yale [dot] edu
