Skip to content

Instantly share code, notes, and snippets.

@danielrenshaw
Created May 29, 2015 07:09
Show Gist options
  • Save danielrenshaw/8fd71250f9cba5a530c2 to your computer and use it in GitHub Desktop.
Save danielrenshaw/8fd71250f9cba5a530c2 to your computer and use it in GitHub Desktop.
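Theano diff (against commit a4e182d) that modifies theano/compile/debugmode.py so that debugprint can optionally annotate variables with NaN and Inf checks on their test values, print those test values, and limit recursion using configurable rules.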
16c17,18
---
> import re
>
517c519,521
< scan_ops=None, profile=None):
---
> scan_ops=None, profile=None, include_nan_info=False,  # tag each variable with <NANS: ALL|SOME|NONE|...>
> include_inf_info=False, recursion_rules='ALWAYS',  # same for Infs; comma-separated recursion rule names
> print_test_value=False):  # append the test value's shape and one-line repr
558c563,586
---
> recursion_rules = set(rule.strip() for rule in recursion_rules.split(',') if rule.strip())
> # Known rules: ALWAYS, ALL_/SOME_/NO_ + NANS/INFS, NO_TEST_VALUE; spaces around names are ignored.
> def nan_inf_info(prefix, enabled, checker):  # -> (annotation string, True if a matching rule asks to recurse)
> if r is not None and enabled:  # r is the variable being printed, closed over from debugprint
> if hasattr(r, 'tag') and r.tag is not None and hasattr(r.tag, 'test_value'):
> if isinstance(r.tag.test_value, numpy.ndarray):
> mask = checker(r.tag.test_value)
> # NOTE(review): numpy.isnan/isinf raise TypeError on object or string dtypes.
> if mask.size > 0 and mask.all():  # empty arrays are vacuously all(); report them as NONE, not ALL
> return ' <%s: ALL>' % prefix, 'ALL_' + prefix in recursion_rules
> elif mask.any():
> return ' <%s: SOME>' % prefix, 'SOME_' + prefix in recursion_rules
> # no element of the test value matched checker
> return ' <%s: NONE>' % prefix, 'NO_' + prefix in recursion_rules
> # test value is not an ndarray: recursion is decided by the NO_TEST_VALUE rule
> return ' <%s: NOT_NDARRAY>' % prefix, 'NO_TEST_VALUE' in recursion_rules
> # r has no test value attached
> return ' <%s: NO_TEST_VALUE>' % prefix, 'NO_TEST_VALUE' in recursion_rules
> # check disabled (or r is None): empty annotation
> return '', 'NO_TEST_VALUE' in recursion_rules
>
> nan_info, nan_recurse = nan_inf_info('NANS', include_nan_info, numpy.isnan)
> inf_info, inf_recurse = nan_inf_info('INFS', include_inf_info, numpy.isinf)
>
575c604,608
---
> if r is not None and print_test_value and hasattr(r, 'tag') and r.tag is not None and hasattr(r.tag, 'test_value'):
> test_value = ' %s %s' % (numpy.shape(r.tag.test_value), re.sub(r'\s+', ' ', repr(r.tag.test_value)))  # numpy.shape also handles values lacking .shape
> else:
> test_value = ''
> # test_value is ' <shape> <repr collapsed to one line>', or '' when disabled / no test value
610,616c643,651
< print('%s%s %s%s \'%s\' %s %s %s' % (prefix, a.op,
< id_str,
< type_str,
< r_name,
< destroy_map_str,
< view_map_str,
< o), file=file)
---
> print('%s%s %s%s%s%s \'%s\' %s %s %s%s' % (prefix, a.op,  # NaN/Inf tags follow the type; test value goes last
> id_str,
> type_str,
> nan_info,  # '' unless include_nan_info
> inf_info,  # '' unless include_inf_info
> r_name,
> destroy_map_str,
> view_map_str,
> o, test_value), file=file)  # test_value is '' unless print_test_value
618,624c653,661
< print('%s%s.%i %s%s \'%s\' %s %s %s' % (prefix, a.op,
< a.outputs.index(r),
< id_str, type_str,
< r_name,
< destroy_map_str,
< view_map_str,
< o), file=file)
---
> print('%s%s.%i %s%s%s%s \'%s\' %s %s %s%s' % (prefix, a.op,  # variant that also prints a.outputs.index(r)
> a.outputs.index(r),
> id_str, type_str,
> nan_info,
> inf_info,
> r_name,
> destroy_map_str,
> view_map_str,
> o, test_value), file=file)
633,644c670,682
< print("%s%s %s%s '%s' %s %s %s --> "
< "%8.2es %4.1f%% %8.2es %4.1f%%"
< % (prefix, a.op,
< id_str,
< type_str,
< r_name,
< destroy_map_str,
< view_map_str,
< o, op_time,
< op_time_percent,
< tot_time,
< tot_time_percent), file=file)
---
> print('%s%s %s%s%s%s \'%s\' %s %s %s%s --> %8.2es %4.1f%% %8.2es %4.1f%%'\
> % (prefix, a.op,  # profiled variant: same fields, then op/total time and percentages
> id_str,
> type_str,
> nan_info,
> inf_info,
> r_name,
> destroy_map_str,
> view_map_str,
> o, test_value, op_time,
> op_time_percent,
> tot_time,
> tot_time_percent), file=file)
646,657c684,696
< print("%s%s.%i %s%s '%s' %s %s %s --> "
< "%8.2es %4.1f%% %8.2es %4.1f%%"
< % (prefix, a.op,
< a.outputs.index(r),
< id_str, type_str,
< r_name,
< destroy_map_str,
< view_map_str,
< o, op_time,
< op_time_percent,
< tot_time,
< tot_time_percent), file=file)
---
> print('%s%s.%i %s%s%s%s \'%s\' %s %s %s%s --> %8.2es %4.1f%% %8.2es %4.1f%%'\
> % (prefix, a.op,  # profiled variant with the output index
> a.outputs.index(r),
> id_str, type_str,
> nan_info,
> inf_info,
> r_name,
> destroy_map_str,
> view_map_str,
> o, test_value, op_time,
> op_time_percent,
> tot_time,
> tot_time_percent), file=file)
659c699,700
---
> recurse = nan_recurse or inf_recurse or 'ALWAYS' in recursion_rules  # recurse into inputs if any rule matched
> # when recurse is False, this node's inputs are not printed at all
674,678c715,723
< debugprint(i, new_prefix, depth=depth - 1, done=done,
< print_type=print_type, file=file, order=order,
< ids=ids, stop_on_name=stop_on_name,
< prefix_child=new_prefix_child,
< scan_ops=scan_ops, profile=profile)
---
> if recurse:
> debugprint(i, new_prefix, depth=depth - 1, done=done,
> print_type=print_type, file=file, order=order,
> ids=ids, stop_on_name=stop_on_name,
> prefix_child=new_prefix_child, scan_ops=scan_ops,
> profile=profile, include_nan_info=include_nan_info,
> include_inf_info=include_inf_info,
> recursion_rules=','.join(recursion_rules),  # children re-parse this comma-separated string
> print_test_value=print_test_value)
683c728
< print('%s%s %s%s' % (prefix, r, id_str, type_str), file=file)
---
> print('%s%s %s%s%s%s%s' % (prefix, r, id_str, type_str, nan_info, inf_info, test_value), file=file)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment