Last active
June 24, 2024 21:17
-
-
Save joecummings/05586af0a08eef0714c7da3c56ee7365 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import time
from functools import partial
from typing import Any, Dict, List, Optional

import psutil
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

from torchtune.data import CROSS_ENTROPY_IGNORE_IDX
from torchtune.datasets import instruct_dataset, PackedDataset
from torchtune.models.llama3 import llama3_tokenizer
from torchtune.utils import get_world_size_and_rank, padded_collate
class OldPackedDataset(Dataset):
    """
    Performs greedy sample packing on a provided dataset. This is done as a single
    preprocessing step before training begins. Shuffling is done outside of this
    class on packed samples with a ``Sampler`` as part of the dataloader. Currently,
    this only supports in-memory map-style datasets.

    The class loads, tokenizes, and packs examples on initialization - no
    tokenization is done during training. The general flow on initialization is:
    load tokenized sample -> add to buffer -> when buffer is long enough, add to
    ``self.packs``.

    During training, returns self.packs[idx] as input, label, attention mask, and
    position ids. The attention mask is a lower triangular block mask to prevent
    samples from cross-attending within a pack. The position ids indicate the
    position of each token relative to its sample within a pack. These are all
    padded to max sequence length, so a batch-wise collator is not needed.

    A packed sample is made up of individual smaller sequence length samples jammed
    together within ``max_seq_len``. For example, if max_seq_len is 6 and there are
    varied length samples::

        tokens = [
            [S1, S1, S1, S2, S2, pad],
            [S3, S3, S4, S4, pad, pad],
            ...,
        ]

    To prevent cross-contamination, the following mask would be returned for the
    first pack in the example::

        mask = [
            [1, 0, 0, 0, 0, 0],
            [1, 1, 0, 0, 0, 0],
            [1, 1, 1, 0, 0, 0],
            [0, 0, 0, 1, 0, 0],
            [0, 0, 0, 1, 1, 0],
            [0, 0, 0, 0, 0, 1],
        ]

    The position ids would be::

        input_pos = [
            [0, 1, 2, 0, 1, 2],
            [0, 1, 0, 1, 2, 3],
            ...,
        ]

    The identity matrix is used in the mask for pad tokens instead of a causal mask.
    For position ids for pad tokens, we simply continue to increment from the
    previous sample normally.

    Args:
        ds (Dataset): dataset to sample pack. This should return a dictionary with
            fields "tokens" and "labels" containing the tokenized and label samples.
        max_seq_len (int): Maximum number of tokens to pack.
        max_packs (Optional[int]): maximum number of packs. Default is None, which
            will create as many packs as possible.
        split_across_pack (bool): if the last sample in a pack does not fit in
            ``max_seq_len``, split the sample into the next pack, or move it
            entirely to the beginning of the next pack. For pre-training, typically
            this is set to True for general text completion. For fine-tuning,
            typically this is set to False to avoid truncating sentences in
            instruct tuning. Default is False.
    """

    def __init__(
        self,
        ds: Dataset,
        max_seq_len: int,
        max_packs: Optional[int] = None,
        split_across_pack: bool = False,
    ) -> None:
        self.ds = ds
        self.max_seq_len = max_seq_len
        self.max_packs = max_packs
        self.split_across_pack = split_across_pack
        # Where final padded packs are held after preprocessing.
        self.packs: List[Dict[str, Any]] = []
        self._pack()

    def _pack(self) -> None:
        """
        Iterate through the dataset. Use a buffer to hold samples until max_seq_len,
        then append the buffer to self.packs as a single "packed" sample. Continue
        until max_packs or end of dataset.
        """
        # Buffer to hold samples until they are long enough to be added to
        # self.packs. "mask" holds one causal-mask tensor per buffered sample
        # (combined into a single block-diagonal mask in ``_add_pack``); the
        # other entries are flat per-token lists.
        current_pack: Dict[str, Any] = {
            "tokens": [],
            "labels": [],
            "mask": [],
            "input_pos": [],
        }
        # Keep track of what index the previous sample ends in case we need
        # to end a pack early (when split_across_pack=False).
        previous_sample_boundary = 0
        # Only show progress bar on rank 0.
        _, rank = get_world_size_and_rank()
        if rank == 0:
            pbar = tqdm(total=len(self.ds), desc="Packing dataset", dynamic_ncols=True)
        for sample in self.ds:
            tokens, labels = sample["tokens"], sample["labels"]
            # If the dataset outputs samples that are larger than the specified
            # max_seq_len and we're unable to split it, user needs to modify
            # one of the two parameters.
            seq_len = len(tokens)
            if seq_len > self.max_seq_len and not self.split_across_pack:
                raise ValueError(
                    f"Dataset sample is too long ({seq_len} > {self.max_seq_len}). "
                    "Please set `split_across_pack=True` or increase `max_seq_len`."
                )
            # Create mask and position ids for the current sample and extend
            # the current pack with it.
            current_sample = {
                "tokens": tokens,
                "labels": labels,
                # Mask is simply a causal mask within this sample's length.
                "mask": [torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))],
                "input_pos": list(range(seq_len)),
            }
            current_pack = {k: v + current_sample[k] for k, v in current_pack.items()}
            # If the current pack is long enough, add it to self.packs and retain
            # any truncated samples for the next pack, if splitting samples.
            if len(current_pack["tokens"]) > self.max_seq_len:
                current_pack = self._add_pack(
                    current_pack=current_pack,
                    boundary=(
                        self.max_seq_len
                        if self.split_across_pack
                        else previous_sample_boundary
                    ),
                )
            if rank == 0:
                pbar.update()
            previous_sample_boundary = len(current_pack["tokens"])
            if self.max_packs is not None and len(self.packs) >= self.max_packs:
                break
        # Add the last pack with remaining samples that did not fit in previous.
        if len(current_pack["tokens"]) > 0 and (
            self.max_packs is None or len(self.packs) < self.max_packs
        ):
            current_pack = self._add_pack(
                current_pack=current_pack, boundary=len(current_pack["tokens"])
            )
            # Packing the full remainder must drain the buffer entirely.
            # BUGFIX: this check previously sat OUTSIDE the ``if`` above, so it
            # fired (AssertionError) whenever ``max_packs`` was reached while
            # samples were still buffered.
            assert len(current_pack["tokens"]) == 0

    def _add_pack(
        self, current_pack: Dict[str, Any], boundary: int
    ) -> Dict[str, Any]:
        """
        Pad and add the first ``boundary`` tokens of ``current_pack`` to
        ``self.packs`` and return whatever remains past the boundary (the sample
        that did not fit, or the truncated tail of one when splitting).
        """
        # So far we've kept a list of causal masks, one for each sample in the
        # pack. Combine them into a single block-diagonal mask for the whole pack.
        packing_mask = torch.block_diag(*current_pack["mask"])
        pack = {
            "tokens": current_pack["tokens"][:boundary],
            "labels": current_pack["labels"][:boundary],
            "mask": packing_mask[:boundary, :boundary],
            "input_pos": current_pack["input_pos"][:boundary],
        }
        # Resultant shapes after padding:
        #   tokens: [max_seq_len, ]
        #   labels: [max_seq_len, ]
        #   mask: [max_seq_len, max_seq_len]
        #   input_pos: [max_seq_len, ]
        padded_pack = self._pad_pack(pack)
        self.packs.append(padded_pack)
        # Keep the sample that did not fit (or its truncated tail) for the next
        # pack. The mask slice of a block-diagonal/causal mask is itself a valid
        # causal mask for the leftover tokens.
        updated_pack = {
            "tokens": current_pack["tokens"][boundary:],
            "labels": current_pack["labels"][boundary:],
            "mask": [packing_mask[boundary:, boundary:]],
            "input_pos": current_pack["input_pos"][boundary:],
        }
        return updated_pack

    def _pad_pack(
        self,
        sample: Dict[str, Any],
        padding_idx: int = 0,
        ignore_idx: int = CROSS_ENTROPY_IGNORE_IDX,
    ) -> Dict[str, torch.Tensor]:
        """Pad a packed sequence to ``self.max_seq_len`` and convert integer
        lists to tensors. Accounts for the attention mask and position ids.

        This is a sample-wise collator and should not be used with the dataloader.
        Sample should look like::

            {
                "tokens": List[int],
                "labels": List[int],
                "mask": Tensor,
                "input_pos": List[int],
            }

        Args:
            sample (Dict[str, Any]): A dictionary containing tokens, labels, mask,
                and position ids.
            padding_idx (int): Padding index for input ids. Defaults to 0.
            ignore_idx (int): Padding index for labels. Defaults to -100.

        Returns:
            Collated tokens, labels, mask, and position ids.
        """
        tokens = sample["tokens"]
        labels = sample["labels"]
        mask = sample["mask"]
        input_pos = sample["input_pos"]
        # Pad token/label sequences to max sequence length.
        tokens = F.pad(
            torch.tensor(tokens), (0, self.max_seq_len - len(tokens)), value=padding_idx
        )
        labels = F.pad(
            torch.tensor(labels), (0, self.max_seq_len - len(labels)), value=ignore_idx
        )
        # For the attention mask, simply use an identity matrix for the pad tokens
        # so they attend only to themselves.
        mask_pad = torch.eye(self.max_seq_len - mask.shape[0], dtype=torch.bool)
        mask = torch.block_diag(mask, mask_pad)
        # For position ids, continue to increment past the last real token...
        next_pos = input_pos[-1] + 1
        input_pos_pad = torch.arange(
            next_pos, next_pos + self.max_seq_len - len(input_pos)
        )
        # ...but do not go beyond max_seq_len - 1.
        input_pos_pad = input_pos_pad.clamp(max=self.max_seq_len - 1)
        input_pos = torch.cat(
            [
                torch.tensor(input_pos),
                input_pos_pad,
            ]
        )
        assert tokens.shape == labels.shape == input_pos.shape
        assert tokens.shape[0] == mask.shape[0]
        return {
            "tokens": tokens,
            "labels": labels,
            "mask": mask,
            "input_pos": input_pos,
        }

    def __len__(self) -> int:
        """Number of packed (and padded) samples."""
        return len(self.packs)

    def __getitem__(self, index: int) -> Dict[str, torch.Tensor]:
        """Return the pre-built pack at ``index``; no work is done here."""
        return self.packs[index]
########## | |
# Tokenizer checkpoint is expected at this local path — run from the repo root.
tokenizer = llama3_tokenizer("./model/original/tokenizer.model")
# Build a small (first 1% of train) unpacked instruct dataset to benchmark
# packing on. Packing is benchmarked separately below, hence packed=False here.
dataset = instruct_dataset(
    tokenizer=tokenizer,
    source="TIGER-Lab/WebInstructSub",
    template="torchtune.data.AlpacaInstructTemplate",
    column_map={
        "instruction": "question",
        "output": "answer",
    },
    max_seq_len=4096,
    packed=False,
    # FIX: was f"train[:1%]" — the f-prefix had no placeholders (lint F541).
    split="train[:1%]",
)
def _benchmark_pack(label: str, dataset_cls, baseline_used: int):
    """Pack ``dataset`` with ``dataset_cls``; print wall time and memory growth.

    Args:
        label (str): "New" or "Old" — prefix for the printed report line.
        dataset_cls: PackedDataset-style class; called as
            ``dataset_cls(dataset, max_seq_len=..., split_across_pack=...)``.
        baseline_used (int): process-wide memory usage (bytes) captured before
            any packing, so growth is reported relative to a common baseline.

    Returns:
        The packed dataset instance.
    """
    start_pack = time.time()
    packed = dataset_cls(
        dataset,
        max_seq_len=4096,
        split_across_pack=False,
    )
    end_pack = time.time()
    print(
        f"{label} implementation packed in: {end_pack - start_pack} with "
        f"{(psutil.virtual_memory().used - baseline_used) / 1e9} GB"
    )
    return packed


def _exercise_dataloader(packed) -> None:
    """Iterate the packed dataset at several batch sizes to exercise loading."""
    for bs in [4, 8, 32, 64, 128, 256]:
        for _batch in DataLoader(packed, batch_size=bs):
            pass


# Memory in use before any packing; all later readings are relative to this.
baseline_virtual_memory_used = psutil.virtual_memory().used

# Pack and load using the new implementation.
new_packed_dataset = _benchmark_pack(
    "New", PackedDataset, baseline_virtual_memory_used
)
_exercise_dataloader(new_packed_dataset)

# Pack and load using the old implementation.
old_packed_dataset = _benchmark_pack(
    "Old", OldPackedDataset, baseline_virtual_memory_used
)
_exercise_dataloader(old_packed_dataset)

# Report total memory growth. (Previous comment said "Plot" — nothing is
# plotted; this only prints the delta.)
print(
    f"Total memory used: {(psutil.virtual_memory().used - baseline_virtual_memory_used) / 1e9} GB"
)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment