Navigation Menu

Skip to content

Instantly share code, notes, and snippets.

@jesusjavierdediego
Last active February 27, 2019 22:41
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jesusjavierdediego/9d719d6a5d36ad786801e2f8a8c11de5 to your computer and use it in GitHub Desktop.
Save jesusjavierdediego/9d719d6a5d36ad786801e2f8a8c11de5 to your computer and use it in GitHub Desktop.
/**
 * Runs the trained model over every candidate pair and keeps only the rows whose
 * predicted label equals 1.0.
 *
 * @param trainedModel         fitted logistic-regression model used to transform the input
 * @param allComparableDataset candidate pairs; the projection below reads the columns
 *                             `left.old_id`, `right.old_id` and `features`
 * @return rows carrying `id_left`, `id_right`, `features`, `probability` and `label`,
 *         restricted to `label === 1.0`
 */
def applyModelToAllCombinations(trainedModel: LogisticRegressionModel, allComparableDataset: Dataset[(Person, Person, Vector)]): Dataset[PredictedVector] = {
  import spark.implicits._

  // Reads the element at index 1 of the vector (i.e. the second entry).
  // NOTE(review): presumably the positive-class probability of a binary model — confirm.
  val probabilityAtIndexOne = udf((probabilities: Vector) => probabilities(1))

  val scored: DataFrame = trainedModel.transform(allComparableDataset)

  // Flatten the nested ids and rename the model output columns to the PredictedVector shape.
  val projected = scored.select(
    $"left.old_id".as("id_left"),
    $"right.old_id".as("id_right"),
    $"features",
    probabilityAtIndexOne($"probability").as("probability"),
    $"prediction".as("label")
  )

  projected
    .filter($"label" === 1.0)
    .as[PredictedVector]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment