Skip to content

Instantly share code, notes, and snippets.

@rjurney
Last active August 18, 2020 17:00
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save rjurney/9a2702b1df87254c83fced0a9b55e9f2 to your computer and use it in GitHub Desktop.
Save rjurney/9a2702b1df87254c83fced0a9b55e9f2 to your computer and use it in GitHub Desktop.
Snorkel LabelingFunction strategy where I assign labels by taking the distance between a document vector and a centroid - the mean class vector
# Snorkel LabelingFunction strategy where I assign labels by taking the distance between a document vector and
# a centroid - the mean class vector
import spacy
from scipy import spatial
# Append the readme to the description separated by a space, lowercased.
# Missing values are replaced with '' first because ' '.join raises a
# TypeError when a row contains a non-string (NaN) entry.
dev_df['description_readme'] = (
    dev_df[['description', 'readme']]
    .fillna('')
    .agg(' '.join, axis=1)
    .str.lower()
)

# spaCy encode the combined description/readme text. The text was already
# lowercased above, so it is passed to nlp as-is.
dev_df['desc_readme_spacy'] = dev_df['description_readme'].apply(nlp)
dev_df['spacy_vector'] = dev_df['desc_readme_spacy'].apply(lambda doc: doc.vector)

# Split the labeled rows by class
gen_df = dev_df[dev_df['label'] == 'GENERAL']
api_df = dev_df[dev_df['label'] == 'API']

# Class centroids: stack each class's document vectors row-wise and take the
# element-wise mean. np.stack raises ValueError if a class has no rows.
gen_vector_mean = np.stack(gen_df['spacy_vector'].values.tolist(), axis=0).mean(axis=0)
api_vector_mean = np.stack(api_df['spacy_vector'].values.tolist(), axis=0).mean(axis=0)
@labeling_function()
def doc_vector_euclidian_lf(x):
    """Assign a label based on the Euclidean distance to each class centroid.

    Args:
        x: A data point indexable by 'spacy_vector', which holds the
            document vector to classify.

    Returns:
        GENERAL if the document vector is strictly closer to gen_vector_mean
        than to api_vector_mean; otherwise API (ties resolve to API).
    """
    doc_vector = x['spacy_vector']
    # euclidean() takes the two vectors as separate positional arguments.
    # The result is a vector norm, so it is already non-negative and needs
    # no absolute value.
    gen_dist = spatial.distance.euclidean(gen_vector_mean, doc_vector)
    api_dist = spatial.distance.euclidean(api_vector_mean, doc_vector)
    return GENERAL if gen_dist < api_dist else API
@labeling_function()
def doc_vector_cosine_lf(x):
    """Assign a label based on the cosine distance to each class centroid.

    Args:
        x: A data point indexable by 'spacy_vector', which holds the
            document vector to classify.

    Returns:
        GENERAL if the document vector has a strictly smaller cosine distance
        to gen_vector_mean than to api_vector_mean; otherwise API (ties
        resolve to API).
    """
    doc_vector = x['spacy_vector']
    # cosine() takes the two vectors as separate positional arguments and
    # returns a distance (smaller means more similar). The absolute value is
    # kept from the original; presumably it guards against slightly negative
    # results from floating-point round-off.
    gen_dist = np.absolute(spatial.distance.cosine(gen_vector_mean, doc_vector))
    api_dist = np.absolute(spatial.distance.cosine(api_vector_mean, doc_vector))
    return GENERAL if gen_dist < api_dist else API
#
# WILL THIS WORK, BATMAN?
#
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment