Scalar Preference Optimization
"""Scalar Preference Optimization.""" | |
import torch | |
from torch.nn import functional as F | |
def logp_completion(logits, tokens, mask): | |
"""Compute the log probabilities of completions given their prompts. | |
Args: | |
tokens: The tokens input to the model. Shape: (..., T). | |
logits: The logits output from the model. Shape: (..., T, V). | |
mask: A mask indicating which tokens should be included in the log probabilities. It should | |
exclude prompt tokens and padding tokens. Shape: (..., T). | |
Returns: | |
The log probabilities of the completions given their prompts. Shape: (...). | |
""" | |
logits = F.log_softmax(logits, dim=-1) | |
logp_tokens = logits[..., :-1, :].gather(-1, tokens[..., 1:, None])[..., 0] | |
return torch.sum(logp_tokens * mask[..., 1:], dim=-1) | |
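A usage sketch (not part of the original gist): calling logp_completion on toy tensors. The shapes, the split point between prompt and completion, and the variable names are assumptions chosen for illustration; in practice the logits would come from a causal language model run on the same tokens.

# Toy shapes: batch of 2 sequences, 6 tokens each, vocabulary of 32.
B, T, V = 2, 6, 32
tokens = torch.randint(V, (B, T))   # prompt followed by completion token ids
logits = torch.randn(B, T, V)       # stand-in for model(tokens) logits
mask = torch.zeros(B, T)
mask[:, 3:] = 1.0                   # pretend the last 3 tokens are the completion
logp = logp_completion(logits, tokens, mask)
print(logp.shape)                   # torch.Size([2]): one log probability per sequence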
def spo_loss(logp_1, logp_2, logp_ref_1, logp_ref_2, reward_1, reward_2, beta):
    """Compute the Scalar Preference Optimization loss.

    The SPO loss takes as input pairs of log probabilities of completions given the same
    prompt for each completion in a pair, under the model and a reference model, and scalar
    rewards for each completion. It regresses the difference between the model's implied
    rewards for the completions toward the difference between their actual rewards, scaled
    by the inverse of the KL penalty coefficient.

    Args:
        logp_1: Log probabilities of the first completions given their prompts under the
            model. Should be differentiable w.r.t. the model parameters. Shape: (...).
        logp_2: Log probabilities of the second completions given their prompts under the
            model. Should be differentiable w.r.t. the model parameters. Shape: (...).
        logp_ref_1: Log probabilities of the first completions given their prompts under the
            reference model. Shape: (...).
        logp_ref_2: Log probabilities of the second completions given their prompts under the
            reference model. Shape: (...).
        reward_1: Rewards for the first completions. Shape: (...).
        reward_2: Rewards for the second completions. Shape: (...).
        beta: The KL penalty coefficient (temperature).

    Returns:
        The Scalar Preference Optimization losses for each pair of completions. Shape: (...).
    """
    implied_reward_1 = logp_1 - logp_ref_1
    implied_reward_2 = logp_2 - logp_ref_2
    return (beta / 4) * ((implied_reward_1 - implied_reward_2) - (reward_1 - reward_2) / beta) ** 2
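A minimal sketch of how the two functions might fit together in a training step, under stated assumptions: model and reference_model are hypothetical causal LMs that return logits of shape (B, T, V) when called on token ids, and the batch dictionary keys, the helper name spo_step, and the default beta are illustrative, not part of the gist.

# Hypothetical training step for SPO. `tokens_*`, `mask_*`, and `reward_*` are the paired
# completions' token ids, completion masks, and scalar rewards for a shared prompt.
def spo_step(model, reference_model, optimizer, batch, beta=0.1):
    tokens_1, mask_1, reward_1 = batch["tokens_1"], batch["mask_1"], batch["reward_1"]
    tokens_2, mask_2, reward_2 = batch["tokens_2"], batch["mask_2"], batch["reward_2"]

    # Log probabilities under the model being trained (differentiable).
    logp_1 = logp_completion(model(tokens_1), tokens_1, mask_1)
    logp_2 = logp_completion(model(tokens_2), tokens_2, mask_2)

    # Log probabilities under the frozen reference model (no gradients needed).
    with torch.no_grad():
        logp_ref_1 = logp_completion(reference_model(tokens_1), tokens_1, mask_1)
        logp_ref_2 = logp_completion(reference_model(tokens_2), tokens_2, mask_2)

    # Per-pair losses are averaged over the batch before backpropagation.
    loss = spo_loss(logp_1, logp_2, logp_ref_1, logp_ref_2, reward_1, reward_2, beta).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.detach()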