Created
December 5, 2017 19:36
-
-
Save duarteocarmo/395b090fef8c46a51f2f761542caee04 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*- | |
""" Use DeepMoji to score texts for emoji distribution. | |
The resulting emoji ids (0-63) correspond to the mapping | |
in emoji_overview.png file at the root of the DeepMoji repo. | |
Writes the result to a csv file. | |
""" | |
from __future__ import print_function, division | |
import json | |
import numpy as np | |
from deepmoji.sentence_tokenizer import SentenceTokenizer | |
from deepmoji.model_def import deepmoji_emojis | |
from deepmoji.global_variables import PRETRAINED_PATH, VOCAB_PATH | |
from pymongo import MongoClient | |
import numpy as np | |
# Destination CSV for scored sentences. NOTE(review): declared but no CSV is
# ever written below — results go to MongoDB instead; confirm intent.
OUTPUT_PATH = 'test_sentences.csv'

# Sample sentences from the DeepMoji example scripts. Unused once tweets are
# pulled from MongoDB; kept for quick local experimentation.
TEST_SENTENCES = [
    u"I love mom's cooking",
    u'I love how you never reply back..',
    u'I love cruising with my homies',
    u'I love messing with yo mind!!',
    u"I love you and now you're just gone..",
    u'This is shit',
    u'This is the shit',
]
# Initialize Mongo client on MongoDB Atlas.
# SECURITY NOTE(review): credentials are hard-coded in the connection URI;
# move them to an environment variable or config file before publishing.
client = MongoClient("mongodb://socialgraphs:interactions@socialgraphs-shard-00-00-al7cj.mongodb.net:27017,socialgraphs-shard-00-01-al7cj.mongodb.net:27017,socialgraphs-shard-00-02-al7cj.mongodb.net:27017/test?ssl=true&replicaSet=SocialGraphs-shard-0&authSource=admin")
db = client.texas
# Collection holding the harvested tweets.
tweet_collection = db.tweetHistory
# Fetch only the fields used below: 'id' (update key) and 'text'.
tweets = list(tweet_collection.find({}, {'id': 1, 'text': 1}))
# List comprehension instead of map(): map() is lazy on Python 3 and would be
# exhausted after one pass; this returns a real list on both Python 2 and 3.
tweets_text = [t['text'] for t in tweets]
def top_elements(array, k):
    """Return the indices of the *k* largest values in *array*,
    ordered from largest to smallest value."""
    # argpartition is O(n): only the top-k slice needs a real sort afterwards.
    candidate_idx = np.argpartition(array, -k)[-k:]
    descending = np.argsort(array[candidate_idx])[::-1]
    return candidate_idx[descending]
# Maximum sequence length (tokens) fed to the tokenizer and the model.
maxlen = 30
# NOTE(review): batch_size is defined but never passed to model.predict()
# below — confirm whether it was meant to be used.
batch_size = 32
print('Tokenizing using dictionary from {}'.format(VOCAB_PATH))
# Load the DeepMoji vocabulary (token -> id mapping) shipped with the repo.
with open(VOCAB_PATH, 'r') as f:
    vocabulary = json.load(f)
st = SentenceTokenizer(vocabulary, maxlen)
# Tokenize all tweet texts; the two extra return values are discarded.
tokenized, _, _ = st.tokenize_sentences(tweets_text)
print('Loading model from {}.'.format(PRETRAINED_PATH))
# Model variant whose output is the emoji probability distribution (ids 0-63,
# per the module docstring).
model = deepmoji_emojis(maxlen, PRETRAINED_PATH)
model.summary()
print('Running predictions.')
# prob[i] is the per-emoji probability distribution for tweet i.
prob = model.predict(tokenized)
# Attach the top-5 DeepMoji emoji predictions to each tweet document.
# Emoji ids (0-63) correspond to the mapping in emoji_overview.png
# at the root of the DeepMoji repo.
# Fix: the original message claimed a CSV write that never happens — results
# are written to MongoDB, not OUTPUT_PATH.
print('Writing results to MongoDB collection {}'.format(tweet_collection.name))
fields = ['Top5%', 'Emoji_1', 'Emoji_2', 'Emoji_3', 'Emoji_4', 'Emoji_5',
          'Pct_1', 'Pct_2', 'Pct_3', 'Pct_4', 'Pct_5']
for i, t in enumerate(tweets):
    print(i)  # crude progress indicator
    t_prob = prob[i]
    ind_top = top_elements(t_prob, 5)
    # Convert numpy scalars to native Python types: pymongo's BSON encoder
    # cannot serialize numpy int/float types.
    t_score = [float(sum(t_prob[ind_top]))]       # total mass of the top 5
    t_score.extend(int(ind) for ind in ind_top)   # the five emoji ids
    # NOTE(review): the per-emoji probabilities were stored as strings in the
    # original; kept as-is to preserve the existing document schema.
    t_score.extend(str(t_prob[ind]) for ind in ind_top)
    deepmoji_doc = dict(zip(fields, t_score))
    # update_one() replaces the deprecated Collection.update() (pymongo >= 3.0);
    # the original also built this dict twice — reuse the single copy.
    tweet_collection.update_one({"id": t["id"]},
                                {"$set": {"deepmoji": deepmoji_doc}})
print('Done.')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment