@arose13
Created November 30, 2023 03:45
LLM dataloader: sample random fixed-length windows from a continuously concatenated stream of tokens
import torch
from torch.utils.data import Dataset

def custom_dataloader(dataset: Dataset, batch_size=16, context_size=1024):
    """Yield (input, target) batches of random fixed-length windows.

    Assumes dataset['tokens'] is a single 1-D tensor of token ids
    (all documents concatenated into one continuous stream).
    """
    # One random permutation of every valid window start per epoch
    random_indices = torch.randperm(len(dataset['tokens']) - context_size)
    for idx in range(0, len(random_indices), batch_size):
        # Inputs: context_size tokens from each sampled offset
        x = torch.stack([
            dataset['tokens'][i: i + context_size]
            for i in random_indices[idx: idx + batch_size]
        ])
        # Targets: the same windows shifted one token to the right
        y = torch.stack([
            dataset['tokens'][i + 1: i + context_size + 1]
            for i in random_indices[idx: idx + batch_size]
        ])
        yield x, y
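A quick sanity check of this sampling scheme: feeding it a hypothetical integer token stream (`torch.arange`, so each token id equals its position) makes it easy to verify that every target batch is the input batch shifted one token to the right. The small `context_size` and dict-style dataset here are toy choices for illustration, not part of the gist.

```python
import torch

def custom_dataloader(dataset, batch_size=16, context_size=1024):
    # Randomly permute every valid window start in the token stream.
    random_indices = torch.randperm(len(dataset['tokens']) - context_size)
    for idx in range(0, len(random_indices), batch_size):
        x = torch.stack([dataset['tokens'][i: i + context_size]
                         for i in random_indices[idx: idx + batch_size]])
        y = torch.stack([dataset['tokens'][i + 1: i + context_size + 1]
                         for i in random_indices[idx: idx + batch_size]])
        yield x, y

dataset = {'tokens': torch.arange(100)}  # hypothetical toy token stream
x, y = next(custom_dataloader(dataset, batch_size=4, context_size=8))
assert x.shape == (4, 8) and y.shape == (4, 8)
assert torch.equal(y, x + 1)  # targets are inputs shifted one token right
```

Because the windows start at arbitrary offsets rather than document boundaries, one epoch covers every possible context window of the stream exactly once, in random order.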