Skip to content

Instantly share code, notes, and snippets.

@theY4Kman
Last active April 9, 2020 19:23
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save theY4Kman/24206740b25dc5aad78129ab2e54c8c5 to your computer and use it in GitHub Desktop.
Save theY4Kman/24206740b25dc5aad78129ab2e54c8c5 to your computer and use it in GitHub Desktop.
pytest plugin to show icdiff-powered diffs for equality comparisons between collections
import ast
import traceback
from _pytest.config import Config
from icdiff import ConsoleDiff
import terminal_helpers
import util_pprint
# Measure the terminal once, at plugin import time (before any test runs),
# rather than inside the assertrepr hook: that hook runs while pytest's
# terminal capturing is active, where get_terminal_width() always reports 80.
# ref: https://github.com/pytest-dev/pytest/issues/4030#issuecomment-425672782
INITIAL_TERM_WIDTH, INITIAL_TERM_HEIGHT = terminal_helpers.get_terminal_size()

# Columns pytest itself consumes on every reported line: the "E" marker with
# its padding (4 chars) plus the extra 2-space indent used for assertion
# failure details. Subtracting them keeps the side-by-side diff from wrapping.
PYTEST_PREFIX_LEN = 4 + 2

# Container types that get a pretty side-by-side diff.
SUPPORTED_TYPES = (tuple, list, dict, set)

# Comparison operators that get a pretty side-by-side diff.
SUPPORTED_OPS = {'=='}
def _format_obj(o, indent=2, width=60):
    """Pretty-format *o* for the diff, keeping dict items in their original order.

    Thin wrapper around ``util_pprint.pformat`` with this plugin's defaults.
    """
    options = {'indent': indent, 'width': width}
    return util_pprint.pformat(o, **options)
# ANSI "reset all attributes" escape sequence.
_ANSI_RESET = '\x1b[0m'


def _find_assert_frame():
    """Return the FrameSummary of the code that ran the failing assert, or None.

    pytest's rewritten asserts call ``_call_reprcompare``; the frame directly
    outside (calling) it is the assert statement itself.
    """
    frames = traceback.extract_stack()
    # extract_stack() is ordered outermost-first; search innermost-first so the
    # nearest _call_reprcompare wins. Pairs are (caller, callee); a match on the
    # outermost frame simply has no caller and yields None instead of crashing.
    for caller, callee in reversed(list(zip(frames[:-1], frames[1:]))):
        if callee.name == '_call_reprcompare':
            return caller
    return None


def _describe_operands(left, right):
    """Return ``(left_desc, right_desc)`` labels for the compared operands.

    Uses the source text of the assert when it can be recovered, otherwise a
    generic ``<type>(<left>)`` / ``<type>(<right>)`` label.
    """
    left_desc = right_desc = None
    test_frame = _find_assert_frame()
    if test_frame is not None and test_frame.line:
        try:
            left_desc, right_desc = get_assert_lhs_rhs(test_frame.line)
        except Exception:
            # Best effort only: the line may be part of a multi-line assert or
            # not a simple comparison. Fall back to generic labels below.
            pass
    if left_desc is None:
        left_desc = f'{type(left).__name__}(<left>)'
    if right_desc is None:
        right_desc = f'{type(right).__name__}(<right>)'
    return left_desc, right_desc


def pytest_assertrepr_compare(config: Config, op: str, left, right):
    """pytest hook: explain a failed ``left <op> right`` with an icdiff table.

    Returns a list of report lines for ``==`` comparisons between two
    SUPPORTED_TYPES instances; returns None otherwise so pytest (or another
    plugin) provides the explanation.
    """
    if op not in SUPPORTED_OPS:
        return None
    if not (isinstance(left, SUPPORTED_TYPES) and isinstance(right, SUPPORTED_TYPES)):
        return None

    available_width = INITIAL_TERM_WIDTH - PYTEST_PREFIX_LEN
    # Half the width per side minus one column for a cleaner gutter. Kept an
    # integer column count, and at least 1 because PrettyPrinter rejects 0.
    col_width = max(available_width // 2 - 1, 1)

    left_desc, right_desc = _describe_operands(left, right)
    rewritten_assert = f'{left_desc} {op} {right_desc}'
    summary = 'Full diff:'

    left_repr = _format_obj(left, width=col_width)
    right_repr = _format_obj(right, width=col_width)

    differ = ConsoleDiff(tabsize=4, cols=available_width)
    diff = differ.make_table(
        fromdesc=left_desc,
        fromlines=left_repr.splitlines(),
        todesc=right_desc,
        tolines=right_repr.splitlines(),
    )

    lines = [
        rewritten_assert,
        '',
        summary,
        '',
    ]
    # Reset the red colour pytest applies from the "E" at the start of each
    # report line, so the diff's own colouring is shown as-is.
    lines.extend(f'{_ANSI_RESET}{diff_line}' for diff_line in diff)
    return lines
def get_assert_lhs_rhs(source):
    """Return the source text of both operands of a comparison assert.

    :param source: source code of exactly one ``assert <left> <op> <right>``
        statement (a trailing message, ``, 'msg'``, is allowed).
    :returns: ``(lhs, rhs)`` source strings. For chained comparisons the right
        side is the first comparator.
    :raises SyntaxError: if *source* does not parse.
    :raises ValueError: if *source* is not a single assert of a comparison.
    """
    mod = ast.parse(source)
    if len(mod.body) != 1 or not isinstance(mod.body[0], ast.Assert):
        raise ValueError(f'expected a single assert statement: {source!r}')
    test = mod.body[0].test
    if not isinstance(test, ast.Compare):
        raise ValueError(f'assert does not test a comparison: {source!r}')
    # get_source_segment interprets col offsets as UTF-8 byte offsets (which is
    # what the AST records), so non-ASCII source is sliced correctly.
    lhs = ast.get_source_segment(source, test.left)
    rhs = ast.get_source_segment(source, test.comparators[0])
    return lhs, rhs
def get_source_for_node(lines, node):
    """Return the source text that *node* spans within *lines*.

    *node* is an AST node carrying position info (lineno, col_offset,
    end_lineno, end_col_offset).
    """
    span = (node.lineno, node.col_offset, node.end_lineno, node.end_col_offset)
    return get_source_from_line_col_offsets(lines, *span)
def get_source_from_line_col_offsets(lines, lineno, col_offset, end_lineno, end_col_offset):
    """Extract the source text between two AST positions.

    :param lines: source lines without line terminators.
    :param lineno: 1-based first line (as ``ast.AST.lineno``).
    :param col_offset: start offset within the first line (as
        ``ast.AST.col_offset``).
    :param end_lineno: 1-based last line, inclusive.
    :param end_col_offset: exclusive end offset within the last line.
    :returns: the selected text, lines joined with ``'\\n'``; ``''`` if the
        line range selects nothing.

    AST column offsets count UTF-8 *bytes*, not characters, so the first and
    last lines are encoded before slicing; slicing the ``str`` directly gives
    the wrong span whenever non-ASCII text precedes or is inside it.
    """
    selected = lines[lineno - 1:end_lineno]
    if not selected:
        return ''
    first = selected[0].encode('utf-8')
    if len(selected) == 1:
        return first[col_offset:end_col_offset].decode('utf-8')
    last = selected[-1].encode('utf-8')
    return '\n'.join([
        first[col_offset:].decode('utf-8'),
        *selected[1:-1],
        last[:end_col_offset].decode('utf-8'),
    ])
#!/usr/bin/env python
"""
Source: https://gist.github.com/jtriley/1108174
"""
import os
import shlex
import struct
import platform
import subprocess
# Public API of this module.
__all__ = ['get_terminal_size']
def get_terminal_size():
    """Return the console size as a ``(width, height)`` tuple.

    - On Windows the Win32 console API is queried first, falling back to
      ``tput`` (needed for Windows' Python in Cygwin's xterm).
    - Every other platform (Linux, macOS, Cygwin, other POSIX systems) uses
      ``_get_terminal_size_linux``: the TIOCGWINSZ ioctl on the standard
      streams or the controlling terminal, then $LINES/$COLUMNS.

    Returns ``(80, 25)`` when no probe yields a positive size.

    originally retrieved from:
    http://stackoverflow.com/questions/566746/how-to-get-console-window-width-in-python
    """
    current_os = platform.system()
    if current_os == 'Windows':
        tuple_xy = _get_terminal_size_windows()
        if tuple_xy is None:
            # needed for window's python in cygwin's xterm!
            tuple_xy = _get_terminal_size_tput()
    else:
        # Previously only Linux/Darwin/CYGWIN* were probed; other POSIX
        # systems (e.g. the BSDs) support the same ioctl, and the probe
        # returns None where it is unsupported.
        tuple_xy = _get_terminal_size_linux()
    # A probe can "succeed" with a 0x0 size (e.g. an unsized pty); that is
    # as useless as no answer and would yield non-positive layout widths.
    if tuple_xy is None or min(tuple_xy) <= 0:
        tuple_xy = (80, 25)  # default value
    return tuple_xy
def _get_terminal_size_windows():
    """Query the Win32 console for its size.

    :returns: ``(width, height)`` of the visible window of stderr's console
        screen buffer, or None when ctypes' ``windll`` is unavailable
        (non-Windows) or the buffer info cannot be read.
    """
    try:
        from ctypes import windll, create_string_buffer
        # stdin handle is -10
        # stdout handle is -11
        # stderr handle is -12
        h = windll.kernel32.GetStdHandle(-12)
        # CONSOLE_SCREEN_BUFFER_INFO is 22 bytes (11 16-bit fields).
        csbi = create_string_buffer(22)
        res = windll.kernel32.GetConsoleScreenBufferInfo(h, csbi)
    except (ImportError, AttributeError, OSError):
        # Narrowed from a bare except, which also swallowed
        # KeyboardInterrupt/SystemExit.
        return None
    if not res:
        return None
    # Fields 5..8 are the srWindow rectangle: left, top, right, bottom.
    left, top, right, bottom = struct.unpack("hhhhHhhhhhh", csbi.raw)[5:9]
    sizex = right - left + 1
    sizey = bottom - top + 1
    return sizex, sizey
def _get_terminal_size_tput():
    """Ask ``tput`` for the terminal size.

    :returns: ``(cols, rows)``, or None if tput is missing, fails, or prints
        something that is not an integer.
    """
    # src: http://stackoverflow.com/questions/263890/how-do-i-find-the-width-height-of-a-terminal-window
    try:
        # check_output, not check_call: we need the number tput prints, not
        # its exit status (check_call always produced 0 here, i.e. a 0x0 size).
        cols = int(subprocess.check_output(shlex.split('tput cols')))
        rows = int(subprocess.check_output(shlex.split('tput lines')))
    except (OSError, subprocess.CalledProcessError, ValueError):
        return None
    return cols, rows
def ioctl_GWINSZ(fd):
    """Read the window size of the terminal open on *fd* via TIOCGWINSZ.

    :returns: ``(rows, cols)``, or None when fcntl/termios are unavailable
        (e.g. Windows), *fd* is invalid, or *fd* is not a terminal.
    """
    try:
        import fcntl
        import termios
        # winsize.ws_row / ws_col are unsigned shorts, hence 'HH'. The input
        # buffer only sizes the returned bytes; its content is irrelevant.
        cr = struct.unpack('HH',
                           fcntl.ioctl(fd, termios.TIOCGWINSZ, b'\0' * 4))
    except (ImportError, AttributeError, OSError, ValueError, TypeError,
            struct.error):
        # Narrowed from a bare except, which also swallowed
        # KeyboardInterrupt/SystemExit.
        return None
    return cr
def _get_terminal_size_linux():
    """Find the terminal size on POSIX systems.

    Tries, in order: the TIOCGWINSZ ioctl on stdin, stdout and stderr; the
    same ioctl on the controlling terminal (``os.ctermid()``); and finally
    the LINES/COLUMNS environment variables.

    :returns: ``(cols, rows)`` with both values positive, or None.
    """

    def usable(cr):
        # ``cr`` is (rows, cols). An ioctl can succeed yet report 0x0 (e.g. a
        # pty whose size was never set); treat that as "unknown" so the next
        # source is tried instead of returning a zero-width terminal.
        return cr is not None and cr[0] > 0 and cr[1] > 0

    for std_fd in (0, 1, 2):
        cr = ioctl_GWINSZ(std_fd)
        if usable(cr):
            return int(cr[1]), int(cr[0])

    try:
        # os.ctermid is POSIX-only (AttributeError elsewhere); opening it fails
        # with OSError when there is no controlling terminal.
        fd = os.open(os.ctermid(), os.O_RDONLY)
    except (AttributeError, OSError):
        cr = None
    else:
        try:
            cr = ioctl_GWINSZ(fd)
        finally:
            os.close(fd)
    if usable(cr):
        return int(cr[1]), int(cr[0])

    try:
        cr = (int(os.environ['LINES']), int(os.environ['COLUMNS']))
    except (KeyError, ValueError):
        # Unset or non-numeric values (previously an uncaught ValueError).
        return None
    if usable(cr):
        return cr[1], cr[0]
    return None
if __name__ == "__main__":
    # Manual check: print the size detected for the current terminal.
    cols, rows = get_terminal_size()
    print('width =', cols, 'height =', rows)
from pprint import PrettyPrinter, _builtin_scalars, _recursion
# Public API: an order-preserving drop-in for pprint's printer and helpers.
__all__ = ['UnsortedPrettyPrinter', 'pprint', 'pformat']
class UnsortedPrettyPrinter(PrettyPrinter):
    """Pretty printer that retains original dict ordering

    ``PrettyPrinter`` sorts dict keys. This subclass emits dict items in
    iteration order both when a dict is split across lines (``_pprint_dict``)
    and when it fits on one line (``format`` -> ``_safe_repr``).
    Constructor arguments are the same as ``PrettyPrinter``'s.
    """

    def __init__(self, *args, **kwargs):
        # Forward indent/width/depth/stream/compact/... to PrettyPrinter.
        # Calling super().__init__() without them silently discarded every
        # option, so all instances used the defaults (e.g. width=80).
        super().__init__(*args, **kwargs)
        # Instance-level dispatch table routing plain dicts to the unsorted
        # formatter below. PrettyPrinter calls dispatch entries as
        # handler(self, object, ...), which is why _pprint_dict is a
        # staticmethod that receives ``self`` explicitly.
        self._dispatch = {
            **self._dispatch,
            dict.__repr__: self._pprint_dict,
        }

    @staticmethod
    def _pprint_dict(self, object, stream, indent, allowance, context, level):
        # Multi-line dict layout: same as PrettyPrinter's, minus key sorting.
        write = stream.write
        write('{')
        if self._indent_per_level > 1:
            write((self._indent_per_level - 1) * ' ')
        length = len(object)
        if length:
            items = object.items()
            self._format_dict_items(items, stream, indent, allowance + 1,
                                    context, level)
        write('}')

    def format(self, object, context, maxlevels, level):
        """Format object for a specific context, returning a string
        and flags indicating whether the representation is 'readable'
        and whether the object represents a recursive construct.
        """
        return self._safe_repr(object, context, maxlevels, level)

    def _safe_repr(self, object, context, maxlevels, level):
        """Single-line repr of *object* as ``(repr, readable, recursive)``.

        Plain dicts, lists and tuples are rendered recursively (dict items in
        iteration order); ``context`` tracks ids of containers currently being
        rendered so self-references are reported instead of recursing forever.
        """
        typ = type(object)
        if typ in _builtin_scalars:
            return repr(object), True, False

        r = getattr(typ, "__repr__", None)
        if issubclass(typ, dict) and r is dict.__repr__:
            if not object:
                return "{}", True, False
            objid = id(object)
            if maxlevels and level >= maxlevels:
                return "{...}", False, objid in context
            if objid in context:
                return _recursion(object), False, True
            context[objid] = 1
            readable = True
            recursive = False
            components = []
            append = components.append
            level += 1
            saferepr = self._safe_repr
            # Unlike pprint's own _safe_repr, items are NOT sorted.
            items = object.items()
            for k, v in items:
                krepr, kreadable, krecur = saferepr(k, context, maxlevels, level)
                vrepr, vreadable, vrecur = saferepr(v, context, maxlevels, level)
                append("%s: %s" % (krepr, vrepr))
                readable = readable and kreadable and vreadable
                if krecur or vrecur:
                    recursive = True
            del context[objid]
            return "{%s}" % ", ".join(components), readable, recursive

        if (issubclass(typ, list) and r is list.__repr__) or \
           (issubclass(typ, tuple) and r is tuple.__repr__):
            if issubclass(typ, list):
                if not object:
                    return "[]", True, False
                format = "[%s]"
            elif len(object) == 1:
                format = "(%s,)"
            else:
                if not object:
                    return "()", True, False
                format = "(%s)"
            objid = id(object)
            if maxlevels and level >= maxlevels:
                return format % "...", False, objid in context
            if objid in context:
                return _recursion(object), False, True
            context[objid] = 1
            readable = True
            recursive = False
            components = []
            append = components.append
            level += 1
            for o in object:
                orepr, oreadable, orecur = self._safe_repr(o, context, maxlevels, level)
                append(orepr)
                if not oreadable:
                    readable = False
                if orecur:
                    recursive = True
            del context[objid]
            return format % ", ".join(components), readable, recursive

        rep = repr(object)
        return rep, (rep and not rep.startswith('<')), False
def pprint(object, stream=None, indent=1, width=80, depth=None, *,
           compact=False):
    """Pretty-print *object* to *stream* (sys.stdout when None).

    Same signature as :func:`pprint.pprint`, but dict items keep their
    original order instead of being sorted.
    """
    options = dict(
        stream=stream,
        indent=indent,
        width=width,
        depth=depth,
        compact=compact,
    )
    UnsortedPrettyPrinter(**options).pprint(object)
def pformat(object, indent=1, width=80, depth=None, *, compact=False):
    """Return *object* as a pretty-printed string.

    Same signature as :func:`pprint.pformat`, but dict items keep their
    original order instead of being sorted.
    """
    printer = UnsortedPrettyPrinter(
        indent=indent, width=width, depth=depth, compact=compact,
    )
    return printer.pformat(object)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment