Last active
February 14, 2017 09:35
-
-
Save domenp/7e8b2572215e28eb75dd to your computer and use it in GitHub Desktop.
Unit testing a Spark job in Python
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import unittest2 | |
import logging | |
import findspark | |
findspark.init() | |
from pyspark.context import SparkContext | |
class ExampleTest(unittest2.TestCase):
    """Unit tests for a Spark transformation, run against a local SparkContext."""

    def setUp(self):
        # Fresh local context (4 worker threads) for each test, so tests
        # stay isolated from one another.
        self.sc = SparkContext('local[4]')
        # quiet_logs is presumably defined elsewhere in the project to
        # silence Spark's verbose logging — confirm it is importable here.
        quiet_logs(self.sc)

    def tearDown(self):
        # Stop the context so the next test (or run) can create its own;
        # only one SparkContext may be active per JVM.
        self.sc.stop()

    def test_something(self):
        # Start by creating a mock-up dataset: two records share the word
        # 'world', one has 'hello'.
        l = [(1, 'hello'), (2, 'world'), (3, 'world')]
        # And create an RDD out of it.
        rdd = self.sc.parallelize(l)
        # Pass it to the transformation we're unit testing.
        result = non_trivial_transform(rdd)
        # Collect the results back to the driver.
        output = result.collect()
        # BUG FIX: the order of collect() output after reduceByKey is not
        # guaranteed, so asserting on output[0] is fragile. Convert to a
        # dict and assert on the key we actually care about.
        counts = dict(output)
        self.assertEqual(counts['world'], 2)
def non_trivial_transform(rdd):
    """Word count over (id, word) pairs — the transformation under test.

    Defined here for convenience only. Maps each record to (word, 1) and
    sums the ones per word.
    """
    word_ones = rdd.map(lambda record: (record[1], 1))
    return word_ones.reduceByKey(lambda left, right: left + right)
if __name__ == "__main__":
    # Run the test suite when this file is executed as a script.
    unittest2.main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment