Skip to content

Instantly share code, notes, and snippets.

@radekosmulski
Created August 9, 2021 15:09
Show Gist options
  • Save radekosmulski/cdeeea80596012946447c7a5acf3f93d to your computer and use it in GitHub Desktop.
Save radekosmulski/cdeeea80596012946447c7a5acf3f93d to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 39,
"id": "d97446de",
"metadata": {},
"outputs": [],
"source": [
"from sentence_transformers import SentenceTransformer, util\n",
"\n",
"# Sentence-embedding model used below to embed both the book's paragraphs and the questions.\n",
"EMBEDDING_MODEL_NAME = 'paraphrase-MiniLM-L6-v2'\n",
"model = SentenceTransformer(EMBEDDING_MODEL_NAME)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "fc60dea7",
"metadata": {},
"outputs": [],
"source": [
"import io\n",
"import librosa\n",
"from time import time\n",
"import numpy as np\n",
"import IPython.display as ipd\n",
"import grpc\n",
"import requests\n",
"\n",
"import riva_api.riva_nlp_pb2 as rnlp\n",
"import riva_api.riva_nlp_pb2_grpc as rnlp_srv"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "caa6d560",
"metadata": {},
"outputs": [],
"source": [
"# -nc (--no-clobber): skip the download when the file already exists, so re-running the\n",
"# notebook does not pile up duplicate copies (plain wget would save '....txt.1', '....txt.2', ...).\n",
"!wget -nc \"https://raw.githubusercontent.com/amephraim/nlp/master/texts/J.%20K.%20Rowling%20-%20Harry%20Potter%201%20-%20Sorcerer's%20Stone.txt\""
]
},
{
"cell_type": "markdown",
"id": "27ca5659",
"metadata": {},
"source": [
"Let's read in the text of the book and split it into a list of paragraphs. Blank lines separate paragraphs, and very short paragraphs (100 characters or fewer) are skipped."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "b9ce423a",
"metadata": {},
"outputs": [],
"source": [
"# Pass the encoding explicitly so the result does not depend on the platform's default\n",
"# (locale) encoding. NOTE(review): assumes the downloaded text is UTF-8 (or plain ASCII).\n",
"with open(\"J. K. Rowling - Harry Potter 1 - Sorcerer's Stone.txt\", encoding=\"utf-8\") as file:\n",
"    lines = file.readlines()"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "1013c8f9",
"metadata": {},
"outputs": [],
"source": [
"# Paragraphs of MIN_PARAGRAPH_CHARS characters or fewer are discarded.\n",
"MIN_PARAGRAPH_CHARS = 100\n",
"\n",
"paragraphs = []\n",
"paragraph = ''\n",
"for line in lines:\n",
"    if line.strip():\n",
"        # Text line: join the wrapped lines of a paragraph with single spaces.\n",
"        paragraph += line.rstrip() + ' '\n",
"        continue\n",
"    # Blank (or whitespace-only) line: the current paragraph is complete.\n",
"    if len(paragraph) > MIN_PARAGRAPH_CHARS:\n",
"        paragraphs.append(paragraph)\n",
"    paragraph = ''\n",
"\n",
"# The file need not end with a blank line, so flush the last paragraph as well.\n",
"if len(paragraph) > MIN_PARAGRAPH_CHARS:\n",
"    paragraphs.append(paragraph)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "736f3347",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1504"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Fail fast with a clear message if nothing was extracted (e.g. the download failed);\n",
"# otherwise the retrieval step below would have no paragraphs to search.\n",
"assert paragraphs, 'no paragraphs were extracted from the book text'\n",
"len(paragraphs)"
]
},
{
"cell_type": "markdown",
"id": "5d6d028a",
"metadata": {},
"source": [
"Now that we have the paragraphs, let's embed them!"
]
},
{
"cell_type": "code",
"execution_count": 40,
"id": "f1f09eeb",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 1.28 s, sys: 115 ms, total: 1.39 s\n",
"Wall time: 506 ms\n"
]
}
],
"source": [
"%%time\n",
"\n",
"# Embed every paragraph once up front. answer_question compares each query against these\n",
"# and looks up paragraphs[i] for the i-th embedding, so the order must match `paragraphs`.\n",
"paragraph_embeddings = model.encode(sentences=paragraphs)"
]
},
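{
"cell_type": "markdown",
"id": "ed2403c8",
"metadata": {},
"source": [
"Embedding every paragraph of the book took about half a second of wall time, which is not too bad for an entire book!\n",
"\n",
"Now let's write the function that takes in a question, retrieves the most relevant paragraphs, and asks Riva for an answer."
]
},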
{
"cell_type": "code",
"execution_count": 41,
"id": "c4d0d93b",
"metadata": {},
"outputs": [],
"source": [
"def answer_question(question, top_k=10, server='localhost:50051'):\n",
"    '''\n",
"    Answer a question about the book using semantic retrieval plus Riva.\n",
"\n",
"    The question is embedded with `model` and compared (cosine similarity) against\n",
"    `paragraph_embeddings`. The `top_k` most similar entries of `paragraphs` are\n",
"    concatenated, most similar first, into a context, and the question and context\n",
"    are sent to the Riva NaturalQuery endpoint over gRPC. The query and the first\n",
"    returned answer are printed; nothing is returned.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    question : str\n",
"        The question to answer.\n",
"    top_k : int, default 10\n",
"        How many of the most similar paragraphs to include in the context.\n",
"    server : str, default 'localhost:50051'\n",
"        host:port of the Riva gRPC server.\n",
"\n",
"    Uses the notebook globals `model`, `paragraphs` and `paragraph_embeddings`.\n",
"    If the server returns no results, an empty answer is printed.\n",
"    '''\n",
"    query_embedding = model.encode(question)\n",
"    similarities = util.pytorch_cos_sim(query_embedding, paragraph_embeddings)\n",
"\n",
"    # argsort is ascending, so flip to rank the most similar paragraphs first.\n",
"    top_indices = similarities.argsort()[0].flip(0)[:top_k]\n",
"    context = ''.join(paragraphs[int(idx)] for idx in top_indices)\n",
"\n",
"    # Close the channel when done instead of leaking one open channel per question.\n",
"    with grpc.insecure_channel(server) as channel:\n",
"        riva_nlp = rnlp_srv.RivaLanguageUnderstandingStub(channel)\n",
"        req = rnlp.NaturalQueryRequest()\n",
"        req.query = question\n",
"        req.context = context\n",
"        resp = riva_nlp.NaturalQuery(req)\n",
"\n",
"    # Avoid an IndexError when the server returns an empty result list.\n",
"    answer = resp.results[0].answer if resp.results else ''\n",
"\n",
"    print(f\"Query: {question}\")\n",
"    print(f\"Answer: {answer}\")"
]
},
{
"cell_type": "code",
"execution_count": 42,
"id": "11d0b544",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Query: What was the name of Harry Potter's owl?\n",
"Answer: Hedwig,\n"
]
}
],
"source": [
"# Example question about a specific detail of the book.\n",
"answer_question(question=\"What was the name of Harry Potter's owl?\")"
]
},
{
"cell_type": "code",
"execution_count": 43,
"id": "0d645de9",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Query: Does the wand choose the wizard or the wizard chooses the wand?\n",
"Answer: The wand chooses the wizard,\n"
]
}
],
"source": [
"# Example question phrased as a choice between two options.\n",
"answer_question(question=\"Does the wand choose the wizard or the wizard chooses the wand?\")"
]
},
{
"cell_type": "code",
"execution_count": 44,
"id": "24eb2e9f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Query: Who was Harry Potter's best friend?\n",
"Answer: \n"
]
}
],
"source": [
"# NOTE(review): the saved output shows an empty answer for this question; worth checking\n",
"# whether the retrieved context actually mentions Harry's best friend.\n",
"answer_question(question=\"Who was Harry Potter's best friend?\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.8"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment