This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
df = google.select( | |
f.lit('google').alias('source'), | |
f.col('id').alias('source_id'), | |
f.col('name'), f.col('description'), | |
f.col('manufacturer'), | |
f.col('price') | |
)\ | |
.union( | |
amazon.select( | |
f.lit('amazon').alias('source'), |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from pyspark.sql import functions as f | |
from pyspark.sql import types as t | |
from pyspark.sql import Window as w | |
from pyspark.ml.linalg import DenseVector, SparseVector | |
from pyspark.ml.feature import HashingTF, IDF, Tokenizer, RegexTokenizer, CountVectorizer, StopWordsRemover, NGram, Normalizer, VectorAssembler, Word2Vec, Word2VecModel, PCA | |
from pyspark.ml import Pipeline, Transformer | |
from pyspark.ml.linalg import VectorUDT, Vectors | |
import tensorflow_hub as hub |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from pyspark.sql import functions as f | |
from pyspark.sql import types as t | |
from pyspark.sql import Window as w | |
import numpy as np | |
from graphframes import GraphFrame | |
# Column names to retain: the raw product attributes plus derived columns.
# Judging by the suffixes, the derived columns are the stop-word-removed text,
# TF-IDF vectors of that text, and text encodings of name/description; they are
# produced upstream (not shown here) — confirm against the feature pipeline.
keep_cols = [
    'source', 'name', 'description', 'manufacturer', 'price',
    'name_swRemoved', 'description_swRemoved', 'manufacturer_swRemoved',
    'name_swRemoved_tfidf', 'description_swRemoved_tfidf', 'manufacturer_swRemoved_tfidf',
    'name_encoding', 'description_encoding',
]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
@f.udf("double")
def dot(x, y):
    """Spark UDF: dot product of two vectors as a double.

    Args:
        x, y: objects exposing ``.dot`` (e.g. ``pyspark.ml.linalg`` dense or
            sparse vectors), or None.

    Returns:
        ``float(x.dot(y))``, or 0.0 when either input is null.
    """
    if x is None or y is None:
        # Must be a float: a plain Python UDF declared as "double" turns a
        # Python int return value into null instead of 0.0.
        return 0.0
    return float(x.dot(y))
def null_safe_levenshtein_sim(c1, c2):
    """Build a Column with the normalized Levenshtein similarity of two string columns.

    similarity = 1 - levenshtein(c1, c2) / max(length(c1), length(c2))

    Args:
        c1, c2: names of the string columns to compare.

    Returns:
        A Column expression that evaluates to 0 when either value is null or
        both are empty strings (no text to compare), otherwise to the
        similarity above (1 for identical strings, 0 for completely different).
    """
    max_len = f.greatest(f.length(c1), f.length(c2))
    return (
        # Empty-vs-empty would otherwise divide 0 / 0 (null, or an error under
        # ANSI mode), so treat it like missing data.
        f.when(f.col(c1).isNull() | f.col(c2).isNull() | (max_len == 0), 0)
        .otherwise(1 - f.levenshtein(c1, c2) / max_len)
    )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
distance_df = spark.read.parquet("YOUR_STORAGE_PATH/amazon_google_distance.parquet")
display_cols = ['name', 'description', 'manufacturer', 'price']

# Draw ~2% (seeded) of the candidate pairs with 0 < overall_sim < 1 for manual
# labelling. Each display column shows the src value, "VS", then the dst value.
sample_df = (
    distance_df
    .filter((f.col('overall_sim') > 0) & (f.col('overall_sim') < 1))
    .select(
        'edge.src',
        'edge.dst',
        *[f.concat_ws('\nVS\n', 'src.' + c, 'dst.' + c).alias(c) for c in display_cols],
        'overall_sim',
    )
    .sample(withReplacement=False, fraction=0.02, seed=42)
)

# header=True so the labelled copy can be read back by column name. The display
# values contain embedded newlines, so readers of this CSV need multi-line parsing.
sample_df.write.mode('overwrite').csv(
    "YOUR_STORAGE_PATH/candidate_pair_sample.csv", header=True
)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Manually labelled candidate-pair sample, keeping only rows that received a
# label. header=True is required to address the 'human_label' column by name
# (without it Spark names the columns _c0, _c1, ...). multiLine=True because the
# display cells were built with embedded newlines ("\nVS\n").
human_label = (
    spark.read.csv(
        "YOUR_STORAGE_PATH/candidate_pair_sample_LABELED.csv",
        header=True,
        multiLine=True,
    )
    .filter(f.col('human_label').isNotNull())
    .distinct()
)
feature_df = distance_df.filter(f.col('overall_sim') > 0.06)\ | |
.withColumn('rules_label', | |
f.when((f.col('name_tfidf_sim') >= 0.999) | (f.col('overall_sim') >= 0.999), 1) | |
.when(f.col('overall_sim') < 0.12, 0) | |
.otherwise(None) | |
)\ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from pyspark.sql.functions import pandas_udf | |
import pandas as pd | |
@pandas_udf(returnType=t.DoubleType())
def pd_predict(feature: pd.Series) -> pd.Series:
    """Vectorized Spark UDF: score a batch of feature rows with the tuned model.

    Args:
        feature: one Arrow batch of the feature column.

    Returns:
        Column 1 of ``predict_proba`` (the positive-class probability) from
        ``gs_rf.best_estimator_`` for each row, aligned to the input index.
    """
    # NOTE(review): `gs_rf` is a global defined elsewhere, presumably a fitted
    # scikit-learn GridSearchCV; it must be serializable to reach the executors.
    rows = feature.values.tolist()
    positive_proba = gs_rf.best_estimator_.predict_proba(rows)[:, 1]
    return pd.Series(positive_proba, index=feature.index)


output_df = feature_df.withColumn('prob', pd_predict('features'))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Keep only the pairs scored prob >= 0.5 and group the matched records via
# connected components of the resulting graph.
strong_edges = (
    output_df
    .filter(f.col('prob') >= 0.5)
    # Selecting struct fields yields columns named `src` and `dst`, the edge
    # column names GraphFrame expects.
    .select('edge.src', 'edge.dst')
)
strong_graph = GraphFrame(node, strong_edges)

# GraphFrames' default connectedComponents algorithm needs a checkpoint dir.
# NOTE(review): /tmp is node-local; on a multi-node cluster this presumably
# needs to be shared storage — confirm for the target environment.
spark.sparkContext.setCheckpointDir("/tmp/match_checkpoints")
comps = (
    strong_graph.connectedComponents()
    .select('component', 'source', f.col('id').alias('source_id'))
)
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
OlderNewer