import torch
import torch.nn.functional as F
from typing import Tuple


def _get_batch_logps(logits: torch.FloatTensor, labels: torch.LongTensor, average_log_prob: bool = False) -> torch.FloatTensor:
"""Compute the log probabilities of the given labels under the given logits. | |
Args: | |
logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size) | |
labels: Labels for which to compute the log probabilities. Label tokens with a value of -100 are ignored. Shape: (batch_size, sequence_length) | |
average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens. | |
Returns: | |
A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits. | |
""" | |
    assert logits.shape[:-1] == labels.shape
    # Shift so the label at position t is scored by the logits at position t-1
    # (standard next-token prediction alignment).
    labels = labels[:, 1:].clone()
    logits = logits[:, :-1, :]
    loss_mask = (labels != -100)
    # Dummy token; we'll ignore the losses on these positions later via the mask.
    labels[labels == -100] = 0
    # Log-probability of each label token under the model's predicted distribution.
    per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
    if average_log_prob:
        return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
    else:
        return (per_token_logps * loss_mask).sum(-1)
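

# Hedged usage sketch for _get_batch_logps. The demo function, shapes, seed,
# and values below are illustrative assumptions, not part of the original gist.
# Prompt positions are masked with -100 so only response tokens contribute.
def _demo_get_batch_logps():
    torch.manual_seed(0)
    logits = torch.randn(2, 6, 10)          # (batch_size, sequence_length, vocab_size)
    labels = torch.randint(0, 10, (2, 6))   # (batch_size, sequence_length)
    labels[:, :3] = -100                    # treat the first 3 tokens as the prompt
    sums = _get_batch_logps(logits, labels)                          # shape: (2,)
    means = _get_batch_logps(logits, labels, average_log_prob=True)  # shape: (2,)
    print(sums, means)
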
def dpo_loss(policy_chosen_logps: torch.FloatTensor,
             policy_rejected_logps: torch.FloatTensor,
             reference_chosen_logps: torch.FloatTensor,
             reference_rejected_logps: torch.FloatTensor,
             beta: float,
             reference_free: bool = False) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
"""Compute the DPO loss for a batch of policy and reference model log probabilities. | |
Args: | |
policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,) | |
policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,) | |
reference_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (batch_size,) | |
reference_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (batch_size,) | |
beta: Temperature parameter for the DPO loss, typically something in the range of 0.1 to 0.5. We ignore the reference model as beta -> 0. | |
reference_free: If True, we ignore the _provided_ reference model and implicitly use a reference model that assigns equal probability to all responses. | |
Returns: | |
A tuple of three tensors: (losses, chosen_rewards, rejected_rewards). | |
The losses tensor contains the DPO loss for each example in the batch. | |
The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively. | |
""" | |
    pi_logratios = policy_chosen_logps - policy_rejected_logps
    ref_logratios = reference_chosen_logps - reference_rejected_logps
    if reference_free:
        ref_logratios = 0

    # Implicit reward margin between chosen and rejected; the loss is the
    # negative log-sigmoid of beta times this margin.
    logits = pi_logratios - ref_logratios
    losses = -F.logsigmoid(beta * logits)
    # Implicit rewards, detached so they are reported but not backpropagated through.
    chosen_rewards = beta * (policy_chosen_logps - reference_chosen_logps).detach()
    rejected_rewards = beta * (policy_rejected_logps - reference_rejected_logps).detach()
    return losses, chosen_rewards, rejected_rewards
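

# Hedged end-to-end sketch: the random log-probs and beta=0.1 below are assumed
# placeholder values, not from the original gist. In a real run the four
# log-prob tensors come from _get_batch_logps applied to policy and reference
# model logits over the same (chosen, rejected) response pairs; per the
# docstring above, the per-example loss is
#     -log(sigmoid(beta * ((pi_w - pi_l) - (ref_w - ref_l)))).
if __name__ == "__main__":
    torch.manual_seed(0)
    batch_size = 4
    policy_chosen_logps = torch.randn(batch_size)
    policy_rejected_logps = torch.randn(batch_size)
    reference_chosen_logps = torch.randn(batch_size)
    reference_rejected_logps = torch.randn(batch_size)
    losses, chosen_rewards, rejected_rewards = dpo_loss(
        policy_chosen_logps, policy_rejected_logps,
        reference_chosen_logps, reference_rejected_logps,
        beta=0.1)
    # Reward accuracy: fraction of pairs where the chosen response gets the higher reward.
    print(losses.mean().item(), (chosen_rewards > rejected_rewards).float().mean().item())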