Skip to content

Instantly share code, notes, and snippets.

@awadalaa
Last active April 29, 2021 17:21
Show Gist options
  • Save awadalaa/bcafb5da46ced7d9373f0d51ce389aa3 to your computer and use it in GitHub Desktop.
Save awadalaa/bcafb5da46ced7d9373f0d51ce389aa3 to your computer and use it in GitHub Desktop.
def export_model(self, model):
    """Save ``model`` to ``self.get_model_dir()`` with TFT-aware serving signatures.

    The serving signatures come from ``self.get_serve_tf_examples_fn``, so the
    exported model parses raw serialized examples and applies the TFT
    transform logic before running the model.

    Configuration is read from ``self.args``:
        forward_keys: truthy to enable passing raw features through to the
            signature outputs.
        keys_to_forward: comma-separated raw feature names to forward. Only
            used when ``forward_keys`` is truthy. Surrounding whitespace and
            empty entries (e.g. from a trailing comma) are ignored.
        attach_prediction_head: truthy to save
            ``self.attach_prediction_head(model)`` instead of ``model`` itself.

    Args:
        model: the trained model to export.
    """
    # Make sure we apply TFT transform logic to the serving function. Strip
    # and filter the names so "a, b," forwards ["a", "b"]. Otherwise " b" and
    # "" would fail the forwarded-feature lookup at trace time.
    if self.args["forward_keys"] and self.args["keys_to_forward"]:
        feature_forward_keys = [
            key.strip()
            for key in self.args["keys_to_forward"].split(",")
            if key.strip()
        ]
    else:
        feature_forward_keys = []

    full_model = (
        self.attach_prediction_head(model)
        if self.args["attach_prediction_head"]
        else model
    )

    # NOTE(review): the signatures are built from the bare `model`, not
    # `full_model`. The signature functions treat the model output as logits,
    # which suggests this is intentional. Confirm it when a head is attached.
    full_model.save(
        filepath=self.get_model_dir(),
        overwrite=True,
        signatures=self.get_serve_tf_examples_fn(model, feature_forward_keys),
    )
def get_serve_tf_examples_fn(self, model, forwarding_keys=None):
    """Build the ``serving_default`` and ``predict`` signatures for ``model``.

    Both signatures:
    - accept a 1-D batch of serialized ``tf.Example`` strings;
    - parse them with the raw (pre-transform) feature spec, minus the label;
    - apply the TFT transform layer and run ``model`` on the result;
    - return the model outputs plus any raw features named in
      ``forwarding_keys``.

    Args:
        model: model whose output is treated as the positive-class logit.
            As a side effect, a ``tft_layer`` attribute is set on it.
        forwarding_keys: iterable of raw feature names to pass through to the
            signature outputs, densified and squeezed on the last axis.
            Defaults to forwarding nothing.

    Returns:
        Dict mapping ``"serving_default"`` and ``"predict"`` to concrete
        functions, suitable for ``Model.save(signatures=...)``.

    Raises:
        KeyError: if the label key is absent from the raw feature spec.
        ValueError: during tracing, if a forwarded key is not a raw feature.
    """
    # Use None instead of a shared mutable default. Snapshot the keys as a
    # list so the traced graph is independent of the caller's object.
    forwarding_keys = [] if forwarding_keys is None else list(forwarding_keys)

    # Stored on the model rather than in a local, presumably so the layer
    # (and its TFT assets) is tracked and exported with the model.
    # TODO confirm.
    model.tft_layer = self.tft_transform_output.transform_features_layer()

    # Serving requests carry no label, so drop it from the parse spec. Copy
    # first so the TFT output's own spec object is never mutated, and build it
    # once rather than on every trace.
    raw_feature_spec = dict(self.tft_transform_output.raw_feature_spec())
    raw_feature_spec.pop(self.get_label_key())

    @tf.function
    def extract_forwarded_features(raw_features):
        """Return the requested raw features as dense tensors keyed by name."""
        forwarded_features = {}
        for key in forwarding_keys:
            if key not in raw_features:
                raise ValueError(
                    "Forwarded feature {} does not exist! Available features: {}".format(
                        key, [*raw_features.keys()]
                    )
                )
            feature = raw_features[key]
            with tf.name_scope("forward_features"):
                # Export signatures only take dense tensors.
                if isinstance(feature, tf.SparseTensor):
                    feature = tf.sparse.to_dense(feature, name="sparse_to_dense")
                # Keep the forwarded-feature signature the same as the Estimator API.
                # NOTE(review): squeeze(axis=-1) needs a trailing dimension of
                # size 1. A rank-1 feature, or a sparse feature with several
                # values, will fail here. Confirm the expected feature shapes.
                forwarded_features[key] = tf.squeeze(feature, axis=-1)
        return forwarded_features

    @tf.function
    def inference_model(serialized_tf_examples):
        """Parse, transform and score a batch; return (logits, forwarded features)."""
        parsed_features = tf.io.parse_example(
            serialized_tf_examples, raw_feature_spec
        )
        transformed_features = model.tft_layer(parsed_features)
        forwarded_features = extract_forwarded_features(parsed_features)
        return model(transformed_features), forwarded_features

    def to_two_class_logits(logits):
        """Prepend a zero logit for the negative class.

        softmax([0, z]) == [1 - sigmoid(z), sigmoid(z)], so this expresses a
        single binary logit as two-class probabilities. Assumes ``logits`` has
        a trailing dimension of 1, e.g. (batch, 1). TODO confirm.
        """
        return tf.concat(
            (tf.zeros_like(logits), logits), axis=-1, name="two_class_logits"
        )

    @tf.function
    def serving_default_signature(serialized_examples):
        logits, forwarded_features = inference_model(serialized_examples)
        # Forwarded features are merged last, so a forwarded key named
        # "scores" would overwrite the model output.
        return {
            "scores": tf.nn.softmax(
                to_two_class_logits(logits), name="probabilities"
            ),
            **forwarded_features,
        }

    @tf.function
    def predict_signature(serialized_examples):
        logits, forwarded_features = inference_model(serialized_examples)
        return {
            "logits": logits,
            "logistic": tf.math.sigmoid(logits),
            "probabilities": tf.nn.softmax(
                to_two_class_logits(logits), name="probabilities"
            ),
            **forwarded_features,
        }

    return {
        "serving_default": serving_default_signature.get_concrete_function(
            tf.TensorSpec(shape=[None], dtype=tf.string, name="inputs")
        ),
        "predict": predict_signature.get_concrete_function(
            tf.TensorSpec(shape=[None], dtype=tf.string, name="examples")
        ),
    }
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment