@raphael-sch
Last active May 7, 2024 16:28
Training and position_ids with left padding
import argparse

import torch
import transformers
from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM

parser = argparse.ArgumentParser(description='Define experiment parameters')
parser.add_argument('--use_custom_position_ids', default='no', choices=['yes', 'no'], type=str)
parser.add_argument('--model_name', default='meta-llama/Llama-2-7b-hf', type=str)
parser.add_argument('--hf_auth_token', default=None, type=str)
parser.add_argument('--seed', default=1111, type=int)
opts = parser.parse_args()

torch.cuda.manual_seed_all(opts.seed)
torch.backends.cudnn.deterministic = True

# Models where the training loss differs between default and custom position_ids
# (the default position_ids are wrong with left padding):
#   GPT-Neo
#   llama
# Models where the loss is the same:
#   opt    (the model already derives correct position ids from the padding / attention mask)
#   bloom  (uses relative position embeddings, so absolute position ids are not used)


def main():
    tokenizer = AutoTokenizer.from_pretrained(opts.model_name, token=opts.hf_auth_token)
    model = AutoModelForCausalLM.from_pretrained(opts.model_name,
                                                 torch_dtype=torch.float16,
                                                 device_map="auto",
                                                 token=opts.hf_auth_token
                                                 )
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id
    tokenizer.padding_side = 'left'

    texts = [
        dict(text="This is a sentence."),
        dict(text="This is a much longer sentence.")
    ]

    def tokenize(d):
        inputs = tokenizer(d['text'])
        inputs['labels'] = inputs['input_ids']
        return inputs

    train_dataset = Dataset.from_list(texts)
    train_dataset = train_dataset.map(tokenize)

    train_args = transformers.TrainingArguments(
        per_device_train_batch_size=2,
        num_train_epochs=1,
        output_dir='./',
        save_strategy='no',
        seed=opts.seed,
        data_seed=opts.seed
    )

    trainer = CustomTrainer(
        use_custom_position_ids=opts.use_custom_position_ids == 'yes',
        model=model,
        train_dataset=train_dataset,
        args=train_args,
        data_collator=transformers.DataCollatorForSeq2Seq(tokenizer, padding=True)
    )
    out = trainer.train()
    print('Training Loss:', out.metrics['train_loss'])


class CustomTrainer(transformers.Trainer):

    def __init__(self, use_custom_position_ids=False, **kwargs):
        super().__init__(**kwargs)
        self.use_custom_position_ids = use_custom_position_ids

    def compute_loss(self, model, inputs, return_outputs=False):
        input_ids = inputs['input_ids']
        print('padded input_ids:', input_ids.tolist())

        if self.use_custom_position_ids:
            # Pad tokens keep position 0; the counter only advances on real tokens,
            # so the first real token of every sequence gets position 0.
            position_ids = list()
            for _input_ids in input_ids:
                _position_ids = list()
                position_id = 0
                for token_id in _input_ids:
                    _position_ids.append(position_id)
                    if token_id != self.data_collator.tokenizer.pad_token_id:
                        position_id += 1
                position_ids.append(_position_ids)
            print('Create custom position_ids with appropriate padding:')
            print(position_ids)
            inputs['position_ids'] = torch.tensor(position_ids).to(model.device)
        else:
            assert 'position_ids' not in inputs
            print('There are no position_ids in the input.')
            print('The transformers model implementation will create position_ids based on the max_length of the batch.')
            # Reconstruct what the model creates internally, only for logging.
            max_length = input_ids.shape[-1]
            position_ids = [list(range(max_length)) for _ in input_ids]
            print(position_ids)
            print('They are incorrect with left padding!')

        return super().compute_loss(model, inputs, return_outputs=return_outputs)


if __name__ == '__main__':
    main()
@raphael-sch (Author)

run with

python minimal_example.py --model_name EleutherAI/gpt-neo-1.3B --use_custom_position_ids no

and

python minimal_example.py --model_name EleutherAI/gpt-neo-1.3B --use_custom_position_ids yes

and compare the training loss.
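
The default model in the script is the gated meta-llama/Llama-2-7b-hf; assuming access to those weights, the same comparison can be run by passing an auth token (placeholder shown):

python minimal_example.py --model_name meta-llama/Llama-2-7b-hf --hf_auth_token <your_hf_token> --use_custom_position_ids yes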

@Emveez commented May 7, 2024

If the tokenizer returns the attention mask, the position ids could be generated with a cumulative sum:

attention_mask = torch.tensor([[0,0,0,0,1,1,1], [0,0,0,1,1,1,1]])
position_ids = attention_mask.cumsum(dim=1)
print(position_ids) # tensor([[0, 0, 0, 0, 1, 2, 3],[0, 0, 0, 1, 2, 3, 4]])
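
Note that the plain cumsum gives the first real token position 1, while the custom loop in the gist starts real tokens at position 0. A minimal sketch of a shift-and-clamp variant that reproduces the zero-based scheme (same attention_mask as above):

import torch

attention_mask = torch.tensor([[0, 0, 0, 0, 1, 1, 1],
                               [0, 0, 0, 1, 1, 1, 1]])
# Shift so the first real token gets position 0, then clamp the padding slots back to 0.
position_ids = (attention_mask.cumsum(dim=1) - 1).clamp(min=0)
print(position_ids)
# tensor([[0, 0, 0, 0, 0, 1, 2],
#         [0, 0, 0, 0, 1, 2, 3]])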
