[Medium] ArXiv Semantic Paper Search
Code from the accompanying Medium post.
# Create the necessary directories
mkdir -p semantic_similarity/notebooks semantic_similarity/data semantic_similarity/milvus

# CD into the data directory
cd semantic_similarity/data

# Create and activate a conda environment
conda create -n semantic_similarity python=3.9
conda activate semantic_similarity

## Create a virtual environment using venv if not using conda
# python -m venv semantic_similarity
# source semantic_similarity/bin/activate

# Pip install the necessary libraries
pip install jupyterlab kaggle matplotlib scikit-learn tqdm ipywidgets
pip install "dask[complete]" sentence-transformers
pip install pandas pyarrow pymilvus protobuf==3.20.0

# Download the data using the Kaggle API
kaggle datasets download -d Cornell-University/arxiv

# Unzip the data into the local directory
unzip arxiv.zip

# Delete the zip file
rm arxiv.zip
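The archive unpacks to a single JSON-lines file. A quick sanity check before moving on (a minimal sketch; the file name matches the Kaggle dataset snapshot used in the next step):

# Confirm the snapshot file exists and check its size
ls -lh arxiv-metadata-oai-snapshot.json

# Peek at the first record; each line is one JSON document
head -n 1 arxiv-metadata-oai-snapshot.json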
import dask.bag as db
import json
from datetime import datetime
import time

data_path = '../data/arxiv-metadata-oai-snapshot.json'
# Read the file in blocks of 10MB and parse the JSON
papers_db = db.read_text(data_path, blocksize="10MB").map(json.loads)

# Print the first row
papers_db.take(1)
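Because Dask bags are lazy, read_text returns immediately; nothing is parsed until a result is requested. To verify the total record count (this scans the whole file, so it takes a while), a standard Dask one-liner:

# Count the total number of records (triggers a full scan of the file)
papers_db.count().compute()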
def v1_date(row):
    """
    For each row in the dask bag,
    find the date of the first version of the paper
    and add it to the row as a new "unix_time" column.

    Args:
        row: a row of the dask bag

    Returns:
        The row with an added "unix_time" column
    """
    versions = row["versions"]
    date = None
    for version in versions:
        if version["version"] == "v1":
            date = datetime.strptime(version["created"], "%a, %d %b %Y %H:%M:%S %Z")
            date = int(time.mktime(date.timetuple()))
    row["unix_time"] = date
    return row


def text_col(row):
    """
    For each row in the dask bag, add a new "text" column that is
    the concatenation of the "title" and "abstract" columns,
    joined by the [SEP] token that the SPECTER model expects.

    Args:
        row: a row of the dask bag

    Returns:
        The row with the "text" column added
    """
    row["text"] = row["title"] + "[SEP]" + row["abstract"]
    return row


def filters(row):
    """
    For each row in the dask bag, keep the row only if it meets the
    filter criteria. The length limits mirror the max_length values
    of the Milvus collection schema defined later.

    Args:
        row: a row of the dask bag

    Returns:
        True if the row should be kept, False otherwise
    """
    return ((len(row["id"]) < 16) and
            (len(row["categories"]) < 200) and
            (len(row["title"]) < 4096) and
            (len(row["abstract"]) < 65535) and
            ("cs." in row["categories"]))  # Keep only CS papers
# Specify the columns to keep in the final table
cols_to_keep = ["id", "categories", "title", "abstract", "unix_time", "text"]

# Apply the pre-processing
papers_db = (
    papers_db.map(v1_date)
    .map(text_col)
    .map(lambda row: {key: value for key, value in row.items() if key in cols_to_keep})
    .filter(filters)
)

# Print the first row
papers_db.take(1)
# Convert the Dask Bag to a Dask DataFrame
schema = {
    "id": str,
    "title": str,
    "categories": str,
    "abstract": str,
    "unix_time": int,
    "text": str,
}
papers_df = papers_db.to_dataframe(meta=schema)

# Display the first 5 rows
papers_df.head()
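One Milvus insert batch per partition is used later, so it helps to know how many partitions Dask created. A quick check using the standard Dask DataFrame API:

# Number of partitions; the Milvus insert loop below iterates over these
print(papers_df.npartitions)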
# CD into the milvus directory
cd semantic_similarity/milvus

# Download the standalone version of the Milvus docker-compose file
wget https://github.com/milvus-io/milvus/releases/download/v2.1.0/milvus-standalone-docker-compose.yml -O ./docker-compose.yml

# Run the Milvus server containers locally
sudo docker-compose up -d
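Standalone Milvus runs as three containers (etcd for metadata, MinIO for object storage, and the Milvus server itself). Before connecting from Python, it is worth confirming they are all up:

# List the compose services and their status
sudo docker-compose ps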
# Make sure a Milvus server is already running
from pymilvus import connections, utility
from pymilvus import Collection, CollectionSchema, FieldSchema, DataType

# Connect to the Milvus server
connections.connect(alias="default", host="localhost", port="19530")

# Collection name
collection_name = "arxiv"

# Embedding size
emb_dim = 768

# # Check for an existing collection and drop it if it exists
# if utility.has_collection(collection_name):
#     print(utility.list_collections())
#     utility.drop_collection(collection_name)
# Create a schema for the collection
idx = FieldSchema(name="id", dtype=DataType.VARCHAR, is_primary=True, max_length=16)
categories = FieldSchema(name="categories", dtype=DataType.VARCHAR, max_length=200)
title = FieldSchema(name="title", dtype=DataType.VARCHAR, max_length=4096)
abstract = FieldSchema(name="abstract", dtype=DataType.VARCHAR, max_length=65535)
unix_time = FieldSchema(name="unix_time", dtype=DataType.INT64)
embedding = FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=emb_dim)

# Fields in the collection
fields = [idx, categories, title, abstract, unix_time, embedding]
schema = CollectionSchema(
    fields=fields, description="Semantic Similarity of Scientific Papers"
)

# Create a collection with the schema
collection = Collection(
    name=collection_name, schema=schema, using="default", shards_num=10
)
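To verify that the collection was actually registered on the server, the pymilvus utility helpers imported above can be reused (a minimal check, not part of the original post):

# Confirm the collection exists and inspect its schema
print(utility.has_collection(collection_name))
print(collection.schema)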
from sentence_transformers import SentenceTransformer
from tqdm import tqdm

# Scientific-papers SBERT model (SPECTER)
model = SentenceTransformer('allenai-specter')


def emb_gen(partition):
    return model.encode(partition['text']).tolist()
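A quick way to confirm that SPECTER's output dimension matches the emb_dim=768 declared in the collection schema (a sanity-check sketch; the sample string is arbitrary):

# Encode a dummy "title[SEP]abstract" string and check the vector size
sample_emb = model.encode("A sample title[SEP]A sample abstract.")
print(sample_emb.shape)  # expected: (768,)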
# Initialize
collection = Collection(collection_name)

for partition in tqdm(range(papers_df.npartitions)):
    # Get the dask dataframe for the partition
    subset_df = papers_df.get_partition(partition)

    # Check if the dataframe is empty
    if len(subset_df.index) != 0:
        # Metadata
        data = [
            subset_df[col].values.compute().tolist()
            for col in ["id", "categories", "title", "abstract", "unix_time"]
        ]

        # Embeddings
        data += [
            subset_df
            .map_partitions(emb_gen)
            .compute()[0]
        ]

        # Insert the data
        collection.insert(data)
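Milvus buffers inserts in growing segments; flushing seals them and makes the row count visible. A post-insert check, assuming the flush/num_entities API of pymilvus 2.x:

# Seal the inserted segments and verify the number of stored entities
collection.flush()
print(collection.num_entities)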
# Add an ANN index to the collection
index_params = {
    "metric_type": "L2",
    "index_type": "HNSW",
    "params": {"efConstruction": 128, "M": 8},
}
collection.create_index(field_name="embedding", index_params=index_params)
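Index building happens asynchronously on the server. To wait for it before loading the collection, pymilvus exposes a progress helper (API name per the pymilvus utility module):

# Check how many rows have been indexed so far
print(utility.index_building_progress(collection_name))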
# Load the collection into memory
collection = Collection(collection_name)
collection.load()
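Loading is also asynchronous; a similar utility call reports progress (again assuming the pymilvus utility module):

# Confirm the collection is fully loaded before searching
print(utility.loading_progress(collection_name))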
def query_and_display(query_text, collection, num_results=10):
    # Embed the query text
    query_emb = [model.encode(query_text)]

    # Search params
    search_params = {"metric_type": "L2", "params": {"ef": 128}}

    # Search
    query_start = datetime.now()
    results = collection.search(
        data=query_emb,
        anns_field="embedding",
        param=search_params,
        limit=num_results,
        expr=None,
        output_fields=["title", "abstract"],
    )
    query_end = datetime.now()

    # Print the results
    print(f"Query Speed: {(query_end - query_start).total_seconds():.2f} s")
    print("Results:")
    for res in results[0]:
        title = res.entity.get("title").replace("\n ", "")
        print(f"➡️ ID: {res.id}. L2 Distance: {res.distance:.2f}")
        print(f"Title: {title}")
        print(f"Abstract: {res.entity.get('abstract')}")
# Query for papers that are similar to the SimCSE paper
title = "SimCSE: Simple Contrastive Learning of Sentence Embeddings"
abstract = """This paper presents SimCSE, a simple contrastive learning framework that greatly advances state-of-the-art sentence embeddings. We first describe an unsupervised approach, which takes an input sentence and predicts itself in a contrastive objective, with only standard dropout used as noise. This simple method works surprisingly well, performing on par with previous supervised counterparts. We find that dropout acts as minimal data augmentation, and removing it leads to a representation collapse. Then, we propose a supervised approach, which incorporates annotated pairs from natural language inference datasets into our contrastive learning framework by using "entailment" pairs as positives and "contradiction" pairs as hard negatives. We evaluate SimCSE on standard semantic textual similarity (STS) tasks, and our unsupervised and supervised models using BERT base achieve an average of 76.3% and 81.6% Spearman's correlation respectively, a 4.2% and 2.2% improvement compared to the previous best results. We also show -- both theoretically and empirically -- that the contrastive learning objective regularizes pre-trained embeddings' anisotropic space to be more uniform, and it better aligns positive pairs when supervised signals are available."""

query_text = f"{title}[SEP]{abstract}"
query_and_display(query_text, collection, num_results=10)
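Milvus can also combine vector search with boolean filtering on the scalar fields. A hypothetical variant of the query above (the expr string and the 2020-01-01 Unix timestamp are illustrative, not from the original post):

# Same search, restricted to papers whose first version appeared after 2020-01-01
results = collection.search(
    data=[model.encode(query_text)],
    anns_field="embedding",
    param={"metric_type": "L2", "params": {"ef": 128}},
    limit=10,
    expr="unix_time > 1577836800",  # 2020-01-01 00:00:00 UTC
    output_fields=["title", "abstract"],
)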
# Release the collection from memory when it is no longer needed
collection.release()
# CD into the milvus directory
cd semantic_similarity/milvus

# Shut down Milvus
sudo docker-compose down

# Delete the data volumes
sudo rm -rf volumes/