Skip to content

Instantly share code, notes, and snippets.

@Blaizzy
Last active May 26, 2025 22:27
Show Gist options
  • Select an option

  • Save Blaizzy/40de0f6b4340490e3920db9e182e6455 to your computer and use it in GitHub Desktop.

Select an option

Save Blaizzy/40de0f6b4340490e3920db9e182e6455 to your computer and use it in GitHub Desktop.
How TRL DataCollatorForCompletionOnlyLM works
```python
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
from transformers import AutoTokenizer
from datasets import load_dataset

# Load dataset and tokenizer.
dataset = load_dataset('prince-canuma/tinyOrca', split='train')
tokenizer = AutoTokenizer.from_pretrained("prince-canuma/Damysus-2.7B-Chat")

# Specify the field name in your dataset containing messages.
messages_field = "messages"  # Replace with the actual field name


def format_dataset(examples):
    """Apply the tokenizer's chat template to one example or a batch.

    With ``dataset.map(..., batched=True)`` each value is a list of
    conversations; with the default non-batched map it is a single
    conversation (a list of message dicts). We branch on the type of the
    first element to support both calling conventions.
    """
    if isinstance(examples[messages_field][0], list):
        # Batched call: format every conversation in the batch.
        output_texts = [
            tokenizer.apply_chat_template(conversation, tokenize=False)
            for conversation in examples[messages_field]
        ]
        return {"formatted_text": output_texts}
    # Non-batched call: a single conversation.
    return {"formatted_text": tokenizer.apply_chat_template(examples[messages_field], tokenize=False)}


# Use the .map method to apply the chat template to the dataset.
formatted_dataset = dataset.map(format_dataset)

# Tokenize one sample.
tokenized_sample = tokenizer(
    formatted_dataset['formatted_text'][0],
    padding=True,
    truncation=True,
    return_special_tokens_mask=True,
)

# Run the data collator: every label up to and including the response
# template is set to -100, so only the assistant completion contributes
# to the language-modeling loss.
data_collator = DataCollatorForCompletionOnlyLM(response_template='assistant', tokenizer=tokenizer)
data_collator.torch_call([tokenized_sample])
```
Output:
All instructions (i.e., the system and user prompts) are assigned the label -100, which means they are ignored when calculating the loss.
In other words, you only train on the model's completions.
```shell
{'input_ids': tensor([[50295, 10057, 198, 1639, 389, 281, 9552, 8796, 326, 5419,
661, 1064, 1321, 13, 11787, 481, 345, 1577, 345, 257,
1808, 13, 3406, 4876, 318, 284, 3280, 355, 42353, 355,
345, 460, 13, 2893, 18877, 892, 2239, 12, 1525, 12,
9662, 290, 12051, 534, 3280, 13, 50296, 198, 50295, 7220,
198, 3260, 3597, 40802, 2735, 357, 10333, 8, 5000, 12199,
11, 15608, 18524, 284, 30081, 351, 1180, 3496, 34422, 290,
11408, 329, 2297, 13, 6660, 11, 673, 1444, 5436, 5780,
290, 17537, 1891, 11, 734, 3496, 34422, 290, 11408, 3025,
670, 673, 29382, 11, 284, 2112, 257, 1744, 12438, 13,
383, 19886, 21581, 262, 3721, 329, 366, 1135, 4231, 7236,
10776, 18067, 5157, 17083, 1, 8972, 706, 257, 1545, 286,
15608, 338, 409, 12, 7081, 6726, 6807, 656, 262, 8296,
8034, 290, 5158, 286, 14923, 339, 2982, 326, 15608, 290,
607, 1966, 16506, 547, 15398, 1780, 13, 2293, 262, 1545,
1364, 11, 5780, 290, 17537, 1891, 1965, 15608, 284, 15962,
319, 262, 3307, 286, 262, 2776, 11, 543, 673, 3417,
355, 366, 9032, 510, 11, 651, 736, 1978, 11, 2270,
510, 11, 651, 736, 1978, 11, 655, 11, 334, 456,
11, 262, 5290, 1911, 1649, 5780, 5220, 326, 484, 3551,
546, 262, 4519, 13, 15608, 2540, 2712, 262, 10047, 290,
13777, 11, 366, 1135, 389, 1239, 1683, 35713, 11, 290,
262, 3496, 35456, 8902, 12979, 13, 1375, 3417, 262, 1429,
355, 530, 286, 262, 749, 36102, 6461, 673, 550, 981,
8296, 11, 290, 531, 262, 10530, 4887, 14451, 607, 9027,
13, 1052, 6597, 10651, 286, 607, 36543, 1146, 5486, 546,
2270, 4739, 460, 307, 2982, 878, 262, 2457, 26441, 13,
383, 3496, 318, 7478, 546, 15608, 338, 409, 11, 14757,
18488, 297, 268, 3099, 282, 11, 355, 262, 734, 550,
5445, 510, 287, 3269, 2813, 475, 550, 587, 1775, 319,
257, 3128, 257, 1178, 1528, 1568, 13, 2293, 262, 2650,
286, 262, 2647, 2008, 11, 517, 20195, 17795, 262, 3496,
284, 18488, 297, 268, 3099, 282, 9349, 11, 351, 262,
8674, 2045, 588, 18488, 297, 268, 3099, 282, 11, 262,
8674, 287, 262, 2008, 3501, 607, 257, 36953, 355, 18488,
297, 268, 3099, 282, 550, 7478, 1760, 329, 15608, 290,
257, 42893, 15608, 17326, 287, 262, 2008, 326, 318, 29131,
284, 804, 2092, 284, 326, 286, 543, 18488, 297, 268,
3099, 282, 373, 29586, 284, 423, 1813, 15608, 329, 607,
10955, 13, 198, 198, 15001, 319, 428, 10066, 11, 508,
318, 256, 7167, 14622, 3375, 546, 287, 356, 389, 1239,
1972, 736, 1978, 30, 3914, 338, 892, 3049, 13, 13860,
286, 10510, 25, 50296, 198, 50295, 562, 10167, 198, 29907,
15608, 318, 3375, 546, 607, 409, 12, 7081, 6726, 287,
262, 3496, 366, 1135, 4231, 7236, 10776, 18067, 5157, 17083,
526, 220, 198, 198, 8600, 12, 1525, 12, 9662, 17734,
25, 198, 198, 16, 13, 383, 10066, 15802, 326, 262,
3496, 373, 21581, 706, 257, 1545, 286, 15608, 338, 409,
12, 7081, 6726, 5158, 546, 14923, 286, 511, 30321, 13,
198, 17, 13, 383, 3496, 318, 546, 257, 2776, 810,
262, 3155, 7622, 7163, 510, 290, 1972, 736, 1978, 11,
543, 15608, 8477, 355, 366, 1169, 5290, 526, 198, 18,
13, 383, 10066, 2585, 326, 262, 3496, 318, 7478, 546,
15608, 338, 409, 11, 14757, 18488, 297, 268, 3099, 282,
11, 355, 484, 550, 257, 2092, 319, 12, 392, 12,
2364, 2776, 13, 198, 19, 13, 15891, 20195, 287, 262,
2647, 2008, 11, 884, 355, 262, 8674, 338, 28204, 284,
18488, 297, 268, 3099, 282, 290, 1728, 3709, 11, 2252,
1104, 262, 13367, 326, 262, 3496, 318, 546, 683, 13,
198, 198, 15001, 319, 428, 1321, 11, 340, 318, 6397,
284, 13796, 326, 8121, 15608, 318, 3375, 546, 14757, 18488,
297, 268, 3099, 282, 287, 262, 3496, 366, 1135, 4231,
7236, 10776, 18067, 5157, 17083, 526, 50296, 198]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1]]), 'labels': tensor([[ -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
-100, -100, -100, -100, -100, -100, -100, -100, 198, 29907,
15608, 318, 3375, 546, 607, 409, 12, 7081, 6726, 287,
262, 3496, 366, 1135, 4231, 7236, 10776, 18067, 5157, 17083,
526, 220, 198, 198, 8600, 12, 1525, 12, 9662, 17734,
25, 198, 198, 16, 13, 383, 10066, 15802, 326, 262,
3496, 373, 21581, 706, 257, 1545, 286, 15608, 338, 409,
12, 7081, 6726, 5158, 546, 14923, 286, 511, 30321, 13,
198, 17, 13, 383, 3496, 318, 546, 257, 2776, 810,
262, 3155, 7622, 7163, 510, 290, 1972, 736, 1978, 11,
543, 15608, 8477, 355, 366, 1169, 5290, 526, 198, 18,
13, 383, 10066, 2585, 326, 262, 3496, 318, 7478, 546,
15608, 338, 409, 11, 14757, 18488, 297, 268, 3099, 282,
11, 355, 484, 550, 257, 2092, 319, 12, 392, 12,
2364, 2776, 13, 198, 19, 13, 15891, 20195, 287, 262,
2647, 2008, 11, 884, 355, 262, 8674, 338, 28204, 284,
18488, 297, 268, 3099, 282, 290, 1728, 3709, 11, 2252,
1104, 262, 13367, 326, 262, 3496, 318, 546, 683, 13,
198, 198, 15001, 319, 428, 1321, 11, 340, 318, 6397,
284, 13796, 326, 8121, 15608, 318, 3375, 546, 14757, 18488,
297, 268, 3099, 282, 287, 262, 3496, 366, 1135, 4231,
7236, 10776, 18067, 5157, 17083, 526, -100, 198]])}
```
@zfscgy
Copy link

zfscgy commented Mar 13, 2025

Why is the attention_mask all ones? I thought it should be a triangular matrix.

@Blaizzy
Copy link
Author

Blaizzy commented Mar 13, 2025

This is a batched request, so the mask here is just the padding mask; it is expanded into the 4D lower-triangular (causal) attention mask at run time inside the model.

@Serega6678
Copy link

This is the code I wrote for the Qwen 2.5 VL that masks all messages except the last one
Basically <|im_start|> is the token used in Qwen to mark a message start; the real payload of the message begins 3 tokens later
In my prompt, only one message belongs to assistant and it's the last one so this is my code

# Locate every <|im_start|> token (id 151644 in Qwen). Because later
# (larger) column positions overwrite earlier ones, each row ends up
# masked through its LAST <|im_start|>, leaving only the final
# (assistant) message with live labels.
row_idx, col_idx = torch.where(labels == 151644)
for row, col in zip(row_idx, col_idx):
    # +3 skips <|im_start|>, the role token, and the following newline,
    # so masking stops right where the message payload begins.
    labels[int(row)][:col + 3] = -100

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment