Last active
May 26, 2025 22:27
-
-
Save Blaizzy/40de0f6b4340490e3920db9e182e6455 to your computer and use it in GitHub Desktop.
How TRL DataCollatorForCompletionOnlyLM works
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| ```python | |
| from trl import SFTTrainer, DataCollatorForCompletionOnlyLM | |
| from transformers import AutoTokenizer | |
| from datasets import load_dataset | |
| # Load Dataset and tokenizer | |
| dataset = load_dataset('prince-canuma/tinyOrca', split='train') | |
| tokenizer = AutoTokenizer.from_pretrained("prince-canuma/Damysus-2.7B-Chat") | |
| # Specify the field name in your dataset containing messages | |
| messages_field = "messages" # Replace with the actual field name | |
| def format_dataset(examples): | |
| if isinstance(examples[messages_field][0], list): | |
| output_texts = [] | |
| for i in range(len(examples[messages_field])): | |
| output_texts.append(tokenizer.apply_chat_template(examples[messages_field][i], tokenize=False)) | |
| return {"formatted_text": output_texts} | |
| else: | |
| return {"formatted_text": tokenizer.apply_chat_template(examples[messages_field], tokenize=False)} | |
| # Use the .map method to apply chat template to the dataset | |
| formatted_dataset = dataset.map(format_dataset) | |
| # Tokenize one sample | |
| tokenized_sample = tokenizer(formatted_dataset['formatted_text'][0], padding=True, truncation=True, return_special_tokens_mask=True) | |
| # Run data collator | |
| data_collator = DataCollatorForCompletionOnlyLM(response_template='assistant', tokenizer=tokenizer) | |
| data_collator.torch_call([tokenized_sample]) | |
| ``` | |
| Output: | |
| All instruction tokens (i.e., the system and user prompts) are assigned the label -100, which means they are excluded from the loss calculation. | |
| In other words, you train only on the model's completions. | |
| ```shell | |
| {'input_ids': tensor([[50295, 10057, 198, 1639, 389, 281, 9552, 8796, 326, 5419, | |
| 661, 1064, 1321, 13, 11787, 481, 345, 1577, 345, 257, | |
| 1808, 13, 3406, 4876, 318, 284, 3280, 355, 42353, 355, | |
| 345, 460, 13, 2893, 18877, 892, 2239, 12, 1525, 12, | |
| 9662, 290, 12051, 534, 3280, 13, 50296, 198, 50295, 7220, | |
| 198, 3260, 3597, 40802, 2735, 357, 10333, 8, 5000, 12199, | |
| 11, 15608, 18524, 284, 30081, 351, 1180, 3496, 34422, 290, | |
| 11408, 329, 2297, 13, 6660, 11, 673, 1444, 5436, 5780, | |
| 290, 17537, 1891, 11, 734, 3496, 34422, 290, 11408, 3025, | |
| 670, 673, 29382, 11, 284, 2112, 257, 1744, 12438, 13, | |
| 383, 19886, 21581, 262, 3721, 329, 366, 1135, 4231, 7236, | |
| 10776, 18067, 5157, 17083, 1, 8972, 706, 257, 1545, 286, | |
| 15608, 338, 409, 12, 7081, 6726, 6807, 656, 262, 8296, | |
| 8034, 290, 5158, 286, 14923, 339, 2982, 326, 15608, 290, | |
| 607, 1966, 16506, 547, 15398, 1780, 13, 2293, 262, 1545, | |
| 1364, 11, 5780, 290, 17537, 1891, 1965, 15608, 284, 15962, | |
| 319, 262, 3307, 286, 262, 2776, 11, 543, 673, 3417, | |
| 355, 366, 9032, 510, 11, 651, 736, 1978, 11, 2270, | |
| 510, 11, 651, 736, 1978, 11, 655, 11, 334, 456, | |
| 11, 262, 5290, 1911, 1649, 5780, 5220, 326, 484, 3551, | |
| 546, 262, 4519, 13, 15608, 2540, 2712, 262, 10047, 290, | |
| 13777, 11, 366, 1135, 389, 1239, 1683, 35713, 11, 290, | |
| 262, 3496, 35456, 8902, 12979, 13, 1375, 3417, 262, 1429, | |
| 355, 530, 286, 262, 749, 36102, 6461, 673, 550, 981, | |
| 8296, 11, 290, 531, 262, 10530, 4887, 14451, 607, 9027, | |
| 13, 1052, 6597, 10651, 286, 607, 36543, 1146, 5486, 546, | |
| 2270, 4739, 460, 307, 2982, 878, 262, 2457, 26441, 13, | |
| 383, 3496, 318, 7478, 546, 15608, 338, 409, 11, 14757, | |
| 18488, 297, 268, 3099, 282, 11, 355, 262, 734, 550, | |
| 5445, 510, 287, 3269, 2813, 475, 550, 587, 1775, 319, | |
| 257, 3128, 257, 1178, 1528, 1568, 13, 2293, 262, 2650, | |
| 286, 262, 2647, 2008, 11, 517, 20195, 17795, 262, 3496, | |
| 284, 18488, 297, 268, 3099, 282, 9349, 11, 351, 262, | |
| 8674, 2045, 588, 18488, 297, 268, 3099, 282, 11, 262, | |
| 8674, 287, 262, 2008, 3501, 607, 257, 36953, 355, 18488, | |
| 297, 268, 3099, 282, 550, 7478, 1760, 329, 15608, 290, | |
| 257, 42893, 15608, 17326, 287, 262, 2008, 326, 318, 29131, | |
| 284, 804, 2092, 284, 326, 286, 543, 18488, 297, 268, | |
| 3099, 282, 373, 29586, 284, 423, 1813, 15608, 329, 607, | |
| 10955, 13, 198, 198, 15001, 319, 428, 10066, 11, 508, | |
| 318, 256, 7167, 14622, 3375, 546, 287, 356, 389, 1239, | |
| 1972, 736, 1978, 30, 3914, 338, 892, 3049, 13, 13860, | |
| 286, 10510, 25, 50296, 198, 50295, 562, 10167, 198, 29907, | |
| 15608, 318, 3375, 546, 607, 409, 12, 7081, 6726, 287, | |
| 262, 3496, 366, 1135, 4231, 7236, 10776, 18067, 5157, 17083, | |
| 526, 220, 198, 198, 8600, 12, 1525, 12, 9662, 17734, | |
| 25, 198, 198, 16, 13, 383, 10066, 15802, 326, 262, | |
| 3496, 373, 21581, 706, 257, 1545, 286, 15608, 338, 409, | |
| 12, 7081, 6726, 5158, 546, 14923, 286, 511, 30321, 13, | |
| 198, 17, 13, 383, 3496, 318, 546, 257, 2776, 810, | |
| 262, 3155, 7622, 7163, 510, 290, 1972, 736, 1978, 11, | |
| 543, 15608, 8477, 355, 366, 1169, 5290, 526, 198, 18, | |
| 13, 383, 10066, 2585, 326, 262, 3496, 318, 7478, 546, | |
| 15608, 338, 409, 11, 14757, 18488, 297, 268, 3099, 282, | |
| 11, 355, 484, 550, 257, 2092, 319, 12, 392, 12, | |
| 2364, 2776, 13, 198, 19, 13, 15891, 20195, 287, 262, | |
| 2647, 2008, 11, 884, 355, 262, 8674, 338, 28204, 284, | |
| 18488, 297, 268, 3099, 282, 290, 1728, 3709, 11, 2252, | |
| 1104, 262, 13367, 326, 262, 3496, 318, 546, 683, 13, | |
| 198, 198, 15001, 319, 428, 1321, 11, 340, 318, 6397, | |
| 284, 13796, 326, 8121, 15608, 318, 3375, 546, 14757, 18488, | |
| 297, 268, 3099, 282, 287, 262, 3496, 366, 1135, 4231, | |
| 7236, 10776, 18067, 5157, 17083, 526, 50296, 198]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, | |
| 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, | |
| 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, | |
| 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, | |
| 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, | |
| 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, | |
| 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, | |
| 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, | |
| 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, | |
| 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, | |
| 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, | |
| 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, | |
| 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, | |
| 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, | |
| 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, | |
| 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, | |
| 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, | |
| 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, | |
| 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, | |
| 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, | |
| 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, | |
| 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, | |
| 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, | |
| 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, | |
| 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, | |
| 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, | |
| 1, 1, 1, 1]]), 'labels': tensor([[ -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, | |
| -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, | |
| -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, | |
| -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, | |
| -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, | |
| -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, | |
| -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, | |
| -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, | |
| -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, | |
| -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, | |
| -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, | |
| -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, | |
| -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, | |
| -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, | |
| -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, | |
| -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, | |
| -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, | |
| -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, | |
| -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, | |
| -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, | |
| -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, | |
| -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, | |
| -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, | |
| -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, | |
| -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, | |
| -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, | |
| -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, | |
| -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, | |
| -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, | |
| -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, | |
| -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, | |
| -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, | |
| -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, | |
| -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, | |
| -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, | |
| -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, | |
| -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, | |
| -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, | |
| -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, | |
| -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, | |
| -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, | |
| -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, | |
| -100, -100, -100, -100, -100, -100, -100, -100, 198, 29907, | |
| 15608, 318, 3375, 546, 607, 409, 12, 7081, 6726, 287, | |
| 262, 3496, 366, 1135, 4231, 7236, 10776, 18067, 5157, 17083, | |
| 526, 220, 198, 198, 8600, 12, 1525, 12, 9662, 17734, | |
| 25, 198, 198, 16, 13, 383, 10066, 15802, 326, 262, | |
| 3496, 373, 21581, 706, 257, 1545, 286, 15608, 338, 409, | |
| 12, 7081, 6726, 5158, 546, 14923, 286, 511, 30321, 13, | |
| 198, 17, 13, 383, 3496, 318, 546, 257, 2776, 810, | |
| 262, 3155, 7622, 7163, 510, 290, 1972, 736, 1978, 11, | |
| 543, 15608, 8477, 355, 366, 1169, 5290, 526, 198, 18, | |
| 13, 383, 10066, 2585, 326, 262, 3496, 318, 7478, 546, | |
| 15608, 338, 409, 11, 14757, 18488, 297, 268, 3099, 282, | |
| 11, 355, 484, 550, 257, 2092, 319, 12, 392, 12, | |
| 2364, 2776, 13, 198, 19, 13, 15891, 20195, 287, 262, | |
| 2647, 2008, 11, 884, 355, 262, 8674, 338, 28204, 284, | |
| 18488, 297, 268, 3099, 282, 290, 1728, 3709, 11, 2252, | |
| 1104, 262, 13367, 326, 262, 3496, 318, 546, 683, 13, | |
| 198, 198, 15001, 319, 428, 1321, 11, 340, 318, 6397, | |
| 284, 13796, 326, 8121, 15608, 318, 3375, 546, 14757, 18488, | |
| 297, 268, 3099, 282, 287, 262, 3496, 366, 1135, 4231, | |
| 7236, 10776, 18067, 5157, 17083, 526, -100, 198]])} | |
| ``` |
Author
This is a batched request, so the attention mask is expanded into a 4D lower-triangular (causal) mask at run time.
This is the code I wrote for the Qwen 2.5 VL that masks all messages except the last one
Basically, <|im_start|> is the token Qwen uses to mark the start of a message; the real payload of the message begins three tokens later.
In my prompt, only one message belongs to assistant and it's the last one so this is my code
# Keep only the last message (the only one belonging to the assistant):
# 151644 is Qwen's <|im_start|> token id. For each occurrence, mask the
# labels up through the 3-token role header (<|im_start|> + role + newline)
# with -100. Since occurrences are visited in order, the net effect is that
# only the tokens after the LAST <|im_start|> — the assistant reply — keep
# their labels.
for cur_num, cur_local_coord in zip(*torch.where(labels == 151644)):
    labels[int(cur_num)][:cur_local_coord + 3] = -100
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Why is the attention_mask all ones? I thought it should be a lower-triangular (causal) matrix.