@zmjjmz
Created December 19, 2017 18:23
shitty lookup layer
import tensorflow
import keras


class TokenizeLookupLayer(keras.layers.Layer):
    """
    Layer that encapsulates the following:
        - Tokenizing sentences by space (or a given delimiter)
        - Looking up the words in a given vocabulary list / table
        - Resetting the shape of the above to be batch_size x pad_len (using dark magic)

    # Input Shape
        2D string tensor with shape `(batch_size, 1)`
    # Output Shape
        2D int32 tensor with shape `(batch_size, pad_len)`
    """

    def __init__(self, word_ind_map, pad_len, pad_value=0, oov_value=1, **kwargs):
        super(TokenizeLookupLayer, self).__init__(**kwargs)
        self.input_spec = keras.engine.InputSpec(ndim=2, dtype='string')
        self.pad_len = pad_len
        self.pad_value = pad_value
        self.oov_value = oov_value
        self.word_ind_map = word_ind_map

    def get_config(self):
        config = {
            'word_ind_map': self.word_ind_map,
            'pad_len': self.pad_len,
            'pad_value': self.pad_value,
            'oov_value': self.oov_value,
        }
        base_config = super(TokenizeLookupLayer, self).get_config()
        config.update(base_config)
        return config

    def build(self, input_shape):
        # Build a static word -> index hash table from the vocabulary dict;
        # words not in the table map to oov_value.
        self.lookup_tab = tensorflow.contrib.lookup.HashTable(
            tensorflow.contrib.lookup.KeyValueTensorInitializer(
                *zip(*self.word_ind_map.items())),  # .iteritems() under Python 2
            default_value=self.oov_value)
        try:
            tensorflow.tables_initializer().run(session=keras.backend.get_session())
        except tensorflow.errors.FailedPreconditionError:
            # TODO(ZJ) this is probably wrong?: DS-209
            pass
        super(TokenizeLookupLayer, self).build(input_shape)

    def call(self, str_inp):
        # string_split wants a 1D tensor; no name supported for this op?!
        tokenized_inp = tensorflow.string_split(
            tensorflow.squeeze(str_inp, axis=1))
        sparse_inp_lookedup = self.lookup_tab.lookup(
            tokenized_inp,
            name='lookup'
        )
        # This could be batch_size x max_seq_len_in_batch, and
        # max_seq_len_in_batch bears no relation to pad_len, but we need to
        # get it out in pad_len.
        dense_inp = tensorflow.sparse_tensor_to_dense(
            sparse_inp_lookedup,
            default_value=self.pad_value,
            name='dense'
        )
        # So essentially: append pad_value to the end up to pad_len...
        pad_full = tensorflow.pad(
            dense_inp,
            paddings=tensorflow.constant([[0, 0], [0, self.pad_len]]),
            mode='CONSTANT',
            constant_values=self.pad_value,
            name='pad'
        )
        # ...then limit the second dimension to pad_len.
        sliced = pad_full[:, :self.pad_len]
        return sliced

    def compute_output_shape(self, input_shape):
        return (input_shape[0], self.pad_len)
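
For reference, a minimal usage sketch, not part of the gist, assuming standalone Keras on the TF 1.x backend; the vocabulary, pad_len, and sentences below are made up for illustration:

import numpy as np

# Hypothetical vocabulary: 0 is reserved for padding, 1 for out-of-vocabulary.
word_ind_map = {'hello': 2, 'world': 3}

inp = keras.layers.Input(shape=(1,), dtype='string')
ids = TokenizeLookupLayer(word_ind_map, pad_len=5)(inp)
model = keras.models.Model(inputs=inp, outputs=ids)

# Each row is a single space-delimited sentence.
sentences = np.array([['hello world'], ['world says goodbye']], dtype=object)
print(model.predict(sentences))
# Expected shape (2, 5): [[2 3 0 0 0], [3 1 1 0 0]]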
@soaxelbrooke

Howdy! I forked this and added a regexp replace before the split, allowing for regexp-based tokenization instead of just delim-based tokenization, in case it's useful to you:

https://gist.github.com/soaxelbrooke/246959a7290313fb22be021d9c82a394
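
The rough shape of that change, as a sketch rather than the fork's exact code (tensorflow.regex_replace only exists in newer TF 1.x releases, and the pattern here is just an example):

# In call(), before string_split: pad punctuation with spaces so the
# whitespace split treats it as separate tokens.
cleaned = tensorflow.regex_replace(
    tensorflow.squeeze(str_inp, axis=1),
    pattern='([.,!?])',
    rewrite=' \\1 ')
tokenized_inp = tensorflow.string_split(cleaned)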
