Skip to content

Instantly share code, notes, and snippets.

@arquolo
Last active November 21, 2022 08:38
Show Gist options
  • Save arquolo/bf595ad5067b45162f8c01cca23398b6 to your computer and use it in GitHub Desktop.
Save arquolo/bf595ad5067b45162f8c01cca23398b6 to your computer and use it in GitHub Desktop.
Fork of https://github.com/gruns/icecream with NumPy support. Python 3.6+
#!/usr/bin/env python
#
# IceCream - Never use print() to debug again
#
# Ansgar Grunseid
# grunseid.com
# grunseid@gmail.com
#
# Pavel Maevskikh
# arquolo@gmail.com
#
# License: MIT
#
# pip install asttokens colorama executing numpy pygments
#
__all__ = ['ic']  # ic() is the only name this module exports
import ast
import inspect
import pprint
import shutil
import sys
from collections import deque
from collections.abc import Iterable, Iterator, Mapping
from dataclasses import fields, is_dataclass, replace
from datetime import datetime
from os.path import basename
from textwrap import dedent
from typing import Dict, List, NamedTuple, Tuple

import colorama
import executing
import numpy as np
from pygments import highlight
from pygments.formatters import TerminalFormatter
from pygments.lexers.python import PythonLexer
# Called once at import time (a global side effect). Presumably this is so the
# ANSI colour codes produced by TerminalFormatter also render on Windows
# consoles.
colorama.init()
PREFIX = 'ic| '  # Starts every line of ic() output
LINE_WRAP_WIDTH = 70 # Characters; longer output is split over several lines
# Used by ic() to syntax-highlight the final output text.
FORMATTER = TerminalFormatter(bg='dark')
LEXER = PythonLexer(ensurenl=False)
def is_literal(s) -> bool:
    """Tell whether source text *s* is a plain Python literal.

    Anything ``ast.literal_eval`` accepts counts as a literal: numbers,
    strings, bytes, and containers of those. Any failure to evaluate,
    whatever the exception, means "not a literal".
    """
    try:
        ast.literal_eval(s)
    except Exception:  # noqa: PIE786
        return False
    return True
class NoSourceAvailableError(OSError):
    """Signal that ic() could not read the source code of its call site.

    ic() parses the calling code to label each value, so it fails when
    that source is out of reach. For example:

    - ic() runs inside a REPL or interactive shell (the command line, or
      ``python -i``);
    - the code was mangled and/or packaged by a project freezer such as
      PyInstaller;
    - the source file changed while the program was running, see
      https://stackoverflow.com/a/33175832.
    """

    # Explanation that ic() prints in place of its normal output.
    info_message = ''.join((
        'Failed to access the underlying source code for analysis. ',
        'Was ic() invoked in a REPL (e.g. from the command line), ',
        'a frozen application (e.g. packaged with PyInstaller), ',
        'or did the underlying source code change during execution?',
    ))
class Source(executing.Source):
    """``executing.Source`` that can return an AST node's text, dedented."""

    def get_text_with_indentation(self, node) -> str:
        """Return the source text of *node*, dedented and stripped.

        For a multi-line node, the first line is padded back out to the
        node's starting column. ``dedent`` then removes only the margin all
        lines share, so continuation lines keep their relative indentation.
        """
        text = self.asttokens().get_text(node)
        if '\n' not in text:
            return text.strip()
        start_column = node.first_token.start[1]
        return dedent(' ' * start_column + text).strip()
def indented_lines(prefix: str, lines: str) -> List[str]:
    """Split *lines* into a list, prefixing the first line with *prefix*.

    Every following line is indented by ``len(prefix)`` spaces so it lines
    up under the first line's text. An empty *lines* yields ``[prefix]``
    rather than raising ValueError from the unpacking below.
    """
    space = ' ' * len(prefix)
    # ''.splitlines() is [], which would make the unpacking fail.
    first, *rest = lines.splitlines() or ['']
    return [prefix + first] + [space + line for line in rest]
def format_pair(prefix: str, arg: str, value: str) -> str:
    """Lay out ``arg: value`` with *arg* starting right after *prefix*.

    Continuation lines of *arg* are aligned under its first line. The
    value follows ``': '`` on the last line of *arg*, and its own
    continuation lines are aligned under its first character. When the
    value is a quoted string spanning several lines, each continuation
    line gets one extra space so the text lines up past the opening quote.
    """
    if value[0] + value[-1] in ("''", '""'):
        first_line, *more = value.splitlines(keepends=True)
        value = ''.join([first_line] + [' ' + line for line in more])
    arg_lines = indented_lines(prefix, arg)
    leading, last = arg_lines[:-1], arg_lines[-1]
    return '\n'.join(leading + indented_lines(last + ': ', value))
def _get_nd_grad(arr: np.ndarray) -> 'Iterator[str]':
    """Yield ``'grad=[...]'`` describing a coarse per-axis trend of *arr*.

    Each axis of size > 1 is cut with ``np.split`` at
    ``[size // 2, -(size // 2)]`` into a head slab, a middle slab and a tail
    slab. The middle slab is the centre element for odd sizes and empty for
    even sizes. For each such axis the reported value is
    mean(tail slab) - mean(head slab), computed after casting to float32.
    Axes of size 1 report 0. Values are rounded to 8 decimals, and nothing
    is yielded when every component is zero.
    """
    # A somewhat elaborate way to compute gradients along all axes at once,
    # but a fast one.
    # Split the tensor along all axes, take the mean of each cell, and then
    # aggregate the cell means into a head and a tail mean for each axis.
    arr_f4 = arr.astype('f4')
    # Pyramid of splits. Each key holds one piece index per processed axis
    # (0 = head, 1 = middle, 2 = tail; always 0 for size-1 axes).
    splits: Dict[Tuple[int, ...], np.ndarray] = {(): arr_f4}
    for axis, size in enumerate(arr.shape):
        if size == 1:
            splits = {(*k, 0): s for k, s in splits.items()}
        else:
            half = size // 2
            splits = {
                (*k, k2): ss for k, s in splits.items()
                for k2, ss in enumerate(np.split(s, [half, -half], axis))
            }
    # Tensor of means. means[loc] is the mean of piece `loc` and weights[loc]
    # is its element count; empty pieces get mean 0 and weight 0.
    s_shape = [1 if size == 1 else 3 for size in arr.shape]
    means = np.zeros(s_shape)
    weights = np.zeros(s_shape, int)
    for loc, s in splits.items():
        means[loc] = s.mean() if s.size else 0
        weights[loc] = s.size
    # Aggregate and compute gradients. For each axis keep only the head (0)
    # and tail (2) slabs, then average over all other axes weighted by piece
    # size. That gives the mean of the head slab and of the tail slab.
    grad = np.zeros(arr.ndim)
    for axis, size in enumerate(arr.shape):
        if size == 1:
            continue
        axes = *(a for a in range(arr.ndim) if a != axis),
        means_ = np.take(means, [0, 2], axis)
        weights_ = np.take(weights, [0, 2], axis)
        head, tail = np.average(means_, axes, weights_)
        grad[axis] = tail - head
    if grad.any():
        yield f'grad={grad.round(8)}'
def _get_properties(arr: np.ndarray) -> 'Iterator[str]':
yield f'{arr.shape}, dtype={arr.dtype}'
if not arr.size:
return
# Small array, print contents as is
if arr.size < 40:
yield f'data={arr.ravel()}'
return
# Small enough binary array, hexify
if arr.size < 500 and arr.dtype == bool:
data = np.packbits(arr.flat).tobytes()
line = ''.join(f'{v:02x}' for v in data).replace('0', '_')
yield f'data={line!r}'
return
# Too much data, use statistics
lo = arr.min()
hi = arr.max()
if arr.dtype.kind == 'f':
yield f'x∈[{lo:.8f}, {hi:.8f}]'
yield f'μ={arr.mean():.8f}, σ={arr.std():.8f}'
# Wide range, only low/high
elif int(hi) - int(lo) > 100:
yield f'x∈[{lo}, {hi}]'
# Medium range or zero crossing, low/high + nuniq
elif int(lo) < 0 or int(hi) > 10:
nuniq = np.unique(arr.ravel()).size
yield f'x∈[{lo}, {hi}], nuniq={nuniq}'
# Narrow range, raw distribution
else:
weights = np.bincount(arr.ravel()).astype('f8') / arr.size
yield f'weights={weights}'
yield from _get_nd_grad(arr)
class _ReprArray(NamedTuple):
data: np.ndarray
def __str__(self) -> str:
return str(self.data)
def __repr__(self) -> str:
return 'np.ndarray(' + ', '.join(_get_properties(self.data)) + ')'
def _patch_repr_types(obj):
if isinstance(obj, np.ndarray):
return _ReprArray(obj)
if isinstance(obj, (str, bytes, bytearray, range)):
return obj
if is_dataclass(obj):
return replace(obj, **_patch_repr_types(vars(obj)))
# namedtuple
if isinstance(obj, tuple) and hasattr(obj, '_fields'):
return type(obj)(*(_patch_repr_types(x) for x in obj))
if isinstance(obj, Mapping):
return dict(_patch_repr_types(kv) for kv in obj.items())
if isinstance(obj, Iterable) and not isinstance(obj, Iterator):
return type(obj)(_patch_repr_types(x) for x in obj)
return obj
def argument_to_string(obj) -> str:
    """Pretty-format *obj* for ic() output.

    NumPy arrays are summarized (via ``_patch_repr_types``), and the text is
    wrapped to the current terminal width. Escaped newline sequences (a
    backslash followed by ``n``) in the result become real line breaks, so
    multi-line strings keep their shape.
    """
    columns = shutil.get_terminal_size().columns
    text = pprint.pformat(_patch_repr_types(obj), width=columns)
    return text.replace('\\n', '\n')
def _format_time() -> str:
now = f'{datetime.now():%H:%M:%S.%f}'[:-3]
return f' at {now}'
def _format_context(frame, call_node) -> str:
info = inspect.getframeinfo(frame)
parent_fn = info.function
if parent_fn != '<module>':
parent_fn = f'{parent_fn}()'
return f'{basename(info.filename)}:{call_node.lineno} in {parent_fn}'
def _construct_argument_output(context, pairs) -> str:
    """Render ``(argument source, value)`` pairs as ic() output text.

    Everything goes on one line when no value spans several lines and the
    line fits within LINE_WRAP_WIDTH. Otherwise each pair gets its own
    (possibly multi-line) block, placed below the context line when there
    is one.
    """
    rendered = [(arg, argument_to_string(val)) for arg, val in pairs]

    # For cleaner output, a literal argument (3, "string", b'bytes', ...) is
    # printed as its value alone, because source and value would read the
    # same or nearly so. ic("hello") prints
    #
    #     ic| 'hello'
    #
    # instead of
    #
    #     ic| "hello": 'hello'
    one_line = ', '.join(
        value if is_literal(arg) else f'{arg}: {value}'
        for arg, value in rendered
    )
    single = PREFIX + (f'{context}- ' if context else '') + one_line

    # ic| foo.py:11 in foo()- a: 1, b: 2
    # ic| a: 1, b: 2, c: 3
    wraps = (len(one_line.splitlines()) > 1
             or len(single.splitlines()[0]) > LINE_WRAP_WIDTH)
    if not wraps:
        return single

    if context:
        # ic| foo.py:11 in foo()
        #     multilineStr: 'line1
        #                    line2'
        #
        # ic| foo.py:11 in foo()
        #     a: 11111111111111111111
        #     b: 22222222222222222222
        indent = ' ' * len(PREFIX)
        blocks = [format_pair(indent, arg, value) for arg, value in rendered]
        return '\n'.join([PREFIX + context, *blocks])

    # ic| multilineStr: 'line1
    #                   line2'
    #
    # ic| a: 11111111111111111111
    #     b: 22222222222222222222
    stacked = '\n'.join(format_pair('', arg, value) for arg, value in rendered)
    return '\n'.join(indented_lines(PREFIX, stacked))
def _format(frame, *args) -> str:
call_node = Source.executing(frame).node
if call_node is None:
raise NoSourceAvailableError()
context = _format_context(frame, call_node)
if not args:
return PREFIX + context + _format_time()
source = Source.for_frame(frame)
sanitized_arg_strs = [
source.get_text_with_indentation(arg) for arg in call_node.args
]
pairs = zip(sanitized_arg_strs, args)
return _construct_argument_output(context, pairs)
def ic(*args):
    """Print the call site and each argument's source and value to stderr.

    The output is syntax-highlighted. If the calling source cannot be
    analysed, an error line is printed instead. Either way the arguments
    are passed through, so ic() can wrap an expression in place: ``ic()``
    returns None, ``ic(x)`` returns ``x``, and ``ic(x, y)`` returns
    ``(x, y)``.
    """
    frame = inspect.currentframe()
    try:
        # currentframe() may return None on Python implementations without
        # stack-frame support. Report that like missing source rather than
        # relying on an assert, which ``python -O`` strips.
        caller = None if frame is None else frame.f_back
        if caller is None:
            raise NoSourceAvailableError()
        out = _format(caller, *args)
    except NoSourceAvailableError as err:
        out = f'{PREFIX}Error: {err.info_message}'
    finally:
        # Release our own frame to break the frame <-> local-variable
        # reference cycle, as the inspect documentation recommends.
        del frame

    s = highlight(out, LEXER, FORMATTER)
    print(s, file=sys.stderr)

    if not args:
        return None  # E.g. ic().
    if len(args) == 1:
        return args[0]  # E.g. ic(1).
    return args  # E.g. ic(1, 2, 3).
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment