Created
October 31, 2017 01:19
-
-
Save evanfrisch/a34a24403c935d6bf49d9be4135914c7 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def create_tag_frequencies(self, dataframe, top_n=20):
    """Add top-tag indicator columns and a summed tag-frequency column to a PySpark dataframe.

    The frequency of a tag is its proportion of the total number of non-null tags found in the
    ``tag1`` .. ``tag5`` columns of the dataframe. ``frequency_sum`` is, per record, the sum of the
    frequencies of that record's five tags (a missing or unknown tag contributes 0).

    Side effects: prints the total tag count, shows the ten most frequent tags and a sample of the
    tag-related columns, and registers the temp views ``df`` and ``df_tag_freq`` (replacing any
    existing views with those names) so they can be queried through ``self.sqlContext``.

    :param dataframe: the PySpark dataframe; must provide the columns ``tag1`` .. ``tag5`` and
        ``tags_split`` (presumably an array of the record's tags -- confirm against the caller)
    :param top_n: the number of top ranked tags to one-hot encode as ``tag_<name>`` columns
    :returns: the PySpark dataframe containing ``frequency_sum``, one indicator column per top tag
        and the fields of the supplied dataframe, minus the temporary columns ``tags_split`` and
        ``tag1`` .. ``tag5``
    """
    tag_fields = ["tag{}".format(i) for i in range(1, 6)]
    frequency_fields = ["frequency_tag{}".format(i) for i in range(1, 6)]

    # Stack the five tag columns into a single "tag" column and discard empty slots.
    df_tags = dataframe.selectExpr("{} AS tag".format(tag_fields[0]))
    for field in tag_fields[1:]:
        df_tags = df_tags.union(dataframe.selectExpr("{} AS tag".format(field)))
    df_tags = df_tags.na.drop(subset=["tag"])

    tags_total_count = df_tags.count()
    print("Total number of tags used, including duplicates:",tags_total_count)

    # No sort here: every consumer below either sorts explicitly or does not depend on order.
    df_tag_freq = df_tags.groupBy("tag").count()
    df_tag_freq = df_tag_freq.withColumn("frequency", col("count")/tags_total_count)
    df_tag_freq.orderBy(desc("frequency")).show(10)

    def one_hot_encode_top_n_tags(dataframe, n):
        """Produces a PySpark dataframe containing columns indicating whether each of the top n tags are present.

        :param dataframe: the PySpark dataframe
        :param n: the number of the top ranked tags to return as tag fields
        :returns: the PySpark dataframe containing the top n tag fields and all fields in the supplied dataframe
        """
        ranked_rows = df_tag_freq.orderBy(desc("frequency")).select("tag").limit(n).collect()
        for row in ranked_rows:
            tag = row.tag
            # replace tag name ".net" with "dotnet", for example, to avoid problems with periods in tag names
            tag_column_name = ("tag_"+tag).replace(".","dot")
            dataframe = dataframe.withColumn(tag_column_name, array_contains(dataframe.tags_split, tag).cast("int"))
        return dataframe

    dataframe = one_hot_encode_top_n_tags(dataframe, top_n)
    # Loop variable is deliberately not called "col", which would shadow the column function.
    tag_columns = [name for name in dataframe.columns if name.startswith('tag')]
    print("Tag-related columns")
    dataframe.select(tag_columns).show(10,False)

    dataframe.createOrReplaceTempView('df')
    df_tag_freq.createOrReplaceTempView('df_tag_freq')
    join_sql = "SELECT df.*, df_tag_freq.frequency AS frequency_tag{} FROM df LEFT JOIN df_tag_freq ON df.tag{} = df_tag_freq.tag"
    for n in range(1, 6):
        dataframe = self.sqlContext.sql(join_sql.format(n, n))
        # Records without an n-th tag get no match from the LEFT JOIN; count them as frequency 0.
        dataframe = dataframe.na.fill({"frequency_tag{}".format(n): 0})
        # Re-register the view so the next iteration's join sees the column just added.
        dataframe.createOrReplaceTempView('df')

    frequency_sum = col(frequency_fields[0])
    for field in frequency_fields[1:]:
        frequency_sum = frequency_sum + col(field)
    dataframe = dataframe.withColumn("frequency_sum", frequency_sum)

    # Remove temporary columns
    temporary_columns = {"tags_split"} | set(tag_fields) | set(frequency_fields)
    dataframe = dataframe.select([c for c in dataframe.columns if c not in temporary_columns])
    return dataframe
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment