Skip to content

Instantly share code, notes, and snippets.

@NaxAlpha
Created November 23, 2022 01:18
Show Gist options
  • Save NaxAlpha/d2da09a0de5c85962bbe42c929f4027a to your computer and use it in GitHub Desktop.
Save NaxAlpha/d2da09a0de5c85962bbe42c929f4027a to your computer and use it in GitHub Desktop.
# Stream the C4 dataset from Hugging Face, tokenized with the GPT-2 tokenizer, for PyTorch language-model training.
import json
import torch
import random
from datasets import load_dataset
from transformers import GPT2Tokenizer
from torch.utils.data import Dataset, get_worker_info
def cycled(itr):
    """Yield the items of *itr* endlessly, re-iterating it on every pass.

    Nothing is cached: each pass starts a fresh ``for`` loop over *itr*, so a
    re-iterable source is walked again from its start every time.

    The generator finishes as soon as a pass produces no items (an empty
    iterable, or a one-shot iterator that has already been consumed).
    Without that guard the outer loop would spin forever without yielding
    and the caller's ``next()`` would never return.
    """
    while True:
        produced = False
        for itm in itr:
            produced = True
            yield itm
        if not produced:
            # Nothing left to repeat; stop instead of busy-looping.
            return
class C4X(Dataset):
    """Map-style dataset that serves random fixed-length token windows.

    Text records are pulled from a streaming ``load_dataset('c4', name='en')``
    source, encoded with the ``'gpt2'`` tokenizer, and concatenated (each
    record followed by the tokenizer's EOS id) until at least ``seq_len``
    tokens are available. A random window of exactly ``seq_len`` tokens is
    then returned. The index passed to ``__getitem__`` is ignored; every call
    simply consumes the next record(s) from the stream.
    """

    def __init__(self, seq_len=512, split='train'):
        """Open the streaming dataset and load the tokenizer.

        Args:
            seq_len: Number of tokens in every returned sample.
            split: Dataset split forwarded to ``load_dataset``.
        """
        super().__init__()
        self.seq = seq_len
        self.ds = load_dataset(
            'c4',
            name='en',
            split=split,
            streaming=True,
        )
        self.tok = GPT2Tokenizer.from_pretrained('gpt2')
        # Shuffling is deferred to the first sample request (see `_init`) so
        # that it can be seeded per DataLoader worker.
        self.init = False

    def __len__(self):
        """Return a fixed nominal length; the stream is cycled endlessly."""
        return 1_000_000_000

    @staticmethod
    def _worker_seed():
        """Return the shuffle seed for the current process.

        Inside a DataLoader worker this is the worker's own seed. In the main
        process ``get_worker_info()`` returns ``None``, so fall back to
        PyTorch's current initial seed instead of failing on ``None.seed``.
        """
        wi = get_worker_info()
        if wi is None:
            return torch.initial_seed()
        return wi.seed

    def _init(self):
        """Lazily wrap the stream in a shuffled, endlessly repeating iterator."""
        if self.init:
            return
        self.ds = cycled(
            self.ds.shuffle(
                seed=self._worker_seed(),
                buffer_size=10_000,
            )
        )
        self.init = True

    def _get_next(self):
        """Return the token ids of the next text record in the stream."""
        self._init()
        text = next(self.ds)['text']
        return self.tok.encode(text)

    def _get_full(self):
        """Return a list of exactly ``self.seq`` token ids.

        Records are appended, each terminated by the EOS id, until the buffer
        holds at least ``self.seq`` tokens; a random window of that length is
        then cut out of the buffer.
        """
        buf = []
        while len(buf) < self.seq:
            buf += self._get_next()
            buf.append(self.tok.eos_token_id)
        # The loop guarantees len(buf) >= self.seq, so the range is non-empty.
        start = random.randint(0, len(buf) - self.seq)
        return buf[start:start + self.seq]

    def __getitem__(self, _):
        """Return the next sample as a 1-D tensor; the index is ignored."""
        return torch.tensor(self._get_full())

    def decode(self, tkns):
        """Decode token ids back to text with the dataset's tokenizer."""
        return self.tok.decode(tkns)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment