Skip to content

Instantly share code, notes, and snippets.

@NaxAlpha
Last active January 12, 2023 02:43
Show Gist options
  • Save NaxAlpha/c6a7c65f40c05af0907b25fd742a8df0 to your computer and use it in GitHub Desktop.
Save NaxAlpha/c6a7c65f40c05af0907b25fd742a8df0 to your computer and use it in GitHub Desktop.
Efficiently stream "The Pile" dataset directly from the web. Requires `pip install zstandard`.
import torch
from torch.utils.data import IterableDataset
from transformers import PreTrainedTokenizerBase
from pile import ThePile
class ThePileTokenized(IterableDataset):
    """Wrap a ThePile stream, tokenizing documents into fixed-length windows.

    Documents are concatenated into one running token buffer, separated by
    the tokenizer's EOS token, and windows of ``max_length`` tokens are
    sliced off the front. With ``repeat_factor`` > 1 the buffer advances by
    only ``max_length / repeat_factor`` tokens per yield, so consecutive
    windows overlap and each token is seen roughly ``repeat_factor`` times.
    """

    def __init__(
        self,
        base_dataset: ThePile,
        tokenizer: PreTrainedTokenizerBase,
        max_length: int = 1024,
        repeat_factor: float = 1,
    ):
        # ValueError instead of assert: asserts are stripped under `python -O`.
        if repeat_factor < 1:
            raise ValueError("repeat_factor must be >= 1 (may be a float)")
        self.pile = base_dataset
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.repeat_factor = repeat_factor

    def __iter__(self):
        ds = iter(self.pile)
        buffer = []  # running token stream across document boundaries
        while True:
            try:
                sample = next(ds)
            except StopIteration:
                # PEP 479: letting StopIteration escape a generator raises
                # RuntimeError, so a finite base dataset must end explicitly.
                return
            tokens = self.tokenizer.encode(sample["text"])
            buffer += [self.tokenizer.eos_token_id] + tokens
            while len(buffer) > self.max_length:
                yield torch.tensor(buffer[: self.max_length])
                # int(): repeat_factor may be a float, and list slices
                # require integer indices.
                buffer = buffer[int(self.max_length // self.repeat_factor) :]
if __name__ == "__main__":
    from tqdm import tqdm
    from torch.utils.data import DataLoader
    from transformers import GPT2Tokenizer

    # Throughput smoke test for the tokenized stream.
    tokenized = ThePileTokenized(
        ThePile("train"),
        GPT2Tokenizer.from_pretrained("gpt2"),
        max_length=1024,
        repeat_factor=2,
    )
    loader = DataLoader(tokenized, batch_size=64)
    for _ in tqdm(loader, smoothing=0.01):
        pass
    # ~6 iters/s for 1 worker
import codecs
import json
import random
import time
from typing import Literal

import requests
import zstandard as zstd
from torch.utils.data import IterableDataset, get_worker_info
# Subsets of The Pile that can be streamed.
Subset = Literal["train", "val", "test"]

_BASE = "https://the-eye.eu/public/AI/pile"

# Download URLs per subset; the train split is sharded into 30 files.
URLs = {
    "val": [f"{_BASE}/val.jsonl.zst"],
    "test": [f"{_BASE}/test.jsonl.zst"],
    "train": [f"{_BASE}/train/{i:02d}.jsonl.zst" for i in range(30)],
}
def _read_line_from_stream(reader, initial_line="", buffer_size=4096):
line = initial_line
while True:
c = reader.read(buffer_size)
if not c:
raise StopIteration
line += c.decode("utf-8")
if "\n" in line:
break
return line.split("\n", 1)
def _line_streamer(reader, buffer_size=4096):
    """Yield complete newline-terminated lines decoded from *reader*.

    Lines are yielded without their trailing newline. Any text after the
    final newline (a trailing partial line) is discarded.
    """
    carry = ""
    while True:
        try:
            line, carry = _read_line_from_stream(reader, carry, buffer_size)
        except StopIteration:
            # Reader exhausted: stop the generator cleanly.
            return
        yield line
class ThePile(IterableDataset):
    """Stream "The Pile" as parsed JSON records directly from the web.

    Shard URLs are shuffled with a per-worker seed so DataLoader workers
    walk the shards in different orders; each shard is then decompressed
    and parsed line by line. Iteration never terminates: after all shards
    are exhausted, the list is reshuffled and streaming starts over.
    """

    # Chunk size (bytes) for reads from the decompressed stream.
    TEXT_BUFFER_SIZE = 4096

    def __init__(self, subset: Subset):
        self.subset = subset

    def __iter__(self):
        urls = URLs[self.subset].copy()
        while True:
            # NOTE(review): the seed is fixed per worker, so every pass
            # over the shard list uses the same order for a given worker.
            wi = get_worker_info()
            seed = wi.id if wi is not None else None
            rnd = random.Random(seed)
            rnd.shuffle(urls)
            for url in urls:
                # `with` closes the response when the shard is done (or on
                # error) instead of leaking the connection; the timeout
                # prevents hanging forever on a dead server, and
                # raise_for_status fails fast instead of feeding an HTTP
                # error page into the zstd decompressor.
                with requests.get(url, stream=True, timeout=30) as r:
                    r.raise_for_status()
                    with zstd.ZstdDecompressor().stream_reader(r.raw) as reader:
                        for line in _line_streamer(reader, self.TEXT_BUFFER_SIZE):
                            yield json.loads(line)
if __name__ == "__main__":
    from tqdm import tqdm

    # Throughput smoke test: stream the train split and count samples.
    pile = ThePile("train")
    for _ in tqdm(pile, smoothing=0.01):
        pass
    # Average: ~2000 samples/sec/worker
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment