@jamescalam
Last active March 25, 2022 09:59
02_negative_mining.ipynb
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/jamescalam/9b84408d6c7f1fe4bf7eda2ab410c086/02_negative_mining.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ERvD83pWfTa7"
},
"source": [
"For the negative mining step we need a vector database that stores the encoded passages and lets us search for passages that are similar to a query but are not its true match. This requires two things:\n",
"\n",
"* a pre-existing retriever model to build encodings - for this we will use a model from the *sentence-transformers* library\n",
"* a vector DB to store encodings - for this we will use Pinecone, as it is a free vector DB that is easy to set up and fast at scale\n",
"\n",
"Let's load the model first."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "LieNdfkPfTa9",
"outputId": "c193b0fd-2c6d-432f-b2da-0df396096225"
},
"outputs": [
{
"data": {
"text/plain": [
"SentenceTransformer(\n",
" (0): Transformer({'max_seq_length': 256, 'do_lower_case': False}) with Transformer model: DistilBertModel \n",
" (1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': True, 'pooling_mode_mean_tokens': False, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False})\n",
")"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sentence_transformers import SentenceTransformer\n",
"\n",
"model = SentenceTransformer('msmarco-distilbert-base-tas-b')\n",
"model.max_seq_length = 256\n",
"model"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "WW0hLoXDfTa-"
},
"source": [
"And now initialize a Pinecone index for storing the encoded passage vectors later."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "4s3QS9-EfTa_"
},
"outputs": [],
"source": [
"import pinecone\n",
"\n",
"with open('secret', 'r') as fp:\n",
"    # get api key from app.pinecone.io; strip any trailing newline in the file\n",
"    API_KEY = fp.read().strip()\n",
"\n",
"pinecone.init(\n",
"    api_key=API_KEY,\n",
"    environment='us-west1-gcp'\n",
")\n",
"# create a new negative-mine index if it does not already exist\n",
"if 'negative-mine' not in pinecone.list_indexes():\n",
"    pinecone.create_index(\n",
"        'negative-mine',\n",
"        dimension=model.get_sentence_embedding_dimension(),\n",
"        metric='dotproduct',\n",
"        pods=1  # increase for faster mining\n",
"    )\n",
"# connect\n",
"index = pinecone.Index('negative-mine')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "am0GBKFYfTa_"
},
"source": [
"Now we encode the passages and store them in the `negative-mine` index. We first create a generator function that loads the *(query, passage)* pairs. Each passage vector is stored under the ID of its row in `pairs`, so a search result can be mapped back to its passage text and the true positive passage can be skipped during mining."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "SqjSKb1sfTa_"
},
"outputs": [],
"source": [
"from tqdm.auto import tqdm\n",
"\n",
"def get_text():\n",
"    with open('data/pairs.tsv', 'r', encoding='utf-8') as fp:\n",
"        lines = fp.read().split('\\n')\n",
"    for line in tqdm(lines):\n",
"        try:\n",
"            query, passage = line.split('\\t')\n",
"            yield query, passage\n",
"        except ValueError:\n",
"            # in case of malformed data, pass onto next row\n",
"            pass"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "bIqDTJkMfTbA",
"outputId": "19962868-17d6-4e07-fe38-b7cc7f99d8e0",
"colab": {
"referenced_widgets": [
"f123d57309b042eca4ce279ec0aff06e"
]
}
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "f123d57309b042eca4ce279ec0aff06e",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/200 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"{'dimension': 768,\n",
" 'index_fullness': 0.0,\n",
" 'namespaces': {'': {'vector_count': 67840}}}"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pair_gen = get_text()\n",
"\n",
"pairs = []\n",
"to_upsert = []\n",
"passage_batch = []\n",
"id_batch = []\n",
"batch_size = 64\n",
"\n",
"for i, (query, passage) in enumerate(pair_gen):\n",
"    pairs.append((query, passage))\n",
"    # skip passages already in the current batch (duplicates are only caught within a batch)\n",
"    if passage not in passage_batch:\n",
"        passage_batch.append(passage)\n",
"        id_batch.append(str(i))\n",
"    # on reaching batch_size, we encode and upsert\n",
"    if len(passage_batch) == batch_size:\n",
"        embeds = model.encode(passage_batch).tolist()\n",
"        # upload to index\n",
"        index.upsert(vectors=list(zip(id_batch, embeds)))\n",
"        # refresh batches\n",
"        passage_batch = []\n",
"        id_batch = []\n",
"\n",
"# encode and upsert any passages left over in the final partial batch\n",
"if passage_batch:\n",
"    embeds = model.encode(passage_batch).tolist()\n",
"    index.upsert(vectors=list(zip(id_batch, embeds)))\n",
"\n",
"# check number of vectors in the index\n",
"index.describe_index_stats()"
]
},
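{
"cell_type": "markdown",
"metadata": {},
"source": [
"The batching logic above can be sketched on its own, without the model or the index. This is a minimal sketch, not part of the original notebook: `batch_unique` is a hypothetical helper name, and it assumes we want to drop duplicate passages across the whole stream (using a set) rather than only within one batch. It also yields the final, partially filled batch.\n",
"\n",
"```python\n",
"def batch_unique(passages, batch_size):\n",
"    '''Yield (ids, texts) batches, skipping passages already seen.'''\n",
"    seen = set()\n",
"    ids, texts = [], []\n",
"    for i, passage in enumerate(passages):\n",
"        if passage in seen:\n",
"            continue  # duplicate of a passage seen earlier in the stream\n",
"        seen.add(passage)\n",
"        ids.append(str(i))  # the ID is the row position, as in the loop above\n",
"        texts.append(passage)\n",
"        if len(texts) == batch_size:\n",
"            yield ids, texts\n",
"            ids, texts = [], []\n",
"    if texts:\n",
"        # flush the final, partially filled batch\n",
"        yield ids, texts\n",
"\n",
"\n",
"batches = list(batch_unique(['a', 'b', 'a', 'c', 'b'], batch_size=2))\n",
"print(batches)\n",
"```\n",
"\n",
"Each yielded batch could then be encoded and upserted in the same way as in the cell above."
]
},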
{
"cell_type": "markdown",
"metadata": {
"id": "HmUAf2LSfTbA"
},
"source": [
"The database is now set up and we can begin the *negative mining* step. We loop through the queries in batches, retrieve the *10* most similar passages for each query, shuffle them, and keep the first one that is *not* the query's positive passage."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "CWDU5WgkfTbA",
"outputId": "9194daf1-11f5-4ab9-c3fb-14b05bf3666b",
"colab": {
"referenced_widgets": [
"33e05700fb8a43daadf39f5c2f2166d5"
]
}
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "33e05700fb8a43daadf39f5c2f2166d5",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/2000 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import random\n",
"\n",
"batch_size = 100\n",
"triplets = []\n",
"\n",
"for i in tqdm(range(0, len(pairs), batch_size)):\n",
"    # embed queries and query pinecone in batches to minimize network latency\n",
"    i_end = min(i+batch_size, len(pairs))\n",
"    queries = [pair[0] for pair in pairs[i:i_end]]\n",
"    pos_passages = [pair[1] for pair in pairs[i:i_end]]\n",
"    # create query embeddings\n",
"    query_embs = model.encode(queries, convert_to_tensor=True, show_progress_bar=False)\n",
"    # search for top_k most similar passages\n",
"    res = index.query(query_embs.tolist(), top_k=10)\n",
"    # iterate through queries and find negatives\n",
"    for query, pos_passage, query_res in zip(queries, pos_passages, res['results']):\n",
"        top_results = query_res['matches']\n",
"        # shuffle results so they are in random order\n",
"        random.shuffle(top_results)\n",
"        for hit in top_results:\n",
"            neg_passage = pairs[int(hit['id'])][1]\n",
"            # check that we're not just returning the positive passage\n",
"            if neg_passage != pos_passage:\n",
"                # if not we can add this to our (Q, P+, P-) triplets\n",
"                triplets.append(query+'\\t'+pos_passage+'\\t'+neg_passage)\n",
"                break\n",
"\n",
"with open('data/triplets.tsv', 'w', encoding='utf-8') as fp:\n",
"    fp.write('\\n'.join(triplets))"
]
},
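{
"cell_type": "markdown",
"metadata": {},
"source": [
"The selection rule inside the mining loop can be isolated as a small function. This is a minimal sketch under stated assumptions, not part of the original notebook: `pick_negative`, `id_to_passage` and `rng` are hypothetical names, and each hit is assumed to be a dict with an `'id'` key that indexes into a list of passages, as the matches are used above.\n",
"\n",
"```python\n",
"import random\n",
"\n",
"\n",
"def pick_negative(hits, pos_passage, id_to_passage, rng=random):\n",
"    '''Return one retrieved passage that differs from the positive, or None.'''\n",
"    candidates = list(hits)  # copy so the caller's list keeps its order\n",
"    rng.shuffle(candidates)  # random order, so we do not always take the top hit\n",
"    for hit in candidates:\n",
"        neg_passage = id_to_passage[int(hit['id'])]\n",
"        if neg_passage != pos_passage:\n",
"            return neg_passage\n",
"    return None  # every hit was the positive passage\n",
"\n",
"\n",
"passages = ['cats purr', 'dogs bark', 'birds sing']\n",
"hits = [{'id': '0'}, {'id': '1'}, {'id': '2'}]\n",
"negative = pick_negative(hits, 'cats purr', passages, rng=random.Random(0))\n",
"print(negative)\n",
"```\n",
"\n",
"Shuffling before picking means the negative is a random member of the top results rather than always the single closest passage. Returning `None` makes the case where no negative exists explicit, which corresponds to the loop above adding no triplet for that query."
]
},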
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Z2EGl2_4fTbB"
},
"outputs": [],
"source": [
"pinecone.delete_index('negative-mine') # delete the index when done to avoid higher charges (if using multiple pods)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "nJvtpjmIfTbB"
},
"source": [
"With that we now have *(query, positive passage, negative passage)* triplets, which give us both positive and negative *(query, passage)* pairs. The next step in GPL is to score all of these pairs using a cross-encoder model."
]
}
],
"metadata": {
"environment": {
"kernel": "python3",
"name": "common-cu110.m91",
"type": "gcloud",
"uri": "gcr.io/deeplearning-platform-release/base-cu110:m91"
},
"interpreter": {
"hash": "5188bc372fa413aa2565ae5d28228f50ad7b2c4ebb4a82c5900fd598adbb6408"
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.12"
},
"colab": {
"name": "02_negative_mining.ipynb",
"provenance": [],
"collapsed_sections": [],
"include_colab_link": true
}
},
"nbformat": 4,
"nbformat_minor": 0
}