@alreadytaikeune
Created August 13, 2018 13:13
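```python
from tensorflow.tools.graph_transforms import TransformGraph


def transform_graph(graph_def, input_node_names, output_node_names):
    """Optimize a frozen GraphDef for inference with the Graph Transform Tool.

    input_node_names and output_node_names are lists of node-name strings.
    """
    transforms = [
        # Drop nodes that are not needed to compute the outputs from the inputs.
        "strip_unused_nodes",
        "remove_nodes(op=Identity, op=CheckNumerics)",
        # Precompute constant subgraphs; skip any that fail to evaluate.
        "fold_constants(ignore_errors=true)",
        # remove colocation attribute
        "remove_attribute(attribute_name=_class)",
        # Fold batch-norm ops into the weights of the preceding conv/matmul.
        "fold_batch_norms",
        "fold_old_batch_norms",
        # Store large float weights as eight-bit constants plus Dequantize ops.
        "quantize_weights",
        # Replace node names with short generated ones (inputs/outputs are kept).
        "obfuscate_names"]
    transformed_graph_def = TransformGraph(
        graph_def, input_node_names, output_node_names, transforms)
    return transformed_graph_def
```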
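Each entry in the transform list is a string of the form `name` or `name(key=value, ...)`, where a key may repeat (as `op` does in `remove_nodes`) and booleans are written in lowercase (`ignore_errors=true`). The sketch below builds such strings from Python values. `format_transform` is a hypothetical helper written for illustration, not part of TensorFlow, and it assumes only the syntax visible in the list above.

```python
from typing import Iterable, Tuple, Union

ParamValue = Union[str, int, float, bool]


def format_transform(name: str, params: Iterable[Tuple[str, ParamValue]] = ()) -> str:
    """Hypothetical helper: return a spec string like 'fold_constants(ignore_errors=true)'.

    params is a sequence of (key, value) pairs rather than a dict, so that a
    key such as `op` can appear more than once, as in remove_nodes.
    """
    rendered = []
    for key, value in params:
        # Python booleans are written in lowercase, matching 'ignore_errors=true'.
        if isinstance(value, bool):
            value = "true" if value else "false"
        rendered.append(f"{key}={value}")
    if not rendered:
        # Transforms without arguments are just the bare name.
        return name
    return f"{name}({', '.join(rendered)})"


transforms = [
    format_transform("strip_unused_nodes"),
    format_transform("remove_nodes", [("op", "Identity"), ("op", "CheckNumerics")]),
    format_transform("fold_constants", [("ignore_errors", True)]),
    format_transform("remove_attribute", [("attribute_name", "_class")]),
]
print(transforms)
```

Using a list of pairs instead of keyword arguments is the design choice that lets `op` repeat; the resulting strings match the literal entries in `transform_graph`.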