Last active
October 25, 2023 13:52
-
-
Save jxmorris12/3d943a30b35b5f3908a896acc62ce696 to your computer and use it in GitHub Desktop.
Load the MS MARCO corpus via BeIR data-loading utilities.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Dict, Tuple

import logging
import os
import pathlib
import zipfile

import beir
import beir.datasets
import beir.datasets.data_loader
import datasets
import requests
import tqdm
def download_url(url: str, save_path: str, chunk_size: int = 1024):
    """Download url with progress bar using tqdm

    https://stackoverflow.com/questions/15644964/python-progress-bar-and-downloads

    Args:
        url (str): downloadable url
        save_path (str): local path to save the downloaded file
        chunk_size (int, optional): chunking of files. Defaults to 1024.

    Raises:
        requests.HTTPError: if the server responds with an error status.
    """
    # Stream the body so large archives are never held in memory at once;
    # the context manager guarantees the connection is released.
    with requests.get(url, stream=True) as r:
        # Fail loudly on 4xx/5xx instead of silently saving an error page.
        r.raise_for_status()
        total = int(r.headers.get('Content-Length', 0))
        with open(save_path, 'wb') as fd, tqdm.tqdm(
            desc=save_path,
            total=total,
            unit='iB',
            unit_scale=True,
            unit_divisor=1024,  # binary-size display; independent of chunk_size
        ) as bar:
            for data in r.iter_content(chunk_size=chunk_size):
                # fd.write returns the number of bytes written.
                bar.update(fd.write(data))
def unzip(zip_file: str, out_dir: str):
    """Extract every member of a .zip archive into a directory.

    Args:
        zip_file (str): path to the .zip archive to extract.
        out_dir (str): directory the archive contents are extracted into.
    """
    # Context manager ensures the archive handle is closed even if
    # extractall raises (the original leaked it on error).
    with zipfile.ZipFile(zip_file, "r") as zf:
        zf.extractall(path=out_dir)
def download_url_and_unzip(url: str, out_dir: str, chunk_size: int = 1024) -> str:
    """Fetch a .zip archive from ``url`` into ``out_dir`` and extract it.

    Both steps are idempotent: the download is skipped when the archive is
    already on disk, and extraction is skipped when the target directory
    already exists, so repeated calls are cheap.

    Args:
        url (str): URL of a .zip archive.
        out_dir (str): directory to place the archive and its contents in.
        chunk_size (int, optional): download chunk size. Defaults to 1024.

    Returns:
        str: path to the extracted dataset directory inside ``out_dir``.
    """
    os.makedirs(out_dir, exist_ok=True)
    archive_name = url.split("/")[-1]
    archive_path = os.path.join(out_dir, archive_name)

    # Download only when the archive is not already cached on disk.
    if not os.path.isfile(archive_path):
        logging.info("Downloading {} ...".format(archive_name))
        download_url(url, archive_path, chunk_size)

    # Extract only when the target directory does not exist yet.
    if not os.path.isdir(archive_path.replace(".zip", "")):
        logging.info("Unzipping {} ...".format(archive_name))
        unzip(archive_path, out_dir)

    return os.path.join(out_dir, archive_name.replace(".zip", ""))
def load_beir_uncached(dataset: str, split: str) -> datasets.Dataset:
    """Loads a BEIR dataset's document corpus through tools provided by BeIR.

    Args:
        dataset (str): BEIR dataset name, e.g. "msmarco".
        split (str): split to load, e.g. "train".

    Returns:
        corpus (datasets.Dataset): Corpus of documents
            keys -- id, text
    """
    # BUG FIX: the original immediately shadowed both parameters with
    # hard-coded "msmarco" / "train", so callers could never load any
    # other dataset or split. The parameters are now honored.
    url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset)
    # Cache downloaded archives under ./datasets next to this file.
    out_dir = os.path.join(pathlib.Path(__file__).parent.absolute(), "datasets")
    data_path = download_url_and_unzip(url, out_dir)
    corpus, queries, qrels = beir.datasets.data_loader.GenericDataLoader(data_path).load(split=split)
    # NOTE(review): get_ance_results is not defined in this file and its
    # result is unused below — confirm it is defined elsewhere / needed.
    ance_results = get_ance_results(dataset=dataset, corpus=corpus, queries=queries)
    # Flatten BeIR's {doc_id: {"text": ..., ...}} mapping into a
    # HuggingFace Dataset with one row per document.
    corpus = datasets.Dataset.from_list(
        [{"id": k, "text": v["text"]} for k, v in corpus.items()])
    return corpus
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment