Skip to content

Instantly share code, notes, and snippets.

@ian-whitestone
Created July 29, 2020 16:04
Embed
What would you like to do?
Test a generic starscream stage
from __future__ import unicode_literals
import pytest
from starscream.pipeline.stage import TransformStage
from pyspark.sql import functions as F, types as T
from starscream.contract import Contract
from starscream.utils.dataframe import as_dicts, from_dicts
import pyspark.sql.types as T
# Schema for the sample input: two LongType columns, `foo` and `bar`.
contract = Contract({
    'foo': {'type': T.LongType()},
    'bar': {'type': T.LongType()},
})

# Build the sample input from five literal rows that follow `contract`.
# NOTE(review): `sc` is not defined anywhere in this snippet; presumably it is
# provided by the surrounding session (e.g. a Spark context) -- confirm.
input_df = from_dicts(
    sc,
    contract,
    [
        {'foo': 1, 'bar': 2},
        {'foo': 2, 'bar': 3},
        {'foo': 3, 'bar': 4},
        {'foo': 4, 'bar': 5},
        {'foo': 5, 'bar': 6},
    ],
)

# Print the input columns so they can be compared with the stage output.
input_df.select('foo', 'bar').show()
class MyStage(TransformStage):
    """Example transform stage that adds a ``baz`` column equal to ``foo + bar``."""

    # Declared output schema: the two input columns plus the derived `baz`.
    OUTPUT = Contract({
        'foo': {'type': T.LongType()},
        'bar': {'type': T.LongType()},
        'baz': {'type': T.LongType()},
    })

    def apply(self, sc, my_input_df):
        """Return ``my_input_df`` with an additional ``baz`` column.

        ``baz`` is computed as the sum of the ``foo`` and ``bar`` columns.
        ``sc`` is accepted as part of the signature but is not used here.
        """
        return my_input_df.withColumn('baz', F.col('foo') + F.col('bar'))
# Call the stage's apply() directly on the sample input, then print the
# resulting columns (the inputs plus the derived `baz`).
output_df = MyStage().apply(sc, input_df)
output_df.select('foo', 'bar', 'baz').show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment