from typing import Tuple
import torch
from datasets import load_dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer


def first_true_indices(bools: torch.Tensor, dtype=torch.long):
    """
    Takes an N-dimensional bool tensor and returns an (N-1)-dimensional tensor of integers giving
    the position of the first True in each "row".

    Returns the length of the rows (bools.size(-1)) if no element is True in a given row.

    Args:
        bools (`torch.Tensor`):
            An N-dimensional boolean tensor.
        dtype (`torch.dtype`, optional):
            The desired data type of the output tensor. Defaults to `torch.long`.

    Returns:
        `torch.Tensor`:
            An (N-1)-dimensional tensor of integers indicating the position of the first True
            in each row. If no True value is found in a row, returns the length of the row.
    """
    row_len = bools.size(-1)
    zero_or_index = row_len * (~bools).type(dtype) + torch.arange(row_len, dtype=dtype, device=bools.device)
    return torch.min(zero_or_index, dim=-1).values
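
# Illustrative sanity check (added here, not part of the original gist): the first True in row 0 is at
# index 1, and the all-False row 1 falls back to the row length (3).
assert first_true_indices(torch.tensor([[False, True, True], [False, False, False]])).tolist() == [1, 3]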


def get_reward(
    model: torch.nn.Module, query_responses: torch.Tensor, pad_token_id: int, context_length: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Computes the reward logits and the rewards for a given model and query responses.

    Args:
        model (`torch.nn.Module`):
            The model used to compute the reward logits.
        query_responses (`torch.Tensor`):
            The tensor containing the query responses.
        pad_token_id (`int`):
            The token ID representing the pad token.
        context_length (`int`):
            The length of the context in the query responses.

    Returns:
        tuple:
            - `reward_logits` (`torch.Tensor`):
                The logits for the reward model.
            - `final_rewards` (`torch.Tensor`):
                The final rewards for each query response.
            - `sequence_lengths` (`torch.Tensor`):
                The lengths of the sequences in the query responses.
    """
    attention_mask = query_responses != pad_token_id
    position_ids = attention_mask.cumsum(1) - attention_mask.long()  # exclusive cumsum
    lm_backbone = getattr(model, model.base_model_prefix)
    input_ids = torch.masked_fill(query_responses, ~attention_mask, 0)
    output = lm_backbone(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        return_dict=True,
        output_hidden_states=True,
        use_cache=False,  # otherwise mistral-based RM would error out
    )
    reward_logits = model.score(output.hidden_states[-1])
    sequence_lengths = first_true_indices(query_responses[:, context_length:] == pad_token_id) - 1 + context_length
    # https://github.com/huggingface/transformers/blob/dc68a39c8111217683bf49a4912d0c9018bab33d/src/transformers/models/gpt2/modeling_gpt2.py#L1454
    return (
        reward_logits,
        reward_logits[
            torch.arange(reward_logits.size(0), device=reward_logits.device),
            sequence_lengths,
        ].squeeze(-1),
        sequence_lengths,
    )
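
# Added note: `reward_logits` has shape (batch, seq_len, 1); the second return value indexes the
# logit at the last non-pad token of each response (position `sequence_lengths`) and squeezes it to a
# scalar per example, mirroring the last-token pooling of `GPT2ForSequenceClassification` linked above.
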
model = AutoModelForSequenceClassification.from_pretrained("cleanrl/EleutherAI_pythia-1b-deduped__reward__tldr")
tokenizer = AutoTokenizer.from_pretrained("cleanrl/EleutherAI_pythia-1b-deduped__reward__tldr")
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
dataset = load_dataset(
    "vwxyzjn/summarize_from_feedback_tldr_3_filtered_oai_preprocessing_1706381144",
    split="validation",
)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = model.to(device)
with torch.no_grad():
    # pad from the right
    _, reward1, _ = get_reward(
        model, torch.LongTensor(dataset[:4]["query_reference_response_token"]).to(device), tokenizer.pad_token_id, 0
    )
    # prompt pad from the left, response pad from the right
    query_token = torch.LongTensor(dataset[:4]["query_token"]).to(device)
    reference_response_token = torch.LongTensor(dataset[:4]["reference_response_token"]).to(device)
    query_reference_response_token = torch.cat((query_token, reference_response_token), dim=1)
    _, reward2, _ = get_reward(model, query_reference_response_token, tokenizer.pad_token_id, query_token.size(1))
    # different batch sizes
    _, reward3, _ = get_reward(
        model, torch.LongTensor(dataset[:2]["query_reference_response_token"]).to(device), tokenizer.pad_token_id, 0
    )
print(f"{reward1=}")
print(f"{reward2=}")
print(f"{reward3=}")