Skip to content

Instantly share code, notes, and snippets.

@domenp
Last active February 14, 2017 09:35
Show Gist options
  • Save domenp/7e8b2572215e28eb75dd to your computer and use it in GitHub Desktop.
Save domenp/7e8b2572215e28eb75dd to your computer and use it in GitHub Desktop.
Unit testing a Spark job in Python
import unittest2
import logging
import findspark
findspark.init()
from pyspark.context import SparkContext
class ExampleTest(unittest2.TestCase):
    """Unit tests for Spark transformations, run against a local SparkContext.

    Each test gets a fresh 4-thread local context in setUp and releases it
    in tearDown so contexts never leak between tests.
    """

    def setUp(self):
        # Local mode with 4 worker threads; no cluster required.
        self.sc = SparkContext('local[4]')
        # quiet_logs is defined elsewhere (not visible in this file);
        # presumably it lowers Spark's log verbosity — TODO confirm.
        quiet_logs(self.sc)

    def tearDown(self):
        # Stop the context so the next test can create its own.
        self.sc.stop()

    def test_something(self):
        """non_trivial_transform counts word occurrences correctly."""
        # Start by creating a mockup dataset of (id, word) pairs.
        rows = [(1, 'hello'), (2, 'world'), (3, 'world')]
        # Create an RDD out of it.
        rdd = self.sc.parallelize(rows)
        # Pass it to the transformation under test and collect the results.
        output = non_trivial_transform(rdd).collect()
        # collect() gives no ordering guarantee after reduceByKey, so
        # assert on the counts by key rather than on list position.
        counts = dict(output)
        self.assertEqual(counts['world'], 2)
        self.assertEqual(counts['hello'], 1)
def non_trivial_transform(rdd):
    """Word count: map (id, word) pairs to (word, count) pairs.

    Defined here for convenience only; it is the transformation the
    test suite above exercises.
    """
    ones = rdd.map(lambda pair: (pair[1], 1))
    return ones.reduceByKey(lambda left, right: left + right)
if __name__ == "__main__":
    # Run the test suite when this file is executed directly.
    unittest2.main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment