Skip to content

Instantly share code, notes, and snippets.

@axel-op
Last active June 24, 2020 13:27
Show Gist options
  • Save axel-op/29a09fe4793a4d03f621ba8fe07784a6 to your computer and use it in GitHub Desktop.
Save axel-op/29a09fe4793a4d03f621ba8fe07784a6 to your computer and use it in GitHub Desktop.
Word2Vec SGNS Incremental
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "incremental-word2vec.ipynb",
"provenance": [],
"toc_visible": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "Fb-imYuPozEr",
"colab_type": "text"
},
"source": [
"# Word2Vec Skip-Gram with Negative Sampling: an incremental and space-saving algorithm\n",
"\n",
"Implementation of the space-saving version of the Word2Vec SGNS algorithm described by May et al., 2017. It continuously retrieves data from a tokenized Wikipedia dump found [here](http://www.llf.cnrs.fr/wikiparse).\n",
"\n",
"[May, Chandler & Duh, Kevin & Durme, Benjamin & Lall, Ashwin. (2017). \"Streaming Word Embeddings with the Space-Saving Algorithm\"](https://arxiv.org/abs/1704.07463)"
]
},
{
"cell_type": "code",
"metadata": {
"id": "jr2b_F9WNbWV",
"colab_type": "code",
"colab": {}
},
"source": [
"from typing import Generator, List, Dict, Tuple"
],
"execution_count": 1,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "-oMPNo397qap",
"colab_type": "code",
"colab": {}
},
"source": [
"from random import randint, uniform, sample\n",
"from collections import namedtuple\n",
"from math import sqrt\n",
"import tarfile\n",
"import json\n",
"import os\n",
"import time\n",
"import requests"
],
"execution_count": 2,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "kxIvWUPOi860",
"colab_type": "code",
"colab": {}
},
"source": [
"# PyTorch\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"import torch.optim as optim"
],
"execution_count": 3,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "9tjKS5zzEF31",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 291
},
"outputId": "710e1de1-098d-42db-eafa-b5e9b10a24ce"
},
"source": [
"USE_WANDB = True\n",
"if USE_WANDB:\n",
" %pip install wandb -q\n",
" import wandb\n",
" wandb.login()"
],
"execution_count": 4,
"outputs": [
{
"output_type": "stream",
"text": [
"\u001b[K |████████████████████████████████| 1.4MB 3.5MB/s \n",
"\u001b[K |████████████████████████████████| 112kB 14.3MB/s \n",
"\u001b[K |████████████████████████████████| 460kB 17.6MB/s \n",
"\u001b[K |████████████████████████████████| 102kB 9.5MB/s \n",
"\u001b[K |████████████████████████████████| 102kB 9.0MB/s \n",
"\u001b[K |████████████████████████████████| 71kB 7.0MB/s \n",
"\u001b[K |████████████████████████████████| 71kB 7.5MB/s \n",
"\u001b[?25h Building wheel for gql (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
" Building wheel for watchdog (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
" Building wheel for subprocess32 (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
" Building wheel for graphql-core (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
" Building wheel for pathtools (setup.py) ... \u001b[?25l\u001b[?25hdone\n"
],
"name": "stdout"
},
{
"output_type": "display_data",
"data": {
"application/javascript": [
"\n",
" window._wandbApiKey = new Promise((resolve, reject) => {\n",
" function loadScript(url) {\n",
" return new Promise(function(resolve, reject) {\n",
" let newScript = document.createElement(\"script\");\n",
" newScript.onerror = reject;\n",
" newScript.onload = resolve;\n",
" document.body.appendChild(newScript);\n",
" newScript.src = url;\n",
" });\n",
" }\n",
" loadScript(\"https://cdn.jsdelivr.net/npm/postmate/build/postmate.min.js\").then(() => {\n",
" const iframe = document.createElement('iframe')\n",
" iframe.style.cssText = \"width:0;height:0;border:none\"\n",
" document.body.appendChild(iframe)\n",
" const handshake = new Postmate({\n",
" container: iframe,\n",
" url: 'https://app.wandb.ai/authorize'\n",
" });\n",
" const timeout = setTimeout(() => reject(\"Couldn't auto authenticate\"), 5000)\n",
" handshake.then(function(child) {\n",
" child.on('authorize', data => {\n",
" clearTimeout(timeout)\n",
" resolve(data)\n",
" });\n",
" });\n",
" })\n",
" });\n",
" "
],
"text/plain": [
"<IPython.core.display.Javascript object>"
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "stream",
"text": [
"\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[32m\u001b[41mERROR\u001b[0m Not authenticated. Copy a key from https://app.wandb.ai/authorize\n"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"API Key: ··········\n"
],
"name": "stdout"
},
{
"output_type": "stream",
"text": [
"\u001b[34m\u001b[1mwandb\u001b[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc\n"
],
"name": "stderr"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_FqXEmmQO2LW",
"colab_type": "text"
},
"source": [
"## Hyperparameters"
]
},
{
"cell_type": "code",
"metadata": {
"id": "1yfLy7R0O6Qb",
"colab_type": "code",
"colab": {}
},
"source": [
"# Mikolov et al., 2013:\n",
"# K = 2 ~ 5 works for large data sets,\n",
"# K = 5 ~ 20 for small data sets.\n",
"NEGATIVE_SAMPLES = 3\n",
"CONTEXT_SIZE = 3\n",
"EMBEDDING_DIMENSIONS = 100\n",
"\n",
"# Specific to incremental version\n",
"\n",
"# Max size of the vocabulary\n",
"# May et al., 2017:\n",
"# (with K = 70000) the relative error is around one or two for many words\n",
"K = 100000\n",
"SUBSAMPLING_THRESHOLD = 0.001"
],
"execution_count": 5,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "u1WCYn9NiEdt",
"colab_type": "text"
},
"source": [
"## Model"
]
},
{
"cell_type": "code",
"metadata": {
"id": "PAiWw25uiGJf",
"colab_type": "code",
"colab": {}
},
"source": [
"class SkipGram(nn.Module):\n",
"    '''\n",
"    Skip-gram model trained with negative sampling (SGNS).\n",
"\n",
"    Holds two (K, EMBEDDING_DIMENSIONS) embedding tables: `embeddings1` is\n",
"    looked up for target words, `embeddings2` for context words and negative\n",
"    samples. Both are created with sparse gradients.\n",
"    '''\n",
"\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        # Kept for backward compatibility; `forward` relies on the actual\n",
"        # device of the weights rather than on this flag.\n",
"        self.use_cuda = False\n",
"        self.embeddings1 = nn.Embedding(K, EMBEDDING_DIMENSIONS, sparse=True)\n",
"        self.embeddings2 = nn.Embedding(K, EMBEDDING_DIMENSIONS, sparse=True)\n",
"\n",
"        # TODO: initialise the embeddings with random values\n",
"        # (note: nn.Embedding already draws its initial weights from N(0, 1)).\n",
"\n",
"    def with_cuda(self):\n",
"        '''Move the model to the GPU, remember it in `use_cuda` and return the model.'''\n",
"        self.use_cuda = True\n",
"        return self.cuda()\n",
"\n",
"    def forward(\n",
"        self,\n",
"        targets: List[int],\n",
"        context_words: List[int],\n",
"        negative_samples: List[List[int]]\n",
"    ):\n",
"        '''\n",
"        Compute the negative-sampling loss, summed over the given examples.\n",
"\n",
"        Args:\n",
"            targets: indices of the target words.\n",
"            context_words: indices of the observed context words, aligned\n",
"                with `targets`.\n",
"            negative_samples: indices of the negative samples, one list per\n",
"                target; a flat list also works for a single target because\n",
"                it broadcasts against the target embedding.\n",
"\n",
"        Returns:\n",
"            A scalar tensor: minus the sum of the log-sigmoid scores of the\n",
"            positive pairs and of the negated scores of the negative pairs.\n",
"        '''\n",
"        # Build the index tensors where the weights actually live, so the\n",
"        # model keeps working however it was moved (.cuda(), .to(), .cpu()).\n",
"        device = self.embeddings1.weight.device\n",
"\n",
"        def create_tensor(values):\n",
"            return torch.tensor(values, dtype=torch.long, device=device)\n",
"\n",
"        targets = create_tensor(targets)\n",
"        context_words = create_tensor(context_words)\n",
"        negative_samples = create_tensor(negative_samples)\n",
"\n",
"        hiddens = self.embeddings1(targets)\n",
"\n",
"        positive_embeddings = self.embeddings2(context_words)\n",
"        negative_embeddings = self.embeddings2(negative_samples)\n",
"\n",
"        # Dot product of each target with its context word.\n",
"        out_positive = F.logsigmoid((hiddens * positive_embeddings).sum(-1))\n",
"\n",
"        # Dot product of each target with every negative sample; the extra\n",
"        # axis lets one target embedding broadcast over its samples.\n",
"        out_negative = (hiddens.unsqueeze(-2) * negative_embeddings).sum(-1)\n",
"        out_negative = F.logsigmoid(-out_negative)\n",
"\n",
"        return -(torch.sum(out_positive) + torch.sum(out_negative))\n",
"\n",
"    def reinit_embeddings(self, index):\n",
"        '''\n",
"        Redraw both embeddings stored at `index` from N(0, 1), in place.\n",
"\n",
"        May et al., 2017:\n",
"        \"Whenever a word w is ejected from the space-saving data structure\n",
"        its target-word embedding v_k and context-word embedding v'_k\n",
"        are re-initialized as draws from N(0, 1)^D.\n",
"        This means that if w appears in the future and is inserted again into\n",
"        the space-saving data structure, training of its embedding starts over\n",
"        from scratch.\"\n",
"        '''\n",
"        with torch.no_grad():\n",
"            self.embeddings1.weight[index].normal_(0, 1)\n",
"            self.embeddings2.weight[index].normal_(0, 1)\n"
],
"execution_count": 6,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "zT84unrkIq-M",
"colab_type": "code",
"colab": {}
},
"source": [
"def get_new_model():\n",
"    '''Create a fresh SkipGram model, moved to the GPU when CUDA is available.'''\n",
"    fresh = SkipGram()\n",
"    return fresh.with_cuda() if torch.cuda.is_available() else fresh"
],
"execution_count": 7,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "tjXtS9QBO_jU",
"colab_type": "text"
},
"source": [
"## Parsing of data\n",
"\n",
"These functions return generators so that the data can be parsed and retrieved continuously."
]
},
{
"cell_type": "code",
"metadata": {
"id": "ntwT88xDMwhT",
"colab_type": "code",
"colab": {}
},
"source": [
"def keep_token(splitted) -> bool:\n",
"    '''\n",
"    Tell whether a token line should be kept.\n",
"\n",
"    `splitted` is a whitespace-split line whose field at index 1 is the token\n",
"    text. The bracket placeholders '-LRB-' and '-RRB-' are dropped.\n",
"    '''\n",
"    # Open question: exclude or keep punctuation? The category field appears\n",
"    # to be splitted[4] (e.g. 'PONCT'); it is deliberately not read here so\n",
"    # that lines with fewer than five fields do not raise IndexError.\n",
"    return splitted[1] not in ('-LRB-', '-RRB-')\n",
"\n",
"def keep_sentence(sentence) -> bool:\n",
"    '''Keep a sentence unless it is empty or contains the footnote marker.'''\n",
"    if len(sentence) == 0:\n",
"        return False\n",
"    return '↑' not in sentence  # Footnotes\n",
"\n",
"def extract_sentences(file) -> Generator[List[str], None, None]:\n",
"    '''\n",
"    Yield the sentences of a token-per-line file as lists of token texts.\n",
"\n",
"    `file` iterates over UTF-8 encoded byte lines. A line with at most one\n",
"    field closes the current sentence; any other line contributes its field\n",
"    at index 1 when `keep_token` accepts it. Sentences rejected by\n",
"    `keep_sentence` are skipped.\n",
"    '''\n",
"    tokens = []\n",
"    for raw in file:\n",
"        fields = raw.decode('utf-8').strip().split()\n",
"        if len(fields) <= 1:\n",
"            # Sentence boundary: emit what was collected and start over.\n",
"            if keep_sentence(tokens):\n",
"                yield tokens\n",
"            tokens = []\n",
"            continue\n",
"        if keep_token(fields):\n",
"            tokens.append(fields[1])\n",
"    # The file may end without a closing boundary line.\n",
"    if keep_sentence(tokens):\n",
"        yield tokens"
],
"execution_count": 8,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "gY3e6wsQHrXR",
"colab_type": "code",
"colab": {}
},
"source": [
"ARCHIVE_URL = 'http://www.llf.cnrs.fr/wikiparse/enwiki.tar.bz2'\n",
"\n",
"def get_sentences() -> Generator[List[str], None, None]:\n",
"    '''\n",
"    Stream the archive at ARCHIVE_URL and yield its sentences one by one.\n",
"\n",
"    The bz2-compressed tar is read as a forward-only stream, so nothing is\n",
"    stored on disk; only members whose name ends with '.conll' are parsed.\n",
"    The HTTP response and the tar stream are closed when the generator is\n",
"    exhausted or closed.\n",
"\n",
"    Raises:\n",
"        requests.HTTPError: if the server answers with an error status.\n",
"    '''\n",
"    with requests.get(ARCHIVE_URL, stream=True) as response:\n",
"        # Fail early with a clear error instead of feeding an error page to tarfile.\n",
"        response.raise_for_status()\n",
"        with tarfile.open(fileobj=response.raw, mode='r|bz2') as tar:\n",
"            for member in tar:\n",
"                if member.isfile() and member.name.endswith('.conll'):\n",
"                    yield from extract_sentences(tar.extractfile(member))"
],
"execution_count": 9,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "g9bSjfm0SXBS",
"colab_type": "text"
},
"source": [
"## Implementation"
]
},
{
"cell_type": "code",
"metadata": {
"id": "g2ufbIDnSZNH",
"colab_type": "code",
"colab": {}
},
"source": [
"# Bookkeeping of the least frequent entries: a count and a dict mapping\n",
"# vocabulary indices to their words.\n",
"Smallest = namedtuple('Smallest', ['count', 'i2w'])\n",
"\n",
"\n",
"class Vocabulary:\n",
"    '''\n",
"    Mutable state of the bounded vocabulary, filled in by `get_examples`.\n",
"\n",
"    Attributes:\n",
"        w2i: word -> index.\n",
"        i2c: index -> count.\n",
"        size: number of indices handed out so far.\n",
"        smallest: `Smallest` record of the least frequent entries.\n",
"        reservoir: list of K indices (initially all 0) used to draw negative samples.\n",
"        n: number of tokens added to the vocabulary so far.\n",
"    '''\n",
"\n",
"    def __init__(self):\n",
"        self.size = 0\n",
"        self.n = 0\n",
"        self.w2i = {}\n",
"        self.i2c = {}\n",
"        self.smallest = Smallest(count=1, i2w={})\n",
"        self.reservoir = [0] * K"
],
"execution_count": 10,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "gesH3KHWTa9G",
"colab_type": "code",
"colab": {}
},
"source": [
"def get_examples(model, vocab) -> Generator[Tuple[int, int], None, None]:\n",
"    '''\n",
"    Stream (target index, context index) training pairs from the corpus.\n",
"\n",
"    For every sentence, each token is first subsampled, then counted in the\n",
"    bounded vocabulary (which may evict a rare word and re-initialise its\n",
"    embeddings in `model`) and offered to the negative-sampling reservoir.\n",
"    Pairs are then produced from every full window of the subsampled\n",
"    sentence.\n",
"\n",
"    Args:\n",
"        model: object exposing `reinit_embeddings(index)`.\n",
"        vocab: `Vocabulary` instance, updated in place.\n",
"    '''\n",
"    for sentence in get_sentences():\n",
"        subsampled = []\n",
"        for token in sentence:\n",
"            # Subsampling only applies to words that already have a count.\n",
"            if token in vocab.w2i and uniform(0.0, 1.0) > _keep_probability(vocab, token):\n",
"                continue\n",
"            subsampled.append(token)\n",
"            _update_vocabulary(model, vocab, token)\n",
"            _update_reservoir(vocab, vocab.w2i[token])\n",
"\n",
"        # May et al., 2017: a sentence that is smaller (after subsampling)\n",
"        # than a single context window only feeds the vocabulary and the\n",
"        # reservoir, then the algorithm moves on to the next sentence.\n",
"        if len(subsampled) < CONTEXT_SIZE * 2 + 1:\n",
"            continue\n",
"\n",
"        # Words evicted while the sentence was being read map to None.\n",
"        yield from _window_examples([vocab.w2i.get(t) for t in subsampled])\n",
"\n",
"\n",
"def _keep_probability(vocab, token) -> float:\n",
"    '''Probability min(1, sqrt(threshold / frequency)) of keeping a known word.'''\n",
"    # Relative frequency estimated from the counts gathered so far; vocab.n\n",
"    # is at least 1 as soon as one word is stored.\n",
"    frequency = vocab.i2c[vocab.w2i[token]] / vocab.n\n",
"    return min(1.0, sqrt(SUBSAMPLING_THRESHOLD / frequency))\n",
"\n",
"\n",
"def _update_vocabulary(model, vocab, token) -> None:\n",
"    '''Count `token`, taking over the slot of a least frequent word when the vocabulary is full.'''\n",
"    # (complexity to be reviewed)\n",
"    w2i, i2c, smallest = vocab.w2i, vocab.i2c, vocab.smallest\n",
"    # namedtuples are immutable: `_replace` returns a new record, which has to\n",
"    # be stored back. The new record shares the same `i2w` dict.\n",
"    if token in w2i:\n",
"        i = w2i[token]\n",
"        i2c[i] += 1\n",
"        if i in smallest.i2w:\n",
"            if len(smallest.i2w) > 1:\n",
"                del smallest.i2w[i]\n",
"            else:\n",
"                # Last tracked entry: it stays tracked with its new count.\n",
"                # The set is not rebuilt, so it may miss other entries that\n",
"                # have the same count.\n",
"                vocab.smallest = smallest._replace(count=i2c[i])\n",
"    elif vocab.size < K:\n",
"        # Free slot available: plain insertion with a count of 1.\n",
"        i = vocab.size\n",
"        w2i[token] = i\n",
"        i2c[i] = 1\n",
"        vocab.size += 1\n",
"        if smallest.count > 1:\n",
"            smallest.i2w.clear()\n",
"            vocab.smallest = smallest = smallest._replace(count=1)\n",
"        smallest.i2w[i] = token\n",
"    else:\n",
"        # Vocabulary full: the new word inherits the slot and the count of a\n",
"        # tracked least frequent word, plus one.\n",
"        i = next(iter(smallest.i2w))\n",
"        del w2i[smallest.i2w[i]]\n",
"        w2i[token] = i\n",
"        i2c[i] += 1\n",
"        model.reinit_embeddings(i)\n",
"        if len(smallest.i2w) > 1:\n",
"            del smallest.i2w[i]\n",
"        else:\n",
"            vocab.smallest = smallest._replace(count=i2c[i])\n",
"            smallest.i2w[i] = token\n",
"\n",
"\n",
"def _update_reservoir(vocab, index) -> None:\n",
"    '''Offer one vocabulary index to the size-K negative-sampling reservoir.'''\n",
"    vocab.n += 1\n",
"    # The first K tokens fill slots 0..K-1 in order; afterwards the n-th token\n",
"    # replaces a uniformly drawn slot with probability K / n.\n",
"    slot = vocab.n - 1 if vocab.n <= K else randint(0, vocab.n - 1)\n",
"    if slot < K:\n",
"        vocab.reservoir[slot] = index\n",
"\n",
"\n",
"def _window_examples(indices) -> Generator[Tuple[int, int], None, None]:\n",
"    '''Yield (target, context) pairs from every full window that contains no None.'''\n",
"    window = CONTEXT_SIZE * 2 + 1\n",
"    left = 0\n",
"    while left + window <= len(indices):\n",
"        centre = left + CONTEXT_SIZE\n",
"        target = indices[centre]\n",
"        if target is None:\n",
"            # Restart right after the unknown target (see `left += 1` below).\n",
"            left = centre\n",
"        else:\n",
"            pairs = []\n",
"            for i in range(left, left + window):\n",
"                if i == centre:\n",
"                    continue\n",
"                if indices[i] is None:\n",
"                    # Unknown context word: drop this window and restart after it.\n",
"                    left = i\n",
"                    pairs.clear()\n",
"                    break\n",
"                pairs.append((target, indices[i]))\n",
"            yield from pairs\n",
"        left += 1\n"
],
"execution_count": 16,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "jDMJ_zdG99vx",
"colab_type": "code",
"colab": {}
},
"source": [
"def get_negative_samples(vocab, k: int, input_, context):\n",
"    '''\n",
"    Draw `k` negative samples from the reservoir of `vocab`.\n",
"\n",
"    k + 2 distinct reservoir positions are drawn; the two extra ones are\n",
"    spares, used once in place of a draw that equals `input_` (last spare)\n",
"    or `context` (second to last spare). The spare itself is not re-checked.\n",
"    '''\n",
"    reservoir = vocab.reservoir\n",
"    *draws, context_spare, input_spare = sample(range(K), k + 2)\n",
"    negatives = []\n",
"    for position in draws:\n",
"        word = reservoir[position]\n",
"        if word == input_:\n",
"            word = reservoir[input_spare]\n",
"        elif word == context:\n",
"            word = reservoir[context_spare]\n",
"        negatives.append(word)\n",
"    return negatives"
],
"execution_count": 12,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "QEyINh0H_Qzo",
"colab_type": "code",
"colab": {}
},
"source": [
"def train(\n",
"    model,\n",
"    vocab,\n",
"    learning_rate: float = 0.01,\n",
"    max_examples: int = None\n",
"):\n",
"    '''\n",
"    Train `model` with plain SGD, one (target, context) example at a time.\n",
"\n",
"    Args:\n",
"        model: SkipGram model, updated in place.\n",
"        vocab: Vocabulary, updated in place while examples are generated.\n",
"        learning_rate: SGD learning rate.\n",
"        max_examples: number of examples to train on; None trains until the\n",
"            corpus is exhausted.\n",
"\n",
"    When USE_WANDB is set, the loss is logged every `log_interval` examples\n",
"    and the weights and the index -> word mapping are saved in the run\n",
"    directory at the end.\n",
"    '''\n",
"    from itertools import islice\n",
"\n",
"    log_interval = 1000\n",
"    if USE_WANDB:\n",
"        wandb.init(project='word2vec-sgns-incremental', reinit=True)\n",
"        wandb.watch(model)\n",
"        wandb.run.summary.update({\n",
"            'learning_rate': learning_rate,\n",
"            'log_interval': log_interval})\n",
"\n",
"    optimizer = optim.SGD(model.parameters(), lr=learning_rate)\n",
"    # islice stops after exactly `max_examples` examples (no limit when None)\n",
"    # without pulling an extra one from the generator.\n",
"    examples = islice(get_examples(model, vocab), max_examples)\n",
"    start_ex = time.time()\n",
"    for ex_count, ex in enumerate(examples):\n",
"        model.zero_grad()\n",
"        input_, context = ex\n",
"        neg_samples = get_negative_samples(vocab, NEGATIVE_SAMPLES, input_, context)\n",
"        loss = model([input_], [context], neg_samples)\n",
"        loss.backward()\n",
"        optimizer.step()\n",
"\n",
"        print(f'\\r{ex_count} examples trained', end='')\n",
"        if USE_WANDB and ex_count % log_interval == 0:\n",
"            wandb.log({\n",
"                'example': ex_count,\n",
"                # Log a plain number rather than a tensor.\n",
"                'loss': loss.item(),\n",
"                'time per example': (time.time() - start_ex) / log_interval})\n",
"            start_ex = time.time()\n",
"\n",
"    if USE_WANDB:\n",
"        torch.save(model.state_dict(), os.path.join(wandb.run.dir, 'model.pt'))\n",
"        with open(os.path.join(wandb.run.dir, 'i2w.json'), 'w') as i2w_file:\n",
"            json.dump({i: w for w, i in vocab.w2i.items()}, i2w_file)"
],
"execution_count": 13,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "LLrgWhNuv0Oh",
"colab_type": "text"
},
"source": [
"## Evaluation"
]
},
{
"cell_type": "code",
"metadata": {
"id": "-KCBQlK7sP-_",
"colab_type": "code",
"colab": {}
},
"source": [
"TEST_SET = 'http://www.fit.vutbr.cz/~imikolov/rnnlm/word-test.v1.txt'\n",
"\n",
"def eval_model(model, w2i):\n",
"    '''\n",
"    Score both embedding layers on the word-analogy test set at TEST_SET.\n",
"\n",
"    For a line 'a b c d', the vector b - a + c is compared by cosine\n",
"    similarity with every row of the embedding table. The score of the line\n",
"    is the fraction of rows that are not strictly more similar to that vector\n",
"    than the row of d (1.0 means d is the nearest row). Lines containing a\n",
"    word missing from `w2i` are skipped.\n",
"\n",
"    Args:\n",
"        model: SkipGram model; it is switched to eval mode.\n",
"        w2i: word -> index mapping.\n",
"\n",
"    Returns:\n",
"        dict mapping each group header (a line starting with ':') that has at\n",
"        least one scored line to a pair of mean scores, one for `embeddings1`\n",
"        and one for `embeddings2`.\n",
"    '''\n",
"    model.eval()\n",
"    sim = nn.CosineSimilarity()\n",
"    # We evaluate both embeddings layers\n",
"    layers = (model.embeddings1, model.embeddings2)\n",
"    rows = model.embeddings1.weight.shape[0]\n",
"    # Build the index tensors on the device of the weights.\n",
"    device = model.embeddings1.weight.device\n",
"    scores = dict()\n",
"    current_group = None\n",
"    accuracy = [0, 0]\n",
"    total = 0\n",
"\n",
"    def record_group():\n",
"        if current_group is not None and total > 0:\n",
"            scores[current_group] = accuracy[0] / total, accuracy[1] / total\n",
"\n",
"    page = requests.get(TEST_SET).text\n",
"    print('Evaluating...')\n",
"    for i, line in enumerate(page.split('\\n')):\n",
"        print(f'\\rline {i}', end='')\n",
"        line = line.strip()\n",
"\n",
"        if line.startswith('//'):\n",
"            continue\n",
"\n",
"        if line.startswith(':'):\n",
"            # New group: store the scores of the previous one.\n",
"            record_group()\n",
"            current_group = line\n",
"            accuracy = [0, 0]\n",
"            total = 0\n",
"            continue\n",
"\n",
"        words = line.split()\n",
"        if len(words) != 4 or any(word not in w2i for word in words):\n",
"            continue\n",
"\n",
"        with torch.no_grad():\n",
"            words = torch.tensor([w2i[word] for word in words], dtype=torch.long, device=device)\n",
"            for layer_index, layer in enumerate(layers):\n",
"                embeds = layer(words)\n",
"                analogy = embeds[1] - embeds[0] + embeds[2]\n",
"                similarities = sim(layer.weight, analogy.expand(rows, -1))\n",
"\n",
"                # Similarity with the ideal vector\n",
"                sim_match = similarities[words[3]]\n",
"\n",
"                # Count the number of vectors that are nearer than the ideal one\n",
"                nearer = torch.sum(similarities > sim_match).item()\n",
"                accuracy[layer_index] += (rows - nearer) / rows\n",
"\n",
"        total += 1\n",
"\n",
"    # The last group has no following header line, so store it here.\n",
"    record_group()\n",
"    return scores"
],
"execution_count": 21,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "ezXJ-nHYr8BM",
"colab_type": "text"
},
"source": [
"## Runs"
]
},
{
"cell_type": "code",
"metadata": {
"id": "ROXfbWoUBcWJ",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 411
},
"outputId": "8d21375b-1a6f-4ece-a49f-d04a16a7e60b"
},
"source": [
"vocab = Vocabulary()\n",
"model = get_new_model()\n",
"train(model, vocab, learning_rate=0.1, max_examples=300000)"
],
"execution_count": 17,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/html": [
"\n",
" Logging results to <a href=\"https://wandb.com\" target=\"_blank\">Weights & Biases</a> <a href=\"https://docs.wandb.com/integrations/jupyter.html\" target=\"_blank\">(Documentation)</a>.<br/>\n",
" Project page: <a href=\"https://app.wandb.ai/axel-op/word2vec-sgns-incremental\" target=\"_blank\">https://app.wandb.ai/axel-op/word2vec-sgns-incremental</a><br/>\n",
" Run page: <a href=\"https://app.wandb.ai/axel-op/word2vec-sgns-incremental/runs/1mqb1ms1\" target=\"_blank\">https://app.wandb.ai/axel-op/word2vec-sgns-incremental/runs/1mqb1ms1</a><br/>\n",
" "
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "stream",
"text": [
"191585 examples trained"
],
"name": "stdout"
},
{
"output_type": "error",
"ename": "KeyboardInterrupt",
"evalue": "ignored",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-17-bf9f8fb43f41>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mvocab\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mVocabulary\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0mmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mget_new_model\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvocab\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlearning_rate\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0.1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmax_examples\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m300000\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;32m<ipython-input-13-8e17766134aa>\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(model, vocab, learning_rate, max_examples)\u001b[0m\n\u001b[1;32m 15\u001b[0m \u001b[0moptimizer\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0moptim\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mSGD\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparameters\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlr\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mlearning_rate\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 16\u001b[0m \u001b[0mstart_ex\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 17\u001b[0;31m \u001b[0;32mfor\u001b[0m \u001b[0mex_count\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mex\u001b[0m \u001b[0;32min\u001b[0m \u001b[0menumerate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mget_examples\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvocab\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 18\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzero_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 19\u001b[0m \u001b[0minput_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcontext\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mex\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m<ipython-input-16-6b42728f14dd>\u001b[0m in \u001b[0;36mget_examples\u001b[0;34m(model, vocab)\u001b[0m\n\u001b[1;32m 33\u001b[0m \u001b[0mvocab\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msmallest\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mi2w\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtoken\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 34\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 35\u001b[0;31m \u001b[0mi\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnext\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0miter\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvocab\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msmallest\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mi2w\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mkeys\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 36\u001b[0m \u001b[0mw2i\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtoken\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 37\u001b[0m \u001b[0mi2c\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
]
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "HmAe9B6rsUVZ",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 257
},
"outputId": "b365c35b-1768-46a0-c37d-99411ed86b3b"
},
"source": [
"# \"Accuracies\" of the two embeddings layers per group ([0, 1], higher is better)\n",
"eval_model(model, vocab.w2i)"
],
"execution_count": 22,
"outputs": [
{
"output_type": "stream",
"text": [
"Evaluating...\n",
"line 19559"
],
"name": "stdout"
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"{': capital-common-countries': (0.4599991190476193, 0.500064857142857),\n",
" ': capital-world': (0.4714395522388058, 0.5043045842217491),\n",
" ': city-in-state': (0.48731167601246084, 0.4845760560747671),\n",
" ': currency': (0.38896300000000006, 0.5298366666666665),\n",
" ': family': (0.4371339181286551, 0.5607295906432747),\n",
" ': gram1-adjective-to-adverb': (0.49082139999999924, 0.5039669000000002),\n",
" ': gram2-opposite': (0.5106824264705883, 0.5056287132352941),\n",
" ': gram3-comparative': (0.5161551825396833, 0.5025758650793655),\n",
" ': gram4-superlative': (0.4350813450292397, 0.4920672222222224),\n",
" ': gram5-present-participle': (0.5449368121693123, 0.4661682936507937),\n",
" ': gram6-nationality-adjective': (0.5093566885485051, 0.5011422902990518),\n",
" ': gram7-past-tense': (0.4716178228228234, 0.534154241741741),\n",
" ': gram8-plural': (0.4897098677248677, 0.47332212962962994)}"
]
},
"metadata": {
"tags": []
},
"execution_count": 22
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "LetEBcQIxDay",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 411
},
"outputId": "370d20d6-cb16-40e5-8a02-a3bf98474b97"
},
"source": [
"vocab = Vocabulary()\n",
"model = get_new_model()\n",
"train(model, vocab, learning_rate=0.01, max_examples=200000)"
],
"execution_count": null,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/html": [
"\n",
" Logging results to <a href=\"https://wandb.com\" target=\"_blank\">Weights & Biases</a> <a href=\"https://docs.wandb.com/integrations/jupyter.html\" target=\"_blank\">(Documentation)</a>.<br/>\n",
" Project page: <a href=\"https://app.wandb.ai/axel-op/word2vec-sgns-incremental\" target=\"_blank\">https://app.wandb.ai/axel-op/word2vec-sgns-incremental</a><br/>\n",
" Run page: <a href=\"https://app.wandb.ai/axel-op/word2vec-sgns-incremental/runs/3cp1zhsx\" target=\"_blank\">https://app.wandb.ai/axel-op/word2vec-sgns-incremental/runs/3cp1zhsx</a><br/>\n",
" "
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "stream",
"text": [
"191489 examples trained"
],
"name": "stdout"
},
{
"output_type": "error",
"ename": "KeyboardInterrupt",
"evalue": "ignored",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-59-2fa73f61d44a>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0minit_vocabulary\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0mmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mget_new_model\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlearning_rate\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0.01\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmax_examples\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m200000\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;32m<ipython-input-57-3b8043fd80c2>\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(model, learning_rate, max_examples)\u001b[0m\n\u001b[1;32m 14\u001b[0m \u001b[0moptimizer\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0moptim\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mSGD\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparameters\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlr\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mlearning_rate\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 15\u001b[0m \u001b[0mstart_ex\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 16\u001b[0;31m \u001b[0;32mfor\u001b[0m \u001b[0mex_count\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mex\u001b[0m \u001b[0;32min\u001b[0m \u001b[0menumerate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mget_examples\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 17\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzero_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 18\u001b[0m \u001b[0minput_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcontext\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mex\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m<ipython-input-55-c3e3cf5ee710>\u001b[0m in \u001b[0;36mget_examples\u001b[0;34m(model)\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mglobal\u001b[0m \u001b[0mn\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;32mglobal\u001b[0m \u001b[0msmallest_count\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m \u001b[0;32mfor\u001b[0m \u001b[0msentence\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mget_sentences\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 6\u001b[0m \u001b[0;31m# Subsampling\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0msubsampled\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m<ipython-input-53-0738f64dd33d>\u001b[0m in \u001b[0;36mget_sentences\u001b[0;34m()\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0misfile\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mendswith\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'.conll'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[0mfile\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtar\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mextractfile\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mm\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 9\u001b[0;31m \u001b[0;32mfor\u001b[0m \u001b[0ms\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mextract_sentences\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfile\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 10\u001b[0m \u001b[0;32myield\u001b[0m \u001b[0ms\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m<ipython-input-52-c0b4cf4a49ba>\u001b[0m in \u001b[0;36mextract_sentences\u001b[0;34m(file)\u001b[0m\n\u001b[1;32m 12\u001b[0m \u001b[0mcurrent\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 13\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mline\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mfile\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 14\u001b[0;31m \u001b[0mline\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mline\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdecode\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'utf-8'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstrip\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 15\u001b[0m \u001b[0msplitted\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mline\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msplit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 16\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msplitted\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m>\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
]
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "sAIjSb4Exlpy",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 257
},
"outputId": "4cb7763b-f8b0-495a-e9c7-58078237e6d4"
},
"source": [
"eval_model(model, vocab.w2i)"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"Evaluating...\n",
"line 19559"
],
"name": "stdout"
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"{': capital-common-countries': (0.47684664285714295, 0.5054748571428571),\n",
" ': capital-world': (0.4861371215351811, 0.49943603411513854),\n",
" ': city-in-state': (0.5144588598130845, 0.5082236573208725),\n",
" ': currency': (0.5027079999999999, 0.37295466666666666),\n",
" ': family': (0.4390372514619884, 0.4730166374269009),\n",
" ': gram1-adjective-to-adverb': (0.4913326000000002, 0.4875859666666661),\n",
" ': gram2-opposite': (0.4850392279411762, 0.4961184926470588),\n",
" ': gram3-comparative': (0.4690430634920641, 0.5247969365079361),\n",
" ': gram4-superlative': (0.4813781578947369, 0.49429385964912276),\n",
" ': gram5-present-participle': (0.5161127645502649, 0.4706116005291002),\n",
" ': gram6-nationality-adjective': (0.492469124726477, 0.5012207512764405),\n",
" ': gram7-past-tense': (0.5005979654654655, 0.46838213213213187),\n",
" ': gram8-plural': (0.48234771164021123, 0.49860328042328017)}"
]
},
"metadata": {
"tags": []
},
"execution_count": 60
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "WbIFhhHvJd9a",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 411
},
"outputId": "829c6d2c-ee1a-4b8c-db86-e9656defb249"
},
"source": [
"vocab = Vocabulary()\n",
"model = get_new_model()\n",
"train(model, vocab, learning_rate=0.001, max_examples=300000)"
],
"execution_count": null,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/html": [
"\n",
" Logging results to <a href=\"https://wandb.com\" target=\"_blank\">Weights & Biases</a> <a href=\"https://docs.wandb.com/integrations/jupyter.html\" target=\"_blank\">(Documentation)</a>.<br/>\n",
" Project page: <a href=\"https://app.wandb.ai/axel-op/word2vec-sgns-incremental\" target=\"_blank\">https://app.wandb.ai/axel-op/word2vec-sgns-incremental</a><br/>\n",
" Run page: <a href=\"https://app.wandb.ai/axel-op/word2vec-sgns-incremental/runs/27lr37hg\" target=\"_blank\">https://app.wandb.ai/axel-op/word2vec-sgns-incremental/runs/27lr37hg</a><br/>\n",
" "
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "stream",
"text": [
"94859 examples trained"
],
"name": "stdout"
},
{
"output_type": "error",
"ename": "KeyboardInterrupt",
"evalue": "ignored",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-87-eb00eb73ec27>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mvocab\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mVocabulary\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0mmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mget_new_model\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvocab\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlearning_rate\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0.001\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmax_examples\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m300000\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;32m<ipython-input-84-8e17766134aa>\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(model, vocab, learning_rate, max_examples)\u001b[0m\n\u001b[1;32m 15\u001b[0m \u001b[0moptimizer\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0moptim\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mSGD\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparameters\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlr\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mlearning_rate\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 16\u001b[0m \u001b[0mstart_ex\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 17\u001b[0;31m \u001b[0;32mfor\u001b[0m \u001b[0mex_count\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mex\u001b[0m \u001b[0;32min\u001b[0m \u001b[0menumerate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mget_examples\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvocab\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 18\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzero_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 19\u001b[0m \u001b[0minput_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcontext\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mex\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m<ipython-input-86-f4eeb41f8f89>\u001b[0m in \u001b[0;36mget_examples\u001b[0;34m(model, vocab)\u001b[0m\n\u001b[1;32m 30\u001b[0m \u001b[0mvocab\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msmallest\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mi2w\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtoken\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 31\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 32\u001b[0;31m \u001b[0mi\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnext\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0miter\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvocab\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msmallest\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mi2w\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mkeys\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 33\u001b[0m \u001b[0mvocab\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mw2i\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtoken\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 34\u001b[0m \u001b[0mvocab\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mi2c\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
]
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "Ja07ZR1rJi39",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 51
},
"outputId": "7ffd2ca4-93a7-4b38-e7e8-29f3bde487d3"
},
"source": [
"eval_model(model, vocab.w2i)"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"Evaluating...\n",
"line 14711"
],
"name": "stdout"
}
]
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment