Last active
June 24, 2020 13:27
-
-
Save axel-op/29a09fe4793a4d03f621ba8fe07784a6 to your computer and use it in GitHub Desktop.
Word2Vec SGNS Incremental
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"name": "incremental-word2vec.ipynb", | |
"provenance": [], | |
"toc_visible": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"accelerator": "GPU" | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "Fb-imYuPozEr", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"# Word2Vec Skip-Gram with Negative Sampling : an incremental and space-saving algorithm\n", | |
"\n", | |
"Implementation of the space-saving version of the Word2Vec SGNS algorithm described by May et al., 2017. It continuously retrieves data from a tokenized Wikipedia dump found [here](http://www.llf.cnrs.fr/wikiparse).\n", | |
"\n", | |
"[May, Chandler & Duh, Kevin & Durme, Benjamin & Lall, Ashwin. (2017). \"Streaming Word Embeddings with the Space-Saving Algorithm\"](https://arxiv.org/abs/1704.07463)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "jr2b_F9WNbWV", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"from typing import Generator, List, Dict, Tuple" | |
], | |
"execution_count": 1, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "-oMPNo397qap", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"from random import randint, uniform, sample\n", | |
"from collections import namedtuple\n", | |
"from math import sqrt\n", | |
"import tarfile\n", | |
"import json\n", | |
"import os\n", | |
"import time\n", | |
"import requests" | |
], | |
"execution_count": 2, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "kxIvWUPOi860", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"# PyTorch\n", | |
"import torch\n", | |
"import torch.nn as nn\n", | |
"import torch.nn.functional as F\n", | |
"import torch.optim as optim" | |
], | |
"execution_count": 3, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "9tjKS5zzEF31", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 291 | |
}, | |
"outputId": "710e1de1-098d-42db-eafa-b5e9b10a24ce" | |
}, | |
"source": [ | |
"# Optional experiment tracking with Weights & Biases: train() only calls wandb when this flag is True.\n", | |
"USE_WANDB = True\n", | |
"if USE_WANDB:\n", | |
" # Install wandb into the kernel's environment, then authenticate.\n", | |
" %pip install wandb -q\n", | |
" import wandb\n", | |
" wandb.login()" | |
], | |
"execution_count": 4, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"\u001b[K |████████████████████████████████| 1.4MB 3.5MB/s \n", | |
"\u001b[K |████████████████████████████████| 112kB 14.3MB/s \n", | |
"\u001b[K |████████████████████████████████| 460kB 17.6MB/s \n", | |
"\u001b[K |████████████████████████████████| 102kB 9.5MB/s \n", | |
"\u001b[K |████████████████████████████████| 102kB 9.0MB/s \n", | |
"\u001b[K |████████████████████████████████| 71kB 7.0MB/s \n", | |
"\u001b[K |████████████████████████████████| 71kB 7.5MB/s \n", | |
"\u001b[?25h Building wheel for gql (setup.py) ... \u001b[?25l\u001b[?25hdone\n", | |
" Building wheel for watchdog (setup.py) ... \u001b[?25l\u001b[?25hdone\n", | |
" Building wheel for subprocess32 (setup.py) ... \u001b[?25l\u001b[?25hdone\n", | |
" Building wheel for graphql-core (setup.py) ... \u001b[?25l\u001b[?25hdone\n", | |
" Building wheel for pathtools (setup.py) ... \u001b[?25l\u001b[?25hdone\n" | |
], | |
"name": "stdout" | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"application/javascript": [ | |
"\n", | |
" window._wandbApiKey = new Promise((resolve, reject) => {\n", | |
" function loadScript(url) {\n", | |
" return new Promise(function(resolve, reject) {\n", | |
" let newScript = document.createElement(\"script\");\n", | |
" newScript.onerror = reject;\n", | |
" newScript.onload = resolve;\n", | |
" document.body.appendChild(newScript);\n", | |
" newScript.src = url;\n", | |
" });\n", | |
" }\n", | |
" loadScript(\"https://cdn.jsdelivr.net/npm/postmate/build/postmate.min.js\").then(() => {\n", | |
" const iframe = document.createElement('iframe')\n", | |
" iframe.style.cssText = \"width:0;height:0;border:none\"\n", | |
" document.body.appendChild(iframe)\n", | |
" const handshake = new Postmate({\n", | |
" container: iframe,\n", | |
" url: 'https://app.wandb.ai/authorize'\n", | |
" });\n", | |
" const timeout = setTimeout(() => reject(\"Couldn't auto authenticate\"), 5000)\n", | |
" handshake.then(function(child) {\n", | |
" child.on('authorize', data => {\n", | |
" clearTimeout(timeout)\n", | |
" resolve(data)\n", | |
" });\n", | |
" });\n", | |
" })\n", | |
" });\n", | |
" " | |
], | |
"text/plain": [ | |
"<IPython.core.display.Javascript object>" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
} | |
}, | |
{ | |
"output_type": "stream", | |
"text": [ | |
"\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[32m\u001b[41mERROR\u001b[0m Not authenticated. Copy a key from https://app.wandb.ai/authorize\n" | |
], | |
"name": "stderr" | |
}, | |
{ | |
"output_type": "stream", | |
"text": [ | |
"API Key: ··········\n" | |
], | |
"name": "stdout" | |
}, | |
{ | |
"output_type": "stream", | |
"text": [ | |
"\u001b[34m\u001b[1mwandb\u001b[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc\n" | |
], | |
"name": "stderr" | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "_FqXEmmQO2LW", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"## Hyperparameters" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "1yfLy7R0O6Qb", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"# Number of negative samples drawn for each (target, context) pair.\n", | |
"# Mikolov et al., 2013 (the number of negative samples is called k there;\n", | |
"# it is unrelated to the constant K defined below):\n", | |
"# k = 2 ~ 5 works for large data sets,\n", | |
"# k = 5 ~ 20 for small data sets.\n", | |
"NEGATIVE_SAMPLES = 3\n", | |
"# Number of words taken on each side of the target word.\n", | |
"CONTEXT_SIZE = 3\n", | |
"# Length of each embedding vector.\n", | |
"EMBEDDING_DIMENSIONS = 100\n", | |
"\n", | |
"# Specific to incremental version\n", | |
"\n", | |
"# Max size of the vocabulary; it also sizes the embedding tables and the\n", | |
"# negative-sampling reservoir.\n", | |
"# May et al., 2017:\n", | |
"# (with K = 70000) the relative error is around one or two for many words\n", | |
"K = 100000\n", | |
"# Threshold used by get_examples when subsampling tokens.\n", | |
"SUBSAMPLING_THRESHOLD = 0.001" | |
], | |
"execution_count": 5, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "u1WCYn9NiEdt", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"## Model" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "PAiWw25uiGJf", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"class SkipGram(nn.Module):\n", | |
"\n", | |
" def __init__(self):\n", | |
" super(SkipGram, self).__init__()\n", | |
" self.use_cuda = False\n", | |
" self.embeddings1 = nn.Embedding(K, EMBEDDING_DIMENSIONS, sparse=True)\n", | |
" self.embeddings2 = nn.Embedding(K, EMBEDDING_DIMENSIONS, sparse=True)\n", | |
"\n", | |
" # TODO: initialiser les embeddings avec des valeurs aléatoires\n", | |
"\n", | |
" def with_cuda(self):\n", | |
" self.use_cuda = True\n", | |
" return self.cuda()\n", | |
"\n", | |
" def forward(\n", | |
" self, \n", | |
" targets: List[int],\n", | |
" context_words: List[int],\n", | |
" negative_samples: List[List[int]]\n", | |
" ):\n", | |
" def create_tensor(values):\n", | |
" return torch.cuda.LongTensor(values) if self.use_cuda else torch.LongTensor(values)\n", | |
" \n", | |
" targets = create_tensor(targets)\n", | |
" context_words = create_tensor(context_words)\n", | |
" negative_samples = create_tensor(negative_samples)\n", | |
"\n", | |
" hiddens = self.embeddings1(targets)\n", | |
"\n", | |
" positive_embeddings = self.embeddings2(context_words)\n", | |
" negative_embeddings = self.embeddings2(negative_samples)\n", | |
"\n", | |
" out_positive = (hiddens * positive_embeddings).sum(-1)\n", | |
" out_positive = F.logsigmoid(out_positive)\n", | |
"\n", | |
" out_negative = (hiddens.unsqueeze(-2) * negative_embeddings).sum(-1)\n", | |
" out_negative = F.logsigmoid(-1 * out_negative)\n", | |
"\n", | |
" loss = -1 * (torch.sum(out_positive) + torch.sum(out_negative))\n", | |
" return loss\n", | |
"\n", | |
" def reinit_embeddings(self, index):\n", | |
" '''\n", | |
" \"Whenever a word w is ejected from the spacesaving data structure\n", | |
" its target-word embedding vk and context-word embedding v'k\n", | |
" are re-initialized as draws from N(0, 1)D.\n", | |
" This means that if w appears in the future and is inserted again into space-saving data structure,\n", | |
" training of its embedding starts over from scratch.\"\n", | |
" '''\n", | |
" with torch.no_grad():\n", | |
" self.embeddings1.weight[index].normal_(0, 1)\n", | |
" self.embeddings2.weight[index].normal_(0, 1)\n" | |
], | |
"execution_count": 6, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "zT84unrkIq-M", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"def get_new_model():\n", | |
" model = SkipGram()\n", | |
" if torch.cuda.is_available():\n", | |
" model = model.with_cuda()\n", | |
" return model" | |
], | |
"execution_count": 7, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "tjXtS9QBO_jU", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"## Parsing of data\n", | |
"\n", | |
"These functions return generators so that the data can be parsed and retrieved continuously." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "ntwT88xDMwhT", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"def keep_token(splitted) -> bool:\n", | |
" text = splitted[1]\n", | |
" cat = splitted[4]\n", | |
" # Exclure ou conserver la ponctuation ?\n", | |
" # if cat == 'PONCT'\n", | |
" return text != '-LRB-' \\\n", | |
" and text != '-RRB-'\n", | |
"\n", | |
"def keep_sentence(sentence) -> bool:\n", | |
" return len(sentence) != 0 \\\n", | |
" and '↑' not in sentence # Footnotes\n", | |
"\n", | |
"def extract_sentences(file) -> Generator[List[str], None, None]:\n", | |
" current = []\n", | |
" for line in file:\n", | |
" line = line.decode('utf-8').strip()\n", | |
" splitted = line.split()\n", | |
" if len(splitted) > 1:\n", | |
" if keep_token(splitted):\n", | |
" text = splitted[1]\n", | |
" current.append(text)\n", | |
" else:\n", | |
" if keep_sentence(current):\n", | |
" yield current\n", | |
" current = []\n", | |
" if keep_sentence(current):\n", | |
" yield current" | |
], | |
"execution_count": 8, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "gY3e6wsQHrXR", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"ARCHIVE_URL = 'http://www.llf.cnrs.fr/wikiparse/enwiki.tar.bz2'\n", | |
"\n", | |
"def get_sentences() -> Generator[List[str], None, None]:\n", | |
" stream = requests.get(ARCHIVE_URL, stream=True)\n", | |
" tar = tarfile.open(fileobj=stream.raw, mode='r|bz2')\n", | |
" for m in tar:\n", | |
" if m.isfile() and m.name.endswith('.conll'):\n", | |
" file = tar.extractfile(m)\n", | |
" for s in extract_sentences(file):\n", | |
" yield s" | |
], | |
"execution_count": 9, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "g9bSjfm0SXBS", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"## Implementation" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "g2ufbIDnSZNH", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"Smallest = namedtuple('Smallest', ['count', 'i2w'])\n", | |
"\n", | |
"class Vocabulary(object):\n", | |
" def __init__(self):\n", | |
" self.w2i = dict()\n", | |
" self.i2c = dict()\n", | |
" self.size = 0\n", | |
" self.smallest = Smallest(1, dict())\n", | |
" self.reservoir = [0 for k in range(K)]\n", | |
" self.n = 0" | |
], | |
"execution_count": 10, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "gesH3KHWTa9G", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"def get_examples(model, vocab) -> Generator[Tuple[int, int], None, None]:\n", | |
" w2i = vocab.w2i\n", | |
" i2c = vocab.i2c\n", | |
" reservoir = vocab.reservoir\n", | |
" for sentence in get_sentences():\n", | |
" # Subsampling\n", | |
" subsampled = []\n", | |
" subsampled_size = 0\n", | |
" for token in sentence:\n", | |
" if token in w2i and uniform(0.0, 1.0) > min(1, sqrt(SUBSAMPLING_THRESHOLD / vocab.size)):\n", | |
" continue\n", | |
" subsampled.append(token)\n", | |
" subsampled_size += 1\n", | |
"\n", | |
" # Update the vocabulary\n", | |
" # (complexités à revoir)\n", | |
" if token in w2i:\n", | |
" i = w2i[token]\n", | |
" i2c[i] += 1\n", | |
" if i in vocab.smallest.i2w:\n", | |
" if len(vocab.smallest.i2w) > 1:\n", | |
" del vocab.smallest.i2w[i]\n", | |
" else:\n", | |
" vocab.smallest._replace(count = i2c[i])\n", | |
" elif vocab.size < K:\n", | |
" i = vocab.size\n", | |
" w2i[token] = i\n", | |
" i2c[i] = 1\n", | |
" vocab.size += 1\n", | |
" if vocab.smallest.count > 1:\n", | |
" vocab.smallest._replace(count = 1)\n", | |
" vocab.smallest.i2w.clear()\n", | |
" vocab.smallest.i2w[i] = token\n", | |
" else:\n", | |
" i = next(iter(vocab.smallest.i2w.keys()))\n", | |
" w2i[token] = i\n", | |
" i2c[i] += 1\n", | |
" del w2i[vocab.smallest.i2w[i]]\n", | |
" model.reinit_embeddings(i)\n", | |
" if len(vocab.smallest.i2w) > 1:\n", | |
" del vocab.smallest.i2w[i]\n", | |
" else:\n", | |
" vocab.smallest._replace(count = i2c[i])\n", | |
" vocab.smallest.i2w[i] = token\n", | |
"\n", | |
" # Update the reservoir distribution\n", | |
" vocab.n += 1\n", | |
" if vocab.n <= K:\n", | |
" reservoir.append(w2i[token])\n", | |
" else:\n", | |
" k = randint(1, vocab.n)\n", | |
" if k <= K:\n", | |
" reservoir[k] = w2i[token]\n", | |
" \n", | |
" # \"When a sentence smaller (after subsampling) than a single context window is encountered,\n", | |
" # its words are first added to the space-saving data structure and negative sampling reservoir,\n", | |
" # and then the algorithm moves on to the next sentence\"\n", | |
" if subsampled_size < CONTEXT_SIZE * 2 + 1:\n", | |
" continue\n", | |
" \n", | |
" subsampled = [vocab.w2i.get(t, None) for t in subsampled]\n", | |
" left = 0\n", | |
" right = left + CONTEXT_SIZE * 2 + 1\n", | |
" while right <= subsampled_size:\n", | |
" examples = []\n", | |
" target = subsampled[left + CONTEXT_SIZE]\n", | |
" if target is None:\n", | |
" left += CONTEXT_SIZE\n", | |
" else:\n", | |
" for i in range(left, right):\n", | |
" if i == left + CONTEXT_SIZE:\n", | |
" continue\n", | |
" c = subsampled[i]\n", | |
" if c is None:\n", | |
" left = i\n", | |
" examples.clear()\n", | |
" break\n", | |
" examples.append((target, c))\n", | |
" yield from examples\n", | |
" left += 1\n", | |
" right = left + CONTEXT_SIZE * 2 + 1\n" | |
], | |
"execution_count": 16, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "jDMJ_zdG99vx", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"def get_negative_samples(vocab, k: int, input_, context):\n", | |
" reservoir = vocab.reservoir\n", | |
" indexes = sample(range(K), k + 2)\n", | |
" samples = []\n", | |
" for i in indexes[:-2]:\n", | |
" w_i = reservoir[i]\n", | |
" if w_i == input_:\n", | |
" w_i = reservoir[indexes[-1]]\n", | |
" elif w_i == context:\n", | |
" w_i = reservoir[indexes[-2]]\n", | |
" samples.append(w_i)\n", | |
" return samples" | |
], | |
"execution_count": 12, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "QEyINh0H_Qzo", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"def train(\n", | |
" model,\n", | |
" vocab,\n", | |
" learning_rate: float = 0.01,\n", | |
" max_examples: int = None\n", | |
"):\n", | |
" log_interval = 1000\n", | |
" if USE_WANDB:\n", | |
" wandb.init(project='word2vec-sgns-incremental', reinit=True)\n", | |
" wandb.watch(model)\n", | |
" wandb.run.summary.update({\n", | |
" \"learning_rate\": learning_rate,\n", | |
" \"log_interval\": log_interval})\n", | |
" \n", | |
" optimizer = optim.SGD(model.parameters(), lr=learning_rate)\n", | |
" start_ex = time.time()\n", | |
" for ex_count, ex in enumerate(get_examples(model, vocab)):\n", | |
" model.zero_grad()\n", | |
" input_, context = ex\n", | |
" neg_samples = get_negative_samples(vocab, NEGATIVE_SAMPLES, input_, context)\n", | |
" loss = model([input_], [context], neg_samples)\n", | |
" loss.backward()\n", | |
" optimizer.step()\n", | |
"\n", | |
" print(f'\\r{ex_count} examples trained', end='')\n", | |
" if USE_WANDB and ex_count % log_interval == 0:\n", | |
" wandb.log({\n", | |
" \"example\": ex_count,\n", | |
" \"loss\": loss,\n", | |
" \"time per example\": (time.time() - start_ex) / log_interval})\n", | |
" start_ex = time.time()\n", | |
"\n", | |
" if max_examples is not None and ex_count >= max_examples:\n", | |
" break\n", | |
" \n", | |
" if USE_WANDB:\n", | |
" torch.save(model.state_dict(), os.path.join(wandb.run.dir, 'model.pt'))\n", | |
" json.dump({i: w for w, i in w2i.items()},\n", | |
" open(os.path.join(wandb.run.dir, 'i2w.json'), 'w'))" | |
], | |
"execution_count": 13, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "LLrgWhNuv0Oh", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"## Evaluation" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "-KCBQlK7sP-_", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"TEST_SET = 'http://www.fit.vutbr.cz/~imikolov/rnnlm/word-test.v1.txt'\n", | |
"\n", | |
"def eval_model(model, w2i):\n", | |
" model.eval()\n", | |
" sim = nn.CosineSimilarity()\n", | |
" rows = model.embeddings1.weight.shape[0]\n", | |
" scores = dict()\n", | |
" current_group = None\n", | |
" # We evaluate both embeddings layers\n", | |
" accuracy = [0, 0]\n", | |
" total = 0\n", | |
" page = requests.get(TEST_SET).text\n", | |
" print('Evaluating...')\n", | |
" for i, line in enumerate(page.split('\\n')):\n", | |
" print(f'\\rline {i}', end='')\n", | |
" line = line.strip()\n", | |
"\n", | |
" if line.startswith('//'):\n", | |
" continue\n", | |
"\n", | |
" if line.startswith(':'):\n", | |
" if current_group is not None and total > 0:\n", | |
" scores[current_group] = accuracy[0] / total, accuracy[1] / total\n", | |
" current_group = line\n", | |
" accuracy = [0, 0]\n", | |
" total = 0\n", | |
" continue\n", | |
"\n", | |
" words = line.split()\n", | |
" if len(words) != 4 or any(word not in w2i for word in words):\n", | |
" continue\n", | |
"\n", | |
" with torch.no_grad(): \n", | |
" words = torch.LongTensor([w2i[word] for word in words])\n", | |
" if torch.cuda.is_available():\n", | |
" words = words.cuda()\n", | |
"\n", | |
" embeds1, embeds2 = model.embeddings1(words), model.embeddings2(words)\n", | |
" tensor1, tensor2 = embeds1[1] - embeds1[0] + embeds1[2], \\\n", | |
" embeds2[1] - embeds2[0] + embeds2[2]\n", | |
" sim1, sim2 = sim(model.embeddings1.weight, tensor1.expand(rows, -1)), \\\n", | |
" sim(model.embeddings2.weight, tensor2.expand(rows, -1))\n", | |
"\n", | |
" # Similarity with the ideal vector\n", | |
" sim_match1, sim_match2 = sim1[words[3]], sim2[words[3]]\n", | |
"\n", | |
" # Count the number of vectors that are nearest than the ideal one\n", | |
" accuracy[0] += (rows - torch.sum(sim1 > sim_match1).item()) / rows\n", | |
" accuracy[1] += (rows - torch.sum(sim2 > sim_match2).item()) / rows\n", | |
" \n", | |
" total += 1\n", | |
" return scores" | |
], | |
"execution_count": 21, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "ezXJ-nHYr8BM", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"## Runs" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "ROXfbWoUBcWJ", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 411 | |
}, | |
"outputId": "8d21375b-1a6f-4ece-a49f-d04a16a7e60b" | |
}, | |
"source": [ | |
"# First run: fresh vocabulary and model, learning rate 0.1, capped at 300000 examples.\n", | |
"vocab = Vocabulary()\n", | |
"model = get_new_model()\n", | |
"train(model, vocab, learning_rate=0.1, max_examples=300000)" | |
], | |
"execution_count": 17, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/html": [ | |
"\n", | |
" Logging results to <a href=\"https://wandb.com\" target=\"_blank\">Weights & Biases</a> <a href=\"https://docs.wandb.com/integrations/jupyter.html\" target=\"_blank\">(Documentation)</a>.<br/>\n", | |
" Project page: <a href=\"https://app.wandb.ai/axel-op/word2vec-sgns-incremental\" target=\"_blank\">https://app.wandb.ai/axel-op/word2vec-sgns-incremental</a><br/>\n", | |
" Run page: <a href=\"https://app.wandb.ai/axel-op/word2vec-sgns-incremental/runs/1mqb1ms1\" target=\"_blank\">https://app.wandb.ai/axel-op/word2vec-sgns-incremental/runs/1mqb1ms1</a><br/>\n", | |
" " | |
], | |
"text/plain": [ | |
"<IPython.core.display.HTML object>" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
} | |
}, | |
{ | |
"output_type": "stream", | |
"text": [ | |
"191585 examples trained" | |
], | |
"name": "stdout" | |
}, | |
{ | |
"output_type": "error", | |
"ename": "KeyboardInterrupt", | |
"evalue": "ignored", | |
"traceback": [ | |
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", | |
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", | |
"\u001b[0;32m<ipython-input-17-bf9f8fb43f41>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mvocab\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mVocabulary\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0mmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mget_new_model\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvocab\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlearning_rate\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0.1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmax_examples\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m300000\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", | |
"\u001b[0;32m<ipython-input-13-8e17766134aa>\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(model, vocab, learning_rate, max_examples)\u001b[0m\n\u001b[1;32m 15\u001b[0m \u001b[0moptimizer\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0moptim\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mSGD\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparameters\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlr\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mlearning_rate\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 16\u001b[0m \u001b[0mstart_ex\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 17\u001b[0;31m \u001b[0;32mfor\u001b[0m \u001b[0mex_count\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mex\u001b[0m \u001b[0;32min\u001b[0m \u001b[0menumerate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mget_examples\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvocab\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 18\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzero_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 19\u001b[0m \u001b[0minput_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcontext\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mex\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m<ipython-input-16-6b42728f14dd>\u001b[0m in \u001b[0;36mget_examples\u001b[0;34m(model, vocab)\u001b[0m\n\u001b[1;32m 33\u001b[0m \u001b[0mvocab\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msmallest\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mi2w\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtoken\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 34\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 35\u001b[0;31m \u001b[0mi\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnext\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0miter\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvocab\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msmallest\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mi2w\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mkeys\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 36\u001b[0m \u001b[0mw2i\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtoken\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 37\u001b[0m \u001b[0mi2c\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;31mKeyboardInterrupt\u001b[0m: " | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "HmAe9B6rsUVZ", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 257 | |
}, | |
"outputId": "b365c35b-1768-46a0-c37d-99411ed86b3b" | |
}, | |
"source": [ | |
"# \"Accuracies\" of the two embeddings layers per group ([0, 1], higher is better).\n", | |
"# Each value is rank-based: the average share of embedding rows that are not more\n", | |
"# similar to the analogy vector than the expected word is.\n", | |
"eval_model(model, vocab.w2i)" | |
], | |
"execution_count": 22, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Evaluating...\n", | |
"line 19559" | |
], | |
"name": "stdout" | |
}, | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"{': capital-common-countries': (0.4599991190476193, 0.500064857142857),\n", | |
" ': capital-world': (0.4714395522388058, 0.5043045842217491),\n", | |
" ': city-in-state': (0.48731167601246084, 0.4845760560747671),\n", | |
" ': currency': (0.38896300000000006, 0.5298366666666665),\n", | |
" ': family': (0.4371339181286551, 0.5607295906432747),\n", | |
" ': gram1-adjective-to-adverb': (0.49082139999999924, 0.5039669000000002),\n", | |
" ': gram2-opposite': (0.5106824264705883, 0.5056287132352941),\n", | |
" ': gram3-comparative': (0.5161551825396833, 0.5025758650793655),\n", | |
" ': gram4-superlative': (0.4350813450292397, 0.4920672222222224),\n", | |
" ': gram5-present-participle': (0.5449368121693123, 0.4661682936507937),\n", | |
" ': gram6-nationality-adjective': (0.5093566885485051, 0.5011422902990518),\n", | |
" ': gram7-past-tense': (0.4716178228228234, 0.534154241741741),\n", | |
" ': gram8-plural': (0.4897098677248677, 0.47332212962962994)}" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 22 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "LetEBcQIxDay", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 411 | |
}, | |
"outputId": "370d20d6-cb16-40e5-8a02-a3bf98474b97" | |
}, | |
"source": [ | |
"# Second run: fresh vocabulary and model, learning rate 0.01, capped at 200000 examples.\n", | |
"vocab = Vocabulary()\n", | |
"model = get_new_model()\n", | |
"train(model, vocab, learning_rate=0.01, max_examples=200000)" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/html": [ | |
"\n", | |
" Logging results to <a href=\"https://wandb.com\" target=\"_blank\">Weights & Biases</a> <a href=\"https://docs.wandb.com/integrations/jupyter.html\" target=\"_blank\">(Documentation)</a>.<br/>\n", | |
" Project page: <a href=\"https://app.wandb.ai/axel-op/word2vec-sgns-incremental\" target=\"_blank\">https://app.wandb.ai/axel-op/word2vec-sgns-incremental</a><br/>\n", | |
" Run page: <a href=\"https://app.wandb.ai/axel-op/word2vec-sgns-incremental/runs/3cp1zhsx\" target=\"_blank\">https://app.wandb.ai/axel-op/word2vec-sgns-incremental/runs/3cp1zhsx</a><br/>\n", | |
" " | |
], | |
"text/plain": [ | |
"<IPython.core.display.HTML object>" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
} | |
}, | |
{ | |
"output_type": "stream", | |
"text": [ | |
"191489 examples trained" | |
], | |
"name": "stdout" | |
}, | |
{ | |
"output_type": "error", | |
"ename": "KeyboardInterrupt", | |
"evalue": "ignored", | |
"traceback": [ | |
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", | |
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", | |
"\u001b[0;32m<ipython-input-59-2fa73f61d44a>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0minit_vocabulary\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0mmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mget_new_model\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlearning_rate\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0.01\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmax_examples\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m200000\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", | |
"\u001b[0;32m<ipython-input-57-3b8043fd80c2>\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(model, learning_rate, max_examples)\u001b[0m\n\u001b[1;32m 14\u001b[0m \u001b[0moptimizer\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0moptim\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mSGD\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparameters\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlr\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mlearning_rate\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 15\u001b[0m \u001b[0mstart_ex\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 16\u001b[0;31m \u001b[0;32mfor\u001b[0m \u001b[0mex_count\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mex\u001b[0m \u001b[0;32min\u001b[0m \u001b[0menumerate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mget_examples\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 17\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzero_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 18\u001b[0m \u001b[0minput_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcontext\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mex\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m<ipython-input-55-c3e3cf5ee710>\u001b[0m in \u001b[0;36mget_examples\u001b[0;34m(model)\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mglobal\u001b[0m \u001b[0mn\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;32mglobal\u001b[0m \u001b[0msmallest_count\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m \u001b[0;32mfor\u001b[0m \u001b[0msentence\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mget_sentences\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 6\u001b[0m \u001b[0;31m# Subsampling\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0msubsampled\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m<ipython-input-53-0738f64dd33d>\u001b[0m in \u001b[0;36mget_sentences\u001b[0;34m()\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0misfile\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mendswith\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'.conll'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[0mfile\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtar\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mextractfile\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mm\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 9\u001b[0;31m \u001b[0;32mfor\u001b[0m \u001b[0ms\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mextract_sentences\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfile\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 10\u001b[0m \u001b[0;32myield\u001b[0m \u001b[0ms\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m<ipython-input-52-c0b4cf4a49ba>\u001b[0m in \u001b[0;36mextract_sentences\u001b[0;34m(file)\u001b[0m\n\u001b[1;32m 12\u001b[0m \u001b[0mcurrent\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 13\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mline\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mfile\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 14\u001b[0;31m \u001b[0mline\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mline\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdecode\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'utf-8'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstrip\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 15\u001b[0m \u001b[0msplitted\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mline\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msplit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 16\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msplitted\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m>\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;31mKeyboardInterrupt\u001b[0m: " | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "sAIjSb4Exlpy", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 257 | |
}, | |
"outputId": "4cb7763b-f8b0-495a-e9c7-58078237e6d4" | |
}, | |
"source": [ | |
"eval_model(model, vocab.w2i)" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Evaluating...\n", | |
"line 19559" | |
], | |
"name": "stdout" | |
}, | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"{': capital-common-countries': (0.47684664285714295, 0.5054748571428571),\n", | |
" ': capital-world': (0.4861371215351811, 0.49943603411513854),\n", | |
" ': city-in-state': (0.5144588598130845, 0.5082236573208725),\n", | |
" ': currency': (0.5027079999999999, 0.37295466666666666),\n", | |
" ': family': (0.4390372514619884, 0.4730166374269009),\n", | |
" ': gram1-adjective-to-adverb': (0.4913326000000002, 0.4875859666666661),\n", | |
" ': gram2-opposite': (0.4850392279411762, 0.4961184926470588),\n", | |
" ': gram3-comparative': (0.4690430634920641, 0.5247969365079361),\n", | |
" ': gram4-superlative': (0.4813781578947369, 0.49429385964912276),\n", | |
" ': gram5-present-participle': (0.5161127645502649, 0.4706116005291002),\n", | |
" ': gram6-nationality-adjective': (0.492469124726477, 0.5012207512764405),\n", | |
" ': gram7-past-tense': (0.5005979654654655, 0.46838213213213187),\n", | |
" ': gram8-plural': (0.48234771164021123, 0.49860328042328017)}" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 60 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "WbIFhhHvJd9a", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 411 | |
}, | |
"outputId": "829c6d2c-ee1a-4b8c-db86-e9656defb249" | |
}, | |
"source": [ | |
"vocab = Vocabulary()\n", | |
"model = get_new_model()\n", | |
"train(model, vocab, learning_rate=0.001, max_examples=300000)" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/html": [ | |
"\n", | |
" Logging results to <a href=\"https://wandb.com\" target=\"_blank\">Weights & Biases</a> <a href=\"https://docs.wandb.com/integrations/jupyter.html\" target=\"_blank\">(Documentation)</a>.<br/>\n", | |
" Project page: <a href=\"https://app.wandb.ai/axel-op/word2vec-sgns-incremental\" target=\"_blank\">https://app.wandb.ai/axel-op/word2vec-sgns-incremental</a><br/>\n", | |
" Run page: <a href=\"https://app.wandb.ai/axel-op/word2vec-sgns-incremental/runs/27lr37hg\" target=\"_blank\">https://app.wandb.ai/axel-op/word2vec-sgns-incremental/runs/27lr37hg</a><br/>\n", | |
" " | |
], | |
"text/plain": [ | |
"<IPython.core.display.HTML object>" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
} | |
}, | |
{ | |
"output_type": "stream", | |
"text": [ | |
"94859 examples trained" | |
], | |
"name": "stdout" | |
}, | |
{ | |
"output_type": "error", | |
"ename": "KeyboardInterrupt", | |
"evalue": "ignored", | |
"traceback": [ | |
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", | |
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", | |
"\u001b[0;32m<ipython-input-87-eb00eb73ec27>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mvocab\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mVocabulary\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0mmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mget_new_model\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvocab\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlearning_rate\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0.001\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmax_examples\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m300000\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", | |
"\u001b[0;32m<ipython-input-84-8e17766134aa>\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(model, vocab, learning_rate, max_examples)\u001b[0m\n\u001b[1;32m 15\u001b[0m \u001b[0moptimizer\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0moptim\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mSGD\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparameters\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlr\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mlearning_rate\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 16\u001b[0m \u001b[0mstart_ex\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 17\u001b[0;31m \u001b[0;32mfor\u001b[0m \u001b[0mex_count\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mex\u001b[0m \u001b[0;32min\u001b[0m \u001b[0menumerate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mget_examples\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvocab\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 18\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzero_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 19\u001b[0m \u001b[0minput_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcontext\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mex\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m<ipython-input-86-f4eeb41f8f89>\u001b[0m in \u001b[0;36mget_examples\u001b[0;34m(model, vocab)\u001b[0m\n\u001b[1;32m 30\u001b[0m \u001b[0mvocab\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msmallest\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mi2w\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtoken\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 31\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 32\u001b[0;31m \u001b[0mi\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnext\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0miter\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvocab\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msmallest\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mi2w\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mkeys\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 33\u001b[0m \u001b[0mvocab\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mw2i\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtoken\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 34\u001b[0m \u001b[0mvocab\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mi2c\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;31mKeyboardInterrupt\u001b[0m: " | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "Ja07ZR1rJi39", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 51 | |
}, | |
"outputId": "7ffd2ca4-93a7-4b38-e7e8-29f3bde487d3" | |
}, | |
"source": [ | |
"eval_model(model, vocab.w2i)" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Evaluating...\n", | |
"line 14711" | |
], | |
"name": "stdout" | |
} | |
] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment