Skip to content

Instantly share code, notes, and snippets.

@avidale
Last active February 11, 2024 16:08
Show Gist options
  • Save avidale/c6b19687d333655da483421880441950 to your computer and use it in GitHub Desktop.
Save avidale/c6b19687d333655da483421880441950 to your computer and use it in GitHub Desktop.
bert_knn.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "bert_knn.ipynb",
"provenance": [],
"collapsed_sections": [],
"authorship_tag": "ABX9TyNeI9VdQP6t7QxJ0kX9dq7u",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"accelerator": "GPU",
"widgets": {
"application/vnd.jupyter.widget-state+json": {
"861f120aa1514ac098d44215e7da52b7": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"state": {
"_view_name": "HBoxView",
"_dom_classes": [],
"_model_name": "HBoxModel",
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.5.0",
"box_style": "",
"layout": "IPY_MODEL_122f54d0abf147b89d781aaad085cbbd",
"_model_module": "@jupyter-widgets/controls",
"children": [
"IPY_MODEL_75061b1b409144b9948bf070976d9f61",
"IPY_MODEL_7ecbed1b02b140b68d4e8c2302a433dd"
]
}
},
"122f54d0abf147b89d781aaad085cbbd": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"75061b1b409144b9948bf070976d9f61": {
"model_module": "@jupyter-widgets/controls",
"model_name": "IntProgressModel",
"state": {
"_view_name": "ProgressView",
"style": "IPY_MODEL_a14b249a587843cf9a02fb6603aff741",
"_dom_classes": [],
"description": "100%",
"_model_name": "IntProgressModel",
"bar_style": "success",
"max": 10000,
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"value": 10000,
"_view_count": null,
"_view_module_version": "1.5.0",
"orientation": "horizontal",
"min": 0,
"description_tooltip": null,
"_model_module": "@jupyter-widgets/controls",
"layout": "IPY_MODEL_b43e6fe4b52c40c18334bc8b6859641b"
}
},
"7ecbed1b02b140b68d4e8c2302a433dd": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"state": {
"_view_name": "HTMLView",
"style": "IPY_MODEL_eb490fe2acb841b58be95a3eb32182ad",
"_dom_classes": [],
"description": "",
"_model_name": "HTMLModel",
"placeholder": "",
"_view_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"value": " 10000/10000 [00:41<00:00, 240.61it/s]",
"_view_count": null,
"_view_module_version": "1.5.0",
"description_tooltip": null,
"_model_module": "@jupyter-widgets/controls",
"layout": "IPY_MODEL_a72f7f4303c94be7ae91731bf7be332d"
}
},
"a14b249a587843cf9a02fb6603aff741": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"state": {
"_view_name": "StyleView",
"_model_name": "ProgressStyleModel",
"description_width": "initial",
"_view_module": "@jupyter-widgets/base",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.2.0",
"bar_color": null,
"_model_module": "@jupyter-widgets/controls"
}
},
"b43e6fe4b52c40c18334bc8b6859641b": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
},
"eb490fe2acb841b58be95a3eb32182ad": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"state": {
"_view_name": "StyleView",
"_model_name": "DescriptionStyleModel",
"description_width": "",
"_view_module": "@jupyter-widgets/base",
"_model_module_version": "1.5.0",
"_view_count": null,
"_view_module_version": "1.2.0",
"_model_module": "@jupyter-widgets/controls"
}
},
"a72f7f4303c94be7ae91731bf7be332d": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_view_name": "LayoutView",
"grid_template_rows": null,
"right": null,
"justify_content": null,
"_view_module": "@jupyter-widgets/base",
"overflow": null,
"_model_module_version": "1.2.0",
"_view_count": null,
"flex_flow": null,
"width": null,
"min_width": null,
"border": null,
"align_items": null,
"bottom": null,
"_model_module": "@jupyter-widgets/base",
"top": null,
"grid_column": null,
"overflow_y": null,
"overflow_x": null,
"grid_auto_flow": null,
"grid_area": null,
"grid_template_columns": null,
"flex": null,
"_model_name": "LayoutModel",
"justify_items": null,
"grid_row": null,
"max_height": null,
"align_content": null,
"visibility": null,
"align_self": null,
"height": null,
"min_height": null,
"padding": null,
"grid_auto_rows": null,
"grid_gap": null,
"max_width": null,
"order": null,
"_view_module_version": "1.2.0",
"grid_template_areas": null,
"object_position": null,
"object_fit": null,
"grid_auto_columns": null,
"margin": null,
"display": null,
"left": null
}
}
}
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/avidale/c6b19687d333655da483421880441950/bert_knn.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "w4Q6usoSwUh6",
"colab_type": "text"
},
"source": [
"In this notebook, we demonstrate how to extract and lookup for contextually-most-similar words using BERT and nearest neighbor search. \n",
"\n",
"This was inspired by the StackOverflow question https://stackoverflow.com/questions/59865719/how-to-find-the-closest-word-to-a-vector-using-bert"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9Ul0lMapo0T3",
"colab_type": "text"
},
"source": [
"# learn to extract embeddings from bert"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "xTeM3tJxhDd6",
"colab_type": "text"
},
"source": [
"We use `bert-embedding` package; see https://pypi.org/project/bert-embedding/\n",
"\n",
"We use GPU, so please choose the Colab kernel accordingly"
]
},
{
"cell_type": "code",
"metadata": {
"id": "G7EZJcvygjTI",
"colab_type": "code",
"colab": {}
},
"source": [
"!pip install bert-embedding mxnet-cu100"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "bBbQwlK9gtsD",
"colab_type": "code",
"colab": {}
},
"source": [
"import mxnet as mx\n",
"from bert_embedding import BertEmbedding"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "Aa3Lb7LfhBCU",
"colab_type": "code",
"outputId": "494c8c61-c92e-4e09-c205-404a903c3562",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 92
}
},
"source": [
"ctx = mx.gpu(0)\n",
"bert = BertEmbedding(ctx=ctx)"
],
"execution_count": 2,
"outputs": [
{
"output_type": "stream",
"text": [
"Vocab file is not found. Downloading.\n",
"Downloading /root/.mxnet/models/book_corpus_wiki_en_uncased-a6607397.zip from https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/gluon/dataset/vocab/book_corpus_wiki_en_uncased-a6607397.zip...\n",
"Downloading /root/.mxnet/models/bert_12_768_12_book_corpus_wiki_en_uncased-75cc780f.zip from https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/gluon/models/bert_12_768_12_book_corpus_wiki_en_uncased-75cc780f.zip...\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "uxAdRCkFmq5R",
"colab_type": "code",
"colab": {}
},
"source": [
"from tqdm.auto import tqdm, trange"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "XrH9e6AIhXy5",
"colab_type": "code",
"colab": {}
},
"source": [
"bert_abstract = \"\"\"We introduce a new language representation model called BERT, which stands for Bidirectional Encoder Representations from Transformers.\n",
" Unlike recent language representation models, BERT is designed to pre-train deep bidirectional representations by jointly conditioning on both left and right context in all layers.\n",
" As a result, the pre-trained BERT representations can be fine-tuned with just one additional output layer to create state-of-the-art models for a wide range of tasks, such as question answering and language inference, without substantial task-specific architecture modifications. \n",
"BERT is conceptually simple and empirically powerful. \n",
"It obtains new state-of-the-art results on eleven natural language processing tasks, including pushing the GLUE benchmark to 80.4% (7.6% absolute improvement), MultiNLI accuracy to 86.7 (5.6% absolute improvement) and the SQuAD v1.1 question answering Test F1 to 93.2 (1.5% absolute improvement), outperforming human performance by 2.0%.\"\"\""
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "HM0FmfpphZP6",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 90
},
"outputId": "2f62b8e8-f0d3-426d-b18a-a4ddc1934bfa"
},
"source": [
"sentences = bert_abstract.split('\\n')\n",
"result = bert(sentences)\n",
"toks, embs = result[0]\n",
"print(toks)\n",
"print(len(toks), len(embs))\n",
"print(embs[0][:10])"
],
"execution_count": 5,
"outputs": [
{
"output_type": "stream",
"text": [
"['we', 'introduce', 'a', 'new', 'language', 'representation', 'model', 'called', 'bert', ',', 'which', 'stands', 'for', 'bidirectional', 'encoder', 'representations', 'from', 'transformers']\n",
"18 18\n",
"[ 0.4796474 0.18248817 -0.285975 -0.4656739 0.01248992 -0.07430518\n",
" -0.18017292 0.3781321 0.9135136 -0.25295877]\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jR-AMiRKoxOP",
"colab_type": "text"
},
"source": [
"# process a corpus"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3_hJoYtvmZw2",
"colab_type": "text"
},
"source": [
"We download a 10k web-public .com corpus from https://wortschatz.uni-leipzig.de/en/download/\n"
]
},
{
"cell_type": "code",
"metadata": {
"id": "9d_fuzIOngqC",
"colab_type": "code",
"colab": {}
},
"source": [
"!wget http://pcai056.informatik.uni-leipzig.de/downloads/corpora/eng-com_web-public_2018_10K.tar.gz"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "zashEvfGnnnK",
"colab_type": "code",
"outputId": "eed70bd1-d411-403b-c3c2-d471c6dc079b",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 181
}
},
"source": [
"!tar -xzvf eng-com_web-public_2018_10K.tar.gz"
],
"execution_count": 7,
"outputs": [
{
"output_type": "stream",
"text": [
"eng-com_web-public_2018_10K/\n",
"eng-com_web-public_2018_10K/eng-com_web-public_2018_10K-import.sql\n",
"eng-com_web-public_2018_10K/eng-com_web-public_2018_10K-inv_so.txt\n",
"eng-com_web-public_2018_10K/eng-com_web-public_2018_10K-words.txt\n",
"eng-com_web-public_2018_10K/eng-com_web-public_2018_10K-co_s.txt\n",
"eng-com_web-public_2018_10K/eng-com_web-public_2018_10K-inv_w.txt\n",
"eng-com_web-public_2018_10K/eng-com_web-public_2018_10K-co_n.txt\n",
"eng-com_web-public_2018_10K/eng-com_web-public_2018_10K-sources.txt\n",
"eng-com_web-public_2018_10K/eng-com_web-public_2018_10K-sentences.txt\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "RCWGTugrocOZ",
"colab_type": "code",
"colab": {}
},
"source": [
"with open('eng-com_web-public_2018_10K/eng-com_web-public_2018_10K-sentences.txt', 'r') as f:\n",
" lines = f.readlines()"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "7QghFckJrumO",
"colab_type": "text"
},
"source": [
"remove row index from each sentence"
]
},
{
"cell_type": "code",
"metadata": {
"id": "nx-mTIeRriJ8",
"colab_type": "code",
"outputId": "17d3ba54-678e-4e85-e566-8509ffd3fd9a",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 35
}
},
"source": [
"lines[0]"
],
"execution_count": 9,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"'1\\t\"100 Butches Number 12\" moved to the West and was disconcerted to find that she didn\\'t draw lesbian attention in the way she had in Singapore.\\n'"
]
},
"metadata": {
"tags": []
},
"execution_count": 9
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "x7FTLEtgrlGE",
"colab_type": "code",
"colab": {}
},
"source": [
"all_sentences = [l.split('\\t')[1] for l in lines] "
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "lOLT6c7yotcW",
"colab_type": "text"
},
"source": [
"# create a search index"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "fZvELpMSsyer",
"colab_type": "text"
},
"source": [
""
]
},
{
"cell_type": "code",
"metadata": {
"id": "B7MAHVbbsz3t",
"colab_type": "code",
"colab": {}
},
"source": [
"from sklearn.neighbors import KDTree\n",
"import numpy as np\n",
"\n",
"\n",
"class ContextNeighborStorage:\n",
" def __init__(self, sentences, model):\n",
" self.sentences = sentences\n",
" self.model = model\n",
"\n",
" def process_sentences(self):\n",
" result = self.model(self.sentences)\n",
"\n",
" self.sentence_ids = []\n",
" self.token_ids = []\n",
" self.all_tokens = []\n",
" all_embeddings = []\n",
" for i, (toks, embs) in enumerate(tqdm(result)):\n",
" for j, (tok, emb) in enumerate(zip(toks, embs)):\n",
" self.sentence_ids.append(i)\n",
" self.token_ids.append(j)\n",
" self.all_tokens.append(tok)\n",
" all_embeddings.append(emb)\n",
" all_embeddings = np.stack(all_embeddings)\n",
" # we normalize embeddings, so that euclidian distance is equivalent to cosine distance\n",
" self.normed_embeddings = (all_embeddings.T / (all_embeddings**2).sum(axis=1) ** 0.5).T\n",
"\n",
" def build_search_index(self):\n",
" # this takes some time\n",
" self.indexer = KDTree(self.normed_embeddings)\n",
"\n",
" def query(self, query_sent, query_word, k=10, filter_same_word=False):\n",
" toks, embs = self.model([query_sent])[0]\n",
"\n",
" found = False\n",
" for tok, emb in zip(toks, embs):\n",
" if tok == query_word:\n",
" found = True\n",
" break\n",
" if not found:\n",
" raise ValueError('The query word {} is not a single token in sentence {}'.format(query_word, toks))\n",
" emb = emb / sum(emb**2)**0.5\n",
"\n",
" if filter_same_word:\n",
" initial_k = max(k, 100)\n",
" else:\n",
" initial_k = k\n",
" di, idx = self.indexer.query(emb.reshape(1, -1), k=initial_k)\n",
" distances = []\n",
" neighbors = []\n",
" contexts = []\n",
" for i, index in enumerate(idx.ravel()):\n",
" token = self.all_tokens[index]\n",
" if filter_same_word and (query_word in token or token in query_word):\n",
" continue\n",
" distances.append(di.ravel()[i])\n",
" neighbors.append(token)\n",
" contexts.append(self.sentences[self.sentence_ids[index]])\n",
" if len(distances) == k:\n",
" break\n",
" return distances, neighbors, contexts"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "YIiBCBOauY6J",
"colab_type": "text"
},
"source": [
"Now let's use this indexer"
]
},
{
"cell_type": "code",
"metadata": {
"id": "rB5NoR7AqJ6a",
"colab_type": "code",
"outputId": "597c26b0-3923-42bf-a740-9fa6f245df43",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 67,
"referenced_widgets": [
"861f120aa1514ac098d44215e7da52b7",
"122f54d0abf147b89d781aaad085cbbd",
"75061b1b409144b9948bf070976d9f61",
"7ecbed1b02b140b68d4e8c2302a433dd",
"a14b249a587843cf9a02fb6603aff741",
"b43e6fe4b52c40c18334bc8b6859641b",
"eb490fe2acb841b58be95a3eb32182ad",
"a72f7f4303c94be7ae91731bf7be332d"
]
}
},
"source": [
"storage = ContextNeighborStorage(sentences=all_sentences, model=bert)\n",
"storage.process_sentences()"
],
"execution_count": 26,
"outputs": [
{
"output_type": "display_data",
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "861f120aa1514ac098d44215e7da52b7",
"version_minor": 0,
"version_major": 2
},
"text/plain": [
"HBox(children=(IntProgress(value=0, max=10000), HTML(value='')))"
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "stream",
"text": [
"\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "j_sEJcvtzPs5",
"colab_type": "text"
},
"source": [
"Creating the index would require some time"
]
},
{
"cell_type": "code",
"metadata": {
"id": "gWcGzTNxuJm9",
"colab_type": "code",
"colab": {}
},
"source": [
"storage.build_search_index()"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "4LBv95lL_WMy",
"colab_type": "text"
},
"source": [
"# query homonymous words"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dLRC0Cliub2l",
"colab_type": "text"
},
"source": [
"Now see how it works: \n",
"\n",
"* if the word \"bank\" is in context of \"power bank\", then the nearest neighbor is a \"power bank\" as well.\n",
"\n",
"* if the word \"bank\" is in context of \"investment bank\", then the nearest neighbor is a \"bank\" as well, but in financial context\n",
"\n"
]
},
{
"cell_type": "code",
"metadata": {
"id": "Siy9GwPdqYN1",
"colab_type": "code",
"outputId": "d98ee734-27c7-4485-fb79-5db2bc74c397",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 108
}
},
"source": [
"distances, neighbors, contexts = storage.query(query_sent='It is a power bank.', query_word='bank', k=5)\n",
"for d, w, c in zip(distances, neighbors, contexts):\n",
" print('{} {} {}'.format(w, d, c.strip()))"
],
"execution_count": 28,
"outputs": [
{
"output_type": "stream",
"text": [
"bank 0.7907381845089179 Finally, there’s a second version of the Duo that incorporates a 2000mAH power bank, the Flip Power World.\n",
"bank 0.8004909388259487 This simplifies the handling of new issues for the fund or the custodian bank, which benefits third-party banks and minimizes administrative costs.\n",
"bank 0.8120657920167197 The bank also was awarded a 5-star, Superior Bauer rating for Dec. 31, 2017, financial data.\n",
"bank 0.8334965657288435 Even though the cost ratio in UBS's investment bank looks stubbornly and dangerously high at 88%.\n",
"bank 0.8420417135949434 The points appear in the card member’s point bank within 24 hours.\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "XUyplIsorGcA",
"colab_type": "code",
"outputId": "6c1d8119-ecbe-43cb-e476-dc808cbe9964",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 108
}
},
"source": [
"distances, neighbors, contexts = storage.query(query_sent='It is an investment bank.', query_word='bank', k=5)\n",
"for d, w, c in zip(distances, neighbors, contexts):\n",
" print('{} {} {}'.format(w, d, c.strip()))"
],
"execution_count": 29,
"outputs": [
{
"output_type": "stream",
"text": [
"bank 0.6642010406218649 The bank also was awarded a 5-star, Superior Bauer rating for Dec. 31, 2017, financial data.\n",
"bank 0.7402058801860062 Even though the cost ratio in UBS's investment bank looks stubbornly and dangerously high at 88%.\n",
"bank 0.7738409107605938 This simplifies the handling of new issues for the fund or the custodian bank, which benefits third-party banks and minimizes administrative costs.\n",
"banks 0.7774697527263753 This is an unusual fee, as many banks don’t charge you to move money in and out of your savings account.\n",
"bank 0.7801288580435606 Pop open your business bank account and take a look at the past few months of transactions.\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "88A4TQdhvubC",
"colab_type": "text"
},
"source": [
"If we look for the neighbors not containing the word \"bank\", then with investment context it is all about finance, but for \"power bank\" there are a few non-financial contexts. \n",
"\n",
"Probably, with a larger corpus, we would be able to find even more relevant examples (like \"battery\")"
]
},
{
"cell_type": "code",
"metadata": {
"id": "doWA6h7ru1O-",
"colab_type": "code",
"outputId": "9d483a9f-2438-4fd0-f9d1-b87b17bd7782",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 108
}
},
"source": [
"distances, neighbors, contexts = storage.query(query_sent='It is an investment bank.', query_word='bank', k=5, filter_same_word=True)\n",
"total = 0\n",
"for d, w, c in zip(distances, neighbors, contexts):\n",
" print('{} {} {}'.format(w, d, c.strip()))\n"
],
"execution_count": 30,
"outputs": [
{
"output_type": "stream",
"text": [
"finance 0.8551898041253515 Cahal is Vice Chairman of Deloitte UK and Chairman of the Advisory Corporate Finance business from 2014 (previously led the business from 2005).\n",
"financial 0.8562345307444398 Risk is involved in most financial ventures, but it's most relevant in discussions of insurance.\n",
"financial 0.8588390418789484 If 2007 is any indication, then 2008 is going to be a wild year for financial institutions facing a slew of risk management issues.\n",
"financial 0.8655535941427024 The bank also was awarded a 5-star, Superior Bauer rating for Dec. 31, 2017, financial data.\n",
"financial 0.8833212575995076 The Group applies the IFRS 9 - Financial Instruments simplified approach to measuring expected credit losses.\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "aa9fI8Q8u7qz",
"colab_type": "code",
"outputId": "821fd8f6-0599-4974-9fab-346f68e5c0c6",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 108
}
},
"source": [
"distances, neighbors, contexts = storage.query(query_sent='It is a power bank.', query_word='bank', k=5, filter_same_word=True)\n",
"total = 0\n",
"for d, w, c in zip(distances, neighbors, contexts):\n",
" print('{} {} {}'.format(w, d, c.strip()))"
],
"execution_count": 33,
"outputs": [
{
"output_type": "stream",
"text": [
"bureau 0.9217660247116287 Is the company I am getting my FAKO score from affiliated with a credit bureau?\n",
"pump 0.9375881767755937 Let’s look at an old handle water pump for instance.\n",
"account 0.9440922848671971 If you can manage another credit account or you have good credit, getting another card might be beneficial.\n",
"center 0.9497044973940797 Tag: \"license\" in \"Data Center\"\n",
"company 0.9534303217383925 I used to work for a “feeder” company that collects debtor records at the county level, they buy this public record stuff and resell it as fact.\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ctizE_Kt_mVs",
"colab_type": "text"
},
"source": [
"For a \"river bank\", there are no relevant examples in our small corpus (10k sentences only), but the result is still not completely meaningless (e.g. \"river side\" is related to \"river bank\")."
]
},
{
"cell_type": "code",
"metadata": {
"id": "2aIQGBJq_cVk",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 108
},
"outputId": "109a9fc5-566d-45cc-989a-7c01128bf784"
},
"source": [
"distances, neighbors, contexts = storage.query(query_sent='It is a river bank.', query_word='bank', k=5, filter_same_word=True)\n",
"total = 0\n",
"for d, w, c in zip(distances, neighbors, contexts):\n",
" print('{} {} {}'.format(w, d, c.strip()))"
],
"execution_count": 51,
"outputs": [
{
"output_type": "stream",
"text": [
"body 1.0063146046509068 Like a big stone, like a body of water, like a strong economy, however it was forged it seems that, once made, it has always been there.\n",
"side 1.010094269452912 Is there a way to fix a leak in the pressurized side of a pool system without digging it up?\n",
"loan 1.0141536943482097 The loan functions don't consider interest.\n",
"soil 1.0144778016304201 A lot of people think of soil a dirt, just something that you have to put bleach on your britches to get it cleaned out.\n",
"parallel 1.0160429339654162 “The nighttime is a parallel to the fog.\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jmse5KnI_QxR",
"colab_type": "text"
},
"source": [
"# query named entities"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "HlPavawt47Vx",
"colab_type": "text"
},
"source": [
"Now let's try a query with named entity. We can see that Amazon the company and Amazon the toponym have nearest neighbors which are a company and a toponym as well. "
]
},
{
"cell_type": "code",
"metadata": {
"id": "vEEndzzI0s4P",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 108
},
"outputId": "4b575c24-39bb-486c-dcc0-3311481ed387"
},
"source": [
"s = \"Bezos announced that its two-day delivery service, Amazon Prime, had surpassed 100 million subscribers worldwide.\"\n",
"distances, neighbors, contexts = storage.query(query_sent=s, query_word='amazon', k=5)\n",
"for d, w, c in zip(distances, neighbors, contexts):\n",
" print('{} {} {}'.format(w, d, c.strip()))"
],
"execution_count": 43,
"outputs": [
{
"output_type": "stream",
"text": [
"amazon 0.6034335621074037 Expanded third-party integration including Amazon Alexa, Google Assistant, and IFTTT.\n",
"amazon 0.6813754730639783 And fewer than 1 percent of e-commerce sales took place at Amazon, everyone’s favorite scapegoat for retail’s struggles.\n",
"amazon 0.6875736233539647 The Alexa Skills Kit (ASK) empowers anyone to leverage Amazon’s years of innovation in the field of voice design.\n",
"amazon 0.6956514512735816 The Republican tax overhaul that Trump signed last year could have singled out Amazon for harsh treatment, but the company still qualifies for every corporate tax break.\n",
"amazon 0.7096277774470231 One agency, for instance, is working with a real estate company and looking for data on people shopping on Amazon.com for moving boxes.\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "Kgx-_aKp5AyN",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 128
},
"outputId": "6f1256e4-d839-46bf-e1b0-06542cd80d8c"
},
"source": [
"s = \"The Atlantic has sufficient wave and tidal energy to carry most of the Amazon's sediments out to sea, thus the river does not form a true delta\"\n",
"distances, neighbors, contexts = storage.query(query_sent=s, query_word='amazon', k=5)\n",
"for d, w, c in zip(distances, neighbors, contexts):\n",
" print('{} {} {}'.format(w, d, c.strip()))"
],
"execution_count": 46,
"outputs": [
{
"output_type": "stream",
"text": [
"brazil 0.9719889877743066 And, this year our stories are the work of traveling from Brazil’s Iguassu Falls to a chicken farm in Atlanta, building a 270-degree in-car VR experience, creating new partnerships to connect our audience’s passions with experiences, and much more.\n",
"amazon 0.9849940681911138 Describes how to create, manage, and use an Amazon CloudSearch domain to implement a search solution for your website or application.\n",
"amazon 0.9903185403586894 The Alexa Skills Kit (ASK) empowers anyone to leverage Amazon’s years of innovation in the field of voice design.\n",
"amazon 1.0300692791411952 Amazon’s dynamic pricing system is often a puzzle, but adding to the mystery is the fact that these items have different prices for different sizes but were all sold and shipped via Amazon — not through different sellers.\n",
"brazilian 1.0338752062964356 In one Brazilian examining the instances of hearing loss in bus drivers, it was found that 32.7% of bus drivers experienced noise-induced hearing loss.\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "a8BpB0He-faY",
"colab_type": "text"
},
"source": [
"Moreover, we can infer that Amazon the company is related to Google, Alexa and Netflix, whereas Amazon the river is related to Brazil and the Brazilian city Belem. "
]
},
{
"cell_type": "code",
"metadata": {
"id": "xTv33CPs5C_g",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 108
},
"outputId": "4fd87e6f-70e6-48df-fcef-5eea045cda6e"
},
"source": [
"s = \"Bezos announced that its two-day delivery service, Amazon Prime, had surpassed 100 million subscribers worldwide.\"\n",
"distances, neighbors, contexts = storage.query(query_sent=s, query_word='amazon', k=5, filter_same_word=True)\n",
"for d, w, c in zip(distances, neighbors, contexts):\n",
" print('{} {} {}'.format(w, d, c.strip()))"
],
"execution_count": 47,
"outputs": [
{
"output_type": "stream",
"text": [
"google 0.8355246761580234 Expanded third-party integration including Amazon Alexa, Google Assistant, and IFTTT.\n",
"alexa 0.9301979249848525 Expanded third-party integration including Amazon Alexa, Google Assistant, and IFTTT.\n",
"netflix 0.9347305391903858 It isn't just on-demand video services like Netflix and Amazon that are taking customers away from cable.\n",
"google 0.9353963024942532 That month, Google sites were accessed by 185.05 million unique mobile users.\n",
"google 0.9362379311474001 Both Google and Amazon don’t place any limits, saying that the frontend load balancers will scale up as needed to support the traffic.\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "kTBlcG9V-bQM",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 128
},
"outputId": "4f91ab03-7ea4-45ca-b2a9-824d9758b3b8"
},
"source": [
"s = \"The Atlantic has sufficient wave and tidal energy to carry most of the Amazon's sediments out to sea, thus the river does not form a true delta\"\n",
"distances, neighbors, contexts = storage.query(query_sent=s, query_word='amazon', k=5, filter_same_word=True)\n",
"for d, w, c in zip(distances, neighbors, contexts):\n",
" print('{} {} {}'.format(w, d, c.strip()))"
],
"execution_count": 50,
"outputs": [
{
"output_type": "stream",
"text": [
"brazil 0.9719889877743066 And, this year our stories are the work of traveling from Brazil’s Iguassu Falls to a chicken farm in Atlanta, building a 270-degree in-car VR experience, creating new partnerships to connect our audience’s passions with experiences, and much more.\n",
"brazilian 1.0338752062964356 In one Brazilian examining the instances of hearing loss in bus drivers, it was found that 32.7% of bus drivers experienced noise-induced hearing loss.\n",
"brazil 1.0511793567879462 A subset of these action figures were also released in Canada, and years later, in Brazil.\n",
"brazil 1.0534603912133236 And with Brazil making a comeback — and AES having operations there — things are looking up for this stock on the growth side of equation.\n",
"belem 1.0627748582658 Start at the Embarcadero Belem dock to experience the waterways.\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "I0lLzMKJ-qj1",
"colab_type": "code",
"colab": {}
},
"source": [
""
],
"execution_count": 0,
"outputs": []
}
]
}
@Shafi2016
Copy link

Okey thanks once again! I shall change, name of "date".Probably I will add the title to it as well. Then I can check a few manually to see if everything is going well.

@Shafi2016
Copy link

I checked it manually it is not working properly, the date and title do not match with para. Can you please illustrate it with the codes.?

@AsmaZbt
Copy link

AsmaZbt commented Sep 2, 2020

try this : change the output of the function query by adding the indexes of the neighbors :

def query(self, query_sent, query_word, k=10, filter_same_word=False):
toks, embs = self.model([query_sent])[0]

    found = False
    for tok, emb in zip(toks, embs):
        if tok == query_word:
            found = True
            break
    if not found:
        raise ValueError('The query word {} is not a single token in sentence {}'.format(query_word, toks))
    emb = emb / sum(emb**2)**0.5

    if filter_same_word:
        initial_k = max(k, 100)
    else:
        initial_k = k
    di, idx = self.indexer.query(emb.reshape(1, -1), k=initial_k)
    distances = []
    neighbors = []
    contexts = []
    
    for i, index in enumerate(idx.ravel()):
        token = self.all_tokens[index]
        if filter_same_word and (query_word in token or token in query_word):
            continue
        distances.append(di.ravel()[i])
        neighbors.append(token)
        contexts.append(self.sentences[self.sentence_ids[index]])
        if len(distances) == k:
            break
    indexes =idx.ravel()
    return distances, neighbors, contexts,indexes

then in your function : you remose date from zip() and replace it by the indexes , then you iterate over the date list using the indexes
like this :

distances, neighbors, contexts,indexes = storage.query(query_sent='It is an investment bank.', query_word='bank', k=50)
dd = []
date = df["date"].tolist()
for d, w, c , idx in zip(distances, neighbors, contexts,indexes):
dd.append(
{'date':date[idx],
'neigh' : w,
'score':d,
'para' : c.strip()
})
ad =pd.DataFrame(dd)

@Shafi2016
Copy link

Thank you so much!. There is one small error-index
image
after running the last chunk of codes.

@Shafi2016
Copy link

Also, I tried to make it but the error is the same

distances, neighbors, contexts, indexes = storage.query(query_sent='It is an investment bank.', query_word='bank', k=50)
dd = []
#date = df["date"].tolist()
for d, w, c , idx in zip(distances, neighbors, contexts, indexes):
row_dic = df2.loc[df2.index==lines[idx]].to_dict()
dd.append(
{'date':row_dict["date"][ lines[idx]],
'neigh' : w,
'score': d,
'para' : c.strip()
})

ad = pd.DataFrame(dd)

@Shafi2016
Copy link

Hello AsmaZbt, I would really appreciate your input. the previous error is fixed but I'm getting a new error, I could not figure the issues despite spending many hours.
given as
image
after running these codes
distances, neighbors, contexts,indexes = storage.query(query_sent='It is an investment bank.', query_word='bank', k=5)
dd = []
#date = df["date"].astype(int)
date = df["date"].tolist()

for d, w, c , idx in zip(distances, neighbors, contexts,indexes):
dd.append(
{'date':date[idx],
'neigh' : w,
'score':d,
'para' : c.strip()
})

ad =pd.DataFrame(dd)

Here are the complete data processing steps:
df = pd.read_csv("/content/df3.csv")
df = df.set_index("content")
df.head(1)
image

text_dict = df.to_dict()
len_text = len(text_dict["date"])
df = df["date"].to_dict()
df_sentences_list = list(df.keys())
len(df_sentences_list)
df_sentences_list = [str(d) for d in tqdm(df_sentences_list)]
file_content = "\n".join(df_sentences_list)
with open("input_text.txt","w") as f:
f.write(file_content)

with open("/content/input_text.txt","r") as f:
lines1 = f.readlines()
image

Then I run all the required parts of the codes. And read the data again before running the final query codes
image
distances, neighbors, contexts,indexes = storage.query(query_sent='It is an investment bank.', query_word='bank', k=5)
dd = []
#date = df["date"].astype(int)
date = df["date"].tolist()

for d, w, c , idx in zip(distances, neighbors, contexts,indexes):
dd.append(
{'date':date[idx],
'neigh' : w,
'score':d,
'para' : c.strip()
})

ad =pd.DataFrame(dd)

But I get the error as given above.

@AsmaZbt
Copy link

AsmaZbt commented Sep 3, 2020

hello @Shafi2016
could you please print idx ? to see the type of this variable

@Shafi2016
Copy link

Very strangely when using print(idx) inside loop I can not get anything out of it. However, when use the outside loop I get 10122.

@Shafi2016
Copy link

Also, I m getting the previous error as
image

@AsmaZbt
Copy link

AsmaZbt commented Sep 3, 2020

could you please share your function query ?

@Shafi2016
Copy link

` def query(self, query_sent, query_word, k=10, filter_same_word=False):
toks, embs = self.model([query_sent])[0]
found = False
for tok, emb in zip(toks, embs):
if tok == query_word:
found = True
break
if not found:
raise ValueError('The query word {} is not a single token in sentence {}'.format(query_word, toks))
emb = emb / sum(emb**2)**0.5
if filter_same_word:
initial_k = max(k, 100)
else:
initial_k = k
di, idx = self.indexer.query(emb.reshape(1, -1), k=initial_k)
distances = []
neighbors = []
contexts = []

    for i, index in enumerate(idx.ravel()):
        token = self.all_tokens[index]
        if filter_same_word and (query_word in token or token in query_word):
           continue
        distances.append(di.ravel()[i])
        neighbors.append(token)
        contexts.append(self.sentences[self.sentence_ids[index]])                  
        #indexes =idx.ravel()
        if len(distances) == k:
            break
            indexes =idx.ravel()
    return distances, neighbors, contexts,indexes
    print(idx)`

@AsmaZbt
Copy link

AsmaZbt commented Sep 3, 2020

indexes must be outside the loop for like this

for i, index in enumerate(idx.ravel()):
token = self.all_tokens[index]
if filter_same_word and (query_word in token or token in query_word):
continue
distances.append(di.ravel()[i])
neighbors.append(token)
contexts.append(self.sentences[self.sentence_ids[index]])
#indexes =idx.ravel()
if len(distances) == k:
break
indexes =idx.ravel() ########## indexes must but out the loop for
return distances, neighbors, contexts,indexes
print(idx)`

@AsmaZbt
Copy link

AsmaZbt commented Sep 3, 2020

`for i, index in enumerate(idx.ravel()):

    token = self.all_tokens[index]
    if filter_same_word and (query_word in token or token in query_word):
       continue
    distances.append(di.ravel()[i])
    neighbors.append(token)
    contexts.append(self.sentences[self.sentence_ids[index]])                  
    #indexes =idx.ravel()
    if len(distances) == k:
        break
indexes =idx.ravel()
return distances, neighbors, contexts,indexes
print(idx)``

@Shafi2016
Copy link

I got a new error after incorporating the change.
image

@AsmaZbt
Copy link

AsmaZbt commented Sep 3, 2020

you need to write it in the same indentation as return

you make a break this means stop so you can't do any thing more with this condition so you return

you just need to get the list of indexes there is no change no condition you can add it att the end or just after this instruction if you did not understand it well :

di, idx = self.indexer.query(emb.reshape(1, -1), k=initial_k)
indexes = idx.ravel()
distances = []

then you delete it from the end

or you do it at the end as I mentioned

@Shafi2016
Copy link

Thank you so much again. and sorry for bothering you. Now again index error is back.
image

@AsmaZbt
Copy link

AsmaZbt commented Sep 3, 2020

not at all , you're welcome
try to print
len(date)
print(idx) inside the for

@Shafi2016
Copy link

Shafi2016 commented Sep 3, 2020

len(date) is 8804
print(idx) is 6554

image

image

@AsmaZbt
Copy link

AsmaZbt commented Sep 3, 2020

the length of the liste date is 8804 and you have an index of 10759 that's why you have in out of index .
i need to think

@AsmaZbt
Copy link

AsmaZbt commented Sep 3, 2020

do you have a two file ? .txt and .csv ? then may be the length of list of sentences is not the same as your csv file.

i thought that your sentences are from the same CSV file so the length of date and para are the same
but now I'm sorry I can't help you

you need information how to link that sentences and dates ( for each sentence the appropriate date ) then you can solve the pb

@Shafi2016
Copy link

It is the same CSV file. I first convert, CSV to list and text as in the original example.
df = pd.read_csv("/content/df3.csv",parse_dates= True)
df = df.set_index("content")
df.head(1)
text_dict = df.to_dict()
len_text = len(text_dict["date"])
df = df["date"].to_dict()
df_sentences_list = list(df.keys())
len(df_sentences_list)
df_sentences_list = [str(d) for d in tqdm(df_sentences_list)]
file_content = "\n".join(df_sentences_list)
with open("input_text.txt","w") as f:
f.write(file_content)

with open("/content/input_text.txt","r") as f:
lines1 = f.readlines()
lines1[0]
all_sentences = [l.split('\t')[0] for l in lines1]

Again for the date, We use the same csv file, I only distable this part df = df.set_index("content")

@Shafi2016
Copy link

If I can have your email Id I will send refine codes with small sample data

@AsmaZbt
Copy link

AsmaZbt commented Sep 3, 2020 via email

@Shafi2016
Copy link

Thanks a lot!!!

@sridhardev07
Copy link

sridhardev07 commented Nov 1, 2021

@avidale Hi thanks for the amazing work.. I want to implement that, but need few points to understand..

  • I want it to work for large dataset, can it handle it or need to implement something else on it?
  • After getting embeddings, can I save it so next time I just load the data and get result by query?

@avidale
Copy link
Author

avidale commented Nov 1, 2021

@sridhardev07 yes and yes

@sridhardev07
Copy link

@sridhardev07 yes and yes

Can you tell me how? that will be really helpful for me!!

@sridhardev07
Copy link

Hi @avidale I tried this with some bigger dataset to test the accuracy. Dataset having sentences about 37126, it is showing me memory error: numpy.core._exceptions.MemoryError: Unable to allocate 2.35 GiB for an array with shape (819827, 768) and data type float32

I am having 16GB of RAM, can you tell any alternate way to do, which uses less RAM or retrieve the data from the disk??

@avidale
Copy link
Author

avidale commented Nov 10, 2021

Hi @sridhardev07!

The simplest trick I could suggest is to convert all vectors from float32 to float16, this will reduce memory requirements by half without significantly affecting the quality.

If this does not suffice, you could look at https://github.com/facebookresearch/faiss - a library for fast vector similarity search that allegedly can work with very large sets. Specifically, they implement product quantization for lossy compression of the vectors. If you choose to use Faiss, you should rewrite my solution: unite process_sentences and build_search_index that processes the sentences incrementally and adds their vectors to a faiss.IndexIVFPQ instead of a KDTree.

@sridhardev07
Copy link

Hi @avidale ! Thanks for the answer!

I tried converting the vectors to float16 it does help to reduce the size but not that much as I am working with large dataset.

I tried the second approach of Faiss, it worked good when I tried Flat index, so I can add the index incrementally. But on saving that to disk taking lots of storage. Approx 1 GB of 15K sentences. here is what I did:

 def __init__(self, sentences, model):
        self.sentences = sentences
        self.model = model
        self.index = faiss.IndexFlatL2(768)

    def process_sentences(self):
        result = self.model(self.sentences)
        self.sentence_ids = []
        self.token_ids = []
        self.all_tokens = []
        for i, (toks, embs) in enumerate(tqdm(result)):
            # initialize all_embeddings for every new sentence
            all_embeddings = []
            for j, (tok, emb) in enumerate(zip(toks, embs)):
                self.sentence_ids.append(i)
                self.token_ids.append(j)
                self.all_tokens.append(tok)
                all_embeddings.append(emb)

            all_embeddings = np.stack(all_embeddings) # Add embeddings after every sentence
            self.index.add(all_embeddings)

        faiss.write_index(self.index, "faiss_Model")

Then I tried with faiss.IndexIVFPQ, it works good, but did not works for incremental index as it needs the training data too. So need to calculate all the embeddings and then train and add. Again the size is small but its taking too much RAM that is causing issue while working with large data. here is what I did:

def __init__(self, sentences, model):
       self.sentences = sentences
       self.model = model
       self.quantizer = faiss.IndexFlatL2(768)
       self.index = faiss.IndexIVFPQ(self.quantizer, 768, 1000, 16, 8)

   def process_sentences(self):
       result = self.model(self.sentences)
       self.sentence_ids = []
       self.token_ids = []
       self.all_tokens = []
       all_embeddings = []
       for i, (toks, embs) in enumerate(tqdm(result)):
           for j, (tok, emb) in enumerate(zip(toks, embs)):
               self.sentence_ids.append(i)
               self.token_ids.append(j)
               self.all_tokens.append(tok)
               all_embeddings.append(emb)

       all_embeddings = np.stack(all_embeddings)
       self.index.train(all_embeddings) # Train
       self.index.add(all_embeddings) # Add to index
       faiss.write_index(self.index, "faiss_Model_mini")

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment