@stephenleo
Last active March 22, 2023 21:50

ArXiv Semantic Paper Search

  • Code from the Medium post Link
# Create the necessary directories
mkdir -p semantic_similarity/notebooks semantic_similarity/data semantic_similarity/milvus
# CD into the data directory
cd semantic_similarity/data
# Create and activate a conda environment
conda create -n semantic_similarity python=3.9
conda activate semantic_similarity
## Create Virtual Environment using venv if not using conda
# python -m venv semantic_similarity
# source semantic_similarity/bin/activate
# Pip install the necessary libraries
pip install jupyterlab kaggle matplotlib scikit-learn tqdm ipywidgets
pip install "dask[complete]" sentence-transformers
pip install pandas pyarrow pymilvus protobuf==3.20.0
# Download data using the kaggle API
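# Note: the kaggle CLI expects an API token at ~/.kaggle/kaggle.json (or KAGGLE_USERNAME/KAGGLE_KEY env vars)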
kaggle datasets download -d Cornell-University/arxiv
# Unzip the data into the local directory
unzip arxiv.zip
# Delete the Zip file
rm arxiv.zip
import dask.bag as db
import json
from datetime import datetime
import time
data_path = '../data/arxiv-metadata-oai-snapshot.json'
# Read the file in blocks of 10MB and parse the JSON.
papers_db = db.read_text(data_path, blocksize="10MB").map(json.loads)
# Print the first row
papers_db.take(1)
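# Optional sanity check (not in the original post): count the total number of records in the bag.
# This triggers a full pass over the JSON file, so it can take a while.
papers_db.count().compute()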
def v1_date(row):
    """
    For each row in the dask bag, find the date of the first version
    of the paper and add it to the row as a new column.

    Args:
        row: a row of the dask bag

    Returns:
        The row with an added "unix_time" column
    """
    versions = row["versions"]
    date = None
    for version in versions:
        if version["version"] == "v1":
            date = datetime.strptime(version["created"], "%a, %d %b %Y %H:%M:%S %Z")
            date = int(time.mktime(date.timetuple()))
    row["unix_time"] = date
    return row
def text_col(row):
    """
    For each row in the dask bag, add a new column called "text"
    that is the concatenation of the "title" and "abstract" columns.

    Args:
        row: a row of the dask bag

    Returns:
        The row with an added "text" column
    """
    row["text"] = row["title"] + "[SEP]" + row["abstract"]
    return row
def filters(row):
    """
    For each row in the dask bag, keep the row only if it meets the filter criteria.

    Args:
        row: a row of the dask bag

    Returns:
        True if the row passes all filter criteria, False otherwise
    """
    return ((len(row["id"]) < 16) and
            (len(row["categories"]) < 200) and
            (len(row["title"]) < 4096) and
            (len(row["abstract"]) < 65535) and
            ("cs." in row["categories"])  # Keep only CS papers
            )
# Specify columns to keep in the final table
cols_to_keep = ["id", "categories", "title", "abstract", "unix_time", "text"]
# Apply the pre-processing
papers_db = (
    papers_db.map(lambda row: v1_date(row))
    .map(lambda row: text_col(row))
    .map(
        lambda row: {
            key: value
            for key, value in row.items()
            if key in cols_to_keep
        }
    )
    .filter(filters)
)
# Print the first row
papers_db.take(1)
# Convert the Dask Bag to a Dask Dataframe
schema = {
    "id": str,
    "title": str,
    "categories": str,
    "abstract": str,
    "unix_time": int,
    "text": str,
}
papers_df = papers_db.to_dataframe(meta=schema)
# Display first 5 rows
papers_df.head()
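# Optional sanity check (not in the original post): the insert loop further down iterates over
# papers_df.npartitions, so it helps to know how many partitions (and which dtypes) the dataframe has
print(papers_df.npartitions)
print(papers_df.dtypes)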
# CD into milvus directory
cd semantic_similarity/milvus
# Download the Standalone version of Milvus docker compose
wget https://github.com/milvus-io/milvus/releases/download/v2.1.0/milvus-standalone-docker-compose.yml -O ./docker-compose.yml
# Run the Milvus server docker container on your local
sudo docker-compose up -d
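# Optional: check that the Milvus containers are up and healthy before connecting
sudo docker-compose ps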
# Make sure a Milvus server is already running
from pymilvus import connections, utility
from pymilvus import Collection, CollectionSchema, FieldSchema, DataType
# Connect to Milvus server
connections.connect(alias="default", host="localhost", port="19530")
# Collection name
collection_name = "arxiv"
# Embedding size
emb_dim = 768
# # Check for existing collection and drop if exists
# if utility.has_collection(collection_name):
#     print(utility.list_collections())
#     utility.drop_collection(collection_name)
# Create a schema for the collection
idx = FieldSchema(name="id", dtype=DataType.VARCHAR, is_primary=True, max_length=16)
categories = FieldSchema(name="categories", dtype=DataType.VARCHAR, max_length=200)
title = FieldSchema(name="title", dtype=DataType.VARCHAR, max_length=4096)
abstract = FieldSchema(name="abstract", dtype=DataType.VARCHAR, max_length=65535)
unix_time = FieldSchema(name="unix_time", dtype=DataType.INT64)
embedding = FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=emb_dim)
# Fields in the collection
fields = [idx, categories, title, abstract, unix_time, embedding]
schema = CollectionSchema(
    fields=fields, description="Semantic Similarity of Scientific Papers"
)
# Create a collection with the schema
collection = Collection(
    name=collection_name, schema=schema, using="default", shards_num=10
)
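# Optional sanity check (not in the original post): confirm the collection now exists on the server
print(utility.has_collection(collection_name))  # expected: True
print(utility.list_collections())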
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
# Scientific Papers SBERT Model
model = SentenceTransformer('allenai-specter')
def emb_gen(partition):
    return model.encode(partition['text']).tolist()
# Initialize
collection = Collection(collection_name)
for partition in tqdm(range(papers_df.npartitions)):
    # Get the dask dataframe for the partition
    subset_df = papers_df.get_partition(partition)
    # Check if dataframe is empty
    if len(subset_df.index) != 0:
        # Metadata
        data = [
            subset_df[col].values.compute().tolist()
            for col in ["id", "categories", "title", "abstract", "unix_time"]
        ]
        # Embeddings
        data += [
            subset_df
            .map_partitions(emb_gen)
            .compute()[0]
        ]
        # Insert data
        collection.insert(data)
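# Optional sanity check (not in the original post): report how many entities the collection holds.
# Milvus may need to flush in-memory segments before this count is fully up to date.
print(collection.num_entities)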
# Add an ANN index to the collection
index_params = {
    "metric_type": "L2",
    "index_type": "HNSW",
    "params": {"efConstruction": 128, "M": 8},
}
collection.create_index(field_name="embedding", index_params=index_params)
# Load the collection into memory
collection = Collection(collection_name)
collection.load()
def query_and_display(query_text, collection, num_results=10):
    # Embed the query text
    query_emb = [model.encode(query_text)]
    # Search params
    search_params = {"metric_type": "L2", "params": {"ef": 128}}
    # Search
    query_start = datetime.now()
    results = collection.search(
        data=query_emb,
        anns_field="embedding",
        param=search_params,
        limit=num_results,
        expr=None,
        output_fields=["title", "abstract"],
    )
    query_end = datetime.now()
    # Print results
    print(f"Query Speed: {(query_end - query_start).total_seconds():.2f} s")
    print("Results:")
    for res in results[0]:
        title = res.entity.get("title").replace("\n ", "")
        print(f"➡️ ID: {res.id}. L2 Distance: {res.distance:.2f}")
        print(f"Title: {title}")
        print(f"Abstract: {res.entity.get('abstract')}")
# Query for papers that are similar to the SimCSE paper
title = "SimCSE: Simple Contrastive Learning of Sentence Embeddings"
abstract = """This paper presents SimCSE, a simple contrastive learning framework that greatly advances state-of-the-art sentence embeddings. We first describe an unsupervised approach, which takes an input sentence and predicts itself in a contrastive objective, with only standard dropout used as noise. This simple method works surprisingly well, performing on par with previous supervised counterparts. We find that dropout acts as minimal data augmentation, and removing it leads to a representation collapse. Then, we propose a supervised approach, which incorporates annotated pairs from natural language inference datasets into our contrastive learning framework by using "entailment" pairs as positives and "contradiction" pairs as hard negatives. We evaluate SimCSE on standard semantic textual similarity (STS) tasks, and our unsupervised and supervised models using BERT base achieve an average of 76.3% and 81.6% Spearman's correlation respectively, a 4.2% and 2.2% improvement compared to the previous best results. We also show -- both theoretically and empirically -- that the contrastive learning objective regularizes pre-trained embeddings' anisotropic space to be more uniform, and it better aligns positive pairs when supervised signals are available."""
query_text = f"{title}[SEP]{abstract}"
query_and_display(query_text, collection, num_results=10)
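# The same search can also be filtered on the scalar fields via the "expr" argument.
# A minimal sketch (not in the original post), restricting results to papers whose first
# version appeared after 2021-01-01 (the cutoff timestamp is a hypothetical example):
results = collection.search(
    data=[model.encode(query_text)],
    anns_field="embedding",
    param={"metric_type": "L2", "params": {"ef": 128}},
    limit=10,
    expr="unix_time > 1609459200",  # Unix time for 2021-01-01 00:00:00 UTC
    output_fields=["title"],
)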
# Release the collection from memory when it's not needed anymore
collection.release()
# CD into milvus directory
cd semantic_similarity/milvus
# Shut down milvus
sudo docker-compose down
# Delete the files
sudo rm -rf volumes/