{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# メモ"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#参考\n",
"#https://towardsdatascience.com/implementing-word2vec-in-pytorch-skip-gram-model-e6bae040d2fb\n",
"#https://github.com/fanglanting/skip-gram-pytorch\n",
"\n",
"#movie lens使って実際に見てみる、blogに記事書く"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# import"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch.autograd import Variable\n",
"import pandas as pd\n",
"import numpy as np\n",
"import torch.nn.functional as F\n",
"import torch.optim as optim"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# prepro"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"data = [\n",
" 'he is a king',\n",
" 'she is a queen',\n",
" 'he is a man',\n",
" 'she is a woman',\n",
" 'warsaw is poland capital',\n",
" 'berlin is germany capital',\n",
" 'paris is france capital',\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"class Corpus:\n",
" def __init__(self, corpus):\n",
" self.corpus = corpus\n",
" \n",
" self.tokenized_corpus = self.tokenize_corpus(self.corpus)\n",
" self.vocabulary = self.get_vocabulary(self.tokenized_corpus)\n",
" self.vocab_size = len(self.vocabulary)\n",
" self.word2idx = {w: idx for (idx, w) in enumerate(self.vocabulary)}\n",
" self.idx2word = {idx: w for (idx, w) in enumerate(self.vocabulary)}\n",
" \n",
" def tokenize_corpus(self, corpus):\n",
" tokens = [x.split() for x in corpus]\n",
" return tokens\n",
" \n",
" def get_vocabulary(self, tokenized_corpus):\n",
" vocabulary = []\n",
" for sentence in tokenized_corpus:\n",
" for token in sentence:\n",
" if token not in vocabulary:\n",
" vocabulary.append(token)\n",
" return vocabulary"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"corpus = Corpus(data)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[['he', 'is', 'a', 'king'],\n",
" ['she', 'is', 'a', 'queen'],\n",
" ['he', 'is', 'a', 'man'],\n",
" ['she', 'is', 'a', 'woman'],\n",
" ['warsaw', 'is', 'poland', 'capital'],\n",
" ['berlin', 'is', 'germany', 'capital'],\n",
" ['paris', 'is', 'france', 'capital']]"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"corpus.tokenized_corpus"
]
},
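{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick sanity check (an added illustration, not part of the original notebook): `word2idx` assigns each unique token an integer id and `idx2word` inverts the mapping; these ids are what the training pairs below are built from."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"corpus.word2idx"
]
},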
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# model"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"class Skipgram(torch.nn.Module):\n",
" def __init__(self, vocab_size, embedding_dim):\n",
" super(Skipgram, self).__init__()\n",
" self.embedding_dim = embedding_dim\n",
" self.vocab_size = vocab_size\n",
" \n",
" self.u_embeddings = torch.nn.Embedding(self.vocab_size, self.embedding_dim, sparse = True)\n",
" self.v_embeddings = torch.nn.Embedding(self.embedding_dim, self.vocab_size, sparse = True) \n",
" \n",
" def forward(self, batch):\n",
" y_true = Variable(torch.from_numpy(np.array([batch[1]])).long())\n",
" \n",
" x1 = torch.LongTensor([[batch[0]]])\n",
" x2 = torch.LongTensor([range(self.embedding_dim)]) \n",
" u_emb = self.u_embeddings(x1)\n",
" v_emb = self.v_embeddings(x2)\n",
" z = torch.matmul(u_emb, v_emb).view(-1) #view reshape\n",
"\n",
" log_softmax = F.log_softmax(z, dim=0)\n",
" loss = F.nll_loss(log_softmax.view(1,-1), y_true)\n",
" return loss\n",
" \n",
" def generate_batch(self, corpus, window_size):\n",
" idx_pairs = [] \n",
" for sentence in corpus.tokenized_corpus:\n",
" indices = [corpus.word2idx[word] for word in sentence]\n",
" for center_word_pos in range(len(indices)):\n",
" for w in range(-window_size, window_size + 1):\n",
" context_word_pos = center_word_pos + w\n",
" if context_word_pos < 0 or context_word_pos >= len(indices) or center_word_pos == context_word_pos:\n",
" continue\n",
" context_word_idx = indices[context_word_pos]\n",
" idx_pairs.append((indices[center_word_pos], context_word_idx))\n",
"\n",
" idx_pairs = np.array(idx_pairs) \n",
" return idx_pairs"
]
},
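{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the training data concrete, the next cell is a small illustrative sketch (not in the original): it builds a throwaway `Skipgram` just to call `generate_batch`, then decodes the first few `(center, context)` index pairs back to words for a window of 1."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative sketch: decode a few (center, context) training pairs back to words.\n",
"sg_demo = Skipgram(corpus.vocab_size, 3)\n",
"pairs = sg_demo.generate_batch(corpus, 1)\n",
"[(corpus.idx2word[c], corpus.idx2word[ctx]) for c, ctx in pairs[:8]]"
]
},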
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"#distributed representation\n",
"class DistributedRepresentation:\n",
" def __init__(self, corpus, window_size, embedding_dim, mode_type = 1):\n",
" self.corpus = corpus\n",
" self.window_size = window_size\n",
" self.embedding_dim = embedding_dim\n",
"\n",
" #model set\n",
" if mode_type == 1:\n",
" self.model = Skipgram(self.corpus.vocab_size, self.embedding_dim)\n",
" elif mode_type == 2:\n",
" print(\"2\")\n",
" \n",
" def train(self, num_epochs = 100, learning_rate = 0.001):\n",
" optimizer = optim.SGD(self.model.parameters(), lr = learning_rate)\n",
" for epo in range(num_epochs):\n",
" loss_val = 0\n",
" batches = self.model .generate_batch(self.corpus, window_size)\n",
" \n",
" for batch in batches: \n",
" optimizer.zero_grad()\n",
" loss = self.model(batch)\n",
" loss.backward()\n",
" loss_val += loss.data\n",
" optimizer.step()\n",
" \n",
" if epo % 10 == 0: \n",
" print(f'Loss at epo {epo}: {loss_val/len(batches)}')\n",
" \n",
" def get_vector(self, word):\n",
" word_idx = self.corpus.word2idx[word]\n",
" word_idx = torch.LongTensor([[ word_idx]])\n",
" \n",
" vector = self.model.u_embeddings(word_idx).view(-1).detach().numpy()\n",
" return vector"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# train"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loss at epo 0: 3.3020033836364746\n",
"Loss at epo 10: 2.71722674369812\n",
"Loss at epo 20: 2.411458969116211\n",
"Loss at epo 30: 2.163210391998291\n",
"Loss at epo 40: 1.9434529542922974\n",
"Loss at epo 50: 1.7515023946762085\n",
"Loss at epo 60: 1.6070774793624878\n",
"Loss at epo 70: 1.5127593278884888\n",
"Loss at epo 80: 1.452197551727295\n",
"Loss at epo 90: 1.4114344120025635\n",
"Loss at epo 100: 1.3822027444839478\n"
]
}
],
"source": [
"window_size = 1\n",
"embedding_dims = 3\n",
"\n",
"DR = DistributedRepresentation(corpus, window_size , embedding_dims, mode_type = 1)\n",
"DR.train(num_epochs = 101, learning_rate = 0.01)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"DR.get_vector('he')"
]
},
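{
"cell_type": "markdown",
"metadata": {},
"source": [
"A follow-up sketch (not in the original; `cosine_similarity` is a helper defined here, not part of the classes above): cosine similarity is the usual way to compare learned word vectors, though with a 3-dimensional embedding trained on a toy corpus the values are noisy."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative helper, not part of the original notebook.\n",
"def cosine_similarity(a, b):\n",
"    return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))\n",
"\n",
"# Compare a related pair against an unrelated one; treat the numbers as a\n",
"# rough sketch rather than an evaluation, given the tiny corpus.\n",
"(cosine_similarity(DR.get_vector('he'), DR.get_vector('she')),\n",
" cosine_similarity(DR.get_vector('he'), DR.get_vector('capital')))"
]
}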
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.6"
}
},
"nbformat": 4,
"nbformat_minor": 2
}