Created September 14, 2018 20:05
Save ianmcook/62e28297366c78e1b284194ae709bced to your computer and use it in GitHub Desktop.
Column references in Spark DataFrame methods
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Demonstrates why Spark DataFrame method arguments should reference columns
# with col('name') rather than df.name.
from pyspark.sql import SparkSession
from pyspark.sql.types import (
    StructType, StructField, IntegerType, StringType, FloatType,
)
# Alias Spark's round() so it does not shadow Python's built-in round().
from pyspark.sql.functions import col, round as spark_round
from pyspark.sql.utils import AnalysisException

spark = SparkSession.builder.master('local').getOrCreate()

try:
    games = spark.createDataFrame(
        [
            (1, 'Monopoly', 'Elizabeth Magie', 1903, 8, 2, 6, 19.99),
            (2, 'Scrabble', 'Alfred Mosher Butts', 1938, 8, 2, 4, 17.99),
            (3, 'Clue', 'Anthony E. Pratt', 1944, 8, 2, 6, 9.99),
            (4, 'Candy Land', 'Eleanor Abbott', 1948, 3, 2, 4, 7.99),
            (5, 'Risk', 'Albert Lamorisse', 1957, 10, 2, 5, 29.99),
        ],
        schema=StructType([
            StructField('id', IntegerType(), True),
            StructField('name', StringType(), True),
            StructField('inventor', StringType(), True),
            StructField('year', IntegerType(), True),
            StructField('min_age', IntegerType(), True),
            StructField('min_players', IntegerType(), True),
            StructField('max_players', IntegerType(), True),
            StructField('list_price', FloatType(), True),
        ]),
    )

    # Which games cost less than ten dollars?
    (
        games
        .filter(games.list_price < 10)
        .show()
    )

    # If we raised the prices by two dollars, then which games would cost
    # less than ten dollars?
    #
    # This first attempt fails. games.list_price refers to the list_price
    # column of the original `games` DataFrame. withColumn() has replaced that
    # column with a new one of the same name, so Spark cannot resolve the
    # stale reference inside filter() and raises an AnalysisException. The
    # exception is caught so the rest of the demonstration still runs.
    try:
        (
            games
            .withColumn('list_price', spark_round(games.list_price + 2.0, 2))
            .filter(games.list_price < 10)
            .show()
        )
    except AnalysisException as err:
        print('But that fails! AnalysisException: {}'.format(err))

    # To avoid failures like that, always use col('colname') instead of
    # df.colname when specifying columns in arguments to DataFrame methods.
    # col() is resolved by name against the DataFrame the method is called
    # on, so it picks up the replacement list_price column.
    (
        games
        .withColumn('list_price', spark_round(col('list_price') + 2.0, 2))
        .filter(col('list_price') < 10)
        .show()
    )
finally:
    # Release the Spark session even if one of the examples above raises.
    spark.stop()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.