Gist jamesonthecrow/e20fac6cb1422d6ed18c0a6932d244b2, created November 11, 2018 03:36
Test the subreddit suggester on the newest 100 posts in each subreddit.
import re

import coremltools
import pandas

# `subreddits` (the list of subreddit names) and `get_n_posts` (the scraping
# helper) are assumed to be defined elsewhere; they are not part of this file.

# Scrape the 100 newest posts from each subreddit.
max_posts = 100
posts = []
for subreddit in subreddits:
    posts.extend(get_n_posts(max_posts, subreddit, sort='new'))

# Apply the same preprocessing that was applied to the training data.
new_df = pandas.DataFrame(posts, columns=['subreddit', 'title'])
# Strip bracketed tags such as "[D]" from the titles.
new_df['title'] = new_df['title'].apply(lambda x: re.sub(r'\[.*\]', '', x))
# Remove every non-word character except spaces.
new_df['title'] = new_df['title'].apply(lambda x: re.sub(r'\W(?<![ ])', '', x))

# Load the mlmodel with coremltools for Python. Note that you need to be
# on macOS for the predict function to work.
mlmodel = coremltools.models.MLModel('PATH/TO/subredditClassifier.mlmodel')

# Predict a subreddit for each title.
new_df['predicted'] = new_df['title'].apply(lambda x: mlmodel.predict({'text': x})['label'])

# Mark a prediction correct if it matches the actual subreddit.
new_df['correct'] = new_df['predicted'] == new_df['subreddit']

# Compute and print the fraction of correct predictions.
print(new_df['correct'].sum() / new_df.shape[0])
# Output: 0.55
# We got 55% correct.
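The script relies on a `get_n_posts` helper that is not included in this file. The sketch below is a hypothetical stand-in, not the original helper: it assumes the posts come from Reddit's public JSON listing endpoint (`https://www.reddit.com/r/<subreddit>/new.json`), which returns at most 100 posts per page and an `after` cursor for the next page. The names `fetch_newest_posts`, `parse_listing`, and `USER_AGENT` are illustrative. Each post is returned as a `(subreddit, title)` tuple, which is the shape `pandas.DataFrame(posts, columns=['subreddit', 'title'])` expects.

```python
import json
import urllib.parse
import urllib.request

# Hypothetical User-Agent string; Reddit asks API clients to identify themselves.
USER_AGENT = "subreddit-suggester-eval/0.1 (hypothetical)"


def parse_listing(listing):
    """Return (subreddit, title) pairs and the pagination cursor of one listing page."""
    data = listing["data"]
    rows = [(child["data"]["subreddit"], child["data"]["title"])
            for child in data["children"]]
    return rows, data.get("after")


def fetch_newest_posts(n, subreddit):
    """Hypothetical stand-in for get_n_posts(n, subreddit, sort='new')."""
    rows, after = [], None
    while len(rows) < n:
        # A single listing page holds at most 100 posts.
        params = {"limit": min(100, n - len(rows))}
        if after:
            params["after"] = after
        url = (f"https://www.reddit.com/r/{subreddit}/new.json?"
               + urllib.parse.urlencode(params))
        request = urllib.request.Request(url, headers={"User-Agent": USER_AGENT})
        with urllib.request.urlopen(request) as response:
            listing = json.load(response)
        page, after = parse_listing(listing)
        rows.extend(page)
        if not page or after is None:
            break  # the subreddit has no more posts to page through
    return rows[:n]
```

Under these assumptions, `posts.extend(fetch_newest_posts(max_posts, subreddit))` would replace the `get_n_posts` call. Keeping the parsing in its own function makes it possible to check the listing handling without any network access.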
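The two `re.sub` calls in the preprocessing step are easy to misread, so here they are applied to single sample titles (the titles are made up for illustration). The first pattern is greedy, so it removes everything from the first `[` to the last `]`, not just each bracketed tag separately. The second pattern matches one non-word character and then uses the negative lookbehind `(?<![ ])` to skip it when that character is a space, so letters, digits, underscores, and spaces survive.

```python
import re


def clean_title(title):
    """Apply the same two substitutions as the script to one title."""
    # Greedy: removes from the first "[" through the last "]".
    title = re.sub(r'\[.*\]', '', title)
    # Remove each non-word character unless it is a space.
    title = re.sub(r'\W(?<![ ])', '', title)
    return title


print(repr(clean_title("[D] What's the best GPU for deep learning?")))
# ' Whats the best GPU for deep learning'
print(repr(clean_title("[P] Tool [beta] release")))
# ' release'
```

The leading space is kept because neither pattern strips whitespace. A non-greedy `\[.*?\]` would turn the second title into `' Tool  release'` instead, but changing the pattern here would make the test data diverge from whatever preprocessing the model was trained with, so the evaluation keeps the original patterns.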
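The final step divides the number of `True` values in the `correct` column by the row count. Since the mean of a boolean column is exactly that fraction, `mean()` gives the same number, and `groupby` and `pandas.crosstab` extend it to a per-subreddit accuracy and a table of actual versus predicted labels. The sketch below uses a few made-up rows with hypothetical subreddit names; they illustrate the calculation only and are not results from this evaluation.

```python
import pandas as pd

# Made-up rows for illustration only; not output of the subreddit suggester.
df = pd.DataFrame({
    "subreddit": ["python", "python", "MachineLearning", "MachineLearning"],
    "predicted": ["python", "MachineLearning", "MachineLearning", "MachineLearning"],
})
df["correct"] = df["predicted"] == df["subreddit"]

# Mean of a boolean column = fraction of True values = sum() / number of rows.
overall = df["correct"].mean()

# Accuracy for each actual subreddit.
per_subreddit = df.groupby("subreddit")["correct"].mean()

# Counts of (actual, predicted) pairs: rows are actual, columns are predicted.
confusion = pd.crosstab(df["subreddit"], df["predicted"])

print(overall)        # 0.75
print(per_subreddit)
print(confusion)
```

Applied to `new_df`, the same three lines would show whether the misclassified titles are spread evenly or concentrated in a few subreddits, and which subreddits the model tends to confuse with each other.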