Created: August 13, 2016, 19:16
Save tsaastam/90f8e862a53d82a7942ec6a94dd0b7d0 to your computer and use it in GitHub Desktop.
Normalised Discounted Cumulative Gain (NDCG) for Spark DataFrames (with a UserDefinedAggregateFunction)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Normalised Discounted Cumulative Gain (NDCG) for Spark DataFrames
// See e.g. https://en.wikipedia.org/wiki/Discounted_cumulative_gain
//
// To run this code in the Spark Shell:
//
// 1) https://spark.apache.org/ -> download a binary Spark distribution
// 2) ./bin/spark-shell
// 3) copy-paste!

import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
import org.apache.spark.sql.expressions.MutableAggregationBuffer
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._

// Example search-results data: one row per (search, result) pair, carrying the
// result's 1-based position in the list, click/conversion flags and a
// precomputed relevance score. All fields are nullable (the default).
val schema = new StructType()
  .add("searchId", LongType)
  .add("timestamp", LongType)
  .add("resultUrl", StringType)
  .add("position", IntegerType)
  .add("clicked", IntegerType)
  .add("converted", IntegerType)
  .add("relevanceScore", DoubleType)

val data = sc.parallelize(Seq(
  Row(123L, 1471097840569L, "https://some.site/", 1, 1, 0, 1.28),
  Row(123L, 1471097840569L, "https://another.site/", 2, 0, 0, 2.3001),
  Row(123L, 1471097840569L, "https://yet.another.site/", 3, 0, 0, 0.792),
  Row(123L, 1471097840569L, "https://a.relevant.site/", 4, 1, 1, 1.51),
  Row(456L, 1471102902205L, "https://another.search/", 1, 0, 0, 0.07),
  Row(456L, 1471102902205L, "https://another.result/", 2, 0, 0, 0.04),
  Row(456L, 1471102902205L, "https://another.site/", 3, 1, 0, 0.02)
))

val df = sqlContext.createDataFrame(data, schema)

// Plain (non-normalised) discounted cumulative gain is a one-liner:
// DCG = sum over results of relevance / log2(position + 1).
df.groupBy($"searchId").agg(sum($"relevanceScore" / log(2.0, $"position" + 1)).as("DCG")).show
// +--------+------------------+
// |searchId|               DCG|
// +--------+------------------+
// |     456|0.1052371901428583|
// |     123|3.7775231288805324|
// +--------+------------------+
// The problem is that plain DCG is not normalised, so values are hard to
// compare across searches of different sizes. Normalised DCG (below) fixes that.
// Normalised DCG as a Spark user-defined aggregate function. For each group it
// collects all (position, relevance) pairs, computes the DCG of the actual
// ordering, divides by the DCG of the ideal ordering (scores sorted
// descending), and returns a value in [0, 1] comparable across searches.
object NDCG extends UserDefinedAggregateFunction {
  // Per-row input: the result's position and its relevance score.
  def inputSchema = new StructType()
    .add("position", DoubleType)
    .add("relevance", DoubleType)
  // Aggregation state: parallel arrays of all positions / relevances seen.
  def bufferSchema = new StructType()
    .add("positions", ArrayType(DoubleType, false))
    .add("relevances", ArrayType(DoubleType, false))
  def dataType = DoubleType
  def deterministic = true

  def initialize(buffer: MutableAggregationBuffer) = {
    buffer(0) = IndexedSeq[Double]()
    buffer(1) = IndexedSeq[Double]()
  }

  // Append one (position, relevance) pair; rows with a null in either column
  // are ignored so they cannot poison the aggregate.
  def update(buffer: MutableAggregationBuffer, input: Row) = {
    if (!input.isNullAt(0) && !input.isNullAt(1)) {
      buffer(0) = buffer.getAs[IndexedSeq[Double]](0) :+ input.getDouble(0)
      buffer(1) = buffer.getAs[IndexedSeq[Double]](1) :+ input.getDouble(1)
    }
  }

  // Combine two partial aggregation states by concatenating their arrays.
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = {
    if (!buffer2.isNullAt(0) && !buffer2.isNullAt(1)) {
      buffer1(0) = buffer1.getAs[IndexedSeq[Double]](0) ++ buffer2.getAs[IndexedSeq[Double]](0)
      buffer1(1) = buffer1.getAs[IndexedSeq[Double]](1) ++ buffer2.getAs[IndexedSeq[Double]](1)
    }
  }

  // DCG of an already rank-ordered list of (position, score) pairs: the i-th
  // entry (0-based) contributes score / log2(i + 2), i.e. the first entry
  // counts in full. Non-positive scores contribute nothing but still occupy
  // their rank slot.
  private def totalGain(scores: Seq[(Double, Double)]): Double =
    scores.zipWithIndex.map { case ((_, score), rank0) =>
      if (score <= 0.0) 0.0
      else score / (Math.log(rank0 + 2) / Math.log(2.0))
    }.sum

  def evaluate(buffer: Row) = {
    val positions = buffer.getAs[IndexedSeq[Double]](0)
    val relevances = buffer.getAs[IndexedSeq[Double]](1)
    // Actual ordering: pairs sorted by position (tuple ordering sorts on
    // position first, which is what we want).
    val actual = positions.zip(relevances).toList.sorted
    // Ideal ordering: the positive scores sorted best-first, re-ranked 1..n.
    val ideal = actual.map(_._2).filter(_ > 0).sortWith(_ > _).zipWithIndex
      .map { case (s, i0) => (i0 + 1.0, s) }
    val idealScore = totalGain(ideal)
    // A group with no positive relevance at all gets NDCG 0 (avoids 0/0).
    if (idealScore == 0.0) 0.0 else totalGain(actual) / idealScore
  }
}
// Usage: NDCG per search, with relevanceScore as the gain.
df.groupBy($"searchId").agg(NDCG($"position", $"relevanceScore").as("NDCG")).show
// +--------+------------------+
// |searchId|              NDCG|
// +--------+------------------+
// |     456|               1.0|
// |     123|0.8922089188046599|
// +--------+------------------+
// Any other definition of relevance works too, e.g. clicks + 3 * conversions:
val clickRelevance = $"clicked" + $"converted".cast(DoubleType) * 3.0
df.groupBy($"searchId").agg(sum(clickRelevance / log(2.0, $"position" + 1)).as("DCG")).show
df.groupBy($"searchId").agg(NDCG($"position", clickRelevance).as("NDCG")).show
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.
Do you know where I could find an implementation for this in pyspark?