Skip to content

Instantly share code, notes, and snippets.

@NaxAlpha
Created December 12, 2022 02:33
Show Gist options
  • Save NaxAlpha/000f20e6758aba2be2cea317d435ae50 to your computer and use it in GitHub Desktop.
Save NaxAlpha/000f20e6758aba2be2cea317d435ae50 to your computer and use it in GitHub Desktop.
When sequence lengths are small, it takes some time to fetch from the HuggingFace dataset server. So, to keep feeding data to the model, we need to cache already-fetched files in memory and serve one of those each time.
import json
import torch
import random
from time import sleep
from threading import Thread
from datasets import load_dataset
from transformers import GPT2Tokenizer
from torch.utils.data import Dataset, get_worker_info
def cycled(itr):
    """Yield items from *itr* endlessly, restarting it each time it is exhausted.

    Unlike ``itertools.cycle`` this does not buffer items; it simply
    re-iterates the source, so *itr* must be re-iterable.
    """
    while True:
        yield from itr
class C4X(Dataset):
    """Pseudo-infinite map-style dataset over the streamed C4 (English) corpus.

    A daemon thread continuously pulls documents from the HuggingFace
    streaming dataset, tokenizes them with the GPT-2 tokenizer, and keeps the
    most recent ``cache_size`` tokenized documents in an in-memory cache.
    ``__getitem__`` ignores its index and returns a random ``seq_len``-token
    window stitched together from cached documents, so the DataLoader is never
    starved while waiting on the network.
    """

    def __init__(self, seq_len=512, split='train', cache_size=10_000):
        """
        Args:
            seq_len: number of tokens per returned sample.
            split: C4 split to stream ('train', 'validation', ...).
            cache_size: max number of tokenized documents kept in memory
                (was a hard-coded 10_000; default preserves old behavior).
        """
        self.seq = seq_len
        self.cache_size = cache_size
        self.ds = load_dataset(
            'c4',
            name='en',
            split=split,
            streaming=True,
        )
        self.tok = GPT2Tokenizer.from_pretrained('gpt2')
        self.init = False  # lazy: worker thread starts on first fetch
        self.queue = []    # cache of tokenized documents (list of token lists)

    def __len__(self):
        # The stream is effectively unbounded; report a huge fixed length so
        # samplers/epoch logic treat it as (nearly) infinite.
        return 1_000_000_000

    def _init(self):
        """Start the background fetcher once, per worker process."""
        if self.init:
            return
        wi = get_worker_info()
        self.ds = cycled(
            self.ds.shuffle(
                # Distinct per-worker seed so workers see different shuffles.
                seed=wi.seed if wi else 0,
                buffer_size=1_000,
            )
        )
        Thread(
            target=self.worker,
            daemon=True,  # dies with the process; no join needed
        ).start()
        self.init = True

    def worker(self):
        """Background loop: fetch, tokenize, and cache documents forever.

        NOTE: the original wrapped this in ``except KeyboardInterrupt``, but
        KeyboardInterrupt is only delivered to the main thread, so that
        handler was dead code and has been removed.
        """
        while True:
            text = next(self.ds)['text']
            tokens = self.tok.encode(text)
            self.queue.append(tokens)
            # Trim only when over capacity (the old code copied the whole
            # list on every single append). Rebinding (rather than mutating
            # in place) keeps concurrent readers safe: they hold the old list.
            if len(self.queue) > self.cache_size:
                self.queue = self.queue[-self.cache_size:]

    def _get_next(self):
        """Return one random cached tokenized document, blocking until available."""
        self._init()
        while not self.queue:
            sleep(0.1)  # cache still warming up
        return random.choice(self.queue)

    def _get_full(self):
        """Concatenate random documents (EOS-separated) and slice a seq-len window."""
        obj = []
        while len(obj) < self.seq:
            obj += self._get_next()
            obj.append(self.tok.eos_token_id)
        s = random.randint(0, len(obj) - self.seq)
        return obj[s:s + self.seq]

    def __getitem__(self, _):
        # Index is ignored: every access yields a fresh random window.
        return torch.tensor(self._get_full())

    def decode(self, tkns):
        """Decode a token-id sequence back to text."""
        return self.tok.decode(tkns)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment