Created December 1, 2020 11:20
-
-
Save zahash/4c84734e4fd7c3258da5e0d89b7086a7 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def _to_text(value):
    """Return *value* as ``str``, decoding ``bytes`` as UTF-8.

    String tensors converted to NumPy typically arrive as an object array of
    ``bytes``; ``astype("str")`` on such an array produces the repr
    (``"b'foo'"``) instead of ``"foo"``, so bytes are decoded explicitly here.
    """
    if isinstance(value, bytes):  # also covers np.bytes_
        return value.decode("utf-8")
    return str(value)


def _collect_vocabularies(column_names):
    """Scan the dataset once and return a sorted vocabulary per column.

    Each value is normalised the same way the model normalises its inputs
    (lowercase, then strip surrounding whitespace).

    Args:
        column_names: names of the string feature columns to collect.

    Returns:
        dict mapping each column name to a sorted list of distinct,
        normalised, non-empty values.
    """
    vocabularies = {name: set() for name in column_names}
    for batch, _ in get_dataset(batch_size=BATCH_SIZE):
        for name in column_names:
            for value in np.asarray(batch[name]).ravel():
                vocabularies[name].add(_to_text(value).lower().strip())
    # Sorting makes the token -> index mapping (and thus the one-hot layout)
    # reproducible across runs; set iteration order is not.
    # "" is dropped because it is StringLookup's default mask token in
    # TF 2.3/2.4 and may not appear in the vocabulary; empty inputs then map
    # to the mask/OOV index instead. NOTE(review): confirm for your TF version.
    return {name: sorted(values - {""}) for name, values in vocabularies.items()}


def string_input_processor(inputs):
    """Lowercase, strip, index and one-hot encode string Keras inputs.

    A vocabulary is built for every column in *inputs* by scanning the full
    dataset returned by ``get_dataset(batch_size=BATCH_SIZE)``.

    Args:
        inputs: mapping of column name -> symbolic Keras string input tensor.

    Returns:
        ``None`` if *inputs* is empty; the single encoded tensor if there is
        one column; otherwise the encoded tensors concatenated with
        ``tf.keras.layers.Concatenate``.
    """
    if not inputs:
        return None

    vocabularies = _collect_vocabularies(list(inputs))

    processed_string_inputs = []
    for col_name, col_input in inputs.items():
        lookup = tf.keras.layers.experimental.preprocessing.StringLookup(
            vocabulary=vocabularies[col_name]
        )
        one_hot = tf.keras.layers.experimental.preprocessing.CategoryEncoding(
            max_tokens=lookup.vocab_size()
        )
        # encoding="utf-8" makes lowering Unicode-aware, matching Python's
        # str.lower() used while building the vocabulary (default is ASCII).
        # NOTE(review): tf.strings.strip may not remove non-ASCII whitespace
        # that Python's str.strip() does; verify if such data occurs.
        x = tf.strings.lower(col_input, encoding="utf-8")
        x = tf.strings.strip(x)
        x = lookup(x)
        x = one_hot(x)
        processed_string_inputs.append(x)

    if len(processed_string_inputs) == 1:
        return processed_string_inputs[0]
    return tf.keras.layers.Concatenate()(processed_string_inputs)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.