Skip to content

Instantly share code, notes, and snippets.

@kell18
Created April 3, 2016 16:02
Show Gist options
  • Save kell18/224dedf0a81cbdff28cd01287bb585f9 to your computer and use it in GitHub Desktop.
Save kell18/224dedf0a81cbdff28cd01287bb585f9 to your computer and use it in GitHub Desktop.
package ir.task8
import ir.utils.Files
import scala.util.{Failure, Success}
/** A document, identified solely by its `id` (equality is by id). */
case class Doc(id: String)

/**
 * One evaluation query.
 *
 * @param actual    the relevant (ground-truth) documents
 * @param predicted the ranked prediction; index 0 is rank 1
 */
case class Query(actual: List[Doc], predicted: Array[Doc])

/**
 * Ranking metrics cut off at rank `k` (the "@k"/"4k" in the method names).
 *
 * The per-query averages (`map4k`, `mrr4k`) divide by `queries.length`, so they return NaN for
 * an empty query list.
 *
 * @param k cut-off rank: only the first `k` predictions are scored (`mrr4k` excepted, see there)
 */
class Metrics(val k: Int) {

  /** Mean average precision@k: the arithmetic mean of [[ap4k]] over `queries` (NaN if empty). */
  def map4k(queries: List[Query]): Double =
    queries.map(q => ap4k(q.actual, q.predicted)).sum / queries.length

  /**
   * Average precision@k for one query: the sum of precision@i over every rank i <= k that holds
   * a relevant document, divided by min(k, actual.length).
   *
   * @return 0.0 when `actual` is empty (there is nothing to retrieve), instead of 0/0 = NaN
   */
  def ap4k(actual: List[Doc], predicted: Array[Doc]): Double =
    if (actual.isEmpty) 0.0
    else scores(actual, predicted).sum / math.min(k, actual.length)

  /**
   * Precision@k: the number of relevant documents among the top `k` predictions, divided by `k`.
   *
   * Every hit in ranks 1..k counts, not just a hit at rank k. Fewer than `k` predictions are
   * allowed; the denominator is still `k`.
   */
  def precision4k(actual: List[Doc], predicted: Array[Doc]): Double =
    // Hit positions are exactly the strictly positive entries of `scores`.
    scores(actual, predicted).count(_ > 0.0).toDouble / k

  /**
   * Mean reciprocal rank over `queries` (NaN if empty).
   *
   * NOTE(review): `k` is not applied here. A relevant document ranked beyond `k` still
   * contributes 1/rank.
   */
  def mrr4k(queries: List[Query]): Double =
    queries.map(q => reciprocal(q.actual, q.predicted)).sum / queries.length

  /**
   * Per-rank precision contributions over the top `k` predictions.
   *
   * Element i is (hits within ranks 1..i+1) / (i+1) when `predicted(i)` is relevant and this is
   * its first occurrence in the ranking. Otherwise it is 0.0. Repeated predictions are not
   * counted again, so [[ap4k]] cannot exceed 1.
   *
   * @return min(k, predicted.length) values (none when k <= 0)
   */
  def scores(actual: List[Doc],
             predicted: Array[Doc]): scala.collection.immutable.IndexedSeq[Double] = {
    val relevant = actual.toSet
    val topK = predicted.take(k).toIndexedSeq
    val isHit = topK.indices.map(i => relevant(topK(i)) && topK.indexOf(topK(i)) == i)
    // hitsSoFar(i) = number of hits within ranks 1..i+1
    val hitsSoFar = isHit.scanLeft(0)((n, hit) => if (hit) n + 1 else n).tail
    isHit.indices.map(i => if (isHit(i)) hitsSoFar(i).toDouble / (i + 1) else 0.0)
  }

  /**
   * Reciprocal rank of the highest-ranked relevant document: 1 / (r + 1) for its 0-based index r.
   *
   * @throws IllegalArgumentException if no document of `actual` occurs in `predicted`
   *                                  (including when `actual` is empty)
   */
  def reciprocal(actual: List[Doc], predicted: Array[Doc]): Double = {
    val relevant = actual.toSet
    // Earliest position of ANY relevant doc; MRR is defined by the first relevant hit.
    val rank = predicted.indexWhere(relevant)
    require(rank != -1, "Predicted and actual contains different data.")
    1.0 / (rank + 1.0)
  }
}
/** Task 8 entry point, plus helpers that load document lists and pair them into [[Query]]s. */
object Task8 extends App {

  /**
   * Builds one [[Query]] per document id present in both inputs.
   *
   * Ground-truth and predicted docs are grouped by id, and each actual group is joined with the
   * predicted group of the same id. The two maps' `values` are not zipped, because their
   * iteration orders are not guaranteed to line up. An id present on only one side produces
   * no query.
   */
  def mkQueries(actual: Array[Doc], predict: Array[Doc]): Iterable[Query] = {
    val predictedById = predict.groupBy(_.id)
    actual.groupBy(_.id).toSeq.collect {
      case (id, docs) if predictedById.contains(id) => Query(docs.toList, predictedById(id))
    }
  }

  /**
   * Reads the file at `path` and turns each non-blank line into a [[Doc]]. The Doc's id is the
   * line's first whitespace-separated field.
   *
   * Lines are trimmed before splitting, so blank lines and leading indentation do not yield an
   * empty-id `Doc("")`. (`split` always returns at least one element, so a non-emptiness check
   * on its result would filter nothing.) The array is wrapped in whatever `Files.readFile`
   * returns. NOTE(review): presumably a `Try`, since callers match on `Success`/`Failure`.
   */
  def loadDocs(path: String) = for {
    file <- Files.readFile(path)
  } yield for {
    line <- file.split("\\r?\\n")
    trimmed = line.trim
    if trimmed.nonEmpty
  } yield Doc(trimmed.split("\\s+")(0))
}
package ir
import ir.task8._
import org.specs2.mutable.Specification
import scala.util.{Failure, Success}
/** Specs for [[Metrics]], checked against hand-computed and published reference values. */
class MetricsSpecs extends Specification {

  /** Wraps each id in a [[Doc]], keeping the given order (usable as a ranking). */
  private def docsWithIds(ids: String*): Array[Doc] = ids.map(Doc(_)).toArray

  // Precision / AP / MAP fixture.
  // Query 1: the relevant docs 4, 1 and 2 sit at ranks 2, 4 and 5.
  // Query 2: the ranking lists exactly the relevant docs, in order.
  val mapQs: List[Query] = List(
    Query(docsWithIds("1", "2", "3", "4", "5").toList, docsWithIds("6", "4", "7", "1", "2")),
    Query(docsWithIds("6", "2", "3", "4", "5").toList, docsWithIds("6", "2", "3", "4", "5"))
  )

  // MRR fixture: each query has a single relevant doc, found at rank 3, 2 and 1 respectively.
  val mrrQs: List[Query] =
    List("3", "2", "1").map(id => Query(List(Doc(id)), docsWithIds("1", "2", "3")))

  "MetricsSpecs" should {
    "precision@2" in {
      val query = mapQs.head
      val precision = new Metrics(2).precision4k(query.actual, query.predicted)
      precision must beCloseTo(0.5, 0.001)
    }
    "ap@2" in {
      val k = 2
      val query = mapQs.head
      val averagePrecision = new Metrics(k).ap4k(query.actual, query.predicted)
      println(s"ap4k $k = $averagePrecision")
      // Reference: http://fastml.com/what-you-wanted-to-know-about-mean-average-precision/
      averagePrecision must beCloseTo(0.25, 0.001)
    }
    "map@2" in {
      val k = 2
      val meanAp = new Metrics(k).map4k(mapQs)
      println(s"map4k $k = $meanAp")
      // Reference: http://fastml.com/what-you-wanted-to-know-about-mean-average-precision/
      meanAp must beCloseTo(0.625, 0.001)
    }
    "mrr@2" in {
      val k = 2
      val meanRr = new Metrics(k).mrr4k(mrrQs)
      println(s"mrr4k $k = $meanRr")
      // (1/3 + 1/2 + 1) / 3: the rank-3 hit still counts even though k = 2.
      // Reference: https://en.wikipedia.org/wiki/Mean_reciprocal_rank
      meanRr must beCloseTo(0.61111, 0.001)
    }
    // NOTE(review): the disabled specs below call loadDocs / mkQueries (members of object Task8)
    // and Files. Re-enabling them presumably also needs `import ir.task8.Task8._` and
    // `import ir.utils.Files`; `import ir.task8._` alone does not bring those members into scope.
    /*"Make queries" in {
      val queries = for {
        actual <- loadDocs(Files.testResources + "/qrels_actual.txt")
        random <- loadDocs(Files.testResources +"/qrels_rand.txt")
      } yield mkQueries(actual, random)
      queries match {
        case Success(qs) => println("Data loaded")
        case Failure(ex) => println("Fail load queries: " + ex.getMessage)
      }
      queries.get.flatMap(_.actual) must contain(Doc("INEX_LD-2009022"))
      queries.get.flatMap(_.predicted) must not contain Doc("__None__")
    }*/
    /*"map@10" in {
      import java.util
      val k = 10
      val metrics = new Metrics(k)
      val queries = for {
        actual <- loadDocs(Files.testResources + "/qrels_actual.txt")
        random <- loadDocs(Files.testResources +"/qrels_rand.txt")
      } yield mkQueries(actual, scala.util.Random.shuffle(random.toSeq).toArray)
      queries match {
        case Success(qs) => println("Data loaded")
        case Failure(ex) => println("Fail load queries: " + ex.getMessage)
      }
      println("map act: " + metrics.map4k(queries.get.toList))
      println("mrr: " + metrics.mrr4k(queries.get.toList))
      1 mustEqual 1
    }*/
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment