-
-
Save justindujardin/b7cceaf5c3eb99cb0bb02902a3d623a1 to your computer and use it in GitHub Desktop.
Prodigy recipe/loader for improving a trained model from the Twitter API.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# coding: utf8 | |
from __future__ import unicode_literals, print_function | |
import cytoolz | |
import spacy | |
from prodigy.components.db import connect | |
from prodigy.components.sorters import prefer_uncertain, find_with_terms, prefer_high_scores | |
from prodigy.core import get_config | |
from prodigy.core import recipe, recipe_args | |
from prodigy.models.textcat import TextClassifier | |
from prodigy.util import get_seeds, get_seeds_from_set, log | |
from twitter_loader import twitter_search_loader, twitter_stream_loader | |
DB = connect() | |
@recipe('model.teach',
        dataset=recipe_args['dataset'],
        spacy_model=recipe_args['spacy_model'],
        source=recipe_args['source'],
        label=recipe_args['label'],
        seeds=recipe_args['seeds'],
        long_text=("Long text", "flag", "L", bool),
        exclude=recipe_args['exclude'],
        prefer=("Prefer type string [uncertain|high]", "flag", "P", str))
def teach(dataset, spacy_model, source=None, label='',
          seeds=None, long_text=False, exclude=None, prefer=False):
    """
    Collect the best possible training data for a text classification model
    with the model in the loop. Based on your annotations, Prodigy will decide
    which questions to ask next.

    Tweets are pulled live from the Twitter API: the sample stream when no
    source is given, otherwise a search for the `source` query string.
    Pass prefer='high' to rank the stream by high scores instead of the
    default uncertainty sampling.
    """
    log('RECIPE: Starting recipe model.teach', locals())
    nlp = spacy.load(spacy_model)
    log('RECIPE: Creating TextClassifier with model {}'.format(spacy_model))
    model = TextClassifier(nlp, label.split(','), long_text=long_text)
    # Twitter credentials are read from the 'twitter' section of the
    # Prodigy configuration.
    api_key = get_config()['twitter']
    if source is None:
        stream = twitter_stream_loader(api_config=api_key)
    else:
        stream = twitter_search_loader(api_config=api_key, query=source)
    # Defined unconditionally so the prepend step below is always safe.
    examples_with_seeds = []
    if seeds is not None:
        if isinstance(seeds, str) and seeds in DB:
            # Seeds may name an existing dataset: pull terms from its examples.
            seeds = get_seeds_from_set(seeds, DB.get_dataset(seeds))
        else:
            seeds = get_seeds(seeds)
        # Find 'seedy' examples
        examples_with_seeds = list(find_with_terms(stream, seeds,
                                                   at_least=10, at_most=1000,
                                                   give_up_after=10000))
        for eg in examples_with_seeds:
            eg.setdefault('meta', {})
            eg['meta']['via_seed'] = True
        print("Found {} examples with seeds".format(len(examples_with_seeds)))
        examples_with_seeds = [task for _, task in model(examples_with_seeds)]
    # Rank the stream. Note this is continuous, as model() is a generator.
    # As we call model.update(), the ranking of examples changes.
    # BUG FIX: the old test `prefer == (False or 'uncertain')` reduced to
    # `prefer == 'uncertain'`, so the documented default (prefer=False)
    # silently selected prefer_high_scores. Only prefer='high' should do that.
    stream_preference = prefer_high_scores if prefer == 'high' else prefer_uncertain
    stream = stream_preference(model(stream))
    # Prepend 'seedy' examples, if present
    if seeds:
        log("RECIPE: Prepending examples with seeds to the stream")
        stream = cytoolz.concat((examples_with_seeds, stream))
    return {
        'view_id': 'classification',
        'dataset': dataset,
        'stream': stream,
        'exclude': exclude,
        'update': model.update,
        'config': {'lang': nlp.lang, 'labels': model.labels}
    }
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/python | |
import re | |
from pprint import pprint | |
from twitter import * | |
from email.utils import parsedate | |
from time import strftime | |
def result_from_tweet(tweet, lang='en', retweets=False):
    """Extract an annotation result from a given search/streaming tweet object.

    Returns a dict with 'text' and 'meta' keys suitable for annotation, or
    None when the tweet should be skipped: stream delete events, retweets
    (when retweets=False), language mismatches, or unrecognized payloads.
    """
    # Ignore stream delete actions
    if "delete" in tweet:
        return None
    # The tweet text lives under different keys depending on the endpoint
    # (search with tweet_mode='extended' vs. the streaming API).
    if "full_text" in tweet:
        tweet_text = tweet["full_text"]
    elif "extended_tweet" in tweet and "full_text" in tweet["extended_tweet"]:
        tweet_text = tweet["extended_tweet"]["full_text"]
    elif "text" in tweet:
        tweet_text = tweet["text"]
    else:
        pprint(tweet)
        print('^----- Ignored tweet with unknown text property')
        return None
    quoted = tweet_text.startswith('RT')
    # BUG FIX: use .get() — some stream payloads omit "retweeted"/"lang",
    # and the old direct indexing raised KeyError on them.
    if retweets is False and (tweet.get("retweeted") is True or quoted is True):
        print('Skipped likely RT: {}'.format(tweet_text))
        return None
    # Ignore tweets that do not match the desired language
    if tweet.get("lang") != lang:
        return None
    date_text = strftime("%B %d, %Y", parsedate(tweet["created_at"]))
    user_text = tweet["user"]["screen_name"]
    return {
        "text": tweet_text,
        "meta": {
            "id": tweet["id"],
            "user": user_text,
            "date": date_text
        }
    }
def auth_for_tweet(api_config):
    """Build a twitter OAuth credential object from a config dict.

    api_config must contain the keys 'access_token',
    'access_token_secret', 'consumer_key' and 'consumer_secret'.
    """
    return OAuth(
        api_config["access_token"],
        api_config["access_token_secret"],
        api_config["consumer_key"],
        api_config["consumer_secret"]
    )
def twitter_stream_loader(api_config,
                          lang='en',
                          retweets=False):
    """Consume whatever real-time tweets twitter wants to give you"""
    credentials = auth_for_tweet(api_config)
    stream_api = TwitterStream(auth=credentials, secure=True)
    # Walk the unfiltered sample stream, yielding only usable annotations.
    for raw_tweet in stream_api.statuses.sample():
        example = result_from_tweet(raw_tweet, lang, retweets)
        if example is None:
            continue
        yield example
def twitter_stream_track_loader(api_config,
                                query,
                                lang='en',
                                retweets=False):
    """Track a filtered real-time stream of tweets and return results for annotation"""
    credentials = auth_for_tweet(api_config)
    stream_api = TwitterStream(auth=credentials, secure=True)
    # Filter the stream by the tracking query; skip tweets that do not
    # convert to a usable annotation.
    for raw_tweet in stream_api.statuses.filter(track=query):
        example = result_from_tweet(raw_tweet, lang, retweets)
        if example is None:
            continue
        yield example
def twitter_search_loader(api_config, query, lang='en', retweets=False):
    """Search twitter for a term and yield results for annotation.

    Pages backwards through the search results, following the `max_id`
    cursor embedded in `search_metadata['next_results']`, until Twitter
    stops returning a next-page cursor.
    """
    auth = auth_for_tweet(api_config)
    api = Twitter(auth=auth, secure=True)
    # Raw string: avoids invalid-escape-sequence warnings on \? and \d.
    max_id_extractor = re.compile(r"\?max_id=(\d+)")
    max_id = None
    while True:
        if max_id is None:
            result = api.search.tweets(q=query, lang=lang, tweet_mode='extended')
        else:
            result = api.search.tweets(q=query, lang=lang, max_id=max_id, tweet_mode='extended')
        for tweet in result["statuses"]:
            annotation = result_from_tweet(tweet, lang, retweets)
            if annotation is not None:
                yield annotation
        # for some reason the next `max_id` value is only encoded in a
        # prebuilt query string value.
        next_results = result['search_metadata'].get('next_results')
        if not next_results:
            # BUG FIX: the old code set max_id back to None here, which made
            # `while max_id is not 0` re-fetch the first page forever (the
            # loop variable could never become 0). Stop paging instead.
            break
        match = max_id_extractor.search(next_results)
        if match is None:
            # Cursor present but unparseable — stop rather than loop on page 1.
            break
        max_id = match.group(1)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment