Last active
May 21, 2020 20:22
-
-
Save laci37/b7bcb88a08ed09e298b6bf34393ad80b to your computer and use it in GitHub Desktop.
TensorFlow dense_sparse_matmul
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
def matmul_dense_sparse(sparse_indices, sparse_shape):
    """Return ``f(dense, sparse_values) -> dense @ B`` for a fixed-structure sparse ``B``.

    ``B`` is a rank-2 sparse matrix whose non-zero positions (``sparse_indices``)
    and dense shape (``sparse_shape``) are fixed when this factory is called;
    only its values are supplied per call. Custom gradients are defined for both
    the dense operand and the sparse values, so the values can be optimized.

    Args:
        sparse_indices: Integer ``(nnz, 2)`` array-like of ``[row, col]``
            positions of the non-zero entries of ``B``. Python lists, NumPy
            arrays and tensors of any integer dtype are accepted; they are
            converted to an int64 tensor on every call.
        sparse_shape: Dense shape ``(k, n)`` of ``B``.

    Returns:
        A function ``f(dense, sparse_values)`` wrapped with
        ``tf.custom_gradient``. ``dense`` is a ``(m, k)`` matrix and
        ``sparse_values`` is a 1-D tensor of length ``nnz`` holding the entries
        of ``B`` in the order of ``sparse_indices``. ``f`` returns the dense
        ``(m, n)`` product ``dense @ B``.
    """

    def result_fun(dense, sparse_values):
        # Normalize the indices to an int64 tensor: SparseTensor needs int64,
        # and the gradient below column-slices them (indices[:, 0]), which a
        # plain Python list does not support.
        indices = tf.cast(tf.convert_to_tensor(sparse_indices), tf.int64)
        b = tf.sparse.SparseTensor(indices, sparse_values, sparse_shape)

        # sparse_dense_matmul takes the sparse operand first, so dense @ B is
        # computed as (B^T @ dense^T)^T.
        b_t = tf.sparse.transpose(b)
        # Re-wrap with the caller-supplied shape (presumably to keep it
        # statically known after tf.sparse.transpose).
        b_t = tf.sparse.SparseTensor(
            b_t.indices, b_t.values, (sparse_shape[1], sparse_shape[0])
        )
        res = tf.transpose(tf.sparse.sparse_dense_matmul(b_t, tf.transpose(dense)))

        def grad_fn(grad_res):
            # Gradient w.r.t. dense: grad_res @ B^T, computed as (B @ grad_res^T)^T.
            grad_dense = tf.transpose(
                tf.sparse.sparse_dense_matmul(b, tf.transpose(grad_res))
            )
            # Gradient w.r.t. each stored value B[i, j]:
            #   sum over rows r of dense[r, i] * grad_res[r, j].
            dense_edge_starts = tf.gather(dense, indices[:, 0], axis=1)
            grad_res_edge_ends = tf.gather(grad_res, indices[:, 1], axis=1)
            grad_values = tf.reduce_sum(
                tf.multiply(dense_edge_starts, grad_res_edge_ends), axis=0
            )
            return grad_dense, grad_values

        return res, grad_fn

    return tf.custom_gradient(result_fun)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.
Hi laci37, thanks for posting this code! It's super useful. I'm working on incorporating something similar into a research project. Would you be open to getting in touch? My e-mail is kevin [dot] anderson [at] yale [dot] edu