# coding: utf-8

# # Training a Word2Vec Model on the Reddit Comments Dataset
#
# ### Ravish Chawla

# In[276]:

get_ipython().magic('matplotlib inline')

import nltk.data;
from gensim.models import word2vec;
from gensim.models.word2vec import LineSentence;
from sklearn.cluster import KMeans;
from sklearn.neighbors import KDTree;
import pandas as pd;
import numpy as np;
import os;
import re;
import logging;
import sqlite3;
import time;
import sys;
import multiprocessing;
from wordcloud import WordCloud, ImageColorGenerator;
import matplotlib.pyplot as plt;
from itertools import cycle;
# In[3]:

nltk.download('punkt')

# In[9]:

sql_con = sqlite3.connect('/mnt/big/data/database.sqlite')

# In[11]:

start = time.time()
sql_data = pd.read_sql("SELECT body FROM May2015", sql_con);
print('Total time: ' + str((time.time() - start)) + ' secs')

# In[53]:

total_rows = len(sql_data);
print(total_rows)

# In[4]:

tokenizer = nltk.data.load('tokenizers/punkt/english.pickle');
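
# In[ ]:

# A minimal sanity check on the Punkt sentence tokenizer (illustrative sketch; the
# sample string below is made up for demonstration, not drawn from the dataset).
print(tokenizer.tokenize('this is one sentence. and here is another.'))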
# In[29]:

def clean_text(all_comments, out_name):
    out_file = open(out_name, 'w');
    for pos in range(len(all_comments)):
        #Get the comment
        val = all_comments.iloc[pos]['body'];
        #Normalize tabs and remove newlines
        no_tabs = str(val).replace('\t', ' ').replace('\n', '');
        #Remove all characters except letters and dots.
        alphas_only = re.sub(r"[^a-zA-Z\.]", " ", no_tabs);
        #Normalize runs of spaces to a single space
        multi_spaces = re.sub(" +", " ", alphas_only);
        #Strip trailing and leading spaces
        no_spaces = multi_spaces.strip();
        #Normalize all characters to lowercase
        clean_text = no_spaces.lower();
        #Get sentences from the tokenizer, and remove the dot from each.
        sentences = tokenizer.tokenize(clean_text);
        sentences = [re.sub(r"[\.]", "", sentence) for sentence in sentences];
        #If the text has at least one space (skipping single-word comments) and one character, write it to the file.
        if len(clean_text) > 0 and clean_text.count(' ') > 0:
            for sentence in sentences:
                out_file.write("%s\n" % sentence)
        #Simple logging. At every 50000th step,
        #print the total number of rows processed and time taken so far, and flush the file.
        if pos % 50000 == 0:
            total_time = time.time() - start;
            sys.stdout.write('Completed ' + str(round(100 * (pos / total_rows), 2)) + '% - ' + str(pos) + ' rows in time ' + str(round(total_time / 60, 0)) + ' min & ' + str(round(total_time % 60, 2)) + ' secs\r');
            out_file.flush();
    out_file.close();
# In[104]:

start = time.time();
clean_text(sql_data, '/mnt/big/out_full')
print('Total time: ' + str((time.time() - start)) + ' secs')
# In[ ]:

start = time.time();
logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.INFO)

num_features = 100;    # Dimensionality of the word vectors (the hidden layer representation)
min_word_count = 40;   # Minimum word count to keep a word in the vocabulary
num_workers = multiprocessing.cpu_count();  # Number of threads to run in parallel, set to the total number of cpus.
context = 5            # Context window size (on each side)
downsampling = 1e-3    # Downsample setting for frequent words

# Initialize and train the model.
print("Training model...");
model = word2vec.Word2Vec(LineSentence('/mnt/big/out_full_clean'), workers=num_workers, size=num_features, min_count=min_word_count, window=context, sample=downsampling);
model.init_sims(replace=True);

# Save the model
model_name = "model_full_reddit";
model.save(model_name);
print('Total time: ' + str((time.time() - start)) + ' secs')
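
# In[ ]:

# Minimal sanity check on the freshly trained model (an illustrative sketch;
# 'pizza' is an assumed probe word, present only if it cleared the min_count cutoff).
print('Vocabulary size: ' + str(len(model.wv.vocab)));
print(model.wv.most_similar('pizza', topn=5));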
# In[33]:

model = word2vec.Word2Vec.load('model_full_reddit');

# In[34]:

Z = model.wv.syn0;

# In[37]:

print(Z[0].shape)
print(Z[0])
# In[60]:

def clustering_on_wordvecs(word_vectors, num_clusters):
    # Initialize a k-means object and use it to extract centroids
    kmeans_clustering = KMeans(n_clusters=num_clusters, init='k-means++');
    idx = kmeans_clustering.fit_predict(word_vectors);
    return kmeans_clustering.cluster_centers_, idx;
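
# In[ ]:

# Optional sketch for choosing num_clusters via the elbow method (not part of the
# original analysis; the candidate k values below are arbitrary). Inertia drops as
# k grows; the point where the improvement flattens suggests a reasonable count.
for k in [10, 25, 50, 100]:
    km = KMeans(n_clusters=k, init='k-means++').fit(Z);
    print(str(k) + ' clusters -> inertia: ' + str(round(km.inertia_, 2)));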
# In[61]:

start = time.time();
centers, clusters = clustering_on_wordvecs(Z, 50);
print('Total time: ' + str((time.time() - start)) + ' secs')

# In[62]:

start = time.time();
centroid_map = dict(zip(model.wv.index2word, clusters));
print('Total time: ' + str((time.time() - start)) + ' secs')
# In[120]:

def get_top_words(index2word, k, centers, wordvecs):
    tree = KDTree(wordvecs);
    #Query the tree for the k word vectors closest to each cluster center.
    closest_points = [tree.query(np.reshape(x, (1, -1)), k=k) for x in centers];
    closest_words_idxs = [x[1] for x in closest_points];
    #Look up the word at each index, and add the lists to a dictionary keyed by cluster.
    closest_words = {};
    for i in range(0, len(closest_words_idxs)):
        closest_words['Cluster #' + str(i+1).zfill(2)] = [index2word[j] for j in closest_words_idxs[i][0]]
    #Generate a DataFrame from the dictionary.
    df = pd.DataFrame(closest_words);
    df.index = df.index + 1
    return df;
# In[121]:

pd.set_option('display.max_columns', None)
pd.set_option('display.max_rows', None)

# In[292]:

top_words = get_top_words(model.wv.index2word, 20, centers, Z);

# In[293]:

top_words
# In[298]:

def display_cloud(cluster_num, cmap):
    wc = WordCloud(background_color="black", max_words=2000, max_font_size=80, colormap=cmap);
    wordcloud = wc.generate(' '.join([word for word in top_words['Cluster #' + str(cluster_num).zfill(2)]]))
    plt.imshow(wordcloud, interpolation='bilinear')
    plt.axis("off")
    plt.savefig('cluster_' + str(cluster_num), bbox_inches='tight')
# In[303]:

cmaps = cycle([
    'flag', 'prism', 'ocean', 'gist_earth', 'terrain', 'gist_stern',
    'gnuplot', 'gnuplot2', 'CMRmap', 'cubehelix', 'brg', 'hsv',
    'gist_rainbow', 'rainbow', 'jet', 'nipy_spectral', 'gist_ncar'])

for i in range(50):
    col = next(cmaps);
    display_cloud(i + 1, col)
# In[361]:

def get_word_table(table, key, sim_key='similarity', show_sim=True):
    if show_sim:
        return pd.DataFrame(table, columns=[key, sim_key])
    else:
        return pd.DataFrame(table, columns=[key, sim_key])[key]
# In[344]:

get_word_table(model.wv.most_similar_cosmul(positive=['king', 'woman'], negative=['queen']), 'Analogy')

# In[82]:

model.wv.doesnt_match("apple microsoft samsung tesla".split())

# In[83]:

model.wv.doesnt_match("trump clinton sanders obama".split())

# In[87]:

model.wv.doesnt_match("joffrey cersei tywin lannister jon".split())

# In[320]:

model.wv.doesnt_match("daenerys rhaegar viserion aemon aegon jon targaryen".split())

# In[442]:

keys = ['musk', 'modi', 'hodor', 'martell', 'apple', 'neutrality', 'snowden', 'batman', 'hulk', 'warriors', 'falcons', 'pizza'];
tables = [];
for key in keys:
    tables.append(get_word_table(model.wv.similar_by_word(key), key, show_sim=False))

# In[443]:

pd.concat(tables, axis=1)