Skip to content

Instantly share code, notes, and snippets.

@zahash
Created December 1, 2020 11:18
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save zahash/b4b1631894a5f0fdb01f284b466a646f to your computer and use it in GitHub Desktop.
Save zahash/b4b1631894a5f0fdb01f284b466a646f to your computer and use it in GitHub Desktop.
def _stack_feature_columns(batch, keys):
    """Stack the columns named by *keys* from *batch* into one NumPy array.

    Each ``batch[k]`` is converted with ``np.array``; the per-key arrays are
    stacked along a new leading axis and the result is fully transposed, so
    1-D columns of length ``b`` yield an array of shape ``(b, len(keys))``.
    """
    columns = np.array([np.array(batch[k]) for k in keys])
    return np.transpose(columns)


def numerical_input_processor(inputs):
    """Build a Normalization layer adapted to the numeric features and apply it.

    Args:
        inputs: mapping of feature name -> Keras input tensor. Presumably the
            keys match the feature keys yielded by ``get_dataset`` (TODO
            confirm against the caller).

    Returns:
        The normalized tensor: ``Normalization`` applied to the
        ``Concatenate`` of all inputs when there is more than one, otherwise
        to the single input. Returns ``None`` when *inputs* is empty.
    """
    if not inputs:
        return None
    keys = list(inputs)
    tensors = list(inputs.values())

    concat = None
    if len(tensors) > 1:
        concat = tf.keras.layers.Concatenate()(tensors)

    norm = tf.keras.layers.experimental.preprocessing.Normalization()
    # The first adapt() uses the default call to initialise the layer's state;
    # later batches accumulate onto it with reset_state=False. Priming with the
    # first real batch (rather than a separate batch drawn from get_dataset,
    # which presumably overlaps the full pass) keeps rows from being counted
    # twice in the statistics.
    # NOTE(review): `reset_state` belongs to the TF 2.3-era preprocessing
    # adapt() API; confirm it against the installed TensorFlow version.
    for i, (batch, _) in enumerate(get_dataset(batch_size=BATCH_SIZE)):
        data = _stack_feature_columns(batch, keys)
        if i == 0:
            norm.adapt(data)
        else:
            norm.adapt(data, reset_state=False)

    if concat is not None:
        return norm(concat)
    return norm(tensors[0])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment