Skip to content

Instantly share code, notes, and snippets.

@yuq-1s
Created May 15, 2018 12:43
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save yuq-1s/ce63a306f1d39d1c0c80d33f7855f3b5 to your computer and use it in GitHub Desktop.
Save yuq-1s/ce63a306f1d39d1c0c80d33f7855f3b5 to your computer and use it in GitHub Desktop.
Find the first n appearances of tensors containing NaNs or Infs in TensorFlow.
#! /usr/bin/env python3
import os
import tensorflow as tf
import numpy as np
import tempfile
from tensorflow.python import debug as tfdbg
def watch_session(dump_root_dir, train_op, times):
    """Run ``train_op`` repeatedly while dumping tensors for later analysis.

    Args:
        dump_root_dir: Directory handed to the dumping debug wrapper as its
            ``session_root``.
        train_op: The op (or fetch) passed to ``run`` on every iteration.
        times: Number of times ``train_op`` is run through the wrapper.
    """
    # One shared set of watch options: the DebugIdentity debug op on every
    # node whose name matches the catch-all regex.
    options = tfdbg.WatchOptions(
        debug_ops="DebugIdentity",
        node_name_regex_whitelist=r".*",
    )

    def always_same_options(fetches, feeds):
        # The watch function ignores what is fetched/fed and always answers
        # with the shared options object.
        return options

    with tf.Session() as plain_session:
        # The initializer runs on the plain session, before the dumping
        # wrapper is created.
        plain_session.run(tf.global_variables_initializer())
        dumping_session = tfdbg.DumpingDebugWrapperSession(
            plain_session,
            watch_fn=always_same_options,
            session_root=dump_root_dir,
        )
        for _step in range(times):
            dumping_session.run(train_op)
def next_nan_or_inf(dump_root_dir):
    """Yield dumped tensor data whose float tensor contains NaN or Inf.

    Every entry of ``dump_root_dir`` is visited in sorted name order and
    loaded with ``tfdbg.DebugDumpDir``. Within one entry the dumped tensor
    data are visited in order of their ``timestamp``.

    Args:
        dump_root_dir: Directory whose entries are individual dump folders.

    Yields:
        Each element of ``dumped_tensor_data`` whose tensor has a dtype name
        containing ``'float'`` and holds at least one NaN or Inf value.
    """
    for folder in sorted(os.listdir(dump_root_dir)):
        dump = tfdbg.DebugDumpDir(os.path.join(dump_root_dir, folder))
        for data in sorted(dump.dumped_tensor_data,
                           key=lambda datum: datum.timestamp):
            # Keep the try body minimal: only these attribute accesses are
            # expected to fail. Wrapping the numpy checks and the yield as
            # well would silently hide unrelated AttributeErrors.
            try:
                tensor = data.get_tensor()
                dtype_name = tensor.dtype.name
            except AttributeError:
                # Some tensors may be uninitialized, we skip those.
                continue
            if 'float' not in dtype_name:
                continue
            if np.isnan(tensor).any() or np.isinf(tensor).any():
                yield data
# Demo entry point: run an op that always produces NaN a few times with tensor
# dumping enabled, then report the offending tensors found in the dumps.
if __name__ == '__main__':
    NUM_RUNS = 3            # how many times the op is run (one dump per run)
    NUM_PRINT_TENSORS = 10  # upper bound on the number of reported tensors

    # Stand-in for a real training op: a constant that is NaN on every run.
    train_op = tf.constant([float('NaN')])
    dump_root_dir = tempfile.mkdtemp()
    print('dump_root_dir: %s' % dump_root_dir)
    watch_session(dump_root_dir, train_op, NUM_RUNS)

    for count, data in enumerate(next_nan_or_inf(dump_root_dir)):
        # Check the limit before printing so that exactly NUM_PRINT_TENSORS
        # entries are reported at most (the old post-increment `>` check
        # let one extra entry through).
        if count >= NUM_PRINT_TENSORS:
            break
        # Locate the per-run folder relative to the dump root instead of
        # relying on a fixed number of '/' separated components, which only
        # worked for temp directories at one particular depth.
        run_folder = os.path.relpath(
            data.file_path, dump_root_dir).split(os.sep)[0]
        # NOTE(review): presumably the run folder name ends in
        # "_<run counter>"; taking everything after the last underscore keeps
        # multi-digit counters intact -- confirm against the tfdbg layout.
        round_label = run_folder.rsplit('_', 1)[-1]
        print('-'*40)
        print('[*] Round %s: %s' % (round_label, data.tensor_name))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment