Skip to content

Instantly share code, notes, and snippets.

@schaunwheeler
Created April 13, 2019 11:44
Show Gist options
  • Save schaunwheeler/fc91fa09267dc7ecb46d494ba5fc4c94 to your computer and use it in GitHub Desktop.
Save schaunwheeler/fc91fa09267dc7ecb46d494ba5fc4c94 to your computer and use it in GitHub Desktop.
Example of using spaCy on Spark
import pyspark.sql.types as t
import pyspark.sql.functions as f
def spacy_word2vec_grouped(cat_list, id_col='idColumn', string_col='documentText'):
    """
    Compute spaCy document vectors for one group of (id, text) records.

    The spaCy model is loaded once per call and then applied to every record
    in ``cat_list``, so passing a whole group of documents in one call avoids
    reloading the model for each document.

    Example usage:
    vec_sdf = (
        sdf
        .select('idColumn', 'documentText')
        .groupby((f.floor(f.rand() * 20)).alias('groupNumber'))
        .agg(f.collect_list(f.struct(f.col('idColumn'), f.col('documentText'))).alias('documentGroup'))
        .repartition('groupNumber')
        .select(f.explode(spacy_word2vec_grouped_udf(f.col('documentGroup'))).alias('results'))
        .select(f.col('results.*'))
    )

    :param cat_list: sequence of records indexable by ``id_col`` and
        ``string_col`` (e.g. Rows or dicts). ``None`` or an empty sequence
        yields an empty result.
    :param id_col: key of the identifier field in each record. Defaults to
        ``'idColumn'`` so the function can be called with the record list
        alone, as the example above does.
    :param string_col: key of the text field in each record. Defaults to
        ``'documentText'`` for the same reason.
    :return: list of ``(id, text, vector)`` tuples, where ``vector`` is a list
        of floats. Records whose text is ``None`` or whose document vector is
        entirely zero are omitted.
    """
    # Nothing to embed: return before paying for the model load.
    if not cat_list:
        return []

    # A null text cannot be parsed; skip it instead of failing the whole group.
    rows = [row for row in cat_list if row[string_col] is not None]
    if not rows:
        return []

    # Imported inside the function (as in the original) so that importing this
    # module does not itself require spaCy.
    import spacy

    nlp = spacy.load('en_core_web_lg')

    output = []
    texts = (row[string_col] for row in rows)
    # nlp.pipe processes the texts as a batch and yields docs in input order,
    # so zipping against ``rows`` keeps each doc paired with its record.
    for row, doc in zip(rows, nlp.pipe(texts)):
        vector = doc.vector.tolist()
        # Keep only documents with at least one non-zero component. Testing
        # the sum instead would wrongly drop vectors whose components cancel.
        if any(vector):
            output.append((row[id_col], row[string_col], vector))
    return output
# Spark UDF wrapping spacy_word2vec_grouped. The declared return type is an
# array of structs with fields idColumn (long), documentText (string) and
# documentVector (array of doubles).
spacy_word2vec_grouped_udf = f.udf(
    spacy_word2vec_grouped,
    returnType=t.ArrayType(
        t.StructType()
        .add('idColumn', t.LongType())
        .add('documentText', t.StringType())
        .add('documentVector', t.ArrayType(t.DoubleType()))
    ),
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment