@Birch-san
Created June 3, 2023 21:45
DataCollatorForCriticLM
from dataclasses import dataclass
from typing import Dict, List, Sequence, TypedDict

import torch
from torch import BoolTensor, LongTensor
from torch.nn.utils.rnn import pad_sequence
import transformers


class ExtractedCriticSample(TypedDict):
    prompt: str
    continuation: str
    rating: int


@dataclass
class DataCollatorForCriticLM(object):
    tokenizer: transformers.PreTrainedTokenizer
    prompt_max_len: int
    continuation_max_len: int

    def __call__(self, instances: Sequence[ExtractedCriticSample]) -> Dict[str, torch.Tensor]:
        # we deliberately avoid inserting BOS here, because we truncate prompts from the left;
        # better to add BOS after truncation, so it survives as the first token
        prompts = [example['prompt'] for example in instances]
        continuations = [f"{example['continuation']}{self.tokenizer.eos_token}" for example in instances]
        # Tokenize
        tokenized_prompts = self.tokenizer(
            prompts,
            # we want to apply truncation in the opposite direction: discard the start instead of the end,
            # so that the continuation follows on from the end of the prompt.
            # the tokenizer truncates from the right, so we skip truncation here and do it ourselves below.
            max_length=None,
            truncation=False,
            add_special_tokens=False,
        )
        # for each prompt: retain the rightmost prompt_max_len-1 tokens, and prepend a BOS token
        tokenized_prompts = [
            [self.tokenizer.bos_token_id, *input_ids[1 - self.prompt_max_len:]]
            for input_ids in tokenized_prompts['input_ids']
        ]
        tokenized_continuations = self.tokenizer(
            continuations,
            max_length=self.continuation_max_len,
            truncation=True,
            add_special_tokens=False,
        )
        continueds: List[LongTensor] = []
        continuation_masks: List[BoolTensor] = []
        for tokenized_prompt, tokenized_continuation in zip(
            tokenized_prompts,
            tokenized_continuations['input_ids'],
        ):
            # concatenate prompt + continuation into one sequence
            continued: LongTensor = torch.tensor(tokenized_prompt + tokenized_continuation)
            continueds.append(continued)
            # True at positions belonging to the continuation (i.e. everything after the prompt)
            continuation_mask: BoolTensor = torch.arange(0, continued.size(-1)) >= len(tokenized_prompt)
            continuation_masks.append(continuation_mask)
        # Apply padding
        continueds = pad_sequence(continueds, batch_first=True, padding_value=self.tokenizer.pad_token_id)
        continuation_masks = pad_sequence(continuation_masks, batch_first=True, padding_value=0)
        data_dict = {
            'input_ids': continueds,
            'attention_mask': continueds.ne(self.tokenizer.pad_token_id),
            'continuation_masks': continuation_masks,
        }
        return data_dict
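
For context, a minimal usage sketch of how this collator could be wired up. The model name, max lengths, and sample texts below are placeholders and not part of the original gist; it assumes the tokenizer defines BOS and EOS tokens, and registers a distinct pad token if one is missing (since the attention mask above is derived from pad_token_id).

# --- usage sketch (assumptions: placeholder model name, lengths, and samples) ---
tokenizer = transformers.AutoTokenizer.from_pretrained('huggyllama/llama-7b')
if tokenizer.pad_token is None:
    # register a distinct pad token, so padding doesn't collide with EOS in the attention mask
    tokenizer.add_special_tokens({'pad_token': '<pad>'})

collator = DataCollatorForCriticLM(
    tokenizer=tokenizer,
    prompt_max_len=512,
    continuation_max_len=256,
)

samples: List[ExtractedCriticSample] = [
    {'prompt': 'Write a haiku about rain.', 'continuation': 'Rain taps on the roof', 'rating': 4},
    {'prompt': 'Explain overfitting.', 'continuation': 'Overfitting is when a model memorizes noise.', 'rating': 3},
]
batch = collator(samples)
# batch['input_ids']          -> (batch, seq) padded prompt+continuation token ids
# batch['attention_mask']     -> True where tokens are not padding
# batch['continuation_masks'] -> True only over continuation positions (e.g. for restricting a critic loss)
# the collator can also be passed to a DataLoader as collate_fn over a dataset of ExtractedCriticSample dicts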