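"""Train skip-gram word embeddings with TensorFlow, using noise-contrastive
estimation (NCE) as the training objective.

The script reads a whitespace-tokenized corpus from stdin, logs progress to
stderr, and writes the learned vectors (word2vec-style text format, one word
per line) to stdout. Example invocation, assuming the file is saved as
`skipgram.py` (the filename is only a placeholder):

    python skipgram.py --embedding_size 128 --epochs 5 < corpus.txt > vectors.txt
"""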
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import math
import random
import sys
from builtins import range
from collections import Counter

import numpy as np
import tensorflow as tf
def read_corpus(io):
    """Read the corpus from io (a file-like object, including stdin),
    and return a concatenated list of words."""
    corpus = []
    for line in io:
        # assume words are already tokenized by whitespaces
        words = line.strip().split(' ')
        corpus.extend(words)
    return corpus
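# Illustrative sketch (not executed): given the two-line input
#     the quick brown fox
#     jumps over the lazy dog
# read_corpus returns the flat token list
#     ['the', 'quick', 'brown', 'fox', 'jumps', 'over', 'the', 'lazy', 'dog']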
def build_dataset(corpus, vocabulary_size):
    """Convert corpus (list of words) to dataset (list of word indices).
    Also build a dictionary, which is a look-up table from word to its index.
    Replace OOVs (any words that do not appear often enough to be included in the vocabulary)
    with a special token `UNK`."""
    word_counts = Counter(corpus)
    vocabulary_list = ['UNK'] + [w for w, _ in word_counts.most_common(vocabulary_size - 1)]
    word_to_index = dict((w, i) for i, w in enumerate(vocabulary_list))
    dataset = [word_to_index.get(w, 0) for w in corpus]
    return dataset, vocabulary_list, word_to_index
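# Illustrative sketch (not executed): with corpus = ['the', 'dog', 'saw', 'the', 'cat']
# and vocabulary_size = 3, 'the' plus one of the singleton words (say 'dog') are kept:
#     vocabulary_list = ['UNK', 'the', 'dog']
#     dataset         = [1, 2, 0, 1, 0]    # 'saw' and 'cat' fall back to UNK (index 0)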
def get_batch(dataset, offset, batch_size, skip_window, num_skips):
    """
    Parameters:
        dataset: dataset (list of ints).
        offset: the offset of the input word (the center of the window)
        batch_size: the number of instances in the returned batch.
        skip_window: Number of word(s) to consider on the left and right
        num_skips: Number of instance(s) to generate per window
    """
    assert batch_size % num_skips == 0
    assert num_skips <= 2 * skip_window
    x_input = np.ndarray(shape=(batch_size,), dtype=np.int32)
    y_labels = np.ndarray(shape=(batch_size, 1), dtype=np.int32)
    k = 0
    num_spans = batch_size // num_skips
    for i in range(offset, offset + num_spans):
        window = range(i - skip_window, i + skip_window + 1)
        # HACK: assume UNK (index = 0) for out-of-bound indices
        window_indices = [dataset[j] if 0 <= j < len(dataset) else 0
                          for j in window if i != j]
        sampled_window_indices = random.sample(window_indices, num_skips)
        for sampled_window_index in sampled_window_indices:
            x_input[k] = dataset[i]
            y_labels[k] = sampled_window_index
            k += 1
    return x_input, y_labels
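# Illustrative sketch (not executed): with dataset = [5, 3, 7, 2], offset = 1,
# batch_size = 4, skip_window = 1 and num_skips = 2, two center words are used
# (positions 1 and 2) and both neighbors of each center are sampled, e.g.:
#     x_input  -> [3, 3, 7, 7]
#     y_labels -> [[5], [7], [3], [2]]    # order within each window is random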
class SkipGram(object):
    def __init__(self, args):
        self.train_inputs = tf.placeholder(tf.int32, shape=(args.batch_size,))
        self.train_labels = tf.placeholder(tf.int32, shape=(args.batch_size, 1))

        # Embedding matrix: one embedding_size-dimensional vector per vocabulary word
        self.embeddings_matrix = tf.Variable(
            tf.random_uniform([args.vocabulary_size, args.embedding_size], -1.0, 1.0))
        input_embed = tf.nn.embedding_lookup(self.embeddings_matrix, self.train_inputs)

        # Weights and biases for the NCE (noise-contrastive estimation) loss
        nce_weights = tf.Variable(
            tf.truncated_normal([args.vocabulary_size, args.embedding_size],
                                stddev=1.0 / math.sqrt(args.embedding_size)))
        nce_biases = tf.Variable(tf.zeros([args.vocabulary_size]))

        self.loss = tf.reduce_mean(
            tf.nn.nce_loss(weights=nce_weights,
                           biases=nce_biases,
                           labels=self.train_labels,
                           inputs=input_embed,
                           num_sampled=args.num_negative_samples,
                           num_classes=args.vocabulary_size))

        self.optimizer = tf.train.GradientDescentOptimizer(args.learning_rate).minimize(self.loss)
def write_vectors(embeddings_matrix, vocabulary_list):
    """Print each word and its embedding vector (word2vec text format) to stdout."""
    for i, word in enumerate(vocabulary_list):
        vec_str = ' '.join(['{:.3f}'.format(val) for val in embeddings_matrix[i]])
        print('{:} {:}'.format(word, vec_str))
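# Illustrative sketch (not executed): for a 3-dimensional embedding the output
# looks like (values made up)
#     UNK 0.012 -0.345 0.678
#     the -0.101 0.222 -0.033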
def main():
    parser = argparse.ArgumentParser(description='Train the skip-gram model.')
    parser.add_argument('--vocabulary_size', type=int, default=10000,
                        help='Size of the vocabulary (the number of unique tokens to consider)')
    parser.add_argument('--skip_window', type=int, default=1,
                        help='Number of word(s) to consider on the left and right')
    parser.add_argument('--num_skips', type=int, default=2,
                        help='Number of instance(s) to generate per window')
    parser.add_argument('--embedding_size', type=int, default=128,
                        help='Size (dimension) of the word embedding vectors')
    parser.add_argument('--batch_size', type=int, default=1024,
                        help='Size of the batch')
    parser.add_argument('--num_negative_samples', type=int, default=64,
                        help='Number of negative samples to consider for NCE')
    parser.add_argument('--epochs', type=int, default=5,
                        help='Number of epochs')
    parser.add_argument('--learning_rate', type=float, default=1.0,
                        help='Learning rate')
    args = parser.parse_args()
    random.seed(31416)

    print('Embedding size: {:d}'.format(args.embedding_size),
          file=sys.stderr)

    corpus = read_corpus(sys.stdin)
    print('# of total words in the corpus: {:d}'.format(len(corpus)),
          file=sys.stderr)
    print('# of unique words in the corpus: {:d}'.format(len(set(corpus))),
          file=sys.stderr)

    dataset, vocabulary_list, _ = build_dataset(corpus, args.vocabulary_size)
    print('Size of the vocabulary: {:d}'.format(len(vocabulary_list)),
          file=sys.stderr)

    model = SkipGram(args)
    init = tf.global_variables_initializer()

    with tf.Session() as session:
        init.run()
        num_batches = len(dataset) // args.batch_size
        for _ in range(args.epochs):
            for i in range(num_batches):
                x_input, y_labels = get_batch(dataset, args.batch_size * i, args.batch_size,
                                              args.skip_window, args.num_skips)
                feed_dict = {model.train_inputs: x_input,
                             model.train_labels: y_labels}
                _, loss_val = session.run([model.optimizer, model.loss], feed_dict=feed_dict)
                if i % 1000 == 0:
                    print('Batch: {:d} / {:d}, loss = {:f}'.format(i, num_batches, loss_val),
                          file=sys.stderr)

        # write embedding vectors
        write_vectors(model.embeddings_matrix.eval(), vocabulary_list)


if __name__ == '__main__':
    main()