Skip to content

Instantly share code, notes, and snippets.

@patzm
Created June 14, 2018 16:22
Show Gist options
  • Save patzm/961dcdcafbf3c253a056807c56604628 to your computer and use it in GitHub Desktop.
Save patzm/961dcdcafbf3c253a056807c56604628 to your computer and use it in GitHub Desktop.
TensorFlow summarize arbitrary metrics
def summarize_metrics(metric_ops=None, scope=None, list_lookup=None, name_transform_fun=None):
    """
    Add scalar summaries (family 'metrics') for metric tensors.

    Metrics are taken from `metric_ops` or, when that is None, from the 'metrics_update_ops' graph collection
    (optionally filtered by `scope`).

    Rank-0 metrics get one summary each. Rank-1 metrics get one scalar summary per element, named
    ``<name>/<identifier>``. The identifiers come from the first `list_lookup` entry whose key occurs in the
    metric's name and whose value has as many entries as the metric has elements. Metrics that cannot be
    summarized (no matching lookup entry, or rank > 1) are reported with a print and skipped.

    :param metric_ops: list of metric update operations or variables that contain the metric value. If None, the
        'metrics_update_ops' collection is used.
    :type metric_ops: list
    :param scope: scope used to filter the 'metrics_update_ops' collection; only used when `metric_ops` is None
    :type scope: str
    :param list_lookup: maps name substrings to lists of identifiers used to label the elements of rank-1 metrics.
        A key matches when it is a substring of the (transformed) metric name and its list has the same length as
        the metric tensor. None is treated as an empty mapping, so rank-1 metrics are then skipped.
    :type list_lookup: dict
    :param name_transform_fun: callable that takes a tensor name and returns the summary name. Must guarantee
        unique names. Ignored if not callable.
    :type name_transform_fun: callable
    """
    if metric_ops is None:
        print('Summarizing metrics{}'.format(' in scope ' + scope if scope else ''))
        metric_ops = tf.get_collection('metrics_update_ops', scope=scope)
    # `list_lookup` defaults to None; treat that as "no lookups" instead of failing on `.items()` below.
    lookups = list_lookup if list_lookup is not None else {}
    for metric_op in metric_ops:
        name = name_transform_fun(metric_op.name) if callable(name_transform_fun) else metric_op.name
        print('Summarizing metric {} as {}'.format(metric_op.name, name))
        shape = metric_op.shape.as_list()
        if not shape:
            # Rank-0 tensor: it can be summarized directly.
            tf.summary.scalar(name, metric_op, family='metrics')
            continue
        if len(shape) > 1:
            # Splitting along axis 0 would still leave non-scalar pieces, which a scalar summary cannot hold.
            print('Metric {} not summarized: only scalar and rank-1 metrics are supported'.format(name))
            continue
        # First lookup entry with one identifier per element whose key occurs in the metric name.
        # An unknown length (shape[0] is None) never matches.
        identifiers = next((val for key, val in lookups.items() if len(val) == shape[0] and key in name), None)
        if identifiers is None:
            print('Metric {} not summarized: no matching lookup entry was found for operator'.format(name))
            continue
        for identifier, component in zip(identifiers, tf.unstack(metric_op, num=shape[0])):
            tf.summary.scalar('/'.join((name, str(identifier))), component, family='metrics')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment