Skip to content

Instantly share code, notes, and snippets.

@zahash
Created December 1, 2020 11:20
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save zahash/4c84734e4fd7c3258da5e0d89b7086a7 to your computer and use it in GitHub Desktop.
Save zahash/4c84734e4fd7c3258da5e0d89b7086a7 to your computer and use it in GitHub Desktop.
def string_input_processor(inputs):
    """Build one-hot encodings for a group of string-valued model inputs.

    A vocabulary is collected per column by scanning every batch yielded by
    ``get_dataset(batch_size=BATCH_SIZE)``. Values are lower-cased and
    stripped while collecting the vocabulary, and the same two operations are
    applied to the symbolic input before the lookup.

    Args:
        inputs: Mapping of column name to the string input for that column
            (presumably ``tf.keras.Input`` tensors -- confirm against the
            caller). Each key must also be a key of the dataset batches.

    Returns:
        ``None`` when ``inputs`` is empty or ``None``. With a single column,
        that column's encoded tensor. Otherwise the encoded tensors of all
        columns joined by a ``Concatenate`` layer, in ``inputs`` order.
    """
    if not inputs:
        return None

    # First pass over the data: gather the distinct normalised values of
    # every requested column.
    vocabularies = defaultdict(set)
    for batch, _ in get_dataset(batch_size=BATCH_SIZE):
        for col_name in inputs:
            column_values = np.array(batch[col_name]).astype("str")
            # NOTE(review): str.lower() here vs tf.strings.lower() below may
            # disagree on non-ASCII text -- confirm if such data is expected.
            vocabularies[col_name].update(
                value.lower().strip() for value in column_values
            )

    preprocessing = tf.keras.layers.experimental.preprocessing
    processed_string_inputs = []
    for col_name, col_input in inputs.items():
        # Sort the vocabulary: set iteration order for strings changes
        # between interpreter runs, which would otherwise shuffle the token
        # indices (and one-hot columns) every time the model is rebuilt.
        lookup = preprocessing.StringLookup(
            vocabulary=sorted(vocabularies[col_name])
        )
        one_hot = preprocessing.CategoryEncoding(
            max_tokens=lookup.vocab_size()
        )
        normalised = tf.strings.strip(tf.strings.lower(col_input))
        processed_string_inputs.append(one_hot(lookup(normalised)))

    # Concatenate is only applied when there is more than one tensor.
    if len(processed_string_inputs) == 1:
        return processed_string_inputs[0]
    return tf.keras.layers.Concatenate()(processed_string_inputs)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment