Skip to content

Instantly share code, notes, and snippets.

@RottenFruits
Created March 3, 2019 08:19
Show Gist options
  • Save RottenFruits/cfb3f92c294dc2467edc0ecf4d0ed8e1 to your computer and use it in GitHub Desktop.
Save RottenFruits/cfb3f92c294dc2467edc0ecf4d0ed8e1 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Reference:\n",
"# https://towardsdatascience.com/implementing-word2vec-in-pytorch-skip-gram-model-e6bae040d2fb"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Imports"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch.autograd import Variable\n",
"import pandas as pd\n",
"import numpy as np\n",
"import torch.functional as F\n",
"import torch.nn.functional as F"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
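"# Preprocessing\n",
"\n",
"A toy corpus of short, whitespace-separated sentences. Tokenization and vocabulary building happen inside the `skipgram` class below."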
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# Toy corpus: he/she sentences about kings, queens, men and women,\n",
"# followed by '<city> is <country> capital' sentences.\n",
"corpus = [\n",
"    # people\n",
"    'he is a king',\n",
"    'she is a queen',\n",
"    'he is a man',\n",
"    'she is a woman',\n",
"    # capitals\n",
"    'warsaw is poland capital',\n",
"    'berlin is germany capital',\n",
"    'paris is france capital',\n",
"]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
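"# Training\n",
"\n",
"A skip-gram word2vec model: for every (center, context) word pair inside the window, a full softmax over the vocabulary is trained with plain SGD."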
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {},
"outputs": [],
"source": [
"class skipgram(torch.nn.Module):\n",
"    \"\"\"Skip-gram word2vec model trained with a full softmax over the vocabulary.\n",
"\n",
"    Args:\n",
"        corpus: list of sentences (str); each is tokenized with str.split().\n",
"        window_size: number of words on each side of a center word that count\n",
"            as its context.\n",
"        embedding_dim: length of each learned word vector.\n",
"        seed: optional int. When given, the embeddings are initialised from a\n",
"            dedicated torch.Generator seeded with it, so runs are reproducible.\n",
"            None (default) draws from torch's global RNG, as before.\n",
"\n",
"    NOTE: `train` below overrides nn.Module.train(mode). It is kept for\n",
"    backward compatibility, but nn.Module.eval() delegates to train(False),\n",
"    so model.eval() runs zero epochs instead of switching modes.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, corpus, window_size, embedding_dim, seed=None):\n",
"        # nn.Module must be initialised before attributes are assigned;\n",
"        # without it repr(model), model.parameters() and state_dict() fail.\n",
"        super().__init__()\n",
"        self.corpus = corpus\n",
"        self.embedding_dim = embedding_dim\n",
"\n",
"        # Tokenize, then build the vocabulary, index maps and training pairs.\n",
"        tokenized_corpus = self.tokenize_corpus(self.corpus)\n",
"        self.vocabulary = self.get_vocabulary(tokenized_corpus)\n",
"        self.word2idx = {w: idx for (idx, w) in enumerate(self.vocabulary)}\n",
"        self.idx2word = {idx: w for (idx, w) in enumerate(self.vocabulary)}\n",
"        self.idx_pairs = self.positive_pair(window_size, tokenized_corpus, self.word2idx)\n",
"\n",
"        self.vocab_size = len(self.vocabulary)\n",
"        generator = None\n",
"        if seed is not None:\n",
"            generator = torch.Generator()\n",
"            generator.manual_seed(seed)\n",
"        self.init_embedding_layer(generator)\n",
"\n",
"    def tokenize_corpus(self, corpus):\n",
"        \"\"\"Split every sentence on whitespace; returns a list of token lists.\"\"\"\n",
"        return [sentence.split() for sentence in corpus]\n",
"\n",
"    def get_vocabulary(self, tokenized_corpus):\n",
"        \"\"\"Return the unique tokens of the corpus in first-occurrence order.\"\"\"\n",
"        # The set gives O(1) membership tests; the list keeps the ordering\n",
"        # that word2idx / idx2word are built from.\n",
"        seen = set()\n",
"        vocabulary = []\n",
"        for sentence in tokenized_corpus:\n",
"            for token in sentence:\n",
"                if token not in seen:\n",
"                    seen.add(token)\n",
"                    vocabulary.append(token)\n",
"        return vocabulary\n",
"\n",
"    def positive_pair(self, window_size, tokenized_corpus, word2idx):\n",
"        \"\"\"Collect (center_idx, context_idx) pairs at most window_size words apart.\n",
"\n",
"        Pairs never cross sentence boundaries and a word is never paired with\n",
"        itself. Returns an int64 array of shape (n_pairs, 2); the shape is\n",
"        (0, 2) when there are no pairs (e.g. empty corpus, one-word sentences).\n",
"        \"\"\"\n",
"        idx_pairs = []\n",
"        for sentence in tokenized_corpus:\n",
"            indices = [word2idx[word] for word in sentence]\n",
"            for center_pos, center_idx in enumerate(indices):\n",
"                # Clip the context window to the sentence.\n",
"                start = max(center_pos - window_size, 0)\n",
"                stop = min(center_pos + window_size + 1, len(indices))\n",
"                for context_pos in range(start, stop):\n",
"                    if context_pos != center_pos:\n",
"                        idx_pairs.append((center_idx, indices[context_pos]))\n",
"        return np.array(idx_pairs, dtype=np.int64).reshape(-1, 2)\n",
"\n",
"    def init_embedding_layer(self, generator=None):\n",
"        \"\"\"Create the two weight matrices with standard-normal initialisation.\n",
"\n",
"        u_embeddings has shape (embedding_dim, vocab_size); column i is the\n",
"        vector of word i (the one returned by get_vector).\n",
"        v_embeddings has shape (vocab_size, embedding_dim); it maps a word\n",
"        vector to one score per vocabulary word.\n",
"\n",
"        Args:\n",
"            generator: optional torch.Generator used for the random draws.\n",
"        \"\"\"\n",
"        # nn.Parameter sets requires_grad=True and registers the tensors with\n",
"        # the module, so they show up in parameters() and state_dict().\n",
"        self.u_embeddings = torch.nn.Parameter(\n",
"            torch.randn(self.embedding_dim, self.vocab_size, generator=generator).float())\n",
"        self.v_embeddings = torch.nn.Parameter(\n",
"            torch.randn(self.vocab_size, self.embedding_dim, generator=generator).float())\n",
"\n",
"    def get_input_layer(self, word_idx):\n",
"        \"\"\"Return a float one-hot vector of length vocab_size with a 1 at word_idx.\"\"\"\n",
"        x = torch.zeros(self.vocab_size).float()\n",
"        x[word_idx] = 1.0\n",
"        return x\n",
"\n",
"    def get_vector(self, word):\n",
"        \"\"\"Return the learned u-embedding of `word` as a numpy array of length embedding_dim.\n",
"\n",
"        Raises KeyError if `word` is not in the vocabulary.\n",
"        \"\"\"\n",
"        x = self.get_input_layer(self.word2idx[word])\n",
"        return torch.matmul(self.u_embeddings, x).detach().numpy()\n",
"\n",
"    def train(self, num_epochs=100, learning_rate=0.001, log_every=100):\n",
"        \"\"\"Train both embedding matrices with plain SGD, one word pair per step.\n",
"\n",
"        Args:\n",
"            num_epochs: number of full passes over self.idx_pairs.\n",
"            learning_rate: SGD step size.\n",
"            log_every: print the mean per-pair loss every `log_every` epochs,\n",
"                starting at epoch 0; 0 or None disables logging.\n",
"        \"\"\"\n",
"        n_pairs = len(self.idx_pairs)\n",
"        for epo in range(num_epochs):\n",
"            loss_val = 0.0\n",
"            for data, target in self.idx_pairs:\n",
"                x = self.get_input_layer(data)\n",
"                y_true = torch.tensor([int(target)], dtype=torch.long)\n",
"\n",
"                z1 = torch.matmul(self.u_embeddings, x)   # center word's u-vector\n",
"                z2 = torch.matmul(self.v_embeddings, z1)  # one score per vocabulary word\n",
"\n",
"                log_softmax = F.log_softmax(z2, dim=0)\n",
"                loss = F.nll_loss(log_softmax.view(1, -1), y_true)\n",
"                # .item() gives a Python float: the log prints a plain number and\n",
"                # no tensor is kept alive across iterations.\n",
"                loss_val += loss.item()\n",
"                loss.backward()\n",
"\n",
"                # Manual SGD step; no_grad keeps the update out of the autograd graph.\n",
"                with torch.no_grad():\n",
"                    self.u_embeddings.sub_(learning_rate * self.u_embeddings.grad)\n",
"                    self.v_embeddings.sub_(learning_rate * self.v_embeddings.grad)\n",
"                self.u_embeddings.grad.zero_()\n",
"                self.v_embeddings.grad.zero_()\n",
"\n",
"            if log_every and epo % log_every == 0:\n",
"                # Avoid dividing by zero when the corpus yields no training pairs.\n",
"                mean_loss = loss_val / n_pairs if n_pairs else float('nan')\n",
"                print(f'Loss at epo {epo}: {mean_loss}')"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loss at epo 0: 3.978736639022827\n",
"Loss at epo 100: 1.8136866092681885\n",
"Loss at epo 200: 1.6931017637252808\n",
"Loss at epo 300: 1.652329444885254\n",
"Loss at epo 400: 1.635833501815796\n",
"Loss at epo 500: 1.6262073516845703\n",
"Loss at epo 600: 1.6190433502197266\n",
"Loss at epo 700: 1.6133298873901367\n",
"Loss at epo 800: 1.6085796356201172\n",
"Loss at epo 900: 1.6043823957443237\n",
"Loss at epo 1000: 1.6007065773010254\n"
]
}
],
"source": [
"# Hyper-parameters for this toy run.\n",
"window_size = 2       # context words on each side of the center word\n",
"embedding_dims = 4    # length of each learned word vector\n",
"SEED = 0\n",
"\n",
"# Seed torch's global RNG: the embeddings are initialised with torch.randn,\n",
"# so without this the losses and similarities below change on every run.\n",
"torch.manual_seed(SEED)\n",
"\n",
"model = skipgram(corpus, window_size, embedding_dims)\n",
"# 1001 epochs so that the final epoch (1000) is included in the every-100-epochs log.\n",
"model.train(num_epochs=1001, learning_rate=0.01)"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['he',\n",
" 'is',\n",
" 'a',\n",
" 'king',\n",
" 'she',\n",
" 'queen',\n",
" 'man',\n",
" 'woman',\n",
" 'warsaw',\n",
" 'poland',\n",
" 'capital',\n",
" 'berlin',\n",
" 'germany',\n",
" 'paris',\n",
" 'france']"
]
},
"execution_count": 38,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Vocabulary in first-occurrence order; a word's position here is its index in model.word2idx.\n",
"list(model.vocabulary)"
]
},
{
"cell_type": "code",
"execution_count": 54,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.96737313"
]
},
"execution_count": 54,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Cosine similarity between the learned vectors for 'queen' and 'king':\n",
"# 1.0 = same direction, 0.0 = orthogonal, -1.0 = opposite.\n",
"v1 = model.get_vector('queen')\n",
"v2 = model.get_vector('king')\n",
"v1.dot(v2) / (np.linalg.norm(v1) * np.linalg.norm(v2))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.6"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment