Measuring the similarity of books using TF-IDF, Doc2vec and TensorFlow
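"""Doc2vec-style book embeddings with TensorFlow.

Downloads a set of Project Gutenberg books, keeps each book's most
distinctive tokens by TF-IDF score, then trains a small model that sums
token embeddings over a sliding window and predicts which book the window
came from. The learned per-book weight vectors act as document embeddings,
which are compared by cosine similarity and projected to 2D with t-SNE.

Note: the script uses the TensorFlow 1.x API (tf.placeholder, tf.Session);
on TensorFlow 2.x it would need tf.compat.v1 with eager execution disabled.
"""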
import collections
import math
import os
import pickle
import random
import re
import time
import urllib.request

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from sklearn.manifold import TSNE
TOP_WORDS = 2500       # top TF-IDF tokens kept per book
EMBED_SIZE = 64        # dimensionality of token and book embeddings
BATCH_SIZE = 256       # training windows per batch
EXAMPLE_SIZE = 3       # number of examples printed in reports
WINDOW_SIZE = 8        # consecutive tokens summed per training window
DEMO_STEPS = 1000      # steps between similarity demos and checkpoints
STATUS_STEPS = 100     # steps between loss/progress updates
TOTAL_STEPS = 50000    # total training steps
ENCODING_FILE = 'book_encodings.pkl'
DATA_FOLDER = 'e:\\doc2vec'    # working folder for books, encodings and logs
TF_FOLDER = 'logs'             # TensorBoard summaries and checkpoints
BOOKS = [
    ('https://www.gutenberg.org/files/46/46-0.txt', 'A Christmas Carol', 'Charles Dickens'),
    ('https://www.gutenberg.org/files/98/98-0.txt', 'A Tale of Two Cities', 'Charles Dickens'),
    ('https://www.gutenberg.org/files/11/11-0.txt', "Alice's Adventures in Wonderland", 'Lewis Carroll'),
    ('http://www.gutenberg.org/cache/epub/996/pg996.txt', 'Don Quixote', 'Miguel de Cervantes'),
    ('https://www.gutenberg.org/files/158/158-0.txt', 'Emma', 'Jane Austen'),
    ('http://www.gutenberg.org/cache/epub/83/pg83.txt', 'From the Earth to the Moon', 'Jules Verne'),
    ('https://www.gutenberg.org/files/1400/1400-0.txt', 'Great Expectations', 'Charles Dickens'),
    ('http://www.gutenberg.org/cache/epub/3748/pg3748.txt', 'Journey to the Center of the Earth', 'Jules Verne'),
    ('https://www.gutenberg.org/files/1342/1342-0.txt', 'Pride and Prejudice', 'Jane Austen'),
    ('http://www.gutenberg.org/cache/epub/21839/pg21839.txt', 'Sense and Sensibility', 'Jane Austen'),
    ('http://www.gutenberg.org/cache/epub/78/pg78.txt', 'Tarzan of the Apes', 'Edgar Rice Burroughs'),
    ('http://www.gutenberg.org/cache/epub/1013/pg1013.txt', 'The First Men In The Moon', 'H. G. Wells'),
    ('https://www.gutenberg.org/files/236/236-0.txt', 'The Jungle Book', 'Rudyard Kipling'),
    ('https://www.gutenberg.org/files/8147/8147-0.txt', 'The Man Who Would Be King', 'Rudyard Kipling'),
    ('https://www.gutenberg.org/files/35/35-0.txt', 'The Time Machine', 'H. G. Wells'),
    ('https://www.gutenberg.org/files/36/36-0.txt', 'The War of the Worlds', 'H. G. Wells'),
    ('http://www.gutenberg.org/cache/epub/43936/pg43936.txt', 'The Wonderful Wizard of Oz', 'L. Frank Baum'),
    ('https://www.gutenberg.org/files/12/12-0.txt', 'Through the Looking-Glass', 'Lewis Carroll'),
    ('https://www.gutenberg.org/files/120/120-0.txt', 'Treasure Island', 'Robert Louis Stevenson'),
    ('https://www.gutenberg.org/files/775/775-0.txt', 'When the Sleeper Wakes', 'H. G. Wells')
]
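
# TF-IDF helpers. Note that term_frequency divides by the number of distinct
# tokens in the book's Counter; since that is constant per book, the per-book
# ranking of tokens is unaffected.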
def books_containing(books, token):
    return sum(1 for book in books if token in book)


def idf(books, token):
    return math.log(len(books) / (1 + books_containing(books, token)))


def term_frequency(book, token):
    return book[token] / len(book)


def tf_idf(books, book, token):
    return term_frequency(book, token) * idf(books, token)
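
# Fetch each book from Project Gutenberg into DATA_FOLDER/books, skipping
# files that have already been downloaded.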
def download_books():
    path = os.path.join(DATA_FOLDER, 'books')
    if not os.path.exists(path):
        os.makedirs(path)
    for url, _, _ in BOOKS:
        name = url.rsplit('/', 1)[-1]
        filename = os.path.join(path, name)
        if not os.path.isfile(filename):
            print('Downloading', url)
            urllib.request.urlretrieve(url, filename)
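
# Tokenise each book, keep the TOP_WORDS highest-scoring TF-IDF tokens per
# book as the shared vocabulary, and encode every book as a list of vocab
# indices plus a one-hot book label. Results are cached in ENCODING_FILE.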
def get_book_encodings():
    # Build authors dictionary.
    _, _, authors = zip(*BOOKS)
    author_set = set(authors)
    authors = {}
    for author in author_set:
        authors[author] = len(authors)
    # Load encodings from file if they exist.
    data_file = os.path.join(DATA_FOLDER, ENCODING_FILE)
    if os.path.isfile(data_file):
        with open(data_file, 'rb') as f:
            dump = pickle.load(f)
        return dump[0], dump[1], authors
    # Count tokens.
    word_count = 0
    book_tokens = []
    book_counters = []
    reg_alpha = re.compile('[^a-z]')
    reg_apostrophe = re.compile(r"['’]")
    path = os.path.join(DATA_FOLDER, 'books')
    for url, _, _ in BOOKS:
        file = url.rsplit('/', 1)[-1]
        with open(os.path.join(path, file), encoding='utf8') as book_file:
            text = book_file.read()
        tokens = reg_alpha.sub(' ', reg_apostrophe.sub('', text.lower())).split()
        tokens = [t for t in tokens if len(t) > 1]
        book_tokens.append(tokens)
        book_counters.append(collections.Counter(tokens))
        word_count += len(tokens)
    # Calculate TF-IDF scores.
    vocab_set = set()
    for book_ix, book in enumerate(book_counters):
        print('Top tokens in:', BOOKS[book_ix][1])
        scores = {token: tf_idf(book_counters, book, token) for token in book}
        sorted_tokens = sorted(scores.items(), key=lambda x: x[1], reverse=True)
        tokens, _ = zip(*sorted_tokens[:TOP_WORDS])
        vocab_set = vocab_set.union(set(tokens))
        for token, score in sorted_tokens[:EXAMPLE_SIZE]:
            print('>', token, round(score, 5))
        print()
    # Build tokens dictionary.
    vocab = {}
    for token in vocab_set:
        vocab[token] = len(vocab)
    print('Word count:', word_count)
    print('Vocab size:', len(vocab))
    print()
    # Encode books.
    book_encodings = []
    for book_ix, tokens in enumerate(book_tokens):
        book_labels = [0] * len(BOOKS)
        book_labels[book_ix] = 1
        encoding = []
        for token in tokens:
            if token in vocab:
                encoding.append(vocab[token])
        book_encodings.append([encoding, book_labels])
    with open(data_file, 'wb') as f:
        dump = [book_encodings, len(vocab)]
        pickle.dump(dump, f)
    return book_encodings, len(vocab), authors
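
# Sample BATCH_SIZE random training examples: each is WINDOW_SIZE consecutive
# token indices from a randomly chosen book, paired with that book's one-hot
# label.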
def generate_batch(encodings):
    tokens_batch = []
    books_batch = []
    while len(tokens_batch) < BATCH_SIZE:
        book_ix = random.randint(0, len(BOOKS) - 1)
        encoding = encodings[book_ix]
        token_ix = random.randint(0, len(encoding[0]) - WINDOW_SIZE)
        inputs = encoding[0][token_ix:token_ix + WINDOW_SIZE]
        tokens_batch.append(inputs)
        books_batch.append(encoding[1])
    return tokens_batch, books_batch
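
# Print the EXAMPLE_SIZE books closest to the given book. The embedding rows
# are L2-normalised, so the dot products are cosine similarities; the first
# entry of the sorted result is the book itself and is skipped.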
def report_closest(book_embeddings, book_ix, session):
    print('Closest to:', BOOKS[book_ix][1])
    norm = tf.sqrt(tf.reduce_sum(tf.square(book_embeddings), 1, keepdims=True))
    norm_book_embeddings = book_embeddings / norm
    sim = session.run(tf.matmul(tf.expand_dims(norm_book_embeddings[book_ix], 0),
                                norm_book_embeddings, transpose_b=True))[0]
    nearest = (-sim).argsort()[1:EXAMPLE_SIZE + 1]
    for k in range(EXAMPLE_SIZE):
        print('> %0.3f %s' % (sim[nearest[k]], BOOKS[nearest[k]][1]))
    print()
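
# Build the graph (summed window embedding -> softmax over books), train it,
# then report the closest books and plot a t-SNE projection of the book
# embeddings coloured by author.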
def main():
    # Setup environment.
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
    if not os.path.exists(DATA_FOLDER):
        os.makedirs(DATA_FOLDER)
    # Download books.
    download_books()
    # Build vocab and training data.
    book_encodings, vocab_size, authors = get_book_encodings()
    # Token embeddings.
    inputs = tf.placeholder(tf.int32, shape=[None, WINDOW_SIZE])
    token_embeddings = tf.Variable(tf.random_uniform([vocab_size, EMBED_SIZE], -1.0, 1.0))
    token_embedding = tf.zeros([BATCH_SIZE, EMBED_SIZE])
    for i in range(WINDOW_SIZE):
        token_embedding += tf.nn.embedding_lookup(token_embeddings, inputs[:, i])
    # Book embeddings. Labels are one-hot distributions, so the placeholder is
    # float for the softmax cross-entropy loss below.
    book_labels = tf.placeholder(tf.float32, shape=[None, len(BOOKS)])
    book_bias = tf.Variable(tf.constant(0.1, shape=[len(BOOKS)]))
    book_embeddings = tf.Variable(tf.random_uniform([EMBED_SIZE, len(BOOKS)], -1.0, 1.0))
    book_outputs = tf.matmul(token_embedding, book_embeddings) + book_bias
    # Calculate loss and train.
    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=book_outputs, labels=book_labels))
    tf.summary.scalar('loss', loss)
    optimizer = tf.train.AdamOptimizer()
    global_step = tf.Variable(0, trainable=False, name='global_step')
    train_op = optimizer.minimize(loss=loss, global_step=global_step)
    saver = tf.train.Saver(max_to_keep=1)
    # Train model.
    with tf.Session() as session:
        # Initialise summary file writer.
        merged = tf.summary.merge_all()
        summary_writer = tf.summary.FileWriter(os.path.join(DATA_FOLDER, TF_FOLDER), session.graph)
        # Load or initialise model.
        checkpoint = tf.train.latest_checkpoint(os.path.join(DATA_FOLDER, TF_FOLDER))
        if checkpoint:
            print('Loading checkpoint')
            saver.restore(session, checkpoint)
        else:
            session.run(tf.global_variables_initializer())
            session.run(tf.local_variables_initializer())
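        # Training loop: sample random windows, take a gradient step, and
        # periodically log the loss, print the closest books to a random
        # title, and save a checkpoint.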
        step = 0
        start_time = time.time()
        while step < TOTAL_STEPS:
            tokens_batch, books_batch = generate_batch(book_encodings)
            feed_dict = {inputs: tokens_batch, book_labels: books_batch}
            _, step = session.run([train_op, global_step], feed_dict=feed_dict)
            # Update status.
            if step % STATUS_STEPS == 0:
                current_time = time.time()
                elapsed_time = current_time - start_time
                time_left = TOTAL_STEPS * elapsed_time / step - elapsed_time
                ls, summary = session.run([loss, merged], feed_dict=feed_dict)
                print('step %d, loss %0.3f, remaining %0.2f' % (step, ls, time_left / 60))
                summary_writer.add_summary(summary, step)
            # Run demo tasks.
            if step % DEMO_STEPS == 0:
                book_ix = int(np.random.choice(len(BOOKS), size=1))
                embeddings = session.run(tf.transpose(book_embeddings))
                report_closest(embeddings, book_ix, session)
                print('Saving checkpoint\n')
                saver.save(session, os.path.join(DATA_FOLDER, TF_FOLDER, 'doc2vec.ckpt'), global_step=step)
        # Book report.
        embeddings = session.run(tf.transpose(book_embeddings))
        for book_ix in range(len(BOOKS)):
            report_closest(embeddings, book_ix, session)
        # Project book embeddings onto 2D plane.
        embeddings = session.run(tf.transpose(book_embeddings))
        x, y = zip(*TSNE(n_components=2, verbose=1, perplexity=len(BOOKS) / len(authors), n_iter=5000).fit_transform(embeddings))
        colours = [authors[BOOKS[book_ix][2]] for book_ix in range(len(BOOKS))]
        fig, ax = plt.subplots()
        plt.scatter(x, y, s=120 ** 2, alpha=0.6, cmap='brg', c=colours)
        for book_ix in range(len(BOOKS)):
            ax.annotate(BOOKS[book_ix][1] + '\n' + BOOKS[book_ix][2], (x[book_ix], y[book_ix]),
                        horizontalalignment='center', verticalalignment='center', fontsize=10)
        plt.show()
        print()


if __name__ == '__main__':
    main()