Skip to content

Instantly share code, notes, and snippets.

@sshleifer
Created October 7, 2020 19:19
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save sshleifer/c4aed7bf4418b50caee731e94be05d9f to your computer and use it in GitHub Desktop.
Save sshleifer/c4aed7bf4418b50caee731e94be05d9f to your computer and use it in GitHub Desktop.
Fetching summarization datasets
from pathlib import Path
import fire
from tqdm import tqdm
# Source/target column names per `datasets` name, used to build the line-aligned
# {split}.source / {split}.target files. Datasets missing from this mapping are
# read with the ('document', 'summary') columns instead.
DS_TO_KEY = dict(
    gigaword=("document", "summary"),
    xsum=("document", "summary"),
    aeslc=("email_body", "subject_line"),
    billsum=("text", "summary"),  # billsum also has a title field, unused here
)
def _single_line(text):
    """Return ``text`` with CR/LF characters replaced by spaces.

    The .source/.target files are line-aligned (line i of each file is one
    example), so an embedded line break would shift every later record out of
    alignment with its summary.
    """
    return text.replace("\r", " ").replace("\n", " ")


def _write_split(records, src_path, tgt_path, src_key, tgt_key):
    """Write ``src_key``/``tgt_key`` of each record to ``src_path``/``tgt_path``, one per line."""
    # `with` guarantees both files are flushed and closed, even if a record is
    # missing a column and we raise part-way through.
    with src_path.open("w", encoding="utf-8") as src_fp:
        with tgt_path.open("w", encoding="utf-8") as tgt_fp:
            # reader is the bottleneck so writing one record at a time doesn't slow things down
            for x in tqdm(records):
                try:
                    src, tgt = x[src_key], x[tgt_key]
                except KeyError as err:
                    raise KeyError(
                        f"Expected keys {src_key!r} and {tgt_key!r}; keys are {x.keys()}"
                    ) from err
                src_fp.write(_single_line(src) + "\n")
                tgt_fp.write(_single_line(tgt) + "\n")


def download_summarization_dataset(dataset: str, save_dir=None, split=None, **load_kwargs) -> None:
    """Download a summarization dataset with the ``datasets`` package and save it for finetune.py.

    Every split is written to ``save_dir`` as two line-aligned text files:
    ``{split}.source`` (one input text per line) and ``{split}.target`` (the
    matching summary on the same line number). The ``validation`` split is
    written as ``val.source`` / ``val.target``. Line breaks inside a record are
    replaced by spaces so that each record occupies exactly one line.

    Args:
        dataset: name passed to ``datasets.load_dataset`` (e.g. ``xsum``, ``aeslc``).
        save_dir: output directory; defaults to ``dataset``. Created, including
            missing parents, if it does not exist.
        split: optional single split to fetch (e.g. ``"test"``). By default every
            split returned by ``load_dataset`` is written.
        **load_kwargs: forwarded unchanged to ``datasets.load_dataset``.

    Raises:
        ImportError: if the ``datasets`` package is not installed.
        KeyError: if a record lacks the source/target column listed for
            ``dataset`` in ``DS_TO_KEY`` (``document``/``summary`` by default).

    Example::

        download_summarization_dataset('xsum', split='test')  # writes xsum/test.source, xsum/test.target
    """
    try:
        import datasets
    except ImportError as err:  # also covers ModuleNotFoundError, a subclass
        raise ImportError("run pip install datasets") from err

    loaded = datasets.load_dataset(dataset, split=split, **load_kwargs)
    # Without `split`, load_dataset returns a DatasetDict (a dict subclass) keyed by
    # split name; with `split` it returns a single Dataset, which has no .keys().
    splits = loaded if isinstance(loaded, dict) else {split: loaded}

    save_dir = Path(dataset if save_dir is None else save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)
    # Column names depend only on the dataset, not on the split.
    src_key, tgt_key = DS_TO_KEY.get(dataset, ("document", "summary"))

    for split_name, records in splits.items():
        print(f"Splitting {split_name} with {records.num_rows} records")
        # to save to val.source, val.target like summary datasets
        fn = "val" if split_name == "validation" else split_name
        _write_split(
            records,
            save_dir / f"{fn}.source",
            save_dir / f"{fn}.target",
            src_key,
            tgt_key,
        )
    print(f"Saved {dataset} dataset to {save_dir}")
if __name__ == "__main__":
    # Script entry point: python-fire exposes download_summarization_dataset's
    # parameters as command-line arguments, e.g. `python <script> xsum --split test`.
    fire.Fire(component=download_summarization_dataset)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment