def numerical_input_processor(inputs):
    """Concatenate numeric Keras inputs and normalize them with dataset statistics.

    A ``Normalization`` layer is adapted over every batch yielded by
    ``get_dataset(batch_size=BATCH_SIZE)``. The adapted layer is then applied
    to the numeric inputs. With several inputs they are concatenated first;
    with one input it is used directly.

    Args:
        inputs: Mapping of feature name -> Keras input tensor. Its key order
            must match the order in which the tensors are concatenated; both
            come from iterating the same mapping.

    Returns:
        The normalized tensor, or ``None`` when ``inputs`` is empty or falsy.
    """
    if not inputs:
        return None

    names = list(inputs.keys())
    tensors = list(inputs.values())

    concat = None
    if len(tensors) > 1:
        concat = tf.keras.layers.Concatenate()(tensors)

    norm = tf.keras.layers.experimental.preprocessing.Normalization()
    # Adapt in one pass over the dataset. The state is reset on the first
    # batch only and accumulated afterwards, so every sample is counted
    # exactly once. A separate warm-up adapt on an extra batch followed by
    # reset_state=False would count that batch's samples twice.
    for step, (batch, _) in enumerate(get_dataset(batch_size=BATCH_SIZE)):
        norm.adapt(_stack_numeric_columns(batch, names), reset_state=(step == 0))

    if concat is not None:
        return norm(concat)
    return norm(tensors[0])


def _stack_numeric_columns(batch, names):
    """Stack ``batch[name]`` for each name into an array, features on the last axis.

    For 1-D columns of length B this yields shape (B, len(names)). Note that
    ``np.transpose`` reverses *all* axes, so higher-rank columns are reordered
    too. NOTE(review): confirm the columns are 1-D per batch.
    """
    columns = np.array([np.array(batch[name]) for name in names])
    return np.transpose(columns)