Last active: September 13, 2020 20:08
-
-
Save yifeihuang/826b9110add28d232e6d30d571bde30c to your computer and use it in GitHub Desktop.
Initial candidate pair match scoring
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
@udf("double")
def dot(x, y):
    """Return the dot product of two vector values as a float.

    Used row-wise on pairs of feature-vector columns, e.g. TF-IDF and encoding vectors.
    Returns 0.0 when either side is null, so a missing vector scores as "no similarity".

    NOTE(review): assumes x and y are pyspark.ml.linalg vectors (or anything exposing
    ``.dot``) — TODO confirm against the feature pipeline. If the vectors are
    L2-normalized, this equals cosine similarity; that is not verified here.
    """
    if x is None or y is None:
        # Must be a float, not the int 0: the UDF is declared "double", and a pickled
        # Python UDF whose return value does not match the declared type yields null.
        # A null would make every downstream aggregate (overall_sim) null.
        return 0.0
    return float(x.dot(y))
def null_safe_levenshtein_sim(c1, c2):
    """Build a normalized Levenshtein similarity expression between two string columns.

    Similarity is ``1 - levenshtein(c1, c2) / max(len(c1), len(c2))``.

    Args:
        c1, c2: column names (e.g. 'src.name', 'dst.name').

    Returns:
        A Column expression. It evaluates to 0 when either value is null, or when both
        strings are empty (there is no text to compare). Otherwise it is the normalized
        similarity.
    """
    max_len = f.greatest(f.length(c1), f.length(c2))
    output = (
        f.when(f.col(c1).isNull() | f.col(c2).isNull(), 0)
        # Guard the 0/0 case: two empty strings would otherwise divide by zero. That
        # yields null (or an error under ANSI mode) and poisons any sum of metrics.
        # Empty text is treated like missing text, matching null_safe_token_overlap.
        .when(max_len == 0, 0)
        .otherwise(1 - f.levenshtein(c1, c2) / max_len)
    )
    return output
def null_safe_num_sim(c1, c2):
    """Build a relative-difference similarity expression between two numeric columns.

    Similarity is ``1 - |c1 - c2| / greatest(c1, c2)``, with these special cases,
    checked in order:
      * either value null        -> 0
      * both values zero         -> 1
      * exactly one value zero   -> 0

    NOTE(review): the denominator is greatest(c1, c2), not the larger absolute value.
    The formula therefore presumes non-negative inputs such as prices — confirm
    before reusing on signed data.
    """
    left = f.col(c1)
    right = f.col(c2)

    missing = left.isNull() | right.isNull()
    both_zero = (left == 0) & (right == 0)
    one_zero = (left == 0) | (right == 0)
    relative_gap = f.abs(left - right) / f.greatest(c1, c2)

    return (
        f.when(missing, 0)
        .when(both_zero, 1)
        .when(one_zero, 0)
        .otherwise(1 - relative_gap)
    )
def null_safe_token_overlap(c1, c2):
    """Build a token-overlap similarity expression between two array columns.

    Measures whether the shared tokens form a significant part of the shorter token set:
    ``|distinct(c1) ∩ distinct(c2)| / min(|distinct(c1)|, |distinct(c2)|)``.

    Args:
        c1, c2: names of array columns (e.g. 'src.name_swRemoved').

    Returns:
        A Column expression. It evaluates to 0 when either array is null or has no
        elements. Otherwise it is the overlap ratio.
    """
    distinct_size_1 = f.size(f.array_distinct(c1))
    distinct_size_2 = f.size(f.array_distinct(c2))
    output = (
        f.when(f.col(c1).isNull() | f.col(c2).isNull(), 0)
        .when((distinct_size_1 == 0) | (distinct_size_2 == 0), 0)
        # Denominator must compare BOTH sides. Previously c1 was used twice, so the
        # ratio was always relative to c1 rather than to the shorter token set.
        .otherwise(
            f.size(f.array_intersect(c1, c2)) / f.least(distinct_size_1, distinct_size_2)
        )
    )
    return output
def calc_sim(df):
    """Append pairwise similarity features and their unweighted mean to ``df``.

    Expects ``df`` to expose struct columns ``src`` and ``dst`` with the fields
    referenced below (raw text, stop-word-removed token arrays, TF-IDF vectors,
    encodings, and price).

    Adds one column per similarity metric. It then adds ``overall_sim``, the arithmetic
    mean of the twelve metrics listed in ``metric_names``. If any metric is null for a
    row, overall_sim is null for that row.
    """
    # (output column, expression), appended in exactly this order.
    feature_columns = [
        ('name_lev', null_safe_levenshtein_sim('src.name', 'dst.name')),
        ('manufacturer_lev', null_safe_levenshtein_sim('src.manufacturer', 'dst.manufacturer')),
        ('description_lev', null_safe_levenshtein_sim('src.description', 'dst.description')),
        ('name_token_sim', null_safe_token_overlap('src.name_swRemoved', 'dst.name_swRemoved')),
        ('manufacturer_token_sim', null_safe_token_overlap('src.manufacturer_swRemoved', 'dst.manufacturer_swRemoved')),
        ('description_token_sim', null_safe_token_overlap('src.description_swRemoved', 'dst.description_swRemoved')),
        ('price_sim', null_safe_num_sim('src.price', 'dst.price')),
        ('name_tfidf_sim', dot(f.col('src.name_swRemoved_tfidf'), f.col('dst.name_swRemoved_tfidf'))),
        ('description_tfidf_sim', dot(f.col('src.description_swRemoved_tfidf'), f.col('dst.description_swRemoved_tfidf'))),
        ('manufacturer_tfidf_sim', dot(f.col('src.manufacturer_swRemoved_tfidf'), f.col('dst.manufacturer_swRemoved_tfidf'))),
        ('name_encoding_sim', dot(f.col('src.name_encoding'), f.col('dst.name_encoding'))),
        ('description_encoding_sim', dot(f.col('src.description_encoding'), f.col('dst.description_encoding'))),
    ]
    for column_name, expression in feature_columns:
        df = df.withColumn(column_name, expression)

    # Summation order is kept fixed so the floating-point result is reproducible.
    metric_names = [
        'manufacturer_lev', 'description_lev', 'name_lev', 'price_sim',
        'name_tfidf_sim', 'description_tfidf_sim', 'manufacturer_tfidf_sim',
        'name_encoding_sim', 'description_encoding_sim',
        'name_token_sim', 'manufacturer_token_sim', 'description_token_sim',
    ]
    metric_total = reduce(add, [f.col(name) for name in metric_names])
    return df.withColumn('overall_sim', metric_total / len(metric_names))
# Score every candidate pair (graph edge triplets) and persist the result.
# NOTE(review): "YOUR_STORAGE_PATH" is a placeholder and must be replaced before running.
distance_df = calc_sim(g.triplets)
(
    distance_df.write
    .mode('overwrite')
    .parquet("YOUR_STORAGE_PATH/amazon_google_distance.parquet")
)
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.