#!/usr/bin/env python | |
""" | |
simple example script for running and testing notebooks. | |
Usage: `ipnbdoctest.py foo.ipynb [bar.ipynb [...]]` | |
Each cell is submitted to the kernel, and the outputs are compared with those stored in the notebook. | |
""" | |
import os,sys,time | |
import base64 | |
import hashlib | |
import png | |
import random | |
import re | |
from collections import defaultdict | |
from Queue import Empty | |
from cStringIO import StringIO | |
import numpy as np | |
try: | |
from IPython.kernel import KernelManager | |
except ImportError: | |
from IPython.zmq.blockingkernelmanager import BlockingKernelManager as KernelManager | |
from IPython.nbformat.current import reads, NotebookNode | |
def png_b64_to_ndarray(a64): | |
"""convert PNG output into a np.ndarray using pypng""" | |
pngdata = png.Reader(StringIO(base64.decodestring(a64))).asRGBA8()[2] | |
return np.array(list(pngdata)) | |
def diff_png(a64, b64): | |
"""compare the pixels of two PNGs""" | |
a_data, b_data = map(png_b64_to_ndarray, (a64, b64)) | |
if a_data.shape != b_data.shape: | |
diff = 255 | |
else: | |
diff = np.mean(np.abs(a_data - b_data)) | |
if diff > 0: | |
digest = hashlib.sha1(a64).digest() | |
prefix = 'ipnbdoctest-%s-' % base64.urlsafe_b64encode(digest)[:4] | |
png.from_array(a_data, mode='RGBA;8').save(prefix + 'original.png') | |
png.from_array(b_data, mode='RGBA;8').save(prefix + 'modified.png') | |
if diff < 255: | |
png.from_array(255 - np.abs(b_data - a_data), mode='RGBA;8').save( | |
prefix + 'diff.png') | |
print 'diff png saved to %s-diff.png' % prefix | |
return diff / 255. | |
def sanitize(s): | |
"""sanitize a string for comparison. | |
fix universal newlines, strip trailing newlines, and normalize likely random values (memory addresses and UUIDs) | |
""" | |
if not isinstance(s, basestring): | |
return s | |
# normalize newline: | |
s = s.replace('\r\n', '\n') | |
# ignore trailing newlines (but not space) | |
s = s.rstrip('\n') | |
# normalize hex addresses: | |
s = re.sub(r'0x[a-f0-9]+', '0xFFFFFFFF', s) | |
# normalize UUIDs: | |
s = re.sub(r'[a-f0-9]{8}(\-[a-f0-9]{4}){3}\-[a-f0-9]{12}', 'U-U-I-D', s) | |
# ignore outputs of %time and %timeit magics: | |
s = re.sub(r'(CPU times|Wall time|\d+ loops, best of).+', 'TIMING', s) | |
return s | |
def consolidate_outputs(outputs): | |
"""consolidate outputs into a summary dict (incomplete)""" | |
data = defaultdict(list) | |
data['stdout'] = '' | |
data['stderr'] = '' | |
for out in outputs: | |
if out.type == 'stream': | |
data[out.stream] += out.text | |
elif out.type == 'pyerr': | |
data['pyerr'] = dict(ename=out.ename, evalue=out.evalue) | |
else: | |
for key in ('png', 'svg', 'latex', 'html', 'javascript', 'text', 'jpeg',): | |
if key in out: | |
data[key].append(out[key]) | |
return data | |
def compare_outputs(test, ref, skip_compare=('traceback', 'latex', 'prompt_number')): | |
for key in ref: | |
if key not in test: | |
print "missing key: %s != %s" % (test.keys(), ref.keys()) | |
return False | |
elif key == 'png': | |
diff = diff_png(test[key], ref[key]) | |
if diff > 0: | |
print "mismatch %s" % key | |
print "%2.3f%% disagree" % diff | |
return False | |
elif key not in skip_compare and sanitize(test[key]) != sanitize(ref[key]): | |
print "mismatch %s:" % key | |
print test[key] | |
print ' != ' | |
print ref[key] | |
return False | |
return True | |
def run_cell(shell, iopub, cell): | |
# print cell.input | |
shell.execute(cell.input) | |
# wait for finish, maximum 60s | |
shell.get_msg(timeout=60) | |
outs = [] | |
while True: | |
try: | |
msg = iopub.get_msg(timeout=0.2) | |
except Empty: | |
break | |
msg_type = msg['msg_type'] | |
if msg_type in ('status', 'pyin'): | |
continue | |
elif msg_type == 'clear_output': | |
outs = [] | |
continue | |
content = msg['content'] | |
# print msg_type, content | |
out = NotebookNode(output_type=msg_type) | |
if msg_type == 'stream': | |
out.stream = content['name'] | |
out.text = content['data'] | |
elif msg_type in ('display_data', 'pyout'): | |
out['metadata'] = content['metadata'] | |
for mime, data in content['data'].iteritems(): | |
attr = mime.split('/')[-1].lower() | |
# this gets most right, but fix svg+html, plain | |
attr = attr.replace('+xml', '').replace('plain', 'text') | |
setattr(out, attr, data) | |
if msg_type == 'pyout': | |
out.prompt_number = content['execution_count'] | |
elif msg_type == 'pyerr': | |
out.ename = content['ename'] | |
out.evalue = content['evalue'] | |
out.traceback = content['traceback'] | |
else: | |
print "unhandled iopub msg:", msg_type | |
outs.append(out) | |
return outs | |
def test_notebook(nb): | |
km = KernelManager() | |
km.start_kernel(extra_arguments=['--pylab=inline'], stderr=open(os.devnull, 'w')) | |
try: | |
kc = km.client() | |
kc.start_channels() | |
iopub = kc.iopub_channel | |
except AttributeError: | |
# IPython 0.13 | |
kc = km | |
kc.start_channels() | |
iopub = kc.sub_channel | |
shell = kc.shell_channel | |
# run %pylab inline, because some notebooks assume this | |
# even though they shouldn't | |
shell.execute("pass") | |
shell.get_msg() | |
while True: | |
try: | |
iopub.get_msg(timeout=1) | |
except Empty: | |
break | |
successes = 0 | |
failures = 0 | |
errors = 0 | |
for ws in nb.worksheets: | |
for cell in ws.cells: | |
if cell.cell_type != 'code': | |
continue | |
try: | |
outs = run_cell(shell, iopub, cell) | |
except Exception as e: | |
print "failed to run cell:", repr(e) | |
print cell.input | |
errors += 1 | |
continue | |
failed = False | |
for out, ref in zip(outs, cell.outputs): | |
if not compare_outputs(out, ref): | |
failed = True | |
if failed: | |
failures += 1 | |
else: | |
successes += 1 | |
sys.stdout.write('.') | |
print "tested notebook %s" % nb.metadata.name | |
print " %3i cells successfully replicated" % successes | |
if failures: | |
print " %3i cells mismatched output" % failures | |
if errors: | |
print " %3i cells failed to complete" % errors | |
kc.stop_channels() | |
km.shutdown_kernel() | |
del km | |
if __name__ == '__main__': | |
for ipynb in sys.argv[1:]: | |
print "testing %s" % ipynb | |
with open(ipynb) as f: | |
nb = reads(f.read(), 'json') | |
test_notebook(nb) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment