"cell_type": "markdown",
"metadata": {
"id": "ERvD83pWfTa7"
"source": [
"To perform the negative mining step, we need a vector database that stores the encoded passages and lets us search for passages that are similar to a query but are not its true (positive) match. This requires two things:\n",
"* a pre-existing retriever model to build encodings - for this we will use a model from the *sentence-transformers* library\n",
"* a vector DB to store the encodings: for this we will use Pinecone, as it is free, easy to set up, and fast at scale\n",
"Let's load the model first."
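Before bringing in a vector DB, the core idea of hard negative mining can be sketched in memory. The sketch below assumes the query and passage embeddings are already computed as NumPy arrays (toy 2-d vectors here) and scores them by dot product, matching the `dotproduct` metric used for the index. `mine_hard_negative` is a hypothetical helper, not part of any library. It returns the best-scoring passage that is not the positive, whereas the notebook's mining loop picks randomly among the top results.

```python
import numpy as np

def mine_hard_negative(query_emb, passage_embs, positive_idx, top_k=3):
    """Return the index of the highest-scoring passage that is not the positive.

    Hypothetical helper: scores are dot products between the query and every
    passage. Returns None if all of the top_k hits are the positive passage.
    """
    scores = passage_embs @ query_emb          # one score per passage
    ranked = np.argsort(-scores)[:top_k]       # highest score first
    for idx in ranked:
        if idx != positive_idx:
            return int(idx)
    return None

# toy embeddings: passage 0 is the positive, passage 2 is a close "near miss"
passages = np.array([[1.0, 0.9], [-1.0, 0.2], [0.9, 0.8], [0.0, -1.0]])
query = np.array([1.0, 1.0])
print(mine_hard_negative(query, passages, positive_idx=0))  # → 2
```

Passage 2 scores almost as high as the positive, which is exactly what makes it a useful *hard* negative: a random low-scoring passage (such as passage 3) would teach the model very little.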
"source": [
"from sentence_transformers import SentenceTransformer\n",
"model = SentenceTransformer('msmarco-distilbert-base-tas-b')\n",
"model.max_seq_length = 256\n",
"cell_type": "markdown",
"metadata": {
"id": "WW0hLoXDfTa-"
"source": [
"Next, we initialize a Pinecone index to store the encoded passage vectors."
"source": [
"import pinecone\n",
"with open('secret', 'r') as fp:\n",
"    API_KEY = fp.read().strip()  # get api key\n",
"pinecone.init(\n",
"    api_key=API_KEY,\n",
"    environment='us-west1-gcp'\n",
")\n",
"# create a new 'negative-mine' index if it does not already exist\n",
"if 'negative-mine' not in pinecone.list_indexes():\n",
" pinecone.create_index(\n",
" 'negative-mine',\n",
" dimension=model.get_sentence_embedding_dimension(),\n",
" metric='dotproduct',\n",
" pods=1 # increase for faster mining\n",
" )\n",
"# connect\n",
"index = pinecone.Index('negative-mine')"
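If you want to test the mining logic without a Pinecone account, a small in-memory stand-in can mimic the two calls this notebook relies on. `InMemoryIndex` below is hypothetical and is not part of the Pinecone client; it only assumes that results are read as `res['results'][k]['matches'][j]['id']`, as the mining loop in this notebook does. It does a brute-force dot-product search, so it is only suitable for small datasets.

```python
import numpy as np

class InMemoryIndex:
    """Hypothetical offline stand-in for a Pinecone index (not a real API).

    Implements only upserting (id, vector) tuples and batched dot-product
    queries returning {'results': [{'matches': [{'id': ..., 'score': ...}]}]}.
    """

    def __init__(self, dimension):
        self.dimension = dimension
        self._row = {}       # vector id -> row position in _vectors
        self._ids = []
        self._vectors = []

    def upsert(self, vectors):
        for vec_id, values in vectors:
            if len(values) != self.dimension:
                raise ValueError(f"expected {self.dimension} dims, got {len(values)}")
            if vec_id in self._row:  # existing id: overwrite, as an upsert does
                self._vectors[self._row[vec_id]] = list(values)
            else:
                self._row[vec_id] = len(self._ids)
                self._ids.append(vec_id)
                self._vectors.append(list(values))

    def query(self, queries, top_k=10):
        matrix = np.array(self._vectors, dtype=float).reshape(-1, self.dimension)
        results = []
        for q in queries:
            scores = matrix @ np.asarray(q, dtype=float)  # dot-product metric
            order = np.argsort(-scores)[:top_k]           # best matches first
            matches = [{'id': self._ids[j], 'score': float(scores[j])} for j in order]
            results.append({'matches': matches})
        return {'results': results}

index = InMemoryIndex(dimension=2)
index.upsert(vectors=[('0', [1.0, 0.0]), ('1', [0.0, 1.0])])
res = index.query(queries=[[1.0, 0.2]], top_k=1)
print(res['results'][0]['matches'][0]['id'])  # prints 0
```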
"cell_type": "markdown",
"metadata": {
"id": "am0GBKFYfTa_"
"source": [
"Now we encode the passages and store them in the `negative-mine` index. We create a generator function to load the (query, passage) pairs, and use each pair's row number as the ID of its passage vector so that search results can be mapped back to the passage text. This lets us skip a query's own positive passage when we mine negatives."
"source": [
"from tqdm.auto import tqdm\n",
"def get_text():\n",
"    with open('data/pairs.tsv', 'r', encoding='utf-8') as fp:\n",
"        lines = fp.read().split('\\n')\n",
"    for line in tqdm(lines):\n",
"        try:\n",
"            query, passage = line.split('\\t')\n",
"            yield query, passage\n",
"        except ValueError:\n",
"            # malformed row (not exactly two tab-separated fields): skip it\n",
"            pass"
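`get_text()` keeps a row only if it splits into exactly two tab-separated fields, so blank lines, rows without a tab, and passages that themselves contain a tab are all silently dropped. The hypothetical helper `parse_pairs` below applies the same rule to an in-memory list of lines, which makes the behaviour easy to check without `data/pairs.tsv`.

```python
def parse_pairs(lines):
    """Yield (query, passage) tuples, skipping rows that are not exactly two
    tab-separated fields (the same rule get_text uses)."""
    for line in lines:
        try:
            query, passage = line.split('\t')
        except ValueError:
            continue  # zero tabs (incl. blank lines) or 2+ tabs: skip the row
        yield query, passage

raw = "what is gpl?\tGPL adapts retrievers.\nrow without a tab\na\tb\tc\n"
print(list(parse_pairs(raw.split('\n'))))
# → [('what is gpl?', 'GPL adapts retrievers.')]
```

If passages containing tabs should be kept, splitting with `line.split('\t', 1)` would treat everything after the first tab as the passage; that is an alternative, not what the notebook does.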
"source": [
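"pair_gen = get_text()\n",
"pairs = []\n",
"seen_passages = set()  # every passage already queued for upsert\n",
"passage_batch = []\n",
"id_batch = []\n",
"batch_size = 64\n",
"\n",
"def upsert_batch(ids, passages):\n",
"    # encode a batch of passages and upload (id, vector) tuples to the index\n",
"    embeds = model.encode(passages).tolist()\n",
"    index.upsert(vectors=list(zip(ids, embeds)))\n",
"\n",
"for i, (query, passage) in enumerate(pair_gen):\n",
"    pairs.append((query, passage))\n",
"    # skip passages we have already seen to avoid duplicates in the vector DB\n",
"    if passage not in seen_passages:\n",
"        seen_passages.add(passage)\n",
"        passage_batch.append(passage)\n",
"        # the ID is the row in `pairs`, so hits can be mapped back to text\n",
"        id_batch.append(str(i))\n",
"    # on reaching batch_size, we encode and upsert\n",
"    if len(passage_batch) == batch_size:\n",
"        upsert_batch(id_batch, passage_batch)\n",
"        # refresh batches\n",
"        passage_batch = []\n",
"        id_batch = []\n",
"\n",
"# upsert the final, partially filled batch\n",
"if passage_batch:\n",
"    upsert_batch(id_batch, passage_batch)\n",
"\n",
"# check number of vectors in the index\n",
"index.describe_index_stats()"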
"cell_type": "markdown",
"metadata": {
"id": "HmUAf2LSfTbA"
"source": [
"The database is now set up, so we can begin the *negative mining* step. We loop through the queries in batches, retrieve the *10* passages most similar to each query, and take a randomly chosen one that is not the query's positive passage as its negative."
"source": [
"import random\n",
"batch_size = 100\n",
"triplets = []\n",
"for i in tqdm(range(0, len(pairs), batch_size)):\n",
"    # embed queries and query pinecone in batches to minimize network latency\n",
"    i_end = min(i+batch_size, len(pairs))\n",
"    queries = [pair[0] for pair in pairs[i:i_end]]\n",
"    pos_passages = [pair[1] for pair in pairs[i:i_end]]\n",
"    # create query embeddings\n",
"    query_embs = model.encode(queries, convert_to_tensor=True, show_progress_bar=False)\n",
"    # search for the top_k most similar passages for every query in the batch\n",
"    res = index.query(queries=query_embs.tolist(), top_k=10)\n",
"    # iterate through queries and find negatives\n",
"    for query, pos_passage, query_res in zip(queries, pos_passages, res['results']):\n",
"        top_results = query_res['matches']\n",
"        # shuffle results so we pick a random hard negative, not always the top hit\n",
"        random.shuffle(top_results)\n",
"        for hit in top_results:\n",
"            neg_passage = pairs[int(hit['id'])][1]\n",
"            # compare text, not IDs: the positive may be stored under another pair's ID\n",
"            if neg_passage != pos_passage:\n",
"                # add this to our (Q, P+, P-) triplets and stop searching\n",
"                triplets.append(query+'\\t'+pos_passage+'\\t'+neg_passage)\n",
"                break\n",
"with open('data/triplets.tsv', 'w', encoding='utf-8') as fp:\n",
" fp.write('\\n'.join(triplets))"
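The selection rule inside the mining loop can be isolated as a pure function to see why it compares passage *text* rather than IDs. Because duplicate passages are stored only once, under the row of the first pair that contains them, a query's positive can come back under a different ID. `pick_negative` below is a hypothetical helper that mirrors the loop; it takes the hit IDs directly instead of a Pinecone response.

```python
import random

def pick_negative(pos_passage, hit_ids, pairs, rng=None):
    """Return a randomly chosen retrieved passage that differs from the positive.

    hit_ids are string IDs (row indices into pairs). Returns None when every
    hit is the positive passage itself.
    """
    rng = rng or random.Random()
    shuffled = list(hit_ids)          # copy so the caller's list is untouched
    rng.shuffle(shuffled)
    for hit_id in shuffled:
        neg_passage = pairs[int(hit_id)][1]
        if neg_passage != pos_passage:  # compare text, not IDs
            return neg_passage
    return None

pairs = [('q0', 'p0'), ('q1', 'p1'), ('q2', 'p0')]
# q2's positive 'p0' is stored under ID '0', the first row that contains it
print(pick_negative('p0', ['0', '1'], pairs, random.Random(0)))  # → p1
print(pick_negative('p0', ['0'], pairs))                         # → None
```

An ID-based check such as `hit_id != '2'` would wrongly accept hit `'0'` for query `q2` and produce a "negative" identical to the positive.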
"cell_type": "markdown",
"metadata": {
"id": "nJvtpjmIfTbB"
"source": [
"With that, we have even more *(query, passage) pairs*, covering both positive and negative matches. The next step in GPL scores all of these pairs using a cross-encoder model."
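As described in the GPL paper, the cross-encoder scores are used to label each triplet with a margin, score(Q, P+) - score(Q, P-). The sketch below reads lines in the `query\tpos\tneg` format written to `data/triplets.tsv` and computes that margin. `margin_labels` and the toy `overlap_score` scorer are hypothetical; `score_fn` stands in for a real cross-encoder.

```python
def margin_labels(triplets, score_fn):
    """Turn 'query\\tpos\\tneg' lines into (query, pos, neg, margin) tuples.

    score_fn is any callable mapping a list of (query, passage) pairs to a
    list of relevance scores, for example a cross-encoder's predict method.
    """
    rows = [line.split('\t') for line in triplets]
    pos_scores = score_fn([(q, p) for q, p, _ in rows])
    neg_scores = score_fn([(q, n) for q, _, n in rows])
    return [(q, p, n, float(ps - ns))
            for (q, p, n), ps, ns in zip(rows, pos_scores, neg_scores)]

# toy scorer (hypothetical): number of words shared by query and passage
def overlap_score(pairs):
    return [len(set(q.split()) & set(p.split())) for q, p in pairs]

triplets = ["what is gpl\tgpl is a domain adaptation method\tbananas are yellow"]
print(margin_labels(triplets, overlap_score))
# → [('what is gpl', 'gpl is a domain adaptation method', 'bananas are yellow', 2.0)]
```

With *sentence-transformers* installed, the `predict` method of a `CrossEncoder` (for example `CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')`) could be passed as `score_fn`; which cross-encoder this walkthrough uses is not specified in this section.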
