Skip to content

Instantly share code, notes, and snippets.

Last active August 3, 2017 14:05
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save nojima/7f286e53d35cde0a09ad24e519513492 to your computer and use it in GitHub Desktop.
Save nojima/7f286e53d35cde0a09ad24e519513492 to your computer and use it in GitHub Desktop.
from logging import getLogger, StreamHandler, DEBUG
from typing import Tuple, Iterator, Dict
import chainer.functions as F
import chainer.links as L
import numpy as np
from chainer import Variable, optimizers, serializers, Chain
from chainer.utils import walker_alias
from scipy.spatial.distance import cosine
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
# Module-level logger: every record (DEBUG and above) goes to a dedicated
# stream handler (StreamHandler's default stream).
logger = getLogger(__name__)
handler = StreamHandler()
handler.setLevel(DEBUG)
logger.setLevel(DEBUG)
# The handler must be attached explicitly: propagation is switched off below,
# so without it INFO/DEBUG records emitted by this module would have no sink.
logger.addHandler(handler)
# Do not forward records to ancestor loggers (avoids duplicate output).
logger.propagate = False
class Vocabulary:
    """Two-way mapping between words and integer ids.

    Ids are handed out in order of first registration, starting at 0, so
    the ids in use are always ``0 .. size - 1``.
    """

    def __init__(self):
        self._word2id = {}  # type: Dict[str, int]
        self._id2word = {}  # type: Dict[int, str]

    def intern(self, word: str) -> int:
        """Return the id of ``word``, registering it first if it is new."""
        word_id = self._word2id.get(word)
        if word_id is None:
            # The next free id equals the number of words registered so far.
            word_id = len(self._word2id)
            self._word2id[word] = word_id
            self._id2word[word_id] = word
        return word_id

    def to_word(self, id: int) -> str:
        """Return the word registered under ``id``.

        Raises:
            KeyError: if no word has that id.
        """
        return self._id2word[id]

    def to_id(self, word: str) -> int:
        """Return the id of an already registered ``word``.

        Raises:
            KeyError: if the word was never interned.
        """
        return self._word2id[word]

    @property
    def size(self) -> int:
        """Number of distinct words registered.

        Exposed as a property because the rest of the module reads it as
        an attribute (``vocabulary.size``), not as a call.
        """
        return len(self._word2id)
class DataSet:
    """A whitespace-tokenised text file encoded as a flat array of word ids."""

    def __init__(self, filename: str, vocabulary: Vocabulary = None) -> None:
        """Read ``filename`` and intern every whitespace-separated token.

        Args:
            filename: Path of the text file to read.
            vocabulary: Vocabulary to extend; a fresh one is created when
                omitted.
        """
        self._vocabulary = Vocabulary() if vocabulary is None else vocabulary
        intern = self._vocabulary.intern
        data = []
        with open(filename) as f:
            for line in f:
                # Record the id of every token; without this the corpus
                # array would stay empty.
                data.extend(intern(word) for word in line.split())
        self._data = np.array(data, dtype=np.int32)

    # ``size``, ``vocabulary`` and ``data`` are properties because the rest of
    # the module reads them as attributes (e.g. ``dataset.size``).

    @property
    def size(self) -> int:
        """Number of tokens in the corpus."""
        return len(self._data)

    @property
    def vocabulary(self) -> Vocabulary:
        """Vocabulary used to encode the corpus."""
        return self._vocabulary

    @property
    def data(self) -> np.ndarray:
        """The corpus as an ``int32`` array of word ids, in reading order."""
        return self._data

    def make_sampler(self) -> walker_alias.WalkerAlias:
        """Build a sampler over word ids weighted by ``count ** 0.75``.

        ``np.bincount`` is used so that position ``i`` of the weight array
        always belongs to word id ``i``, even when some ids of a shared
        vocabulary do not occur in this corpus (those get weight 0).
        """
        counts = np.bincount(self._data)
        weights = np.power(counts, 0.75)
        return walker_alias.WalkerAlias(weights)
class Word2Vec(Chain):
    """Word-pair scoring model with separate input and output embeddings."""

    def __init__(self, n_vocabulary: int, n_units: int) -> None:
        """Create two ``n_vocabulary`` x ``n_units`` embedding tables."""
        # Chain must be initialised before links can be registered inside
        # init_scope().
        super().__init__()
        with self.init_scope():
            self._embed_input = L.EmbedID(n_vocabulary, n_units)
            self._embed_output = L.EmbedID(n_vocabulary, n_units)

    def __call__(self, x1: Variable, x2: Variable, t: Variable) -> Variable:
        """Return the sigmoid cross-entropy loss of the pair scores against ``t``."""
        output = self.forward(x1, x2)
        return F.sigmoid_cross_entropy(output, t)

    def forward(self, x1: Variable, x2: Variable) -> Variable:
        """Score each pair as the dot product of its two embeddings.

        ``x1`` is looked up in the input table and ``x2`` in the output
        table; the element-wise product is summed over axis 1.
        """
        h1 = self._embed_input(x1)
        h2 = self._embed_output(x2)
        return F.sum(h1 * h2, axis=1)

    def distributed_representation(self, word_id: np.ndarray) -> np.ndarray:
        """Return the input-table embeddings of the given word ids."""
        return self._embed_input(Variable(word_id)).data
class Word2VecOneW(Chain):
    """Word-pair scoring model that uses one embedding table for both words."""

    def __init__(self, n_vocabulary: int, n_units: int) -> None:
        """Create a single ``n_vocabulary`` x ``n_units`` embedding table."""
        # Chain must be initialised before links can be registered inside
        # init_scope().
        super().__init__()
        with self.init_scope():
            self._embed = L.EmbedID(n_vocabulary, n_units)

    def __call__(self, x1: Variable, x2: Variable, t: Variable) -> Variable:
        """Return the sigmoid cross-entropy loss of the pair scores against ``t``."""
        output = self.forward(x1, x2)
        return F.sigmoid_cross_entropy(output, t)

    def forward(self, x1: Variable, x2: Variable) -> Variable:
        """Score each pair as the dot product of its two embeddings.

        Both ``x1`` and ``x2`` are looked up in the same table; the
        element-wise product is summed over axis 1.
        """
        h1 = self._embed(x1)
        h2 = self._embed(x2)
        return F.sum(h1 * h2, axis=1)

    def distributed_representation(self, word_id: np.ndarray) -> np.ndarray:
        """Return the embeddings of the given word ids."""
        return self._embed(Variable(word_id)).data
def train(dataset: DataSet, n_epoch: int = 10, batch_size: int = 100) -> Iterator[Tuple[Word2Vec, int]]:
    """Train a ``Word2Vec`` model on ``dataset`` and yield it after every epoch.

    Each corpus position is paired with the words up to ``window_size``
    positions away (label 1) and with words drawn from the dataset's
    sampler (label 0).

    Args:
        dataset: Corpus to train on; its ``size``, ``data`` and
            ``vocabulary`` are read as attributes.
        n_epoch: Number of passes over the corpus.
        batch_size: Number of corpus positions per optimisation step.

    Yields:
        ``(model, epoch)`` once per finished epoch. The same model object
        is yielded every time and keeps being updated.

    NOTE(review): the statements that fill the batch and perform the update
    step were missing from the source and have been reconstructed; confirm
    that negatives are meant to be drawn once per positive pair.
    """
    n_units = 100
    model = Word2Vec(dataset.vocabulary.size, n_units)
    optimizer = optimizers.Adam()
    # The optimizer has to be bound to the model before update() can be used.
    optimizer.setup(model)
    sampler = dataset.make_sampler()
    window_size = 3
    n_negative_samples = 5
    corpus = dataset.data
    n_words = dataset.size

    def make_batch_set(indices: np.ndarray) -> Tuple[Variable, Variable, Variable]:
        """Build (centre id, other id, label) arrays for the given positions."""
        x1, x2, t = [], [], []
        for index in indices:
            id1 = corpus[index]
            for i in range(-window_size, window_size + 1):
                p = index + i
                # Skip the centre word itself and positions outside the corpus.
                if i == 0 or p < 0 or p >= n_words:
                    continue
                id2 = corpus[p]
                # Positive pair: a word that really occurs in the window.
                x1.append(id1)
                x2.append(id2)
                t.append(1)
                # Negative pairs: words drawn from the sampler.
                for nid in sampler.sample(n_negative_samples):
                    x1.append(id1)
                    x2.append(nid)
                    t.append(0)
        return (Variable(np.array(x1, dtype=np.int32)),
                Variable(np.array(x2, dtype=np.int32)),
                Variable(np.array(t, dtype=np.int32)))

    for epoch in range(n_epoch):
        logger.info("epoch: {}".format(epoch))
        # Visit the corpus positions in a fresh random order every epoch.
        indices = np.random.permutation(n_words)
        for i in range(0, n_words, batch_size):
            logger.info("-- {}, {}".format(epoch, i))
            x1, x2, t = make_batch_set(indices[i:i + batch_size])
            model.cleargrads()
            loss = model(x1, x2, t)
            loss.backward()
            optimizer.update()
        yield model, epoch
class Search:
    """Nearest-neighbour lookup over a model's word vectors (cosine distance)."""

    def __init__(self, vocabulary: Vocabulary, model: Word2Vec):
        """Fetch and cache the vector of every word id in ``vocabulary``."""
        self._vocabulary = vocabulary
        all_ids = np.arange(0, vocabulary.size, dtype=np.int32)
        self._vectors = model.distributed_representation(all_ids)

    def find_similar_words(self, word: str, n: int = 10):
        """Return the ``n`` words closest to ``word``."""
        query = self.get_vector(word)
        return self.find_similar_words_by_vector(query, n)

    def find_similar_words_by_vector(self, vector: np.ndarray, n: int = 10):
        """Return the ``n`` words whose vectors are closest to ``vector``.

        Words are ordered by increasing cosine distance; ties keep id order.
        """
        vocab = self._vocabulary
        candidates = range(0, vocab.size)
        # Compute each distance once, then rank the ids by it.
        distances = [cosine(self._vectors[word_id], vector) for word_id in candidates]
        ranked = sorted(candidates, key=distances.__getitem__)
        return [vocab.to_word(word_id) for word_id in ranked[:n]]

    def get_vector(self, word: str):
        """Return the cached vector of ``word``."""
        return self._vectors[self._vocabulary.to_id(word)]
def save_model(dir_name: str, model: Word2Vec, epoch: int) -> None:
    """Serialise ``model`` to ``<dir_name>/w2v_model_epoch<epoch>.npz``.

    The target directory is created when it does not exist yet, so a
    missing directory cannot abort a run after an epoch of training.
    """
    import os  # local import keeps this block self-contained

    if dir_name:
        os.makedirs(dir_name, exist_ok=True)
    filename = "{}/w2v_model_epoch{}.npz".format(dir_name, epoch)
    serializers.save_npz(filename, model)
def load_model(filename: str, vocabulary: Vocabulary, model_class: type = Word2Vec,
               n_units: int = 100) -> Word2Vec:
    """Rebuild a model and load its parameters from an ``.npz`` file.

    Args:
        filename: Path of the file to load the parameters from.
        vocabulary: Vocabulary whose ``size`` sets the number of embeddings.
        model_class: Class to instantiate; called as
            ``model_class(vocabulary.size, n_units)``.
        n_units: Embedding dimensionality passed to ``model_class``. The
            default of 100 is the value ``train`` uses.

    Returns:
        The model instance with the loaded parameters.
    """
    model = model_class(vocabulary.size, n_units)
    serializers.load_npz(filename, model)
    return model
def to_2d_vectors(vocabulary: Vocabulary, model: Word2Vec):
    """Project the vector of every word id in ``vocabulary`` to 2-D with t-SNE.

    Returns whatever ``TSNE.fit_transform`` produces for the model's
    vectors, with row ``i`` belonging to word id ``i``.
    """
    all_ids = np.arange(0, vocabulary.size, dtype=np.int32)
    embeddings = model.distributed_representation(all_ids)
    # The random state is fixed; verbose=3 makes t-SNE report its progress.
    reducer = TSNE(n_components=2, verbose=3, random_state=12345)
    return reducer.fit_transform(embeddings)
def visualize(vocabulary: Vocabulary, vectors_2d: np.ndarray):
    """Scatter-plot a fixed set of country and capital words and show the figure.

    Args:
        vocabulary: Used to translate the words into row indices.
        vectors_2d: Array indexed by word id; columns 0 and 1 are used as
            the x and y coordinates.

    Raises:
        KeyError: if one of the words is not in ``vocabulary``.
    """
    countries = ['u.s.', 'u.k.', 'italy', 'korea', 'china', 'germany', 'japan', 'france', 'russia', 'egypt']
    capitals = ['washington', 'london', 'rome', 'seoul', 'beijing', 'berlin', 'tokyo', 'paris', 'moscow', 'cairo']
    labels = countries + capitals
    word_ids = [vocabulary.to_id(word) for word in labels]
    points = vectors_2d[word_ids]

    _, ax = plt.subplots()
    ax.scatter(points[:, 0], points[:, 1])
    for label, point in zip(labels, points):
        ax.annotate(label, (point[0], point[1]))
    # Display the figure; otherwise it is built and then silently discarded.
    plt.show()
def run(seed: int = 12345) -> None:
    """Train on ``ptb.train.txt`` for 50 epochs, saving the model after each one.

    Args:
        seed: Seed for NumPy's global random generator, which ``train``
            uses to shuffle the corpus positions.
    """
    # Apply the seed; previously the parameter was accepted but never used.
    np.random.seed(seed)
    dataset = DataSet("ptb.train.txt")
    for model, epoch in train(dataset, n_epoch=50):
        save_model("./models/v2", model, epoch)
set -eux
for v in train valid test; do
wget "${v}.txt"
This file has been truncated, but you can view the full file.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment