Skip to content

Instantly share code, notes, and snippets.

@pavanky
Created August 21, 2019 18:21
Show Gist options
  • Save pavanky/3e1635edb322c81436d66ac9d9bdf69f to your computer and use it in GitHub Desktop.
Save pavanky/3e1635edb322c81436d66ac9d9bdf69f to your computer and use it in GitHub Desktop.
Keras sparse-input example
import numpy as np
import tensorflow.compat.v2 as tf
class Sparse(tf.keras.layers.Dense):
    """Dense layer that accepts a ``tf.SparseTensor`` as its input.

    Sparse inputs are multiplied with ``tf.sparse.sparse_dense_matmul`` so
    they never have to be densified. Dense inputs are handed back to the
    parent ``tf.keras.layers.Dense.call`` unchanged.
    """

    def call(self, inputs):
        """Return ``activation(inputs @ kernel + bias)``.

        Args:
            inputs: a ``tf.SparseTensor`` (multiplied via
                ``tf.sparse.sparse_dense_matmul``) or a dense tensor
                (delegated to ``Dense.call``).

        Returns:
            The layer output as a dense tensor.
        """
        if not isinstance(inputs, tf.SparseTensor):
            return super().call(inputs)
        outputs = tf.sparse.sparse_dense_matmul(inputs, self.kernel)
        if self.use_bias:
            outputs = tf.nn.bias_add(outputs, self.bias)
        # Honour the constructor's ``activation=`` argument like Dense does;
        # without this step a non-linear activation would be silently skipped.
        if self.activation is not None:
            outputs = self.activation(outputs)
        return outputs
def dummy_parse_fn(iterable):
    """Map function that returns a constant sparse ``(features, labels)`` pair.

    The dataset element passed in is ignored. Every call yields the same
    values.

    Args:
        iterable: the dataset element being mapped (unused).

    Returns:
        A ``(sparse_tensor, labels)`` tuple:

        - ``sparse_tensor`` has dense shape ``[2, 2]`` with float32 ones at
          ``[0, 0]`` and ``[1, 1]``.
        - ``labels`` is the float32 tensor ``[1.0, 1.0]``.
    """
    del iterable  # the input is always constant
    sparse_tensor = tf.SparseTensor(
        indices=tf.constant([[0, 0], [1, 1]], dtype=tf.int64),
        values=tf.constant([1.0, 1.0], dtype=tf.float32),
        dense_shape=tf.constant([2, 2], dtype=tf.int64))
    labels = tf.constant([1.0, 1.0], dtype=tf.float32)
    return sparse_tensor, labels
def get_dummy_dataset(num_examples=128):
    """Build a small dataset of constant sparse examples for smoke-testing.

    Args:
        num_examples: number of dataset elements to produce. Each element, as
            built by ``dummy_parse_fn``, already holds two rows, and no
            ``.batch()`` is applied here.

    Returns:
        A ``tf.data.Dataset`` of ``(tf.SparseTensor, labels)`` pairs with
        exactly ``num_examples`` elements.
    """
    # The random values are discarded by ``dummy_parse_fn``; they only
    # determine how many elements the dataset has. No ``.take()`` cap is
    # needed because the slice count already fixes the length.
    seeds = np.random.random((num_examples, 1)).astype(np.float32)
    return tf.data.Dataset.from_tensor_slices(seeds).map(dummy_parse_fn)
class SparseModel(tf.keras.Model):
    """Subclassed Keras model: one ``Sparse`` unit applied to a sparse input.

    NOTE(review): relies on the private ``Model._set_inputs`` API. The gist
    author reports that this example only works with TensorFlow 1.15.
    """

    def __init__(self, num_features=2):
        """Create the model.

        Args:
            num_features: width (last dimension) of the sparse input.
                Defaults to 2, matching the tensors from ``dummy_parse_fn``.
        """
        super().__init__()
        self._num_features = num_features
        self._sparse_layer = Sparse(1)
        inputs = tf.keras.layers.Input(
            shape=(num_features,), sparse=True, name="sparse_tensor")
        # Private Keras API. Presumably this gives the subclassed model a
        # sparse input spec before compile/fit; verify on upgrade.
        self._set_inputs(inputs)

    def call(self, sparse_tensor):
        """Apply the ``Sparse`` layer to ``sparse_tensor``.

        The input is rebuilt with an explicit ``dense_shape`` whose last
        dimension is the constant ``num_features``. Presumably this keeps the
        static width known so ``Sparse`` can build its kernel. The row count
        is read from the incoming tensor rather than hard-coded, so batches of
        any size keep a ``dense_shape`` consistent with their indices.
        """
        dense_shape = tf.stack([
            sparse_tensor.dense_shape[0],
            tf.constant(self._num_features, dtype=tf.int64),
        ])
        sparse_tensor = tf.sparse.SparseTensor(
            indices=sparse_tensor.indices,
            values=sparse_tensor.values,
            dense_shape=dense_shape)
        return self._sparse_layer(sparse_tensor)
if __name__ == "__main__":
    # Smoke-test: train the sparse model on the constant dummy dataset.
    print(tf.__version__)
    model = SparseModel()
    optimizer = tf.keras.optimizers.SGD(learning_rate=0.1, momentum=0.9, nesterov=True)
    loss = tf.keras.losses.BinaryCrossentropy(from_logits=True)
    # The model emits raw logits (Sparse(1) with no activation, hence
    # from_logits=True above). The decision boundary is therefore 0; the
    # 'accuracy' string shortcut would threshold at 0.5 instead.
    accuracy = tf.keras.metrics.BinaryAccuracy(name="accuracy", threshold=0.0)
    model.compile(optimizer=optimizer, loss=loss, metrics=[accuracy])
    model.fit(get_dummy_dataset(), epochs=2)
@pavanky
Copy link
Author

pavanky commented Sep 5, 2019

@naiveHobo this only works with TensorFlow 1.15.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment