import traceback

from pyspark.sql.types import StructType
from sparktestingbase.sqltestcase import SQLTestCase

class SparkSQLTestCase(SQLTestCase):
    def getConf(self):
        from pyspark import SparkConf
        conf = SparkConf()
        conf.set(
            'spark.sql.session.timeZone', 'UTC'
        )
        # set shuffle partitions to a low number, e.g. <= cores * 2 to speed
        # things up, otherwise the tests will use the default 200 partitions
        # and it will take a lot more time to complete
        conf.set('spark.sql.shuffle.partitions', '12')
        return conf

    def setUp(self):
        try:
            # create (or reuse) a SparkSession with the test configuration
            from pyspark.sql import SparkSession
            self.session = SparkSession.builder.config(
                conf=self.getConf()
            ).appName(
                self.__class__.__name__
            ).getOrCreate()
            # keep a SQLContext-compatible handle for the SQLTestCase base class
            self.sqlCtx = self.session._wrapped
        except Exception:
            traceback.print_exc()
            # fall back to a plain SQLContext (pre-2.0 style) if the
            # SparkSession could not be created
            from pyspark.sql import SQLContext
            self.sqlCtx = SQLContext(self.sc)

    def assertOrderedDataFrameEqual(self, expected, result, tol=0):
        """
        Order both dataframes by the columns of the expected df before
        comparing them, so the comparison is insensitive to row order.
        (A hypothetical usage sketch follows the class definition.)
        """
        expected = expected.select(expected.columns).orderBy(expected.columns)
        result = result.select(expected.columns).orderBy(expected.columns)
        super(SparkSQLTestCase, self).assertDataFrameEqual(
            expected, result, tol
        )

    def schema_nullable_helper(self, df, expected_schema, fields=None):
        """
        Since column nullables cannot easily be changed after a dataframe has
        been created, given a dataframe df, an expected_schema and the fields
        whose nullable flag needs to change, return a dataframe whose schema
        has the nullables of expected_schema (only for the fields specified).
        (A hypothetical usage sketch follows the class definition.)

        :param pyspark.sql.DataFrame df: the dataframe that needs schema
            adjustments
        :param pyspark.sql.types.StructType expected_schema: the schema to be
            followed
        :param list[str] fields: the fields that need adjustment of the
            nullable flag
        :return: the dataframe with the corrected nullable flags
        :rtype: pyspark.sql.DataFrame
        """
        new_schema = []
        current_schema = df.schema
        if not fields:
            fields = df.columns

        for item in current_schema:
            if item.name in fields:
                # copy the nullable flag from the matching expected field
                for expected_item in expected_schema:
                    if expected_item.name == item.name:
                        item.nullable = expected_item.nullable
            new_schema.append(item)

        new_schema = StructType(new_schema)
        df = self.session.createDataFrame(
            df.rdd, schema=new_schema
        )
        return df
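

# --- Hypothetical usage sketch (not part of the original gist) ---
# The test class, data and column names below are illustrative assumptions;
# they rely only on the SparkSQLTestCase defined above and on pyspark itself.
# assertOrderedDataFrameEqual sorts both frames on the expected columns, so
# the comparison below does not depend on the row order of the result.
class ExampleAggregationTest(SparkSQLTestCase):
    def test_counts_by_key(self):
        input_df = self.session.createDataFrame(
            [('a', 1), ('a', 2), ('b', 3)], ['key', 'value']
        )
        result = input_df.groupBy('key').count()
        expected = self.session.createDataFrame(
            [('a', 2), ('b', 1)], ['key', 'count']
        )
        self.assertOrderedDataFrameEqual(expected, result)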
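

# --- Hypothetical sketch of schema_nullable_helper (assumed names again) ---
# createDataFrame typically infers nullable=True for every column, so a strict
# schema comparison against fields declared nullable=False would fail; the
# helper copies the expected nullable flags onto the listed fields only.
from pyspark.sql.types import StringType, StructField


class ExampleSchemaNullableTest(SparkSQLTestCase):
    def test_nullable_flag_is_aligned(self):
        expected_schema = StructType([
            StructField('key', StringType(), nullable=False),
            StructField('value', StringType(), nullable=True),
        ])
        df = self.session.createDataFrame([('a', 'x')], ['key', 'value'])
        fixed = self.schema_nullable_helper(df, expected_schema, ['key'])
        self.assertFalse(fixed.schema['key'].nullable)
        self.assertTrue(fixed.schema['value'].nullable)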