# (GitHub gist page header removed; the source file begins below.)
#!/usr/bin/env python
# encoding: utf-8
# This file lives in tests/project_test.py in the usual disutils structure
# Remember to set the SPARK_HOME evnironment variable to the path of your spark installation
import logging
import sys
import unittest
from nose.tools import eq_, set_trace
def add_pyspark_path():
    """Make the local Spark installation's Python bindings importable.

    Reads ``SPARK_HOME`` from the environment and appends both the
    ``python`` directory and the bundled py4j zip to ``sys.path`` so
    that ``import pyspark`` succeeds afterwards.

    Exits the process with status 1 if ``SPARK_HOME`` is not set.

    Thanks go to this project: https://github.com/holdenk/sparklingpandas
    """
    import os
    import sys
    try:
        # Keep the try body minimal: only the environment lookup can raise.
        spark_home = os.environ['SPARK_HOME']
    except KeyError:
        # print(<single string>) behaves the same under Python 2 and 3;
        # the original bare `print` statement was Python-2-only.
        print("SPARK_HOME not set")
        sys.exit(1)
    sys.path.append(os.path.join(spark_home, "python"))
    sys.path.append(os.path.join(spark_home, "python", "lib",
                                 "py4j-0.9-src.zip"))
# Must run before any pyspark import below, otherwise they fail.
add_pyspark_path() # Now we can import pyspark
# Core Spark entry points plus the SQL/window helpers used by the tests.
from pyspark import SparkContext
from pyspark import SparkConf
from pyspark.sql import SQLContext, HiveContext
from pyspark.sql.window import Window
import pyspark.sql.functions as func
def quiet_py4j():
    """Raise py4j's log level to WARN so test output stays readable."""
    logging.getLogger('py4j').setLevel(logging.WARN)
class GSparkTestCase(unittest.TestCase):
    """Base case giving each test its own small local Spark context."""

    def setUp(self):
        quiet_py4j()
        # Fresh, minimally-sized context per test; SparkConf.set returns
        # self, so the settings chain into one expression.
        conf = (SparkConf()
                .set("spark.executor.memory", "1g")
                .set("spark.cores.max", "1")
                .set("spark.app.name", "nosetest"))
        self.sc = SparkContext(conf=conf)
        self.sqlContext = HiveContext(self.sc)

    def tearDown(self):
        # Release the context so the next test can create its own.
        self.sc.stop()
# This would go in tests/project_test.py
# This would go in tests/project_test.py
class BasicSparkTests(GSparkTestCase):
    """Forward-fill a nullable id column using Spark window functions."""

    def null_test(self):
        df = self.sqlContext.createDataFrame([
            (1, 1, None),
            (1, 2, 109),
            (1, 3, None),
            (1, 4, None),
            (1, 5, 109),
            (1, 6, None),
            (1, 7, 110),
            (1, 8, None),
            (1, 9, None),
        ], ("session", "timestamp", "id"))
        eq_(df.count(), 9)

        def process(frame):
            # Sentinel -1 stands in for NULL while windows are computed.
            no_nulls = frame.na.fill(-1)
            by_time = Window.partitionBy('session').orderBy('timestamp')
            # Previous row's id within the session (-1 before the first row).
            lagged = no_nulls.withColumn(
                'id_lag', func.lag('id', default=-1).over(by_time))
            # 1 whenever a real (non-sentinel) id differs from its predecessor.
            flagged = lagged.withColumn(
                'id_change',
                ((lagged['id'] != lagged['id_lag']) &
                 (lagged['id'] != -1)).cast('integer'))
            # Running sum of change flags numbers each constant-id stretch.
            numbered = flagged.withColumn(
                'sub_session',
                func.sum("id_change").over(
                    by_time.rowsBetween(-sys.maxsize, 0)))
            # First id seen inside each stretch becomes the filled value.
            filled = numbered.withColumn(
                'nn_id',
                func.first('id').over(
                    Window.partitionBy('sub_session').orderBy('timestamp')))
            # NOTE(review): replace(-1, 'null') mixes an int with a str —
            # presumably intended to restore NULLs; confirm it behaves as
            # expected on the pyspark version in use.
            restored = filled.replace(-1, 'null')
            result = (restored.drop('id')
                      .drop('id_lag')
                      .drop('id_change')
                      .drop('sub_session')
                      .withColumnRenamed('nn_id', 'id'))
            return result

        df_filled = process(df)
        df_exp = self.sqlContext.createDataFrame([
            (1, 1, None),
            (1, 2, 109),
            (1, 3, 109),
            (1, 4, 109),
            (1, 5, 109),
            (1, 6, 109),
            (1, 7, 110),
            (1, 8, 110),
            (1, 9, 110),
        ], ("session", "timestamp", "id"))
        eq_(df_filled.collect(), df_exp.collect())
# (End of source; GitHub gist page footer removed.)