Created August 27, 2023 06:47
-
-
Save Spartee/1784b68c3ca1bbf0fed00d98623b551d to your computer and use it in GitHub Desktop.
arXiv search for papers with Redis and LangChain==0.0.272 (master)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os
from typing import List, Tuple

import openai
from langchain.docstore.document import Document
from langchain.document_loaders import ArxivLoader
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores.redis import Redis, RedisTag, RedisText

# Indexing `os.environ` directly raises KeyError at import time when the key
# is missing, rather than failing later inside an embedding call.
openai.api_key = os.environ["OPENAI_API_KEY"]

# Single embedding model shared by indexing (run_search) and querying.
EMBEDDINGS = OpenAIEmbeddings()
def get_arxiv_docs(query, num_docs=10):
    """Load up to ``num_docs`` arXiv papers matching ``query``.

    All available arXiv metadata is requested for each paper
    (``load_all_available_meta=True``).

    Args:
        query (str): The arXiv search query.
        num_docs (int, optional): Maximum number of papers to load. Defaults to 10.

    Returns:
        tuple: ``(texts, metadatas)`` -- two parallel lists holding each
        paper's page content and its metadata dict, in load order.
    """
    papers = ArxivLoader(
        query,
        load_max_docs=num_docs,
        load_all_available_meta=True,
    ).load()

    texts, metadatas = [], []
    for paper in papers:
        texts.append(paper.page_content)
        metadatas.append(paper.metadata)
    return texts, metadatas
def run_search(
    search_query,
    content_query,
    num_papers=10,
    num_results=5,
    redis_url="redis://localhost:6379",
) -> List[Tuple[Document, float]]:
    """Run a search on arXiv, then a filtered semantic search over the results.

    The papers returned by arXiv for ``search_query`` are embedded with
    ``EMBEDDINGS`` and written to the Redis index ``"arxiv"``. ``content_query``
    is then run as a similarity search restricted to papers whose
    ``categories`` tag equals ``"cs.LG"`` and whose ``published_first_time``
    text field matches the pattern ``"2023*"``.

    Args:
        search_query (str): The query to search arXiv for.
        content_query (str): The query to search the content of the papers for.
        num_papers (int, optional): The number of arXiv papers to fetch and
            index. Defaults to 10.
        num_results (int, optional): The number of results to return. Defaults to 5.
        redis_url (str, optional): The Redis URL to connect to. Defaults to
            "redis://localhost:6379".

    Returns:
        List[Tuple[Document, float]]: Matching documents paired with their
        relevance scores.
    """
    # Forward num_papers. Previously it was accepted but ignored, so the
    # arXiv fetch always used get_arxiv_docs' default of 10 papers.
    texts, metadatas = get_arxiv_docs(search_query, num_docs=num_papers)

    # NOTE(review): the index name is fixed, so repeated runs against the same
    # Redis instance presumably add to the existing "arxiv" index rather than
    # replacing it -- confirm whether stale papers should be cleared first.
    rds, _ = Redis.from_texts_return_keys(
        texts,
        EMBEDDINGS,
        metadatas=metadatas,
        redis_url=redis_url,
        index_name="arxiv",
    )

    # Only machine-learning (cs.LG) papers first published in 2023.
    has_cslg_category = RedisTag("categories") == "cs.LG"
    published_in_2023 = RedisText("published_first_time") % "2023*"
    cslg_in_2023 = has_cslg_category & published_in_2023

    return rds.similarity_search_with_relevance_scores(
        content_query,
        k=num_results,
        filter=cslg_in_2023,
    )
if __name__ == "__main__":
    import argparse

    # Command-line interface mirroring run_search's parameters.
    cli = argparse.ArgumentParser()
    cli.add_argument("--search_query", type=str, default="Retrieval Augmented Generation")
    cli.add_argument("--content_query", type=str, default="Implementing context retrieval for chatbots")
    cli.add_argument("--num_papers", type=int, default=10)
    cli.add_argument("--num_results", type=int, default=5)
    cli.add_argument("--redis_url", type=str, default="redis://localhost:6379")
    opts = cli.parse_args()

    results = run_search(
        opts.search_query,
        opts.content_query,
        num_papers=opts.num_papers,
        num_results=opts.num_results,
        redis_url=opts.redis_url,
    )

    # Print one four-line summary per (document, score) result.
    for paper, relevance in results:
        info = paper.metadata
        print("Title: ", info["Title"])
        print("URL: ", info["links"])
        print("Score: ", relevance)
        print("Published: ", info["published_first_time"])
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment