Skip to content

Instantly share code, notes, and snippets.

@yifeihuang
Last active September 13, 2020 20:08
Show Gist options
  • Save yifeihuang/826b9110add28d232e6d30d571bde30c to your computer and use it in GitHub Desktop.
Save yifeihuang/826b9110add28d232e6d30d571bde30c to your computer and use it in GitHub Desktop.
Initial candidate pair match scoring
@udf("double")
def dot(x, y):
    """Return the dot product of two vectors as a Spark ``double``.

    Args:
        x: First operand; any object exposing ``.dot`` (presumably a
            pyspark.ml vector -- confirm against the caller), or None.
        y: Second operand, same kind as ``x``, or None.

    Returns:
        ``float(x.dot(y))`` when both operands are present, otherwise 0.0.
    """
    if x is None or y is None:
        # Must be a Python float: the UDF is declared "double", and a plain
        # Python UDF that returns an int for a double column yields null
        # instead of 0, which would then null out any sum built on it.
        return 0.0
    return float(x.dot(y))
def null_safe_levenshtein_sim(c1, c2):
    """Build a column expression for length-normalised Levenshtein similarity.

    Args:
        c1: Name of the first string column.
        c2: Name of the second string column.

    Returns:
        A Column equal to 0 when either input is null or both inputs are
        empty strings, and ``1 - levenshtein(c1, c2) / max(len(c1), len(c2))``
        otherwise.
    """
    either_missing = f.col(c1).isNull() | f.col(c2).isNull()
    longest = f.greatest(f.length(c1), f.length(c2))
    return (
        f.when(either_missing, 0)
        # Both strings empty: the normaliser would be 0 and the division
        # would produce null (or fail under ANSI mode). Empty values carry
        # no evidence, so they are scored like missing values.
        .when(longest == 0, 0)
        .otherwise(1 - f.levenshtein(c1, c2) / longest)
    )
def null_safe_num_sim(c1, c2):
    """Build a column expression for relative similarity of two numbers.

    Args:
        c1: Name of the first numeric column.
        c2: Name of the second numeric column.

    Returns:
        A Column equal to:
          * 0 when either value is null,
          * 1 when both values are 0,
          * 0 when exactly one value is 0,
          * 0 when the values have opposite signs,
          * ``1 - |a - b| / max(|a|, |b|)`` otherwise.
        For non-negative inputs this matches ``1 - |a - b| / max(a, b)``.
    """
    a = f.col(c1)
    b = f.col(c2)
    return (
        f.when(a.isNull() | b.isNull(), 0)
        .when((a == 0) & (b == 0), 1)
        .when((a == 0) | (b == 0), 0)
        # Both values are non-zero here, so differing signums means the
        # values lie on opposite sides of zero; score that as no similarity
        # rather than letting the ratio exceed 1 and the score go negative.
        .when(f.signum(a) != f.signum(b), 0)
        # Scale by the larger magnitude so negative pairs stay within [0, 1].
        .otherwise(1 - f.abs(a - b) / f.greatest(f.abs(a), f.abs(b)))
    )
def null_safe_token_overlap(c1, c2):
    """Build a column expression for token overlap between two array columns.

    The score asks whether the shared tokens make up a significant part of
    the smaller of the two token sets.

    Args:
        c1: Name of the first array column.
        c2: Name of the second array column.

    Returns:
        A Column equal to 0 when either array is null or has no distinct
        elements, otherwise the size of the intersection divided by the
        smaller of the two distinct-element counts.
    """
    distinct_1 = f.size(f.array_distinct(c1))
    distinct_2 = f.size(f.array_distinct(c2))
    shared = f.size(f.array_intersect(c1, c2))
    return (
        f.when(f.col(c1).isNull() | f.col(c2).isNull(), 0)
        .when((distinct_1 == 0) | (distinct_2 == 0), 0)
        # Normalise by the smaller set; both sides must be consulted, or the
        # score ignores c2's size and can exceed the intended meaning.
        .otherwise(shared / f.least(distinct_1, distinct_2))
    )
def calc_sim(df):
    """Append per-feature similarity columns and their mean to ``df``.

    Args:
        df: DataFrame exposing ``src.*`` and ``dst.*`` columns for the fields
            referenced below.

    Returns:
        ``df`` with twelve similarity columns added, plus ``overall_sim``,
        the unweighted mean of those twelve columns.
    """
    # (output column, expression) pairs, applied in this order so the
    # resulting schema keeps the same column order.
    features = [
        ('name_lev',
         null_safe_levenshtein_sim('src.name', 'dst.name')),
        ('manufacturer_lev',
         null_safe_levenshtein_sim('src.manufacturer', 'dst.manufacturer')),
        ('description_lev',
         null_safe_levenshtein_sim('src.description', 'dst.description')),
        ('name_token_sim',
         null_safe_token_overlap('src.name_swRemoved', 'dst.name_swRemoved')),
        ('manufacturer_token_sim',
         null_safe_token_overlap('src.manufacturer_swRemoved', 'dst.manufacturer_swRemoved')),
        ('description_token_sim',
         null_safe_token_overlap('src.description_swRemoved', 'dst.description_swRemoved')),
        ('price_sim',
         null_safe_num_sim('src.price', 'dst.price')),
        ('name_tfidf_sim',
         dot(f.col('src.name_swRemoved_tfidf'), f.col('dst.name_swRemoved_tfidf'))),
        ('description_tfidf_sim',
         dot(f.col('src.description_swRemoved_tfidf'), f.col('dst.description_swRemoved_tfidf'))),
        ('manufacturer_tfidf_sim',
         dot(f.col('src.manufacturer_swRemoved_tfidf'), f.col('dst.manufacturer_swRemoved_tfidf'))),
        ('name_encoding_sim',
         dot(f.col('src.name_encoding'), f.col('dst.name_encoding'))),
        ('description_encoding_sim',
         dot(f.col('src.description_encoding'), f.col('dst.description_encoding'))),
    ]

    scored = df
    for column_name, expression in features:
        scored = scored.withColumn(column_name, expression)

    # Order of the columns folded into the overall score.
    metrics = [
        'manufacturer_lev', 'description_lev', 'name_lev', 'price_sim',
        'name_tfidf_sim', 'description_tfidf_sim', 'manufacturer_tfidf_sim',
        'name_encoding_sim', 'description_encoding_sim',
        'name_token_sim', 'manufacturer_token_sim', 'description_token_sim',
    ]
    metric_total = reduce(add, [f.col(metric) for metric in metrics])
    return scored.withColumn('overall_sim', metric_total / len(metrics))
# Score the rows of g.triplets (presumably the candidate pairs of a graph
# object built earlier -- confirm) and save them as Parquet, replacing any
# previous output at that path.
distance_df = calc_sim(g.triplets)
distance_df.write.parquet("YOUR_STORAGE_PATH/amazon_google_distance.parquet", mode='overwrite')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment