Skip to content

Instantly share code, notes, and snippets.

@jithinjees
Last active March 21, 2024 04:07
Show Gist options
  • Save jithinjees/a99e57af3812be2c84bdc2ef84ad0de6 to your computer and use it in GitHub Desktop.
Save jithinjees/a99e57af3812be2c84bdc2ef84ad0de6 to your computer and use it in GitHub Desktop.
TensorFlow 2.2 sample code for using lookup tables (file-backed and in-memory)
import tensorflow as tf
# Show which TensorFlow build is running this example.
print('tensorflow version ', tf.__version__)
# Written against TensorFlow 2.2 (also works with TensorFlow 2.1).
# Minimal example of table lookup in TensorFlow, done in two ways:
#   1. a lookup table backed by a vocabulary file, and
#   2. a lookup table built from an in-memory list.
vocab_path = 'vocab_test.txt'
model_dir_lookup = 'model/lookup'
# The second model is exported next to the first one, with a "_2" suffix.
model_dir_lookup2 = '{}_2'.format(model_dir_lookup)
##file based lookup
class VocabLookup(tf.keras.layers.Layer):
    """Map whitespace-separated words to ids using a vocabulary file.

    Every whole line of ``vocab_path`` is a string key and its line number
    is the int64 value. Words that are not in the table map to -1.

    Args:
        vocab_path: Path of the vocabulary text file (one word per line).
        **kwargs: Standard Keras layer arguments such as ``name``,
            ``trainable`` and ``dtype``. They are accepted so that
            ``VocabLookup.from_config(layer.get_config())`` works, since the
            base config contains those keys. ``trainable`` defaults to
            False and ``dtype`` to ``tf.int64``, as before.
    """

    def __init__(self, vocab_path, **kwargs):
        # setdefault (instead of hard-coded keywords) avoids a duplicate
        # keyword error when the values arrive through from_config().
        kwargs.setdefault('trainable', False)
        kwargs.setdefault('dtype', tf.int64)
        super(VocabLookup, self).__init__(**kwargs)
        self.vocab_path = vocab_path

    def build(self, input_shape):
        """Create the static hash table from the vocabulary file."""
        table_init = tf.lookup.TextFileInitializer(
            self.vocab_path,
            tf.string,
            tf.lookup.TextFileIndex.WHOLE_LINE,
            tf.int64,
            tf.lookup.TextFileIndex.LINE_NUMBER,
        )
        # -1 is returned for any key that is not present in the table.
        self.table = tf.lookup.StaticHashTable(table_init, -1)
        self.built = True

    def call(self, input_text):
        """Split each input string on whitespace and look up every token.

        The ragged split result is densified with ``to_tensor()`` before the
        lookup, so padding entries are looked up like any other token.
        """
        splitted_text = tf.strings.split(input_text).to_tensor()
        word_ids = self.table.lookup(splitted_text)
        return word_ids

    def get_config(self):
        """Return the base layer config extended with ``vocab_path``."""
        config = super(VocabLookup, self).get_config()
        config.update({'vocab_path': self.vocab_path})
        return config
#list based lookup
class VocabLookup2(tf.keras.layers.Layer):
    """Map whitespace-separated words to ids using an in-memory key list.

    The i-th key (counting from 1) is mapped to the id ``i``. Words that are
    not in the table map to -1.

    Args:
        keys: Optional sequence of vocabulary words. Defaults to the
            built-in sample vocabulary, so ``VocabLookup2()`` behaves as
            before.
        **kwargs: Standard Keras layer arguments such as ``name``,
            ``trainable`` and ``dtype``. They are accepted so that
            ``VocabLookup2.from_config(layer.get_config())`` works.
            ``trainable`` defaults to False and ``dtype`` to ``tf.int32``,
            as before.
    """

    # Sample vocabulary used when no keys are supplied.
    DEFAULT_KEYS = ('hi', 'testing', 'lookup', 'in', 'tf')

    def __init__(self, keys=None, **kwargs):
        # setdefault (instead of hard-coded keywords) avoids a duplicate
        # keyword error when the values arrive through from_config().
        kwargs.setdefault('trainable', False)
        kwargs.setdefault('dtype', tf.int32)
        super(VocabLookup2, self).__init__(**kwargs)
        # Stored here rather than in build() so that get_config() also works
        # on a layer that has not been built yet.
        self.keys = list(self.DEFAULT_KEYS if keys is None else keys)

    def build(self, input_shape):
        """Create the static hash table from ``self.keys``."""
        # Ids start at 1; the original author chose this intending to match
        # the file based lookup.
        # NOTE(review): TextFileIndex.LINE_NUMBER is documented as starting
        # from zero, so the two lookups only agree if the vocabulary file has
        # one extra line before the first word -- confirm against the file.
        values = list(range(1, len(self.keys) + 1))
        table_init = tf.lookup.KeyValueTensorInitializer(
            keys=list(self.keys), values=values)
        # -1 is returned for any key that is not present in the table.
        self.table = tf.lookup.StaticHashTable(table_init, -1)
        self.built = True

    def call(self, input_text):
        """Split each input string on whitespace and look up every token.

        The ragged split result is densified with ``to_tensor()`` before the
        lookup, so padding entries are looked up like any other token.
        """
        splitted_text = tf.strings.split(input_text).to_tensor()
        word_ids = self.table.lookup(splitted_text)
        return word_ids

    def get_config(self):
        """Return the base layer config extended with ``keys``."""
        config = super(VocabLookup2, self).get_config()
        # A plain list keeps the config independent of any wrapper type that
        # Keras may use for list attributes.
        config.update({'keys': list(self.keys)})
        return config
# One scalar string input feeds both lookup layers.
input_text = tf.keras.Input(shape=(), dtype=tf.string, name='input_text')
lookup_out = VocabLookup(vocab_path=vocab_path)(input_text)
lookup_out2 = VocabLookup2()(input_text)

model_lookup = tf.keras.Model(
    inputs={'input_text': input_text}, outputs=lookup_out)
model_lookup2 = tf.keras.Model(
    inputs={'input_text': input_text}, outputs=lookup_out2)

# Single-element batch: the five vocabulary words plus one unknown word.
sample_batch = ['hi testing lookup in tf randomtext']


def _save_and_reload(model, export_dir):
    """Save ``model`` to ``export_dir`` and return the model loaded back."""
    model.save(export_dir)
    return tf.keras.models.load_model(export_dir)


# Predictions from the freshly built models.
print('predict from model1 ', model_lookup.predict(sample_batch))
print('predict from model2 ', model_lookup2.predict(sample_batch))

# Round-trip the file based model through disk and predict again.
model_lookup_loaded = _save_and_reload(model_lookup, model_dir_lookup)
print('loaded model config 1 - \n', model_lookup_loaded.get_config(), '\n')
print('predict from loaded model1 ',
      model_lookup_loaded.predict(sample_batch))

# Same round-trip for the list based model.
model_lookup_loaded2 = _save_and_reload(model_lookup2, model_dir_lookup2)
print('loaded model config 2 - \n', model_lookup_loaded2.get_config(), '\n')
print('predict from loaded model2 ',
      model_lookup_loaded2.predict(sample_batch))
hi
testing
lookup
in
tf
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment