from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import math
import random
import sys
from builtins import range
from collections import Counter

import numpy as np
import tensorflow as tf
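
# Note (added): this script uses the TensorFlow 1.x API (tf.placeholder, tf.Session,
# tf.train.GradientDescentOptimizer).  Under TensorFlow 2.x these calls are only
# available through the tf.compat.v1 compatibility module with eager execution disabled.
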
def read_corpus(io):
    """Read the corpus from io (a file-like object, including stdin),
    and return a concatenated list of words."""
    corpus = []
    for line in io:
        # assume words are already tokenized by whitespace
        words = line.strip().split(' ')
        corpus.extend(words)
    return corpus
def build_dataset(corpus, vocabulary_size):
    """Convert corpus (list of words) to dataset (list of word indices).
    Also build a dictionary, which is a look-up table from word to its index.
    Replace OOVs (any words that do not appear often enough to be included
    in the vocabulary) with a special token `UNK`."""
    word_counts = Counter(corpus)
    vocabulary_list = ['UNK'] + [w for w, _ in word_counts.most_common(vocabulary_size - 1)]
    word_to_index = dict((w, i) for i, w in enumerate(vocabulary_list))
    dataset = [word_to_index.get(w, 0) for w in corpus]
    return dataset, vocabulary_list, word_to_index
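
# Illustrative example (added, not part of the original gist): for a corpus like
# ['the', 'cat', 'sat', 'on', 'the', 'mat'] with vocabulary_size=4, build_dataset
# keeps 'UNK' plus the 3 most frequent words; every word outside that vocabulary
# is mapped to index 0 ('UNK') in the returned dataset.
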
def get_batch(dataset, offset, batch_size, skip_window, num_skips):
    """
    Parameters:
        dataset: dataset (list of ints).
        offset: the offset of the input word (the center of the window).
        batch_size: the number of instances in the returned batch.
        skip_window: number of word(s) to consider on the left and right.
        num_skips: number of instance(s) to generate per window.
    """
    assert batch_size % num_skips == 0
    assert num_skips <= 2 * skip_window
    x_input = np.ndarray(shape=(batch_size,), dtype=np.int32)
    y_labels = np.ndarray(shape=(batch_size, 1), dtype=np.int32)
    k = 0
    num_spans = batch_size // num_skips
    for i in range(offset, offset + num_spans):
        window = range(i - skip_window, i + skip_window + 1)
        # HACK: assume UNK (index = 0) for out-of-bound indices
        window_indices = [dataset[j] if 0 <= j < len(dataset) else 0
                          for j in window if i != j]
        sampled_window_indices = random.sample(window_indices, num_skips)
        for sampled_window_index in sampled_window_indices:
            x_input[k] = dataset[i]
            y_labels[k] = sampled_window_index
            k += 1
    return x_input, y_labels
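
# Illustrative example (added): with skip_window=1 and num_skips=2, each center word
# contributes two (input, label) pairs, one per neighbor.  For instance,
# get_batch(dataset, offset=1, batch_size=4, skip_window=1, num_skips=2) uses the
# centers at positions 1 and 2 and returns 4 input/label pairs drawn from their
# immediate left and right neighbors.
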
class SkipGram(object):
    def __init__(self, args):
        # Placeholders for a mini-batch of center-word indices and their context-word labels.
        self.train_inputs = tf.placeholder(tf.int32, shape=[args.batch_size])
        self.train_labels = tf.placeholder(tf.int32, shape=[args.batch_size, 1])

        # Embedding matrix (vocabulary_size x embedding_size), initialized uniformly in [-1, 1).
        self.embeddings_matrix = tf.Variable(
            tf.random_uniform([args.vocabulary_size, args.embedding_size], -1.0, 1.0))
        input_embed = tf.nn.embedding_lookup(self.embeddings_matrix, self.train_inputs)

        # Output-side weights and biases used by the NCE loss.
        nce_weights = tf.Variable(
            tf.truncated_normal([args.vocabulary_size, args.embedding_size],
                                stddev=1.0 / math.sqrt(args.embedding_size)))
        nce_biases = tf.Variable(tf.zeros([args.vocabulary_size]))

        # Noise-contrastive estimation: approximate the full softmax over the vocabulary
        # by contrasting each true context word against num_negative_samples sampled words.
        self.loss = tf.reduce_mean(
            tf.nn.nce_loss(weights=nce_weights,
                           biases=nce_biases,
                           labels=self.train_labels,
                           inputs=input_embed,
                           num_sampled=args.num_negative_samples,
                           num_classes=args.vocabulary_size))
        self.optimizer = tf.train.GradientDescentOptimizer(args.learning_rate).minimize(self.loss)
def write_vectors(embeddings_matrix, vocabulary_list):
    for i, word in enumerate(vocabulary_list):
        vec_str = ' '.join(['{:.3f}'.format(val) for val in embeddings_matrix[i]])
        print('{:} {:}'.format(word, vec_str))
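
# Output format (added note): one line per vocabulary word, the word followed by
# embedding_size space-separated floats formatted to 3 decimal places, e.g.
#   the 0.012 -0.034 ... 0.125
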
def main():
    parser = argparse.ArgumentParser(description='Train the skip-gram model.')
    parser.add_argument('--vocabulary_size', type=int, default=10000,
                        help='Size of the vocabulary (the number of unique tokens to consider)')
    parser.add_argument('--skip_window', type=int, default=1,
                        help='Number of word(s) to consider on the left and right')
    parser.add_argument('--num_skips', type=int, default=2,
                        help='Number of instance(s) to generate per window')
    parser.add_argument('--embedding_size', type=int, default=128,
                        help='Size (dimension) of the word embedding vectors')
    parser.add_argument('--batch_size', type=int, default=1024,
                        help='Size of the batch')
    parser.add_argument('--num_negative_samples', type=int, default=64,
                        help='Number of negative samples to consider for NCE')
    parser.add_argument('--epochs', type=int, default=5,
                        help='Number of epochs')
    parser.add_argument('--learning_rate', type=float, default=1.0,
                        help='Learning rate')
    args = parser.parse_args()

    random.seed(31416)

    print('Embedding size: {:d}'.format(args.embedding_size),
          file=sys.stderr)

    corpus = read_corpus(sys.stdin)
    print('# of total words in the corpus: {:d}'.format(len(corpus)),
          file=sys.stderr)
    print('# of unique words in the corpus: {:d}'.format(len(set(corpus))),
          file=sys.stderr)

    dataset, vocabulary_list, _ = build_dataset(corpus, args.vocabulary_size)
    print('Size of the vocabulary: {:d}'.format(len(vocabulary_list)),
          file=sys.stderr)

    model = SkipGram(args)
    init = tf.global_variables_initializer()

    with tf.Session() as session:
        init.run()
        num_batches = len(dataset) // args.batch_size
        for _ in range(args.epochs):
            for i in range(num_batches):
                x_input, y_labels = get_batch(dataset, args.batch_size * i, args.batch_size,
                                              args.skip_window, args.num_skips)
                feed_dict = {model.train_inputs: x_input,
                             model.train_labels: y_labels}
                _, loss_val = session.run([model.optimizer, model.loss], feed_dict=feed_dict)
                if i % 1000 == 0:
                    print('Batch: {:d} / {:d}, loss = {:f}'.format(i, num_batches, loss_val),
                          file=sys.stderr)

        # write embedding vectors
        write_vectors(model.embeddings_matrix.eval(), vocabulary_list)


if __name__ == '__main__':
    main()
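
# Example invocation (added; assumes this file is saved as skipgram.py and that the
# corpus is whitespace-tokenized text on stdin):
#   python skipgram.py --embedding_size 128 --epochs 5 < corpus.txt > vectors.txt
# Training progress is printed to stderr; the learned word vectors go to stdout.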