Skip to content

Instantly share code, notes, and snippets.

@girving
Created June 19, 2016 20:10
Show Gist options
  • Save girving/fbe861e6df1ce9d5add8e57bf32a247d to your computer and use it in GitHub Desktop.
Save girving/fbe861e6df1ce9d5add8e57bf32a247d to your computer and use it in GitHub Desktop.
"""Print ops that have float but not double definitions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from google3.net.proto2.python.public import text_format
from google3.pyglib import app
from google3.pyglib import flags
from google3.third_party.tensorflow.core.framework import op_def_pb2
import tensorflow as tf
FLAGS = flags.FLAGS
# --ops: path to a text-format OpList proto. main() opens this path and
# parses it with text_format.Merge, so it must be supplied on the command line
# (it defaults to None).
flags.DEFINE_string('ops', None, 'Path to ops file.')
# DataType enum values (see tensorflow/core/framework/types.proto).
_DT_FLOAT = 1   # DT_FLOAT, i.e. float32
_DT_DOUBLE = 2  # DT_DOUBLE, i.e. float64


def _op_dtypes(op):
    """Return the frozenset of DataType enum values `op` can consume or produce.

    Args:
      op: an `OpDef` message.

    Returns:
      A frozenset of DataType enum ints gathered from every input and output
      arg. An arg typed by a type attr contributes that attr's
      `allowed_values`; an arg with a fixed type contributes `arg.type`.
    """
    # Allowed dtypes for each attr whose kind is exactly 'type'. Per
    # op_def.proto an empty allowed_values list means "unrestricted", so args
    # bound to such an attr add nothing here and the op is never reported.
    allowed = {
        attr.name: list(attr.allowed_values.list.type)
        for attr in op.attr
        if attr.type == 'type'
    }
    dtypes = set()
    for arg in list(op.input_arg) + list(op.output_arg):
        if arg.type_attr:
            dtypes.update(allowed[arg.type_attr])
        else:
            # Presumably the unset (0) enum for list-typed args, which never
            # matches float32/float64.
            dtypes.add(arg.type)
    return frozenset(dtypes)


def main(unused_argv):
    """Print the names of ops that allow float32 but not float64.

    Reads the text-format `OpList` named by --ops and prints every qualifying
    op name, sorted, one per line with a leading space.

    Raises:
      ValueError: if --ops was not given.
    """
    if not FLAGS.ops:
        raise ValueError('--ops must name a text-format OpList file.')
    ops = op_def_pb2.OpList()
    # Close the file promptly instead of leaking the handle.
    with open(FLAGS.ops) as f:
        text_format.Merge(f.read(), ops)
    # Keyed by name: history files can list several versions of one op, and
    # the last definition seen wins (same as the original behavior).
    op_types = {op.name: _op_dtypes(op) for op in ops.op}
    float_only = [
        name for name, dtypes in op_types.items()
        if _DT_FLOAT in dtypes and _DT_DOUBLE not in dtypes
    ]
    for name in sorted(float_only):
        print(' ' + name)
if __name__ == '__main__':
    # Entry point: hand control to pyglib's app runner, which presumably
    # parses command-line flags and then calls main(argv) — confirm against
    # google3.pyglib.app.
    app.run()
@yaroslavvb
Copy link

I used `--ops tensorflow/core/ops/compat/ops_history.v0.pbtxt`.

@ibab
Copy link

ibab commented Jul 3, 2016

This can be made to work with the open-source version of TensorFlow by replacing the `google3` imports with:

from google.protobuf import text_format
from tensorflow.python.platform import app, flags
from tensorflow.core.framework import op_def_pb2

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment