Skip to content

Instantly share code, notes, and snippets.

@dyangrev
Created October 4, 2016 20:17
Show Gist options
  • Save dyangrev/ed9b6f05169ee3a392004d402f536693 to your computer and use it in GitHub Desktop.
Save dyangrev/ed9b6f05169ee3a392004d402f536693 to your computer and use it in GitHub Desktop.
import org.apache.spark.sql.DataFrame
import com.rockymadden.stringmetric.similarity.NGramMetric
import org.apache.spark.sql.SaveMode
import org.apache.spark.sql.functions._
import sqlContext.implicits._
import java.util.UUID
import com.datastax.spark.connector._
/**
 * Helpers for blocking-based profile matching on Spark DataFrames stored in Cassandra:
 * loading/persisting tables, building blocked self-joins and scoring field similarity.
 *
 * NOTE(review): relies on `sqlContext` and `sqlContext.implicits._` being in scope
 * (as they are in spark-shell).
 */
object BlockingUtils {

  /** Per-field matching configuration consumed by `calculateSimilarityForBlockedAndJoinedDF`. */
  object MatchingConfig {
    // Matching configs

    /** How a single field is compared between the two sides of a candidate pair. */
    sealed trait MatchMethod

    /** Compare the two values with an n-gram similarity of size `gram`. */
    final case class NGramMethod(gram: Int) extends MatchMethod

    /** Compare the two values by exact string equality. */
    case object ExactMethod extends MatchMethod

    val config: Map[String, MatchMethod] = Map(
      "normalized_first_name" -> NGramMethod(2),
      "normalized_middle_name" -> NGramMethod(2),
      "normalized_last_name" -> NGramMethod(2),
      "normalized_postal_address" -> NGramMethod(2),
      "normalized_email_address" -> NGramMethod(2),
      "normalized_phone_number" -> ExactMethod,
      "normalized_date_of_birth" -> ExactMethod,
      "normalized_primary_language" -> ExactMethod,
      "normalized_gender" -> ExactMethod)

    // Order matters: it is the argument order of the per-field scores fed to UDFS.meanUdf.
    val fieldsToCompare: Seq[String] = Seq(
      "normalized_first_name",
      "normalized_middle_name",
      "normalized_last_name",
      "normalized_postal_address",
      "normalized_email_address",
      "normalized_phone_number",
      "normalized_date_of_birth",
      "normalized_primary_language",
      "normalized_gender")
    // End of matching configs
  }

  /** Spark UDFs for id generation and similarity scoring. */
  object UDFS {

    /**
     * Produces a random UUID string per row.
     *
     * NOTE(review): this UDF is non-deterministic, but Spark assumes UDFs are deterministic.
     * If a DataFrame using it is recomputed (e.g. an evicted cached partition), the generated
     * ids change. On Spark 2.3+ consider `.asNondeterministic()`.
     */
    val uuidUdf = udf(() => UUID.randomUUID().toString)

    /**
     * Runs `comparisonMethod` only when both strings are usable, i.e. non-null, non-empty
     * and not the literal text "null".
     *
     * @return `Some(result)` when both inputs are usable, otherwise `None`
     */
    def applyComparison[T](str1: String, str2: String, comparisonMethod: (String, String) => T): Option[T] = {
      def isValid(str: String): Boolean = Option(str).exists(s => s.nonEmpty && s != "null")
      if (isValid(str1) && isValid(str2)) Some(comparisonMethod(str1, str2)) else None
    }

    /**
     * Builds a UDF scoring two strings with an n-gram metric of size `gram`.
     *
     * The score is -1.0 when either input is unusable (see `applyComparison`) and 0.0 when
     * the metric yields no score.
     */
    val nGramUdf = (gram: Int) => udf[Double, String, String] { (str1: String, str2: String) =>
      val matcher = (s1: String, s2: String) => NGramMetric(gram).compare(s1, s2).getOrElse(0.0)
      applyComparison(str1, str2, matcher).getOrElse(-1.0)
    }

    /** Scores 1.0 for equal strings, 0.0 for different ones, -1.0 when either is unusable. */
    val exactUdf = udf[Double, String, String] { (str1: String, str2: String) =>
      val matcher = (s1: String, s2: String) => if (s1 == s2) 1 else 0
      applyComparison(str1, str2, matcher).map(_.toDouble).getOrElse(-1.0)
    }

    /**
     * Averages nine per-field scores, skipping the -1.0 "not comparable" sentinel.
     *
     * @return mean of the scores that are not -1.0, or 0.0 when every score is -1.0
     */
    val meanUdf = udf[Double, Double, Double, Double, Double, Double, Double, Double, Double, Double] {
      (d1, d2, d3, d4, d5, d6, d7, d8, d9) =>
        Seq(d1, d2, d3, d4, d5, d6, d7, d8, d9).filter(_ != -1.0) match {
          case comparable if comparable.nonEmpty => comparable.sum / comparable.size
          case _ => 0.0
        }
    }
  }

  /**
   * Load a Cassandra table into a DataFrame.
   *
   * @param tableName Name of the table
   * @param keyspace Name of the keyspace
   * @param cluster Name of the cluster
   */
  def loadTableIntoDF(tableName: String, keyspace: String = "doppler", cluster: String = "cpark"): DataFrame =
    sqlContext
      .read
      .format("org.apache.spark.sql.cassandra")
      .options(Map(
        "table" -> tableName,
        "keyspace" -> keyspace,
        "cluster" -> cluster))
      .load()

  /**
   * Create a blocked and self-joined DataFrame.
   *
   * Every column of `df` appears twice in the result, suffixed `_1` and `_2`. A pair of rows is
   * kept when `profile_id_1 < profile_id_2` and every blocking-key column is equal on both
   * sides, non-null and not the literal string "null".
   *
   * @param df DataFrame to operate on, usually the blocking-normalized profile DataFrame
   * @param blockingKey Blocking key columns; more than one is treated as a composite key
   * @return the joined candidate pairs
   */
  def blockedAndJoinedDF(df: DataFrame, blockingKey: Seq[String]): DataFrame = {
    // Suffix every column so the two sides of the self-join can be told apart.
    def addPostfixForDFColumns(cols: Iterable[String], df: DataFrame, postfix: String): DataFrame =
      cols.foldLeft(df)((renamed, colName) => renamed.withColumnRenamed(colName, s"$colName$postfix"))

    // `profile_id_1 < profile_id_2` drops self-pairs and keeps each unordered pair only once.
    // `!==` ends in '=' so Scala parses it as an assignment operator (lowest precedence);
    // the parentheses around those comparisons are required.
    val joinExpr = blockingKey.foldLeft($"profile_id_1" < $"profile_id_2") { (expr, key) =>
      expr && $"${key}_1" === $"${key}_2" &&
        !isnull($"${key}_1") &&
        !isnull($"${key}_2") &&
        ($"${key}_1" !== lit("null")) &&
        ($"${key}_2" !== lit("null"))
    }

    val cols = df.columns
    val left = addPostfixForDFColumns(cols, df, "_1")
    val right = addPostfixForDFColumns(cols, df, "_2")
    left.join(right, joinExpr)
  }

  /**
   * Calculate the similarity for a blocked and joined DataFrame.
   *
   * Adds one `<field>_sim` column per configured field in `MatchingConfig.fieldsToCompare`,
   * plus `mean_sim`, the average of those scores ignoring -1.0 sentinels.
   *
   * @param blockedAndJoinedDF The blocked and joined DataFrame (columns suffixed `_1`/`_2`)
   */
  def calculateSimilarityForBlockedAndJoinedDF(blockedAndJoinedDF: DataFrame): DataFrame = {
    val matchedDF = MatchingConfig.fieldsToCompare.foldLeft(blockedAndJoinedDF) { (df, fieldToCompare) =>
      MatchingConfig.config.get(fieldToCompare) match {
        case Some(MatchingConfig.NGramMethod(gram)) =>
          df.withColumn(s"${fieldToCompare}_sim",
            UDFS.nGramUdf(gram)(df(s"${fieldToCompare}_1"), df(s"${fieldToCompare}_2")))
        case Some(MatchingConfig.ExactMethod) =>
          df.withColumn(s"${fieldToCompare}_sim",
            UDFS.exactUdf(df(s"${fieldToCompare}_1"), df(s"${fieldToCompare}_2")))
        case None =>
          df
      }
    }
    // Derive the score columns from the same field list (and order) used above, so the mean
    // cannot drift out of sync with the configuration. meanUdf expects exactly nine columns.
    val simColumns = MatchingConfig.fieldsToCompare.map(field => matchedDF(s"${field}_sim"))
    matchedDF.withColumn("mean_sim", UDFS.meanUdf(simColumns: _*))
  }

  /**
   * Persist a DataFrame into Cassandra.
   *
   * @param df The DataFrame to be persisted
   * @param keyspace Keyspace to persist to
   * @param table Table to persist to
   * @param saveMode Save mode used for the write (defaults to `SaveMode.Append`)
   */
  def persistDataFrame(
      df: DataFrame,
      keyspace: String,
      table: String,
      saveMode: SaveMode = SaveMode.Append): Unit =
    df.write
      .mode(saveMode)
      .format("org.apache.spark.sql.cassandra")
      .options(Map("keyspace" -> keyspace, "table" -> table))
      .save()

  /**
   * Load normalized profiles into a DataFrame, dedupe them, assign fresh profile ids and cache.
   *
   * NOTE(review): ids come from the non-deterministic `UDFS.uuidUdf`; they stay stable only
   * while the cached data is not recomputed.
   */
  def loadNormalizedProfiles: DataFrame =
    loadTableIntoDF("normalized_profiles")
      .drop("profile_id")
      .distinct // Drop the profile id first so the distinct dedupes on the remaining columns
      .withColumn("profile_id", UDFS.uuidUdf()) // Then assign a new profile id
      .select(
        "property",
        "profile_id",
        "normalized_first_name",
        "normalized_middle_name",
        "normalized_last_name",
        "normalized_email_address",
        "normalized_phone_number",
        "normalized_first_name_first_3",
        "normalized_last_name_first_3",
        "normalized_postal_address",
        "normalized_date_of_birth",
        "normalized_primary_language",
        "normalized_gender")
      .cache()

  /**
   * Blocking keys to use for a property; each inner Seq is one (possibly composite) key.
   *
   * @param property Property identifier; unknown properties get the default key set
   */
  def getBlockingKeysByProperty(property: String): Seq[Seq[String]] = property match {
    case "30100" => Seq(
      Seq("normalized_first_name"),
      Seq("normalized_last_name", "normalized_first_name_first_3"),
      Seq("normalized_email_address"),
      Seq("normalized_phone_number"),
      Seq("normalized_last_name_first_3", "normalized_postal_address")
    )
    case "GOVERN" | "LUCIA" => Seq(
      Seq("normalized_first_name", "normalized_last_name_first_3"),
      Seq("normalized_last_name"),
      Seq("normalized_email_address"),
      Seq("normalized_phone_number"),
      Seq("normalized_last_name_first_3", "normalized_postal_address")
    )
    case _ => Seq(
      Seq("normalized_first_name"),
      Seq("normalized_last_name"),
      Seq("normalized_email_address"),
      Seq("normalized_phone_number"),
      Seq("normalized_last_name_first_3", "normalized_postal_address")
    )
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment