Last active
April 29, 2021 17:21
-
-
Save awadalaa/bcafb5da46ced7d9373f0d51ce389aa3 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def export_model(self, model):
    """Save the trained model to ``self.get_model_dir()`` with TFT-aware signatures.

    If ``self.args["attach_prediction_head"]`` is truthy, the model returned by
    ``self.attach_prediction_head(model)`` is what gets saved; otherwise
    ``model`` itself is saved. In both cases the serving ``signatures`` come
    from ``self.get_serve_tf_examples_fn``, so the TFT transform logic is
    applied inside the serving functions.

    Feature forwarding is driven by ``self.args``: when ``"forward_keys"`` is
    truthy, ``"keys_to_forward"`` is read as a comma-separated list of raw
    feature names. Whitespace around each name is stripped and empty entries
    (e.g. from a trailing comma or ``",,"``) are ignored. A falsy
    ``"keys_to_forward"`` forwards nothing.

    Args:
        model: the trained Keras model to export.
    """
    # make sure we apply TFT transform logic to serving function
    raw_keys = self.args["keys_to_forward"] if self.args["forward_keys"] else None
    feature_forward_keys = [
        key.strip() for key in (raw_keys or "").split(",") if key.strip()
    ]

    if self.args["attach_prediction_head"]:
        full_model = self.attach_prediction_head(model)
    else:
        full_model = model

    full_model.save(
        filepath=self.get_model_dir(),
        overwrite=True,
        # NOTE(review): the signatures are traced against the bare ``model``,
        # not ``full_model``; the serving functions treat its output as logits
        # and apply sigmoid/softmax themselves. Confirm this is intended when a
        # prediction head is attached.
        signatures=self.get_serve_tf_examples_fn(model, feature_forward_keys),
    )
def get_serve_tf_examples_fn(self, model, forwarding_keys=None):
    """Build serving signatures that parse tf.Examples and apply the TFT transform.

    Each signature takes a batch of serialized ``tf.Example`` strings and does
    the following:

    - parses them with the TFT raw feature spec, minus the label key;
    - runs the TFT ``transform_features_layer``;
    - feeds the transformed features to ``model``.

    Raw features named in ``forwarding_keys`` are copied into the signature
    outputs. Sparse features are densified, and all forwarded features are
    squeezed on the last axis.

    Args:
        model: Keras model whose output is treated as a binary-class logit
            (presumably shape ``[batch, 1]`` -- confirm). The TFT transform
            layer is attached to it as ``model.tft_layer``.
        forwarding_keys: iterable of raw feature names to forward. ``None``
            (the default) forwards nothing.

    Returns:
        dict of signature name -> concrete function:

        - ``"serving_default"``: input ``inputs``; outputs ``scores`` plus
          the forwarded features.
        - ``"predict"``: input ``examples``; outputs ``logits``,
          ``logistic`` and ``probabilities`` plus the forwarded features.

    Raises:
        ValueError: while tracing, if a forwarded key is absent from the parsed
            raw features, or if it collides with a signature output key (which
            would otherwise silently replace the model's prediction).
    """
    # ``None`` default instead of ``[]`` so no list is shared between calls;
    # copy so later mutation by the caller cannot affect retracing.
    forwarding_keys = [] if forwarding_keys is None else list(forwarding_keys)

    # Attach the transform layer to the model object rather than a local --
    # presumably so it (and its assets) are tracked when the model is saved.
    model.tft_layer = self.tft_transform_output.transform_features_layer()

    @tf.function
    def extract_forwarded_features(raw_features):
        # Copy the requested raw (pre-transform) features to the outputs.
        forwarded_features = {}
        for key in forwarding_keys:
            if key not in raw_features:
                raise ValueError(
                    "Forwarded feature {} does not exist! Available features: {}".format(
                        key, list(raw_features)
                    )
                )
            feature = raw_features[key]
            with tf.name_scope("forward_features"):
                # Export signatures only take dense tensors
                if isinstance(feature, tf.SparseTensor):
                    feature = tf.sparse.to_dense(feature, name="sparse_to_dense")
                # Keeping the export signature for forwarded features the same as the Estimator API
                forwarded_features[key] = tf.squeeze(feature, axis=-1)
        return forwarded_features

    @tf.function
    def inference_model(serialized_tf_examples):
        # The label is dropped from the parse spec so serving requests need not
        # carry it; returns (model output, forwarded raw features).
        raw_feature_spec = self.tft_transform_output.raw_feature_spec()
        raw_feature_spec.pop(self.get_label_key())
        parsed_features = tf.io.parse_example(
            serialized_tf_examples, raw_feature_spec
        )
        transformed_features = model.tft_layer(parsed_features)
        forwarded_features = extract_forwarded_features(parsed_features)
        return model(transformed_features), forwarded_features

    def two_class_probabilities(logits):
        # softmax([0, x]) == [1 - sigmoid(x), sigmoid(x)]: expands one binary
        # logit per row into a two-column probability distribution.
        two_class_logits = tf.concat(
            (tf.zeros_like(logits), logits), axis=-1, name="two_class_logits"
        )
        return tf.keras.layers.Softmax(name="probabilities")(two_class_logits)

    def merge_outputs(outputs, forwarded_features):
        # Forwarded features follow the model outputs; a clashing name would
        # silently overwrite a prediction, so reject it at trace time.
        clashes = sorted(set(outputs) & set(forwarded_features))
        if clashes:
            raise ValueError(
                "Forwarded features {} collide with signature output keys".format(
                    clashes
                )
            )
        return {**outputs, **forwarded_features}

    @tf.function
    def serving_default_signature(serialized_examples):
        logits, forwarded_features = inference_model(serialized_examples)
        return merge_outputs(
            {"scores": two_class_probabilities(logits)}, forwarded_features
        )

    @tf.function
    def predict_signature(serialized_examples):
        logits, forwarded_features = inference_model(serialized_examples)
        return merge_outputs(
            {
                "logits": logits,
                "logistic": tf.keras.layers.Activation("sigmoid")(logits),
                "probabilities": two_class_probabilities(logits),
            },
            forwarded_features,
        )

    return {
        "serving_default": serving_default_signature.get_concrete_function(
            tf.TensorSpec(shape=[None], dtype=tf.string, name="inputs")
        ),
        "predict": predict_signature.get_concrete_function(
            tf.TensorSpec(shape=[None], dtype=tf.string, name="examples")
        ),
    }
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment