{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#参考\n",
"#https://towardsdatascience.com/implementing-word2vec-in-pytorch-skip-gram-model-e6bae040d2fb"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#movie lens使って実際に見てみる、blogに記事書く"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# import"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch.autograd import Variable\n",
"import pandas as pd\n",
"import numpy as np\n",
"import torch.nn.functional as F\n",
"import torch.optim as optim"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# prepro"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"corpus = [\n",
" 'he is a king',\n",
" 'she is a queen',\n",
" 'he is a man',\n",
" 'she is a woman',\n",
" 'warsaw is poland capital',\n",
" 'berlin is germany capital',\n",
" 'paris is france capital',\n",
"]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# train"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [],
"source": [
"class skipgram(torch.nn.Module):\n",
" def __init__(self, vocab_size, embedding_dim):\n",
" super(skipgram, self).__init__()\n",
" self.embedding_dim = embedding_dim\n",
" self.vocab_size = vocab_size\n",
" self.init_embedding_layer()\n",
" \n",
" def init_embedding_layer(self):\n",
" self.u_embeddings = torch.nn.Embedding(self.vocab_size, self.embedding_dim, sparse=True)\n",
" self.v_embeddings = torch.nn.Embedding(self.embedding_dim, self.vocab_size, sparse=True)\n",
" \n",
" def forward(self, data, target):\n",
" x1 = torch.LongTensor([[data]])\n",
" x2 = torch.LongTensor([range(self.embedding_dim)])\n",
" \n",
" y_true = Variable(torch.from_numpy(np.array([target])).long())\n",
" \n",
" u_emb = self.u_embeddings(x1)\n",
" v_emb = self.v_embeddings(x2)\n",
" \n",
" z = torch.matmul(u_emb, v_emb).view(-1) #view reshape\n",
" \n",
" log_softmax = F.log_softmax(z, dim=0)\n",
" loss = F.nll_loss(log_softmax.view(1,-1), y_true)\n",
" return loss"
]
},
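{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick sanity check (an illustrative cell, not in the original gist): instantiate `skipgram` with arbitrary toy sizes and run one forward pass on a made-up (center, context) index pair. The loss should come back as a scalar tensor with a `grad_fn`, confirming the graph is wired up."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# illustrative only: toy sizes chosen arbitrarily\n",
"toy_model = skipgram(vocab_size=10, embedding_dim=4)\n",
"print(toy_model(0, 1))  # loss for center word 0 predicting context word 1"
]
},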
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [],
"source": [
"#distributed representation\n",
"class DR:\n",
" def __init__(self, corpus, window_size, embedding_dim):\n",
" self.corpus = corpus\n",
" self.window_size = window_size\n",
" self.embedding_dim = embedding_dim\n",
" \n",
" #treat corpus\n",
" tokenized_corpus = self.tokenize_corpus(self.corpus)\n",
" self.vocabulary = self.get_vocabulary(tokenized_corpus)\n",
" self.vocab_size = len(self.vocabulary)\n",
" self.word2idx = {w: idx for (idx, w) in enumerate(self.vocabulary)}\n",
" self.idx2word = {idx: w for (idx, w) in enumerate(self.vocabulary)}\n",
" \n",
" #generate_batchにする\n",
" self.idx_pairs = self.positive_pair(window_size, tokenized_corpus, self.word2idx)\n",
" \n",
" #model set\n",
" self.model = skipgram(self.vocab_size, self.embedding_dim)\n",
"\n",
" def tokenize_corpus(self, corpus):\n",
" tokens = [x.split() for x in corpus]\n",
" return tokens\n",
" \n",
" def get_vocabulary(self, tokenized_corpus):\n",
" vocabulary = []\n",
" for sentence in tokenized_corpus:\n",
" for token in sentence:\n",
" if token not in vocabulary:\n",
" vocabulary.append(token)\n",
" return vocabulary\n",
"\n",
" def positive_pair(self, window_size, tokenized_corpus, word2idx):\n",
" idx_pairs = [] \n",
" #create positive pair\n",
" for sentence in tokenized_corpus:\n",
" indices = [word2idx[word] for word in sentence]\n",
" for center_word_pos in range(len(indices)):\n",
" for w in range(-window_size, window_size + 1):\n",
" context_word_pos = center_word_pos + w\n",
" if context_word_pos < 0 or context_word_pos >= len(indices) or center_word_pos == context_word_pos:\n",
" continue\n",
" context_word_idx = indices[context_word_pos]\n",
" idx_pairs.append((indices[center_word_pos], context_word_idx))\n",
"\n",
" idx_pairs = np.array(idx_pairs) \n",
" return idx_pairs\n",
" \n",
" def get_vector(self, word):\n",
" word_idx = self.word2idx[word]\n",
" word_idx = torch.LongTensor([[ word_idx]])\n",
" \n",
" vector = self.model.u_embeddings(word_idx).view(-1).detach().numpy()\n",
" return vector\n",
" \n",
" def train(self, num_epochs = 100, learning_rate = 0.001):\n",
" optimizer = optim.SGD(self.model.parameters(), lr = learning_rate)\n",
" for epo in range(num_epochs):\n",
" loss_val = 0\n",
" #ここでバッチ作る\n",
" for data, target in self.idx_pairs:\n",
" optimizer.zero_grad()\n",
"\n",
" loss = self.model(data, target)\n",
" loss.backward()\n",
" loss_val += loss.data\n",
" optimizer.step()\n",
" \n",
" if epo % 10 == 0: \n",
" print(f'Loss at epo {epo}: {loss_val/len(self.idx_pairs)}')"
]
},
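{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the pair generation concrete (an illustrative cell, not in the original gist): with `window_size = 1`, `positive_pair` pairs each word with its immediate left and right neighbours. The `_dr` instance below is a throwaway built just for inspection."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# illustrative only: inspect a few (center, context) pairs on the toy corpus\n",
"_dr = DR(corpus, window_size=1, embedding_dim=5)\n",
"for center, context in _dr.idx_pairs[:6]:\n",
"    print(_dr.idx2word[center], '->', _dr.idx2word[context])"
]
},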
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loss at epo 0: 3.6340999603271484\n",
"Loss at epo 10: 2.3864893913269043\n"
]
}
],
"source": [
"window_size = 1\n",
"embedding_dims = 5\n",
"\n",
"dr = DR(corpus, window_size , embedding_dims)\n",
"dr.train(num_epochs = 11, learning_rate = 0.01)"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([ 0.5356836 , -0.8086047 , 0.26012975, -0.30699113, -0.17636605],\n",
" dtype=float32)"
]
},
"execution_count": 36,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dr.get_vector('he')"
]
},
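{
"cell_type": "markdown",
"metadata": {},
"source": [
"As an illustrative follow-up (not in the original gist), we can compare the learned vectors with cosine similarity. `cosine_similarity` below is a hypothetical helper defined only for this sketch; with just 11 epochs on a toy corpus the scores are not expected to be meaningful."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# illustrative helper, not part of the original gist\n",
"def cosine_similarity(a, b):\n",
"    return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))\n",
"\n",
"for w in ['king', 'queen', 'man', 'woman']:\n",
"    print(w, cosine_similarity(dr.get_vector('he'), dr.get_vector(w)))"
]
}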
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.6"
}
},
"nbformat": 4,
"nbformat_minor": 2
}