@dajor
Created November 10, 2019 18:28
import traceback

from pyspark.sql.types import StructType
from sparktestingbase.sqltestcase import SQLTestCase

class SparkSQLTestCase(SQLTestCase):
    def getConf(self):
        from pyspark import SparkConf
        conf = SparkConf()
        conf.set(
            'spark.sql.session.timeZone', 'UTC'
        )
        # Set shuffle partitions to a low number, e.g. <= cores * 2, to
        # speed things up; otherwise the tests will use the default 200
        # partitions and take a lot more time to complete.
        conf.set('spark.sql.shuffle.partitions', '12')
        return conf
    def setUp(self):
        try:
            from pyspark.sql import SparkSession
            self.session = SparkSession.builder.config(
                conf=self.getConf()
            ).appName(
                self.__class__.__name__
            ).getOrCreate()
            self.sqlCtx = self.session._wrapped
        except Exception:
            # Fall back to the SQLContext API for older PySpark versions
            # that do not provide SparkSession.
            traceback.print_exc()
            from pyspark.sql import SQLContext
            self.sqlCtx = SQLContext(self.sc)
    def assertOrderedDataFrameEqual(self, expected, result, tol=0):
        """
        Order both dataframes by the columns of the expected df before
        comparing them.
        """
        expected = expected.select(expected.columns).orderBy(expected.columns)
        result = result.select(expected.columns).orderBy(expected.columns)
        super(SparkSQLTestCase, self).assertDataFrameEqual(
            expected, result, tol
        )
    def schema_nullable_helper(self, df, expected_schema, fields=None):
        """
        Since column nullability cannot easily be changed after a dataframe
        has been created, take a dataframe df, an expected_schema and the
        fields whose nullable flag needs to change, and return a dataframe
        whose schema has the nullables of expected_schema (only for the
        fields specified).

        :param pyspark.sql.DataFrame df: the dataframe that needs schema
            adjustments
        :param pyspark.sql.types.StructType expected_schema: the schema to
            be followed
        :param list[str] fields: the fields that need adjustment of the
            nullable flag
        :return: the dataframe with the corrected nullable flags
        :rtype: pyspark.sql.DataFrame
        """
        new_schema = []
        current_schema = df.schema
        if not fields:
            fields = df.columns
        for item in current_schema:
            if item.name in fields:
                for expected_item in expected_schema:
                    if expected_item.name == item.name:
                        item.nullable = expected_item.nullable
            new_schema.append(item)
        new_schema = StructType(new_schema)
        df = self.session.createDataFrame(
            df.rdd, schema=new_schema
        )
        return df
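
For reference, a minimal usage sketch (not part of the original gist; the test class name, column names, and data are hypothetical stand-ins): it builds two dataframes that differ in row order and in nullable flags, aligns the flags with schema_nullable_helper, then compares with assertOrderedDataFrameEqual.

from pyspark.sql.types import StringType, StructField, StructType


class ExampleSparkSQLTest(SparkSQLTestCase):
    def test_transform_output(self):
        # Hypothetical expected output: 'name' is declared non-nullable.
        expected_schema = StructType([
            StructField('name', StringType(), False),
            StructField('label', StringType(), True),
        ])
        expected = self.session.createDataFrame(
            [('alice', 'a'), ('bob', 'b')], schema=expected_schema
        )

        # Stand-in for the output of the code under test; createDataFrame
        # infers nullable=True for every column here, and the rows arrive
        # in a different order than in `expected`.
        result = self.session.createDataFrame(
            [('bob', 'b'), ('alice', 'a')], ['name', 'label']
        )

        # Align the nullable flags with the expected schema, then compare
        # schemas and (ordered) contents.
        result = self.schema_nullable_helper(result, expected_schema)
        self.assertEqual(expected.schema, result.schema)
        self.assertOrderedDataFrameEqual(expected, result)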