Skip to content

Instantly share code, notes, and snippets.

@spektom
Last active July 12, 2017 13:42
Show Gist options
  • Save spektom/7b7cd4a85739083dbfc93d673cbbd47f to your computer and use it in GitHub Desktop.
Save spektom/7b7cd4a85739083dbfc93d673cbbd47f to your computer and use it in GitHub Desktop.
/**
 * Protect cardinality of a column in a data frame.
 *
 * Within each partition defined by `groupBy`, the distinct values of `col` are
 * ranked by the total of `metric` over all rows carrying that value (highest
 * total first, ties broken by the value itself so the ranking is deterministic).
 * The `limit` top-ranked non-null values are kept. Every other non-null value is
 * replaced with `replaceWith`. Null values of `col` are always kept as-is, and
 * they are ranked after all non-null values, so they never take up one of the
 * `limit` slots.
 *
 * Two scratch columns, `_group_total` and `_card_rank`, are added and then dropped.
 * NOTE(review): `df` presumably must not already contain columns with those names.
 *
 * @param df Input data frame
 * @param col Column whose cardinality will be protected; the protected result is
 *            written back under the name `col.toString`
 * @param metric Metric to use when ranking column values (summed per value)
 * @param groupBy Additional columns to group by when
 *                calculating column value ranks
 * @param limit Max number of distinct non-null values kept per group
 * @param replaceWith Static text to replace all values that
 *                    exceed the limit with
 * @return Data frame with fixed cardinality of a column
 */
def protectDF(df: DataFrame, col: Column, metric: Column, groupBy: Seq[Column], limit: Int,
              replaceWith: String): DataFrame = {
  import df.sqlContext.implicits._
  // One partition per (groupBy..., value): every row of a value shares its total.
  val totalsWindow = Window.partitionBy(groupBy :+ col: _*)
  // Rank values within each group. `col.isNull` sorts false before true, so
  // non-null values always receive ranks 1..k. Nulls (kept unconditionally
  // below) therefore can never push a real value past `limit`. dense_rank
  // assigns all rows of the same value the same rank. Ordering by `col` last
  // makes equal totals rank deterministically and gives each distinct value
  // its own rank.
  val rankWindow = Window.partitionBy(groupBy: _*)
    .orderBy(col.isNull, $"_group_total".desc, col)
  df.select($"*", sum(metric).over(totalsWindow).as("_group_total"))
    .withColumn("_card_rank", dense_rank().over(rankWindow))
    .withColumn(col.toString,
      when(col.isNull || $"_card_rank" <= limit, col).otherwise(replaceWith))
    .drop("_group_total")
    .drop("_card_rank")
}
val df = sqlContext.read.json("/home/michael/events.json")
// Baseline before protection: total event_count per (client, day, event name).
df.groupBy("client_id", "date", "event_name")
  .agg(sum("event_count").as("event_count"))
  .orderBy("client_id", "date", "event_name")
  .show()
//+---------+----------+----------+-----------+
//|client_id| date|event_name|event_count|
//+---------+----------+----------+-----------+
//| client1|2016-01-01| event_1| 2|
//| client1|2016-01-01| event_10| 18|
//| client1|2016-01-01| event_100| 17|
//| client1|2016-01-01| event_11| 4|
//| client1|2016-01-01| event_12| 12|
//| client1|2016-01-01| event_13| 14|
//| client1|2016-01-01| event_14| 15|
//| client1|2016-01-01| event_15| 10|
//| client1|2016-01-01| event_16| 9|
//| client1|2016-01-01| event_17| 22|
//| client1|2016-01-01| event_18| 4|
//| client1|2016-01-01| event_19| 8|
//| client1|2016-01-01| event_21| 9|
//| client1|2016-01-01| event_23| 15|
//| client1|2016-01-01| event_25| 10|
//| client1|2016-01-01| event_27| 7|
//| client1|2016-01-01| event_3| 6|
//| client1|2016-01-01| event_32| 12|
//| client1|2016-01-01| event_34| 9|
//| client1|2016-01-01| event_35| 13|
//+---------+----------+----------+-----------+
// only showing top 20 rows
// Keep the 5 event names with the largest total event_count per client and day;
// every other event name is replaced with "Other":
val protectedDF = protectDF(
  df = df,
  col = $"event_name",
  metric = $"event_count",
  groupBy = Seq($"client_id", $"date"),
  limit = 5,
  replaceWith = "Other")
// Show results: totals per (client, day, event name) after protection; the
// replaced event names are now summed together under "Other".
protectedDF
  .groupBy("client_id", "date", "event_name")
  .agg(sum("event_count").as("event_count"))
  .orderBy("client_id", "date", "event_name")
  .show()
//+---------+----------+----------+-----------+
//|client_id| date|event_name|event_count|
//+---------+----------+----------+-----------+
//| client1|2016-01-01| Other| 579|
//| client1|2016-01-01| event_10| 18|
//| client1|2016-01-01| event_17| 22|
//| client1|2016-01-01| event_42| 27|
//| client1|2016-01-01| event_82| 20|
//| client1|2016-01-01| event_97| 19|
//| client1|2016-01-02| Other| 461|
//| client1|2016-01-02| event_15| 18|
//| client1|2016-01-02| event_26| 27|
//| client1|2016-01-02| event_73| 24|
//| client1|2016-01-02| event_76| 22|
//| client1|2016-01-02| event_9| 19|
//| client1|2016-01-03| Other| 517|
//| client1|2016-01-03| event_48| 28|
//| client1|2016-01-03| event_49| 19|
//| client1|2016-01-03| event_62| 22|
//| client1|2016-01-03| event_94| 21|
//| client1|2016-01-03| event_98| 25|
//| client2|2016-01-01| Other| 520|
//| client2|2016-01-01| event_18| 18|
//+---------+----------+----------+-----------+
//only showing top 20 rows
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment