Skip to content

Instantly share code, notes, and snippets.

@calmofthestorm
Last active August 29, 2015 14:00
Show Gist options
  • Save calmofthestorm/45e1688f6964d1506a67 to your computer and use it in GitHub Desktop.
Save calmofthestorm/45e1688f6964d1506a67 to your computer and use it in GitHub Desktop.
#!/usr/bin/env python
# encoding: utf-8
import argparse
import json
import os
import select
import socket
import subprocess
import sys
# Maximum number of bytes requested from a shard socket per recv() call.
READ_CHUNK = 2048
class Error(Exception):
    """Root of every exception raised by this module."""


class BadStateError(Error):
    """An event arrived that the shard state machine cannot accept."""


class UnknownGTestVersionError(Error):
    """The shard announced a streaming protocol version we do not handle."""


class TesterClosedSocket(Error):
    """The test process closed its end of the result socket."""
class Shard(object):
    """State machine that consumes the result stream of one gtest shard.

    Each line received on the socket is an event encoded as
    ``key=value&key=value``.  ``self.state`` maps every event name that is
    legal at this point to the handler that consumes it; each handler
    returns the mapping for the next step.  An empty mapping means the
    program finished and no further event is accepted.

    Results are collected in ``self.results`` as
    ``{iteration: (summary, {case: (summary, {test: summary})})}`` where a
    summary is ``{'passed': int}`` plus ``'elapsed_time'`` when the event
    carried one.  Entries of an iteration or case that has not ended yet
    are plain dicts instead of ``(summary, children)`` tuples.
    """

    def __init__(self, sock):
        """Create a shard reading from ``sock`` (only used by read_events)."""
        self.sock = sock
        self.buf = []          # pieces of the line currently being received
        self.alive = True      # False once TestProgramEnd was processed
        self.version = None
        self.started = False
        self.iteration = None  # iteration currently open, if any
        self.case = None       # test case currently open, if any
        self.test = None       # test currently open, if any
        # The very first line carries no 'event' key, hence the None key.
        self.state = {None: self.state_initial}
        self.results = {}
        self.passed = None

    def process_event(self, event_string):
        """Feed one decoded event line into the state machine.

        Returns True while more events are expected and False once the
        program-end event was consumed.

        Raises BadStateError if the line is malformed, if the event is not
        legal in the current state, or if the shard already finished.
        """
        try:
            event = dict(kv.split('=', 1) for kv in event_string.split('&'))
        except ValueError:
            # A field without '=' used to leak a bare ValueError from dict().
            raise BadStateError('Malformed event: %r' % (event_string,))
        name = event.get('event', None)

        # gtest reports every failed assertion as a TestPartResult event.
        # It carries diagnostics only (the verdict arrives with TestEnd),
        # so it is accepted anywhere after program start and leaves the
        # state untouched.
        if name == 'TestPartResult' and self.started and self.alive:
            return self.alive

        if not self.alive or name not in self.state:
            raise BadStateError('Expected next states ' +
                                str(list(self.state)) +
                                ' got ' + str(name))
        self.state = self.state[name](event)
        self.alive = self.state != {}
        return self.alive

    def read_events(self):
        """Receive one chunk from the socket and process each full line.

        A trailing partial line is kept in ``self.buf`` until the rest of
        it arrives.  Raises TesterClosedSocket when recv() returns nothing.
        """
        chunk = self.sock.recv(READ_CHUNK)
        if not chunk:
            raise TesterClosedSocket()
        while '\n' in chunk:
            line, chunk = chunk.split('\n', 1)
            self.buf.append(line)
            self.process_event(''.join(self.buf))
            self.buf = []
        self.buf.append(chunk)

    def state_initial(self, event):
        """Consume the protocol version line; only version 1.0 is known."""
        if 'gtest_streaming_protocol_version' not in event:
            raise BadStateError('Expected version information.')
        self.version = event['gtest_streaming_protocol_version']
        if self.version != '1.0':
            raise UnknownGTestVersionError(self.version)
        return {'TestProgramStart': self.state_program_start}

    def state_program_start(self, event):
        """Consume TestProgramStart."""
        self.started = True
        return self._between_iterations()

    def state_iteration_start(self, event):
        """Open a new iteration and give it an empty result mapping."""
        self._expect_keys(event, 'iteration')
        assert self.iteration is None
        self.iteration = int(event['iteration'])
        self.results[self.iteration] = {}
        return self._between_cases()

    def state_case_start(self, event):
        """Open a test case inside the current iteration."""
        self._expect_keys(event, 'name')
        assert self.case is None
        name = event['name']
        # Depends on the input stream, so it must survive ``python -O``.
        if name in self.results[self.iteration]:
            raise BadStateError('Duplicate test case %s.' % name)
        self.case = name
        self.results[self.iteration][self.case] = {}
        return self._between_tests()

    def state_test_start(self, event):
        """Open a test inside the current case."""
        self._expect_keys(event, 'name')
        assert self.test is None
        self.test = event['name']
        return {'TestEnd': self.state_test_end}

    def state_test_end(self, event):
        """Record the outcome of the open test and close it."""
        self._expect_keys(event, 'passed')
        assert self.test is not None and self.case is not None
        tests = self.results[self.iteration][self.case]
        # Depends on the input stream, so it must survive ``python -O``.
        if self.test in tests:
            raise BadStateError('Duplicate test %s in case %s.'
                                % (self.test, self.case))
        tests[self.test] = self._outcome(event)
        self.test = None
        return self._between_tests()

    def state_case_end(self, event):
        """Close the open case, pairing its summary with its tests."""
        self._expect_keys(event, 'passed')
        assert self.case is not None and self.test is None
        cases = self.results[self.iteration]
        assert self.case in cases
        cases[self.case] = (self._outcome(event), cases[self.case])
        self.case = None
        return self._between_cases()

    def state_iteration_end(self, event):
        """Close the open iteration, pairing its summary with its cases."""
        self._expect_keys(event, 'passed')
        assert self.case is None and self.test is None
        assert self.iteration is not None
        assert self.iteration in self.results
        self.results[self.iteration] = (self._outcome(event),
                                        self.results[self.iteration])
        self.iteration = None
        return self._between_iterations()

    def state_program_end(self, event):
        """Record the overall verdict; no further event is accepted."""
        self._expect_keys(event, 'passed')
        self.passed = int(event['passed'])
        assert self.iteration is None and self.test is None
        assert self.case is None
        return {}

    def _between_iterations(self):
        """Events legal when no iteration is open."""
        return {'TestIterationStart': self.state_iteration_start,
                'TestProgramEnd': self.state_program_end}

    def _between_cases(self):
        """Events legal inside an iteration when no case is open."""
        return {'TestCaseStart': self.state_case_start,
                'TestIterationEnd': self.state_iteration_end}

    def _between_tests(self):
        """Events legal inside a case when no test is open."""
        return {'TestStart': self.state_test_start,
                'TestCaseEnd': self.state_case_end}

    @staticmethod
    def _outcome(event):
        """Extract 'passed' (as int) and, if present, 'elapsed_time'."""
        res = {'passed': int(event['passed'])}
        if 'elapsed_time' in event:
            res['elapsed_time'] = event['elapsed_time']
        return res

    def _expect_keys(self, event, *keys):
        """Raise BadStateError unless ``event`` has every key in ``keys``."""
        for key in keys:
            if key not in event:
                raise BadStateError('Expected event to have key %s.' % key)
def add_times(a, b):
    """Add two durations written as '<integer>ms' and return the same form."""
    unit = 'ms'
    assert a.endswith(unit) and b.endswith(unit)
    total = sum(int(duration[:-len(unit)]) for duration in (a, b))
    return '%ims' % total
def accumulate_time(acc, new):
    """Fold the elapsed time of ``new`` into the running total in ``acc``.

    Both arguments are result dicts that may carry an 'elapsed_time' value
    such as '3ms'.  When both have one, ``acc`` is updated in place with
    their sum.  When ``new`` has none, the total can no longer be trusted
    and is removed from ``acc``.  When ``acc`` has none, it stays that way.
    """
    if 'elapsed_time' in acc and 'elapsed_time' in new:
        acc['elapsed_time'] = add_times(acc['elapsed_time'],
                                        new['elapsed_time'])
    else:
        # The original ``del acc`` only unbound the local name and left the
        # stale total in the dict; drop the key itself.
        acc.pop('elapsed_time', None)
def _child_results(entry):
    """Return the children mapping of a finished or unfinished result node.

    Finished iterations and cases are ``(summary, children)`` tuples; ones
    that never saw their end event are still the bare ``children`` dict.
    """
    if isinstance(entry, tuple):
        return entry[1]
    return entry


def accumulate_results(this_runner, full_results, cross_iter, results):
    """Merge one shard's results into the three aggregate structures.

    Args:
        this_runner: one shard's results, ``{iteration: (summary, {case:
            (summary, {test: summary})})}``; unfinished levels may be
            plain dicts.
        full_results: list that receives ``this_runner`` unchanged.
        cross_iter: ``{case: {test: totals}}`` summed over every iteration
            and shard; ``totals`` holds 'passed' (number of passing runs),
            'total' (number of runs) and, while known, 'elapsed_time'.
        results: ``{iteration: {case: {test: summary}}}`` merged over all
            shards.
    """
    full_results.append(this_runner)
    for iteration, iteration_entry in this_runner.items():
        per_iteration = results.setdefault(iteration, {})
        for case, case_entry in _child_results(iteration_entry).items():
            # Nest the case under its iteration; the original added it at
            # the top level and then overwrote it with the last test seen.
            per_case = per_iteration.setdefault(case, {})
            totals_by_test = cross_iter.setdefault(case, {})
            for test, test_results in _child_results(case_entry).items():
                per_case[test] = test_results
                if test not in totals_by_test:
                    # Copy so the running totals never leak back into the
                    # shard's own record stored in full_results/results.
                    totals = dict(test_results)
                    totals.setdefault('total', 1)
                    totals_by_test[test] = totals
                else:
                    totals = totals_by_test[test]
                    totals['passed'] += test_results['passed']
                    totals['total'] += 1
                    accumulate_time(totals, test_results)
def print_summary(results):
    """Write a compact, coloured per-case summary to standard output.

    ``results`` maps case name -> {test name -> totals}, where each totals
    dict has 'passed' and 'total' counts.  Every case gets one line: its
    padded name followed by '.' for each test that passed every run and
    'F' for each test that did not.  The line is green when all tests of
    the case passed and red otherwise.  A verdict line follows a rule.
    """
    out = sys.stdout
    failed = False
    max_len = max(map(len, results)) if results else 0
    for case, res in results.items():
        # The original also called add_times() on each 'elapsed_time' and
        # discarded the result; that dead code only served to raise
        # KeyError for results without a timing, so it is gone.
        outcomes = [meta['passed'] == meta['total'] for meta in res.values()]
        green = all(outcomes)
        failed = failed or not green
        out.write('\033[92m\033[1m' if green else '\033[91m\033[1m')
        out.write(case + ' ' * (max_len - len(case) + 5))
        out.write(''.join('.' if ok else 'F' for ok in outcomes))
        out.write('\033[0m\n')
    out.write('-' * (max_len + 20) + '\n')
    if failed:
        out.write('\033[91m\033[1mThere were failed tests.\033[0m\n')
    else:
        out.write('\033[92m\033[1mAll tests passed!\033[0m\n')
def make_parser():
    """Build and return the argparse parser for the command line."""
    parser = argparse.ArgumentParser(description='Run gtest in parallel.')
    add = parser.add_argument

    add('gtest_binary', metavar='gtest_binary', type=str,
        help='Path to the gtest binary to run.')
    add('num_shards', metavar='num_shards', type=int,
        help='How many shard processes to split the tests into.')

    # Boolean switches: (flag, destination attribute, help text).
    switches = (
        ('-q', 'quiet', 'Don\'t print compact text summary.'),
        ('-j', 'json', 'Output test results in JSON.'),
        ('-s', 'nosquelch', 'Don\'t squelch native gtest output.'),
    )
    for flag, dest, help_text in switches:
        add(flag, dest=dest, action='store_true', help=help_text)

    add('gtest_flags', metavar='gtest_flags', type=str, nargs='*',
        help='Flags to pass to gtest.')
    return parser
def main():
    """Run the gtest binary as several shards and aggregate their results.

    Opens a listening socket on an ephemeral localhost port, starts
    ``num_shards`` copies of the binary that stream their results to it,
    and feeds each connection through a Shard until the worker closes it.
    Prints the text summary unless -q was given and dumps JSON when -j was
    given.

    Returns 0 when every recorded test passed in every run, 1 otherwise.
    NOTE(review): worker exit codes are still not folded into this status,
    so a shard that dies before reporting a test is not counted as failed.
    """
    args = make_parser().parse_args()

    listener = socket.socket()
    # SO_REUSEADDR has to be set before bind() to have any effect.
    listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    listener.bind(('localhost', 0))
    listener.listen(args.num_shards)
    port = listener.getsockname()[1]
    cli = args.gtest_flags + ['--gtest_stream_result_to=localhost:%i' % port]

    # Squelched output goes to the null device.  The original used pipes
    # that nobody read, so a chatty worker could block forever once the
    # pipe buffer filled up.
    sink = None if args.nosquelch else open(os.devnull, 'w')
    workers = []
    try:
        for i in range(args.num_shards):
            # Start from the inherited environment and let the per-shard
            # values win; the other order let stale GTEST_* variables from
            # the caller override the sharding.
            env = dict(os.environ)
            env['GTEST_SHARD_INDEX'] = str(i)
            env['GTEST_TOTAL_SHARDS'] = str(args.num_shards)
            workers.append(subprocess.Popen([args.gtest_binary] + cli,
                                            env=env, stdout=sink,
                                            stderr=sink))
    finally:
        if sink is not None:
            sink.close()  # the children keep their own copies

    try:
        sockets = [listener.accept()[0] for _ in range(args.num_shards)]
    finally:
        listener.close()

    shards = dict((s, Shard(s)) for s in sockets)
    results = {}
    cross_iter = {}
    full_results = []
    while shards:
        for fd in select.select(list(shards), (), ())[0]:
            try:
                shards[fd].read_events()
            except TesterClosedSocket:
                accumulate_results(shards[fd].results, full_results,
                                   cross_iter, results)
                del shards[fd]
                fd.close()

    # Reap the workers so none is left behind as a zombie.
    for worker in workers:
        worker.wait()

    if not args.quiet:
        print_summary(cross_iter)
    if args.json:
        obj = {'results': results,
               'full_results': full_results,
               'summary': cross_iter}
        json.dump(obj, sys.stdout)

    green = all(meta['passed'] == meta['total']
                for res in cross_iter.values()
                for meta in res.values())
    return 0 if green else 1
if __name__ == '__main__':
    # Use main()'s return value as the process exit status.
    raise SystemExit(main())
# vim: sta:et:sw=4:ts=4:sts=4
import unittest
import gatherer
# Alex Roper
# alex@aroper.net
# Templates for the lines fed to gatherer.Shard.process_event() in the tests
# below.  Fields are joined with '&' and written as key=value.

# Protocol handshake line; the only one without an 'event' field.
VERSION = 'gtest_streaming_protocol_version=%s'
START = 'event=TestProgramStart'
# Start events take the iteration number or the case/test name.
ITERATION_START = 'event=TestIterationStart&iteration=%s'
CASE_START = 'event=TestCaseStart&name=%s'
TEST_START = 'event=TestStart&name=%s'
# End events take (passed, elapsed milliseconds).
TEST_END = 'event=TestEnd&passed=%i&elapsed_time=%ims'
CASE_END = 'event=TestCaseEnd&passed=%i&elapsed_time=%ims'
ITERATION_END = 'event=TestIterationEnd&passed=%i&elapsed_time=%ims'
# Program end takes only the passed flag.
END = 'event=TestProgramEnd&passed=%i'
class TestStateMachineBase(unittest.TestCase):
    """Fixture providing one ready-made event line of every kind."""

    def setUp(self):
        # Every *_end event reports passed=1 and an elapsed time of 3ms.
        outcome = (1, 3)

        self.version = VERSION % '1.0'
        self.start = START
        self.end = END % 1

        self.iteration_start = ITERATION_START % 0
        self.iteration_end = ITERATION_END % outcome

        self.case_start = CASE_START % 'FooCase'
        self.case_end = CASE_END % outcome

        self.test_start = TEST_START % 'FooTest'
        self.test_end = TEST_END % outcome
class TestStateMachineCases(TestStateMachineBase):
    """Well-formed event streams with one or more cases and tests."""

    def _accept_all(self, shard, *events):
        """Feed the events in order; each must leave the shard alive."""
        for event in events:
            self.assertTrue(shard.process_event(event))

    def prelude(self):
        """Return a shard advanced to just inside iteration 0."""
        shard = gatherer.Shard(0)
        self._accept_all(shard, self.version, self.start,
                         self.iteration_start)
        return shard

    def postlude(self, s):
        """Close the iteration, then check that program end finishes."""
        s.process_event(self.iteration_end)
        finished = s.process_event(self.end)
        self.assertFalse(finished)

    def test_simple1(self):
        shard = self.prelude()
        self._accept_all(shard,
                         self.case_start,
                         self.test_start,
                         self.test_end,
                         self.case_end)
        self.postlude(shard)

    def test_simple2(self):
        shard = self.prelude()
        self._accept_all(shard,
                         self.case_start,
                         self.test_start,
                         self.test_end,
                         TEST_START % 'test2',
                         self.test_end,
                         self.case_end)
        self.postlude(shard)

    def test_simple3(self):
        shard = self.prelude()
        self._accept_all(shard,
                         self.case_start,
                         self.test_start,
                         self.test_end,
                         self.case_end,
                         CASE_START % 'case2',
                         TEST_START % 'test2',
                         self.test_end,
                         self.case_end)
        self.postlude(shard)
class TestStateMachineEdge(TestStateMachineBase):
    """Well-formed streams in which some nesting level is empty."""

    def _assert_complete_run(self, *events):
        """All events but the last keep the shard alive; the last ends it."""
        shard = gatherer.Shard(0)
        for event in events[:-1]:
            self.assertTrue(shard.process_event(event))
        self.assertFalse(shard.process_event(events[-1]))

    def test_no_iterations(self):
        self._assert_complete_run(self.version,
                                  self.start,
                                  self.end)

    def test_no_cases(self):
        self._assert_complete_run(self.version,
                                  self.start,
                                  self.iteration_start,
                                  self.iteration_end,
                                  self.end)

    def test_no_tests(self):
        self._assert_complete_run(self.version,
                                  self.start,
                                  self.iteration_start,
                                  self.case_start,
                                  self.case_end,
                                  self.iteration_end,
                                  self.end)
class TestStateMachineBadTransitions(TestStateMachineBase):
    """Event sequences the state machine must refuse with BadStateError."""

    def _shard_after(self, *events):
        """Return a fresh shard that has accepted ``events`` in order."""
        shard = gatherer.Shard(0)
        for event in events:
            self.assertTrue(shard.process_event(event))
        return shard

    def _assert_rejected(self, shard, event):
        """Check that feeding ``event`` raises BadStateError."""
        with self.assertRaises(gatherer.BadStateError):
            shard.process_event(event)

    def prelude_through_case(self):
        """Return a shard advanced to just inside the first test case."""
        return self._shard_after(self.version, self.start,
                                 self.iteration_start, self.case_start)

    # An enclosing level is ended while an inner level is still open.

    def test_no_iteration_end(self):
        shard = self.prelude_through_case()
        self.assertTrue(shard.process_event(self.case_end))
        self._assert_rejected(shard, self.end)

    def test_no_case_end1(self):
        shard = self.prelude_through_case()
        self._assert_rejected(shard, self.iteration_end)

    def test_no_case_end2(self):
        shard = self.prelude_through_case()
        self._assert_rejected(shard, self.end)

    def test_no_test_end1(self):
        shard = self.prelude_through_case()
        self.assertTrue(shard.process_event(self.test_start))
        self._assert_rejected(shard, self.case_end)

    def test_no_test_end2(self):
        shard = self.prelude_through_case()
        self.assertTrue(shard.process_event(self.test_start))
        self._assert_rejected(shard, self.iteration_end)

    def test_no_test_end3(self):
        shard = self.prelude_through_case()
        self.assertTrue(shard.process_event(self.test_start))
        self._assert_rejected(shard, self.end)

    # A level is started outside the level that has to contain it.

    def test_case_outside_iteration(self):
        shard = self._shard_after(self.version, self.start)
        self._assert_rejected(shard, self.case_start)

    def test_test_outside_case(self):
        shard = self._shard_after(self.version, self.start,
                                  self.iteration_start)
        self._assert_rejected(shard, self.test_start)

    def test_test_outside_iteration(self):
        shard = self._shard_after(self.version, self.start)
        self._assert_rejected(shard, self.test_start)

    # A level is started again while it is already open.

    def test_start_inside_start(self):
        shard = self._shard_after(self.version, self.start)
        self._assert_rejected(shard, self.start)

    def test_iteration_inside_iteration(self):
        shard = self._shard_after(self.version, self.start,
                                  self.iteration_start)
        self._assert_rejected(shard, ITERATION_START % 1)

    def test_case_inside_case(self):
        shard = self._shard_after(self.version, self.start,
                                  self.iteration_start, self.case_start)
        self._assert_rejected(shard, CASE_START % 'case2')

    def test_test_inside_test(self):
        shard = self._shard_after(self.version, self.start,
                                  self.iteration_start, self.case_start,
                                  self.test_start)
        self._assert_rejected(shard, TEST_START % 'test2')
if __name__ == '__main__':
    # Run every TestCase in this module when executed as a script.
    unittest.main()
# vim: sta:et:sw=4:ts=4:sts=4
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment