Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save evanfrisch/a34a24403c935d6bf49d9be4135914c7 to your computer and use it in GitHub Desktop.
Save evanfrisch/a34a24403c935d6bf49d9be4135914c7 to your computer and use it in GitHub Desktop.
def create_tag_frequencies(self, dataframe):
    """Add a tag ``frequency_sum`` column and one-hot columns for the top 20 tags.

    A tag's frequency is its share of all non-null values found in the
    ``tag1``..``tag5`` columns of the dataframe. ``frequency_sum`` is the sum of
    those frequencies over a record's five tag slots; a slot whose tag is
    missing or unmatched contributes 0.

    Side effects: prints the total tag count, shows the ten most frequent tags
    and a sample of the tag-related columns, and registers (replacing any
    existing ones) the temporary views ``df`` and ``df_tag_freq``, which are
    queried through ``self.sqlContext``.

    :param dataframe: the PySpark dataframe; the code reads its ``tag1``..``tag5``
        and ``tags_split`` columns
    :returns: the PySpark dataframe with ``frequency_sum`` and one ``tag_<name>``
        column per top-20 tag added, and with the helper columns ``tags_split``,
        ``tag1``..``tag5`` and ``frequency_tag1``..``frequency_tag5`` removed
    """
    tag_slots = ["tag{}".format(slot_number) for slot_number in range(1, 6)]

    # Stack the five tag slots into a single "tag" column so every use of a tag
    # counts once, then discard the empty slots.
    df_tags = dataframe.selectExpr("{} AS tag".format(tag_slots[0]))
    for slot in tag_slots[1:]:
        df_tags = df_tags.union(dataframe.selectExpr("{} AS tag".format(slot)))
    df_tags = df_tags.na.drop(subset=["tag"])

    tags_total_count = df_tags.count()
    print("Total number of tags used, including duplicates:", tags_total_count)

    # No ordering is applied here: each consumer that needs an order sorts by
    # frequency itself, and the joins below do not need sorted input.
    df_tag_freq = df_tags.groupBy("tag").count()
    df_tag_freq = df_tag_freq.withColumn("frequency", col("count") / tags_total_count)
    df_tag_freq.orderBy(desc("frequency")).show(10)

    def one_hot_encode_top_n_tags(dataframe, n):
        """Add one 0/1 column per top-n tag indicating whether the record has that tag.

        :param dataframe: the PySpark dataframe; its ``tags_split`` column is searched
        :param n: the number of the top ranked tags to return as tag fields
        :returns: the PySpark dataframe containing the top n tag fields and all
            fields in the supplied dataframe
        """
        top_rows = df_tag_freq.orderBy(desc("frequency")).select("tag").limit(n).collect()
        for row in top_rows:
            tag = row.tag
            # replace tag name ".net" with "dotnet", for example, to avoid
            # problems with periods in tag names
            tag_column_name = ("tag_" + tag).replace(".", "dot")
            dataframe = dataframe.withColumn(
                tag_column_name,
                array_contains(dataframe.tags_split, tag).cast("int"),
            )
        return dataframe

    dataframe = one_hot_encode_top_n_tags(dataframe, 20)

    # The loop variable is deliberately not called `col`, which would shadow the
    # pyspark `col` function used further down.
    tag_columns = [name for name in dataframe.columns if name.startswith('tag')]
    print("Tag-related columns")
    dataframe.select(tag_columns).show(10, False)

    # Look up the frequency of the tag in each slot with a left join so that
    # records without a tag in that slot are kept; their frequency becomes 0.
    dataframe.createOrReplaceTempView('df')
    df_tag_freq.createOrReplaceTempView('df_tag_freq')
    frequency_columns = []
    for slot in tag_slots:
        frequency_column = "frequency_" + slot
        dataframe = self.sqlContext.sql(
            "SELECT df.*, df_tag_freq.frequency AS {} FROM df LEFT JOIN df_tag_freq ON df.{} = df_tag_freq.tag".format(
                frequency_column, slot
            )
        )
        dataframe = dataframe.na.fill({frequency_column: 0})
        # Re-register the view so the next iteration sees the new column.
        dataframe.createOrReplaceTempView('df')
        frequency_columns.append(frequency_column)

    frequency_sum = col(frequency_columns[0])
    for frequency_column in frequency_columns[1:]:
        frequency_sum = frequency_sum + col(frequency_column)
    dataframe = dataframe.withColumn("frequency_sum", frequency_sum)

    # Remove temporary columns
    temporary_columns = {"tags_split"} | set(tag_slots) | set(frequency_columns)
    dataframe = dataframe.select([name for name in dataframe.columns if name not in temporary_columns])
    return dataframe
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment