Created
July 30, 2020 07:51
-
-
Save manmohan24nov/659326a2c446cdb80217ccbb0c7f9685 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pandas as pd
import praw
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForTokenClassification
# Pre-trained token-classification model; going by the checkpoint name it is a
# cased BERT-large fine-tuned on the English CoNLL-2003 NER data.
model = TFAutoModelForTokenClassification.from_pretrained("dbmdz/bert-large-cased-finetuned-conll03-english")
# NOTE(review): the tokenizer is loaded from a different checkpoint than the
# model; this presumably works because both cased BERT checkpoints share one
# vocabulary -- confirm before swapping either checkpoint.
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

# Label names indexed by the class id the model predicts for each token
# (main() looks labels up as label_list[prediction]).
label_list = [
    "O",       # Outside of a named entity
    "B-MISC",  # Beginning of a miscellaneous entity right after another miscellaneous entity
    "I-MISC",  # Miscellaneous entity
    "B-PER",   # Beginning of a person's name right after another person's name
    "I-PER",   # Person's name
    "B-ORG",   # Beginning of an organisation right after another organisation
    "I-ORG",   # Organisation
    "B-LOC",   # Beginning of a location right after another location
    "I-LOC"    # Location
]

# Reddit API client. The three credential strings are placeholders and must be
# replaced with real application credentials before the script can fetch data.
reddit = praw.Reddit(client_id='my_client_id',
                     client_secret='my_secret',
                     user_agent='my user agent')
def replies_of(top_level_comment, comment_list):
    """Append the text of every reply below ``top_level_comment`` to ``comment_list``.

    The reply tree is walked depth-first: each reply's ``body`` is appended
    (converted with ``str``) before that reply's own replies are visited.
    The body of ``top_level_comment`` itself is NOT appended.

    Args:
        top_level_comment: Object exposing an iterable ``replies`` attribute
            whose items expose ``body`` and ``replies`` themselves.
        comment_list: List that is extended in place with the reply texts.

    Returns:
        None. The result is delivered through ``comment_list``.
    """
    for comment in top_level_comment.replies:
        try:
            comment_list.append(str(comment.body))
        except AttributeError:
            # Entries without a ``body`` (presumably PRAW "load more comments"
            # placeholders) are skipped together with anything below them.
            continue
        replies_of(comment, comment_list)
# Labels whose tokens are collected as named entities; any other label
# ('O', 'B-MISC', 'I-MISC') terminates the entity currently being built.
_ENTITY_TAGS = ('I-PER', 'I-ORG', 'I-LOC', 'B-LOC', 'B-ORG', 'B-PER')

# Limit applied to the raw text length (characters) and to the encoded length
# (tokens). NOTE(review): 512 is assumed to be the maximum input length of the
# loaded BERT checkpoint -- confirm against the model configuration.
_MAX_SEQUENCE_LENGTH = 512


def _store_entity(word, tag, master_dict):
    """Record ``word`` under ``tag + 'X'`` unless it is empty or one character long."""
    cleaned = word.strip()
    if tag and len(cleaned) > 1:
        master_dict[tag + 'X'].append(cleaned)


def _collect_entities(tagged_tokens, master_dict):
    """Assemble the entities of one tagged token sequence into ``master_dict``.

    Args:
        tagged_tokens: Iterable of ``(token, label)`` pairs in sequence order.
        master_dict: Dict with a list under each ``<label> + 'X'`` key for the
            labels in ``_ENTITY_TAGS``; extended in place.

    Consecutive entity tokens of the same type (PER/ORG/LOC) are joined into
    one upper-cased string. An entity is stored under the label of its first
    token. A ``B-`` label on a whole-word token always starts a new entity;
    ``##`` word pieces are glued to the open word without a space.
    """
    word = ''
    word_tag = ''  # label of the first token of the entity being built
    for token, tag in tagged_tokens:
        if tag not in _ENTITY_TAGS:
            _store_entity(word, word_tag, master_dict)
            word, word_tag = '', ''
            continue
        is_piece = token.startswith('##')
        same_type = bool(word_tag) and tag[2:] == word_tag[2:]
        if not (same_type and (is_piece or tag.startswith('I-'))):
            # Type change, explicit B- start, or no open entity: close the
            # previous entity (if any) and open a new one with this label.
            _store_entity(word, word_tag, master_dict)
            word, word_tag = '', tag
        if is_piece:
            word += token[2:].upper()
        else:
            word += " " + token.upper()
    # Flush at the end so an unfinished entity never leaks into the next text.
    _store_entity(word, word_tag, master_dict)


def main():
    """Print the named entities found in the week's top post of each subreddit.

    For every subreddit in the hard-coded list the top post of the week is
    fetched, its title and the replies collected by ``replies_of`` are tagged
    with the module-level ``model``/``tokenizer``, and the recognised
    locations, organisations and person names are printed: first the raw
    per-label lists, then the de-duplicated, sorted summary.
    """
    master_dict = {'I-LOCX': [], 'I-ORGX': [], 'I-PERX': [], 'B-LOCX': [], 'B-ORGX': [], 'B-PERX': []}
    list_of_subreddit = ['worldnews']
    for subreddit_name in list_of_subreddit:
        # Top post of the past week (limit=1) for this subreddit.
        top_posts = reddit.subreddit(subreddit_name).top('week', limit=1)
        comment_list = []
        for submission in top_posts:
            print('\n\n')
            print("Title :", submission.title)
            submission_comm = reddit.submission(id=submission.id)
            comment_list.append(str(submission.title))
            # NOTE(review): only the replies below each top-level comment are
            # collected; the top-level comment bodies themselves are not.
            # Confirm whether that is intended.
            for top_level_comment in submission_comm.comments:
                try:
                    replies_of(top_level_comment, comment_list)
                except Exception:
                    # Best effort: skip a thread that cannot be walked, but do
                    # not swallow KeyboardInterrupt/SystemExit.
                    continue

        for sequence in comment_list:
            if len(sequence) > _MAX_SEQUENCE_LENGTH:
                continue
            token_ids = tokenizer.encode(sequence)
            if len(token_ids) > _MAX_SEQUENCE_LENGTH:
                # Special tokens can push a text of up to 512 characters past
                # the token limit.
                continue
            # Map the ids straight back to tokens so that tokens and
            # predictions are guaranteed to line up one-to-one.
            tokens = tokenizer.convert_ids_to_tokens(token_ids)
            outputs = model(tf.constant([token_ids]))[0]
            predictions = tf.argmax(outputs, axis=2)
            list_bert = [(token, label_list[prediction])
                         for token, prediction in zip(tokens, predictions[0].numpy())]
            _collect_entities(list_bert, master_dict)

    print(master_dict)
    # Sorted so that the printed summary is deterministic between runs.
    print_dict = {
        'Location': sorted(set(master_dict['I-LOCX'] + master_dict['B-LOCX'])),
        'Organisation': sorted(set(master_dict['I-ORGX'] + master_dict['B-ORGX'])),
        'Person Name': sorted(set(master_dict['I-PERX'] + master_dict['B-PERX'])),
    }
    print('\n\n\n')
    print(print_dict)
# Run the pipeline only when the file is executed as a script, not on import.
if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment