Created June 14, 2018 16:22
-
-
Save patzm/961dcdcafbf3c253a056807c56604628 to your computer and use it in GitHub Desktop.
TensorFlow summarize arbitrary metrics
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def summarize_metrics(metric_ops=None, scope=None, list_lookup=None, name_transform_fun=None):
    """
    Add scalar summaries (summary family ``'metrics'``) for metric tensors.

    The metrics are either provided by `metric_ops` or retrieved from the ``'metrics_update_ops'`` graph
    collection. Each scalar metric gets one summary. A 1-D metric is split into one scalar summary per element,
    but only if `list_lookup` has an entry whose key occurs in the metric's (transformed) name and whose value
    has exactly as many entries as the metric's length. The entries of that value become the summary-name
    suffixes. Metrics that cannot be summarized (no matching lookup entry, or rank > 1) are skipped with a
    printed message.

    :param metric_ops: list of metric update operations or variables that contain the metric value. If ``None``,
        the operators are taken from the ``'metrics_update_ops'`` collection.
    :type metric_ops: list
    :param scope: optional filter passed to ``tf.get_collection`` when retrieving operators from the
        ``'metrics_update_ops'`` collection; ignored when `metric_ops` is given
    :type scope: str
    :param list_lookup: maps name substrings to lists of identifiers used to annotate the per-element scalar
        summaries of 1-D metrics. The first entry whose key is a substring of the metric name and whose list
        length equals the metric length is used. ``None`` is treated as an empty mapping.
    :type list_lookup: dict
    :param name_transform_fun: callable that takes a string and returns the transformed string used as summary
        name. Must guarantee unique names.
    :type name_transform_fun: callable
    """
    if metric_ops is None:
        print('Summarizing metrics{}'.format(' in scope ' + scope if scope else ''))
        metric_ops = tf.get_collection('metrics_update_ops', scope=scope)
    # A missing lookup behaves like an empty one, so vector metrics are reported as "not summarized"
    # instead of crashing on None.items().
    if list_lookup is None:
        list_lookup = {}

    for metric_op in metric_ops:
        name = name_transform_fun(metric_op.name) if callable(name_transform_fun) else metric_op.name
        print('Summarizing metric {} as {}'.format(metric_op.name, name))
        shape = metric_op.shape.as_list()

        if not shape:
            # Scalar metric: summarize directly.
            tf.summary.scalar(name, metric_op, family='metrics')
            continue

        if len(shape) != 1:
            # Splitting along axis 0 would leave non-scalar pieces, which tf.summary.scalar cannot record.
            print('Metric {} not summarized: only scalar and 1-D metrics are supported, got shape {}'.format(
                name, shape))
            continue

        for key, identifiers in list_lookup.items():
            if len(identifiers) == shape[0] and key in name:
                components = tf.split(metric_op, shape[0])
                for component, identifier in zip(components, identifiers):
                    tf.summary.scalar('/'.join((name, str(identifier))),
                                      tf.squeeze(component, axis=[0]),
                                      family='metrics')
                break  # the first matching lookup entry wins
        else:
            print('Metric {} not summarized: no matching lookup entry was found for operator'.format(name))
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.