# coding: utf-8
# # Training a Word2Vec Model on the Reddit Comments Dataset
#
# ### Ravish Chawla
# In[276]:
get_ipython().magic('matplotlib inline')
import nltk.data
from gensim.models import word2vec
from gensim.models.word2vec import LineSentence
from sklearn.cluster import KMeans
from sklearn.neighbors import KDTree
import pandas as pd
import numpy as np
import os
import re
import logging
import sqlite3
import time
import sys
import multiprocessing
from wordcloud import WordCloud, ImageColorGenerator
import matplotlib.pyplot as plt
from itertools import cycle
# In[3]:
nltk.download('punkt')
# In[9]:
sql_con = sqlite3.connect('/mnt/big/data/database.sqlite')
# In[11]:
start = time.time()
sql_data = pd.read_sql("SELECT body FROM May2015", sql_con)
print('Total time: ' + str((time.time() - start)) + ' secs')
# In[53]:
total_rows = len(sql_data)
print(total_rows)
# In[4]:
tokenizer = nltk.data.load('tokenizers/punkt/english.pickle')
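# In[ ]:
# Illustrative check of the punkt tokenizer on a made-up sample string: it
# splits raw text into sentences, which the cleaning function below relies on.
print(tokenizer.tokenize('this is the first sentence. here is a second one.'))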
# In[29]:
def clean_text(all_comments, out_name):
    out_file = open(out_name, 'w')
    for pos in range(len(all_comments)):
        # Get the comment.
        val = all_comments.iloc[pos]['body']
        # Normalize tabs and remove newlines.
        no_tabs = str(val).replace('\t', ' ').replace('\n', '')
        # Remove all characters except letters and dots.
        alphas_only = re.sub(r"[^a-zA-Z\.]", " ", no_tabs)
        # Collapse runs of spaces to a single space.
        multi_spaces = re.sub(r" +", " ", alphas_only)
        # Strip leading and trailing spaces.
        no_spaces = multi_spaces.strip()
        # Normalize all characters to lowercase.
        clean_text = no_spaces.lower()
        # Split into sentences with the punkt tokenizer, then drop the dots.
        sentences = tokenizer.tokenize(clean_text)
        sentences = [re.sub(r"[\.]", "", sentence) for sentence in sentences]
        # Skip empty and single-word comments; write one sentence per line.
        if len(clean_text) > 0 and clean_text.count(' ') > 0:
            for sentence in sentences:
                out_file.write("%s\n" % sentence)
        # Simple logging: every 50,000 rows, print the number of rows processed
        # and the time taken so far, and flush the file.
        if pos % 50000 == 0:
            total_time = time.time() - start
            sys.stdout.write('Completed ' + str(round(100 * (pos / total_rows), 2)) + '% - ' + str(pos) + ' rows in ' + str(round(total_time / 60, 0)) + ' min & ' + str(round(total_time % 60, 2)) + ' secs\r')
            out_file.flush()
    out_file.close()
# In[104]:
start = time.time()
clean_text(sql_data, '/mnt/big/out_full')
print('Total time: ' + str((time.time() - start)) + ' secs')
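# In[ ]:
# Illustrative sanity check: peek at the first few cleaned sentences that were
# written out (one sentence per line).
with open('/mnt/big/out_full', 'r') as f:
    for _ in range(5):
        print(f.readline().strip())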
# In[ ]:
start = time.time()
logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.INFO)
num_features = 100                          # Dimensionality of the hidden layer representation
min_word_count = 40                         # Minimum word count to keep a word in the vocabulary
num_workers = multiprocessing.cpu_count()   # Number of threads to run in parallel, set to the total number of CPUs
context = 5                                 # Context window size (on each side)
downsampling = 1e-3                         # Downsample setting for frequent words
# Initialize and train the model.
print("Training model...")
model = word2vec.Word2Vec(LineSentence('/mnt/big/out_full_clean'), workers=num_workers, size=num_features, min_count=min_word_count, window=context, sample=downsampling)
# Replace the raw vectors with their L2-normalized forms to save memory;
# the model cannot be trained further after this.
model.init_sims(replace=True)
# Save the model.
model_name = "model_full_reddit"
model.save(model_name)
print('Total time: ' + str((time.time() - start)) + ' secs')
# In[33]:
model = word2vec.Word2Vec.load('model_full_reddit')
# In[34]:
Z = model.wv.syn0
# In[37]:
print(Z[0].shape)
print(Z[0])
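# In[ ]:
# Each row of Z holds the vector for one vocabulary word; row i corresponds to
# model.wv.index2word[i], and words are ordered by descending frequency.
# Illustrative shape check:
print(Z.shape)  # (vocabulary size, num_features)
print(model.wv.index2word[0])  # the most frequent word in the vocabulary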
# In[60]:
def clustering_on_wordvecs(word_vectors, num_clusters):
    # Initialize a k-means object and use it to extract centroids.
    kmeans_clustering = KMeans(n_clusters=num_clusters, init='k-means++')
    idx = kmeans_clustering.fit_predict(word_vectors)
    return kmeans_clustering.cluster_centers_, idx
# In[61]:
start = time.time()
centers, clusters = clustering_on_wordvecs(Z, 50)
print('Total time: ' + str((time.time() - start)) + ' secs')
# In[62]:
start = time.time()
centroid_map = dict(zip(model.wv.index2word, clusters))
print('Total time: ' + str((time.time() - start)) + ' secs')
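# In[ ]:
# centroid_map maps each vocabulary word to its cluster index, so individual
# words can be looked up directly, e.g. (assuming 'pizza' survived the
# min_word_count cutoff):
print(centroid_map['pizza'])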
# In[120]:
def get_top_words(index2word, k, centers, wordvecs):
    tree = KDTree(wordvecs)
    # Query the k closest points (words) to each cluster center.
    closest_points = [tree.query(np.reshape(x, (1, -1)), k=k) for x in centers]
    closest_words_idxs = [x[1] for x in closest_points]
    # Look up the word at each index and collect the results in a dictionary,
    # keyed by cluster.
    closest_words = {}
    for i in range(0, len(closest_words_idxs)):
        closest_words['Cluster #' + str(i + 1).zfill(2)] = [index2word[j] for j in closest_words_idxs[i][0]]
    # Build a DataFrame from the dictionary, with rows ranked by closeness.
    df = pd.DataFrame(closest_words)
    df.index = df.index + 1
    return df
# In[121]:
pd.set_option('display.max_columns', None)
pd.set_option('display.max_rows', None)
# In[292]:
top_words = get_top_words(model.wv.index2word, 20, centers, Z)
# In[293]:
top_words
# In[298]:
def display_cloud(cluster_num, cmap):
    wc = WordCloud(background_color="black", max_words=2000, max_font_size=80, colormap=cmap)
    wordcloud = wc.generate(' '.join([word for word in top_words['Cluster #' + str(cluster_num).zfill(2)]]))
    plt.imshow(wordcloud, interpolation='bilinear')
    plt.axis("off")
    plt.savefig('cluster_' + str(cluster_num), bbox_inches='tight')
# In[303]:
cmaps = cycle([
    'flag', 'prism', 'ocean', 'gist_earth', 'terrain', 'gist_stern',
    'gnuplot', 'gnuplot2', 'CMRmap', 'cubehelix', 'brg', 'hsv',
    'gist_rainbow', 'rainbow', 'jet', 'nipy_spectral', 'gist_ncar'])
for i in range(50):
    col = next(cmaps)
    display_cloud(i + 1, col)
# In[361]:
def get_word_table(table, key, sim_key='similarity', show_sim=True):
    if show_sim:
        return pd.DataFrame(table, columns=[key, sim_key])
    else:
        return pd.DataFrame(table, columns=[key, sim_key])[key]
# In[344]:
get_word_table(model.wv.most_similar_cosmul(positive=['king', 'woman'], negative=['queen']), 'Analogy')
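# In[ ]:
# For comparison, the same analogy in the more familiar direction,
# vector('king') - vector('man') + vector('woman'):
get_word_table(model.wv.most_similar_cosmul(positive=['king', 'woman'], negative=['man']), 'Analogy')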
# In[82]:
model.wv.doesnt_match("apple microsoft samsung tesla".split())
# In[83]:
model.wv.doesnt_match("trump clinton sanders obama".split())
# In[87]:
model.wv.doesnt_match("joffrey cersei tywin lannister jon".split())
# In[320]:
model.wv.doesnt_match("daenerys rhaegar viserion aemon aegon jon targaryen".split())
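# In[ ]:
# doesnt_match returns the word whose vector is farthest from the mean of the
# set; an illustrative example mixing food words with a sports team:
model.wv.doesnt_match("pizza burger sushi falcons".split())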
# In[442]:
keys = ['musk', 'modi', 'hodor', 'martell', 'apple', 'neutrality', 'snowden', 'batman', 'hulk', 'warriors', 'falcons', 'pizza']
tables = []
for key in keys:
    tables.append(get_word_table(model.wv.similar_by_word(key), key, show_sim=False))
# In[443]:
pd.concat(tables, axis=1)