Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save RottenFruits/bd32406d62a2ce293afd6fccacddc338 to your computer and use it in GitHub Desktop.
Save RottenFruits/bd32406d62a2ce293afd6fccacddc338 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Reference:\n",
"#https://towardsdatascience.com/implementing-word2vec-in-pytorch-skip-gram-model-e6bae040d2fb"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Imports"
]
},
{
"cell_type": "code",
"execution_count": 84,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch.autograd import Variable\n",
"import pandas as pd\n",
"import numpy as np\n",
"import torch.functional as F\n",
"import torch.nn.functional as F\n",
"import torch.optim as optim"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Preprocessing"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"corpus = [\n",
" 'he is a king',\n",
" 'she is a queen',\n",
" 'he is a man',\n",
" 'she is a woman',\n",
" 'warsaw is poland capital',\n",
" 'berlin is germany capital',\n",
" 'paris is france capital',\n",
"]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# train"
]
},
{
"cell_type": "code",
"execution_count": 312,
"metadata": {},
"outputs": [],
"source": [
"class skipgram(torch.nn.Module):\n",
"    \"\"\"Minimal skip-gram (word2vec) model trained with a full softmax.\n",
"\n",
"    Every sentence in `corpus` is split on whitespace; each (center, context)\n",
"    pair at most `window_size` positions apart becomes one training example.\n",
"\n",
"    Args:\n",
"        corpus: iterable of sentences (strings).\n",
"        window_size: number of positions on each side of the center word that\n",
"            count as context.\n",
"        embedding_dim: length of each learned word vector.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, corpus, window_size, embedding_dim):\n",
"        super(skipgram, self).__init__()\n",
"        self.corpus = corpus\n",
"        self.embedding_dim = embedding_dim\n",
"\n",
"        # Build vocabulary, index lookups and (center, context) training pairs.\n",
"        tokenized_corpus = self.tokenize_corpus(self.corpus)\n",
"        self.vocabulary = self.get_vocabulary(tokenized_corpus)\n",
"        self.word2idx = {w: idx for (idx, w) in enumerate(self.vocabulary)}\n",
"        self.idx2word = {idx: w for (idx, w) in enumerate(self.vocabulary)}\n",
"        self.idx_pairs = self.positive_pair(window_size, tokenized_corpus, self.word2idx)\n",
"\n",
"        self.vocab_size = len(self.vocabulary)\n",
"        self.init_embedding_layer()\n",
"\n",
"    def tokenize_corpus(self, corpus):\n",
"        \"\"\"Split each sentence on whitespace; returns a list of token lists.\"\"\"\n",
"        return [sentence.split() for sentence in corpus]\n",
"\n",
"    def get_vocabulary(self, tokenized_corpus):\n",
"        \"\"\"Return the unique tokens in first-occurrence order.\"\"\"\n",
"        vocabulary = []\n",
"        seen = set()  # O(1) membership test instead of scanning the list\n",
"        for sentence in tokenized_corpus:\n",
"            for token in sentence:\n",
"                if token not in seen:\n",
"                    seen.add(token)\n",
"                    vocabulary.append(token)\n",
"        return vocabulary\n",
"\n",
"    def positive_pair(self, window_size, tokenized_corpus, word2idx):\n",
"        \"\"\"Return a NumPy array whose rows are (center_idx, context_idx) pairs.\n",
"\n",
"        For every token, each other token of the same sentence at most\n",
"        `window_size` positions away yields one pair.\n",
"        \"\"\"\n",
"        idx_pairs = []\n",
"        for sentence in tokenized_corpus:\n",
"            indices = [word2idx[word] for word in sentence]\n",
"            for center_word_pos in range(len(indices)):\n",
"                for w in range(-window_size, window_size + 1):\n",
"                    context_word_pos = center_word_pos + w\n",
"                    # Skip positions outside the sentence and the center word itself.\n",
"                    if context_word_pos < 0 or context_word_pos >= len(indices) or center_word_pos == context_word_pos:\n",
"                        continue\n",
"                    idx_pairs.append((indices[center_word_pos], indices[context_word_pos]))\n",
"        return np.array(idx_pairs)\n",
"\n",
"    def init_embedding_layer(self):\n",
"        \"\"\"Create the two sparse embedding tables.\n",
"\n",
"        u_embeddings: (vocab_size, embedding_dim) center-word vectors.\n",
"        v_embeddings: (embedding_dim, vocab_size); train() looks up all of its\n",
"        rows at once, so it acts as the output projection matrix.\n",
"        \"\"\"\n",
"        self.u_embeddings = torch.nn.Embedding(self.vocab_size, self.embedding_dim, sparse=True)\n",
"        self.v_embeddings = torch.nn.Embedding(self.embedding_dim, self.vocab_size, sparse=True)\n",
"\n",
"    def get_vector(self, word):\n",
"        \"\"\"Return the center-word (u) embedding of `word` as a 1-D NumPy array.\n",
"\n",
"        Raises KeyError if `word` is not in the vocabulary.\n",
"        \"\"\"\n",
"        word_idx = torch.LongTensor([[self.word2idx[word]]])\n",
"        return self.u_embeddings(word_idx).view(-1).detach().numpy()\n",
"\n",
"    def train(self, num_epochs = 100, learning_rate = 0.001):\n",
"        \"\"\"Fit the embeddings with SGD, one (center, context) pair per step.\n",
"\n",
"        Prints the mean per-pair loss every 10 epochs.\n",
"\n",
"        NOTE: this shadows torch.nn.Module.train(mode), so model.eval() and\n",
"        model.train(False) do not toggle training mode for this class.\n",
"        \"\"\"\n",
"        # Optimize this instance's own parameters (not a global variable), so\n",
"        # the method works whatever name the instance is bound to.\n",
"        optimizer = optim.SGD(self.parameters(), lr = learning_rate)\n",
"        # Looking up rows 0..embedding_dim-1 returns the whole v matrix as\n",
"        # (1, embedding_dim, vocab_size). Loop-invariant, so build it once.\n",
"        all_v_rows = torch.LongTensor([list(range(self.embedding_dim))])\n",
"        n_pairs = max(len(self.idx_pairs), 1)  # avoid ZeroDivisionError on an empty corpus\n",
"\n",
"        for epo in range(num_epochs):\n",
"            loss_val = 0.0\n",
"\n",
"            for data, target in self.idx_pairs:\n",
"                optimizer.zero_grad()\n",
"                x = torch.LongTensor([[data]])  # (1, 1) center index\n",
"                y_true = torch.LongTensor([target])  # (1,) context index\n",
"\n",
"                u_emb = self.u_embeddings(x)  # (1, 1, embedding_dim)\n",
"                v_emb = self.v_embeddings(all_v_rows)  # (1, embedding_dim, vocab_size)\n",
"                z = torch.matmul(u_emb, v_emb).view(-1)  # (vocab_size,) scores\n",
"\n",
"                log_softmax = F.log_softmax(z, dim=0)\n",
"                loss = F.nll_loss(log_softmax.view(1, -1), y_true)\n",
"                loss_val += loss.item()\n",
"\n",
"                loss.backward()\n",
"                optimizer.step()\n",
"\n",
"            if epo % 10 == 0:\n",
"                print(f'Loss at epo {epo}: {loss_val / n_pairs}')"
]
},
{
"cell_type": "code",
"execution_count": 333,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loss at epo 0: 4.63685941696167\n",
"Loss at epo 10: 2.768066883087158\n",
"Loss at epo 20: 2.325962781906128\n",
"Loss at epo 30: 2.0715606212615967\n",
"Loss at epo 40: 1.8912194967269897\n",
"Loss at epo 50: 1.7484570741653442\n",
"Loss at epo 60: 1.6268519163131714\n",
"Loss at epo 70: 1.5213346481323242\n",
"Loss at epo 80: 1.4327272176742554\n",
"Loss at epo 90: 1.3639706373214722\n",
"Loss at epo 100: 1.3147941827774048\n",
"Loss at epo 110: 1.2804938554763794\n",
"Loss at epo 120: 1.2561253309249878\n",
"Loss at epo 130: 1.238266944885254\n",
"Loss at epo 140: 1.2247698307037354\n",
"Loss at epo 150: 1.2142980098724365\n",
"Loss at epo 160: 1.206006407737732\n",
"Loss at epo 170: 1.1993403434753418\n",
"Loss at epo 180: 1.1939210891723633\n",
"Loss at epo 190: 1.1894774436950684\n",
"Loss at epo 200: 1.1858071088790894\n",
"Loss at epo 210: 1.1827542781829834\n",
"Loss at epo 220: 1.1801972389221191\n",
"Loss at epo 230: 1.1780400276184082\n",
"Loss at epo 240: 1.176207184791565\n",
"Loss at epo 250: 1.1746381521224976\n",
"Loss at epo 260: 1.1732861995697021\n",
"Loss at epo 270: 1.1721131801605225\n",
"Loss at epo 280: 1.1710888147354126\n",
"Loss at epo 290: 1.1701886653900146\n",
"Loss at epo 300: 1.16939377784729\n",
"Loss at epo 310: 1.1686875820159912\n",
"Loss at epo 320: 1.168056845664978\n",
"Loss at epo 330: 1.1674913167953491\n",
"Loss at epo 340: 1.16698157787323\n",
"Loss at epo 350: 1.1665197610855103\n",
"Loss at epo 360: 1.1661005020141602\n",
"Loss at epo 370: 1.1657178401947021\n",
"Loss at epo 380: 1.1653674840927124\n",
"Loss at epo 390: 1.165045142173767\n",
"Loss at epo 400: 1.164748191833496\n"
]
}
],
"source": [
"# Fix the RNG so the random embedding initialisation is reproducible across runs.\n",
"torch.manual_seed(0)\n",
"\n",
"window_size = 1  # context words taken on each side of the center word\n",
"embedding_dims = 5  # length of each learned word vector\n",
"\n",
"model = skipgram(corpus, window_size, embedding_dims)\n",
"model.train(num_epochs = 401, learning_rate = 0.01)"
]
},
{
"cell_type": "code",
"execution_count": 338,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'is'"
]
},
"execution_count": 338,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.vocabulary"
]
},
{
"cell_type": "code",
"execution_count": 336,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"-0.2428816"
]
},
"execution_count": 336,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def cosine_similarity(a, b):\n",
"    \"\"\"Cosine similarity between two 1-D vectors.\"\"\"\n",
"    return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))\n",
"\n",
"\n",
"v1 = model.get_vector('he')\n",
"v2 = model.get_vector('king')\n",
"\n",
"cosine_similarity(v1, v2)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.6"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment