
@tsaastam
Created August 13, 2016 19:16
Normalised Discounted Cumulative Gain (NDCG) for Spark DataFrames (with a UserDefinedAggregateFunction)
// Normalised Discounted Cumulative Gain (NDCG) for Spark DataFrames
// See e.g. https://en.wikipedia.org/wiki/Discounted_cumulative_gain
//
// To run this code in the Spark Shell:
//
// 1) https://spark.apache.org/ -> download a binary Spark distribution
// 2) ./bin/spark-shell
// 3) copy-paste!
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
import org.apache.spark.sql.expressions.MutableAggregationBuffer
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._
import org.apache.spark.sql.functions._ // pre-imported in the Spark shell; needed if run elsewhere
// Let's say you have some search results data, a bit like this:
val schema = new StructType(Array(
  StructField("searchId", LongType),
  StructField("timestamp", LongType),
  StructField("resultUrl", StringType),
  StructField("position", IntegerType),
  StructField("clicked", IntegerType),
  StructField("converted", IntegerType),
  StructField("relevanceScore", DoubleType)))
val data = sc.parallelize(Seq(
  Row(123L, 1471097840569L, "https://some.site/", 1, 1, 0, 1.28),
  Row(123L, 1471097840569L, "https://another.site/", 2, 0, 0, 2.3001),
  Row(123L, 1471097840569L, "https://yet.another.site/", 3, 0, 0, 0.792),
  Row(123L, 1471097840569L, "https://a.relevant.site/", 4, 1, 1, 1.51),
  Row(456L, 1471102902205L, "https://another.search/", 1, 0, 0, 0.07),
  Row(456L, 1471102902205L, "https://another.result/", 2, 0, 0, 0.04),
  Row(456L, 1471102902205L, "https://another.site/", 3, 1, 0, 0.02)
))
val df = sqlContext.createDataFrame(data, schema)
// Discounted cumulative gain (DCG), non-normalised, is easy to calculate:
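//
// DCG = Σ_i rel_i / log2(i + 1), where i is the 1-based result position and rel_i the
// relevance at position i; since log2(2) = 1, the top result is not discounted at all.
//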
df.groupBy($"searchId").agg(sum($"relevanceScore"/log(2.0, $"position"+1)).as("DCG")).show
// +--------+------------------+
// |searchId| DCG|
// +--------+------------------+
// | 456|0.1052371901428583|
// | 123|3.7775231288805324|
// +--------+------------------+
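// Checking searchId 456 by hand: 0.07/log2(2) + 0.04/log2(3) + 0.02/log2(4)
// = 0.07 + 0.0252... + 0.01 ≈ 0.1052, matching the value above.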
// The problem is that this DCG is not normalised: its magnitude depends on the relevance
// scores and on the number of results, so it's difficult to compare DCG values across
// searches. To solve this we can use normalised DCG (NDCG):
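//
// NDCG = DCG / IDCG, where IDCG is the DCG of the ideal ordering (the same results
// sorted by descending relevance), so NDCG always lands in [0, 1].
//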
object NDCG extends UserDefinedAggregateFunction {
  // input: one (position, relevance) pair per search result
  def inputSchema = new StructType()
    .add("position", DoubleType)
    .add("relevance", DoubleType)
  // buffer: all positions and relevances seen so far; rows can arrive in any order,
  // so we collect everything and sort by position in evaluate
  def bufferSchema = new StructType()
    .add("positions", ArrayType(DoubleType, false))
    .add("relevances", ArrayType(DoubleType, false))
  def dataType = DoubleType
  def deterministic = true
  def initialize(buffer: MutableAggregationBuffer) = {
    buffer(0) = IndexedSeq[Double]()
    buffer(1) = IndexedSeq[Double]()
  }
  def update(buffer: MutableAggregationBuffer, input: Row) = {
    if(!input.isNullAt(0) && !input.isNullAt(1)) {
      val (position, relevance) = (input.getDouble(0), input.getDouble(1))
      buffer(0) = buffer.getAs[IndexedSeq[Double]](0) :+ position
      buffer(1) = buffer.getAs[IndexedSeq[Double]](1) :+ relevance
    }
  }
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = {
    if(!buffer2.isNullAt(0) && !buffer2.isNullAt(1)) {
      buffer1(0) = buffer1.getAs[IndexedSeq[Double]](0) ++
        buffer2.getAs[IndexedSeq[Double]](0)
      buffer1(1) = buffer1.getAs[IndexedSeq[Double]](1) ++
        buffer2.getAs[IndexedSeq[Double]](1)
    }
  }
  // DCG of (position, relevance) pairs, assumed already sorted by position;
  // the fold carries (current position, gain so far) and discounts the entry
  // at position p by log2(p + 1)
  private def totalGain(scores: Seq[(Double, Double)]): Double = {
    val (_, gain) = scores.foldLeft((1, 0.0))(
      (fa, tuple) => tuple match { case (_, score) =>
        if(score <= 0.0) (fa._1+1, fa._2)                 // zero relevance contributes nothing
        else if(fa._1 == 1) (fa._1+1, fa._2+score)        // log2(2) = 1: no discount at the top
        else (fa._1+1, fa._2+score/(Math.log(fa._1+1)/Math.log(2.0)))
      })
    gain
  }
  def evaluate(buffer: Row) = {
    val (positions, relevances) =
      (buffer.getAs[IndexedSeq[Double]](0), buffer.getAs[IndexedSeq[Double]](1))
    // actual ordering: the pairs sorted by position; ideal ordering: the positive
    // relevances sorted descending, re-numbered from position 1
    val scores = (positions, relevances).zipped.toList.sorted
    val ideal = scores.map(_._2).filter(_>0).sortWith(_>_).zipWithIndex.map { case (s,i0) => (i0+1.0,s) }
    val (thisScore, idealScore) = (totalGain(scores), totalGain(ideal))
    // println(s"scores $scores -> $thisScore\nideal $ideal -> $idealScore")
    if(idealScore == 0.0) 0.0 else thisScore / idealScore
  }
}
// How to use it:
df.groupBy($"searchId").agg(NDCG($"position", $"relevanceScore").as("NDCG")).show
// +--------+------------------+
// |searchId| NDCG|
// +--------+------------------+
// | 456| 1.0|
// | 123|0.8922089188046599|
// +--------+------------------+
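// Checking searchId 123 by hand:
//   DCG  = 1.28/log2(2) + 2.3001/log2(3) + 0.792/log2(4) + 1.51/log2(5) ≈ 3.7775
//   IDCG = 2.3001/log2(2) + 1.51/log2(3) + 1.28/log2(4) + 0.792/log2(5) ≈ 4.2339
//   NDCG = 3.7775 / 4.2339 ≈ 0.8922, matching the value above. searchId 456's results
// happen to be in ideal (descending-relevance) order already, hence its NDCG of 1.0.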
// Of course one can use other definitions of relevance as well, e.g. clicks + 3*conversions:
df.groupBy($"searchId").agg(sum(($"clicked"+$"converted".cast(DoubleType)*3.0)/log(2.0, $"position"+1)).as("DCG")).show
df.groupBy($"searchId").agg(NDCG($"position", $"clicked"+$"converted".cast(DoubleType)*3.0).as("NDCG")).show
@dhruvix commented Feb 27, 2021

Do you know where I could find an implementation for this in pyspark?
