Last active
April 9, 2020 19:23
-
-
Save the Y4Kman/24206740b25dc5aad78129ab2e54c8c5 gist to your computer and use it in GitHub Desktop.
pytest plugin to show icdiff-powered diffs for equality comparisons between collections
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import ast | |
import traceback | |
from _pytest.config import Config | |
from icdiff import ConsoleDiff | |
import terminal_helpers | |
import util_pprint | |
# The terminal size is measured once, at import time, before any test runs.
# Inside the assertrepr hook pytest's output capturing is active, and terminal
# width queries made there report 80 columns regardless of the real terminal.
# ref: https://github.com/pytest-dev/pytest/issues/4030#issuecomment-425672782
INITIAL_TERM_WIDTH, INITIAL_TERM_HEIGHT = terminal_helpers.get_terminal_size()

# Columns pytest consumes in front of every explanation line: the "E   "
# error marker (4 chars) plus the extra 2-space indent applied to assertion
# failure output. Subtracting them keeps the side-by-side diff from wrapping.
PYTEST_PREFIX_LEN = len('E   ') + 2

# Collection types that get the pretty side-by-side diff.
SUPPORTED_TYPES = (tuple, list, dict, set)

# Comparison operators that get the pretty diff.
SUPPORTED_OPS = {'=='}
def _format_obj(o, indent=2, width=60):
    """Pretty-format *o* as a string using ``util_pprint.pformat``.

    ``util_pprint``'s printer leaves dict items unsorted, so the diff shows
    keys in their original order. *indent* and *width* are forwarded as-is.
    """
    formatted = util_pprint.pformat(o, width=width, indent=indent)
    return formatted
def pytest_assertrepr_compare(config: Config, op: str, left, right):
    """Render a side-by-side icdiff for failed ``==`` asserts on collections.

    pytest hook: returning a list of lines replaces the default assertion
    explanation; returning None leaves pytest's own output in place.

    :param config: the pytest config object (unused; required by the hook spec)
    :param op: the comparison operator string, e.g. ``'=='``
    :param left: left-hand operand of the failed comparison
    :param right: right-hand operand of the failed comparison
    :return: explanation lines, or None when the operator or operand types
        are not supported
    """
    if op not in SUPPORTED_OPS:
        return None
    if not (isinstance(left, SUPPORTED_TYPES) and isinstance(right, SUPPORTED_TYPES)):
        return None

    # A failed or unknown terminal probe can report a width of 0 (or less);
    # that would hand a non-positive width to the pretty-printer and icdiff.
    # Keep a small usable floor instead.
    available_width = max(INITIAL_TERM_WIDTH - PYTEST_PREFIX_LEN, 20)
    # Each side gets half the room, minus one column for a cleaner output.
    col_width = available_width // 2 - 1

    # Locate the frame running the failing assert so the diff headers can show
    # the operands' source text. Walking outwards from the innermost frame,
    # the caller of pytest's ``_call_reprcompare`` is taken to be that frame.
    # Stopping at index 1 avoids an IndexError if the match has no caller.
    test_frame = None
    frames = traceback.extract_stack()
    for index in range(len(frames) - 1, 0, -1):
        if frames[index].name == '_call_reprcompare':
            test_frame = frames[index - 1]
            break

    left_desc = right_desc = None
    if test_frame is not None:
        try:
            left_desc, right_desc = get_assert_lhs_rhs(test_frame.line)
        except Exception:
            # Best effort only (e.g. a multi-line assert, or a test expression
            # that is not a plain comparison): fall back to generic labels.
            pass
    if left_desc is None:
        left_desc = f'{type(left).__name__}(<left>)'
    if right_desc is None:
        right_desc = f'{type(right).__name__}(<right>)'

    differ = ConsoleDiff(tabsize=4, cols=available_width)
    diff = differ.make_table(
        fromdesc=left_desc,
        fromlines=_format_obj(left, width=col_width).splitlines(),
        todesc=right_desc,
        tolines=_format_obj(right, width=col_width).splitlines(),
    )

    lines = [
        f'{left_desc} {op} {right_desc}',
        '',
        'Full diff:',
        '',
    ]
    # '\x1b[0m' resets the red colour pytest applies from the "E" at the start
    # of each explanation line, so the diff's own colours show through.
    lines.extend(f'\x1b[0m{diff_line}' for diff_line in diff)
    return lines
def get_assert_lhs_rhs(source):
    """Return the source text of both operands of a single-statement assert.

    :param source: source of exactly one ``assert`` statement whose test is a
        comparison, e.g. ``"assert a == b"``
    :return: ``(lhs_text, rhs_text)``; for a chained comparison only the first
        comparator is used as the right-hand side
    :raises Exception: (``SyntaxError``, ``ValueError``, ``AttributeError``, ...)
        when *source* is not a single assert over a comparison
    """
    source_lines = source.splitlines()
    (statement,) = ast.parse(source).body
    comparison = statement.test
    lhs_node, rhs_node = comparison.left, comparison.comparators[0]
    return (
        get_source_for_node(source_lines, lhs_node),
        get_source_for_node(source_lines, rhs_node),
    )
def get_source_for_node(lines, node):
    """Return the slice of *lines* covered by the AST *node*'s position span.

    *node* must carry ``lineno``/``col_offset``/``end_lineno``/
    ``end_col_offset`` (Python 3.8+ AST nodes do).
    """
    start = (node.lineno, node.col_offset)
    end = (node.end_lineno, node.end_col_offset)
    return get_source_from_line_col_offsets(lines, *start, *end)
def get_source_from_line_col_offsets(lines, lineno, col_offset, end_lineno, end_col_offset):
    """Extract the text between two AST positions from a list of source lines.

    :param lines: source split into lines, without line terminators
    :param lineno: 1-based first line, as in ``ast.AST.lineno``
    :param col_offset: UTF-8 *byte* offset into the first line
    :param end_lineno: 1-based last line (inclusive)
    :param end_col_offset: UTF-8 byte offset just past the span on the last line
    :return: the extracted text, with lines joined by ``'\\n'``; ``''`` when the
        line range is empty

    ``ast`` reports column offsets in UTF-8 bytes, not characters. Each line is
    therefore encoded before slicing; slicing the ``str`` directly misplaces
    the span on any line that contains non-ASCII characters.
    """
    chunks = [line.encode('utf-8') for line in lines[lineno - 1:end_lineno]]
    if not chunks:
        return ''
    if len(chunks) == 1:
        chunks[0] = chunks[0][col_offset:end_col_offset]
    else:
        chunks[0] = chunks[0][col_offset:]
        chunks[-1] = chunks[-1][:end_col_offset]
    return '\n'.join(chunk.decode('utf-8') for chunk in chunks)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
""" | |
Source: https://gist.github.com/jtriley/1108174 | |
""" | |
import os | |
import shlex | |
import struct | |
import platform | |
import subprocess | |
# Public API: only the size query; the per-platform probes are helpers.
__all__ = ['get_terminal_size']
def get_terminal_size():
    """Return the console size as a ``(width, height)`` tuple.

    Runs the probes appropriate to the current OS (Linux, macOS, Windows,
    Cygwin) and falls back to ``(80, 25)`` when none of them reports a size.

    originally retrieved from:
    http://stackoverflow.com/questions/566746/how-to-get-console-window-width-in-python
    """
    system = platform.system()
    size = None

    if system == 'Windows':
        size = _get_terminal_size_windows()
        if size is None:
            # needed for Windows Python running in Cygwin's xterm
            size = _get_terminal_size_tput()

    if system in ('Linux', 'Darwin') or system.startswith('CYGWIN'):
        size = _get_terminal_size_linux()

    return size if size is not None else (80, 25)
def _get_terminal_size_windows(): | |
try: | |
from ctypes import windll, create_string_buffer | |
# stdin handle is -10 | |
# stdout handle is -11 | |
# stderr handle is -12 | |
h = windll.kernel32.GetStdHandle(-12) | |
csbi = create_string_buffer(22) | |
res = windll.kernel32.GetConsoleScreenBufferInfo(h, csbi) | |
if res: | |
(bufx, bufy, curx, cury, wattr, | |
left, top, right, bottom, | |
maxx, maxy) = struct.unpack("hhhhHhhhhhh", csbi.raw) | |
sizex = right - left + 1 | |
sizey = bottom - top + 1 | |
return sizex, sizey | |
except: | |
pass | |
def _get_terminal_size_tput(): | |
# get terminal width | |
# src: http://stackoverflow.com/questions/263890/how-do-i-find-the-width-height-of-a-terminal-window | |
try: | |
cols = int(subprocess.check_call(shlex.split('tput cols'))) | |
rows = int(subprocess.check_call(shlex.split('tput lines'))) | |
return cols, rows | |
except: | |
pass | |
def ioctl_GWINSZ(fd):
    """Return ``(rows, cols)`` for the terminal on file descriptor *fd*.

    Uses the ``TIOCGWINSZ`` ioctl. Returns None on any failure, e.g. when
    ``fcntl``/``termios`` are unavailable or *fd* is not a terminal.
    """
    try:
        import fcntl
        import termios
        # A 4-byte buffer is enough to read back the first two shorts.
        packed = fcntl.ioctl(fd, termios.TIOCGWINSZ, struct.pack('hh', 0, 0))
        return struct.unpack('hh', packed)
    except Exception:
        return None
def _get_terminal_size_linux():
    """Return ``(width, height)`` of the terminal on POSIX systems, or None.

    Probes, in order:

    1. the TIOCGWINSZ ioctl on stdin, stdout, then stderr;
    2. the same ioctl on the controlling terminal (``os.ctermid()``);
    3. the ``LINES`` / ``COLUMNS`` environment variables.

    A probe that reports a zero/negative size, or an environment value that is
    not an integer, counts as a miss, and the next probe is tried.
    """
    def _to_width_height(rows_cols):
        # Both the ioctl and the env fallback yield (rows, cols); flip to
        # (cols, rows) and reject unusable values.
        if not rows_cols:
            return None
        try:
            rows, cols = int(rows_cols[0]), int(rows_cols[1])
        except (TypeError, ValueError):
            return None
        if rows <= 0 or cols <= 0:
            return None
        return cols, rows

    for std_fd in (0, 1, 2):
        size = _to_width_height(ioctl_GWINSZ(std_fd))
        if size is not None:
            return size

    try:
        tty_fd = os.open(os.ctermid(), os.O_RDONLY)
    except (AttributeError, OSError):
        # os.ctermid is unavailable, or there is no controlling terminal.
        pass
    else:
        try:
            size = _to_width_height(ioctl_GWINSZ(tty_fd))
        finally:
            os.close(tty_fd)
        if size is not None:
            return size

    return _to_width_height((os.environ.get('LINES'), os.environ.get('COLUMNS')))
if __name__ == "__main__":
    # Manual check: print the detected console size.
    _width, _height = get_terminal_size()
    print('width =', _width, 'height =', _height)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from pprint import PrettyPrinter, _builtin_scalars, _recursion | |
# Public API of this module.
__all__ = ['UnsortedPrettyPrinter', 'pprint', 'pformat']
class UnsortedPrettyPrinter(PrettyPrinter):
    """Pretty printer that retains original dict ordering.

    ``pprint.PrettyPrinter`` sorts dict items by key. This subclass overrides
    both the single-line path (``format`` / ``_safe_repr``) and the multi-line
    path (the ``dict`` entry of the dispatch table), so dict items are emitted
    in iteration order instead.

    Constructor arguments are the same as ``PrettyPrinter``'s.
    """

    def __init__(self, *args, **kwargs):
        # Forward every argument so indent/width/depth/stream/compact actually
        # take effect; the module-level pformat()/pprint() helpers pass them in.
        super().__init__(*args, **kwargs)
        # Per-instance copy of the inherited dispatch table, with the dict
        # formatter replaced by the unsorted one.
        self._dispatch = {
            **self._dispatch,
            dict.__repr__: self._pprint_dict,
        }

    @staticmethod
    def _pprint_dict(self, object, stream, indent, allowance, context, level):
        # Multi-line dict formatter. It takes the printer explicitly (hence
        # @staticmethod), matching how PrettyPrinter invokes dispatch-table
        # entries. Items are written in iteration order, not sorted.
        write = stream.write
        write('{')
        if self._indent_per_level > 1:
            write((self._indent_per_level - 1) * ' ')
        if len(object):
            self._format_dict_items(object.items(), stream, indent,
                                    allowance + 1, context, level)
        write('}')

    def format(self, object, context, maxlevels, level):
        """Format object for a specific context, returning a string
        and flags indicating whether the representation is 'readable'
        and whether the object represents a recursive construct.
        """
        return self._safe_repr(object, context, maxlevels, level)

    def _safe_repr(self, object, context, maxlevels, level):
        """Return ``(repr_string, is_readable, is_recursive)`` for *object*.

        :param context: dict keyed by ``id()`` of the containers currently
            being formatted; used to detect recursion
        :param maxlevels: depth limit (falsy means unlimited)
        :param level: current nesting depth
        """
        typ = type(object)
        if typ in _builtin_scalars:
            return repr(object), True, False

        r = getattr(typ, "__repr__", None)

        if issubclass(typ, dict) and r is dict.__repr__:
            if not object:
                return "{}", True, False
            objid = id(object)
            if maxlevels and level >= maxlevels:
                return "{...}", False, objid in context
            if objid in context:
                return _recursion(object), False, True
            context[objid] = 1
            readable = True
            recursive = False
            components = []
            level += 1
            # Iteration order is kept; this is the point of the subclass.
            for k, v in object.items():
                krepr, kreadable, krecur = self._safe_repr(k, context, maxlevels, level)
                vrepr, vreadable, vrecur = self._safe_repr(v, context, maxlevels, level)
                components.append("%s: %s" % (krepr, vrepr))
                readable = readable and kreadable and vreadable
                if krecur or vrecur:
                    recursive = True
            del context[objid]
            return "{%s}" % ", ".join(components), readable, recursive

        if (issubclass(typ, list) and r is list.__repr__) or \
                (issubclass(typ, tuple) and r is tuple.__repr__):
            if issubclass(typ, list):
                if not object:
                    return "[]", True, False
                fmt = "[%s]"
            elif len(object) == 1:
                fmt = "(%s,)"
            else:
                if not object:
                    return "()", True, False
                fmt = "(%s)"
            objid = id(object)
            if maxlevels and level >= maxlevels:
                return fmt % "...", False, objid in context
            if objid in context:
                return _recursion(object), False, True
            context[objid] = 1
            readable = True
            recursive = False
            components = []
            level += 1
            for o in object:
                orepr, oreadable, orecur = self._safe_repr(o, context, maxlevels, level)
                components.append(orepr)
                if not oreadable:
                    readable = False
                if orecur:
                    recursive = True
            del context[objid]
            return fmt % ", ".join(components), readable, recursive

        rep = repr(object)
        return rep, (rep and not rep.startswith('<')), False
def pprint(object, stream=None, indent=1, width=80, depth=None, *,
           compact=False):
    """Pretty-print *object* to *stream* (default is sys.stdout).

    Unlike the stdlib function of the same name, dict items are left
    unsorted.
    """
    UnsortedPrettyPrinter(
        stream=stream, indent=indent, width=width,
        depth=depth, compact=compact,
    ).pprint(object)
def pformat(object, indent=1, width=80, depth=None, *, compact=False):
    """Return *object* as a pretty-printed string.

    Unlike the stdlib function of the same name, dict items are left
    unsorted.
    """
    printer = UnsortedPrettyPrinter(indent=indent, width=width,
                                    depth=depth, compact=compact)
    return printer.pformat(object)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment