@baberabb
Forked from crowsonkb/spo_loss.py
Created June 10, 2024 15:38
Scalar Preference Optimization
```python
"""Scalar Preference Optimization."""

import torch
from torch.nn import functional as F


def logp_completion(logits, tokens, mask):
    """Compute the log probabilities of completions given their prompts.

    Args:
        logits: The logits output from the model. Shape: (..., T, V).
        tokens: The tokens input to the model. Shape: (..., T).
        mask: A mask indicating which tokens should be included in the log probabilities
            (1 for completion tokens, 0 otherwise). It should exclude prompt tokens and
            padding tokens. Shape: (..., T).

    Returns:
        The log probabilities of the completions given their prompts. Shape: (...).
    """
    logprobs = F.log_softmax(logits, dim=-1)
    # The logits at position t predict token t + 1, so drop the last position's logits and
    # the first token, then pick out the log probability of each actual next token.
    logp_tokens = logprobs[..., :-1, :].gather(-1, tokens[..., 1:, None])[..., 0]
    # Sum over completion tokens only; the mask is shifted the same way as the tokens.
    return torch.sum(logp_tokens * mask[..., 1:], dim=-1)
```
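The shift by one in `logp_completion` (logits at position t scored against token t + 1) is easy to get wrong. The sketch below is a hypothetical, loop-based reference, `logp_completion_reference`, that is not part of the gist. It handles a single sequence in plain Python and can be used to check the vectorized `gather` indexing on tiny inputs. It assumes `logits` is a list of T lists of V floats, `tokens` is a list of T ints, and `mask` uses 1 for completion tokens and 0 otherwise.

```python
import math


def log_softmax(row):
    # Numerically stable log-softmax over one vector of logits.
    m = max(row)
    log_z = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - log_z for x in row]


def logp_completion_reference(logits, tokens, mask):
    """Hypothetical loop-based equivalent of logp_completion for one sequence."""
    total = 0.0
    # Position t's logits predict token t + 1, so pair logits[t] with tokens[t + 1]
    # and use mask[t + 1] to decide whether that predicted token is counted.
    for t in range(len(tokens) - 1):
        if mask[t + 1]:
            total += log_softmax(logits[t])[tokens[t + 1]]
    return total


# Toy example: prompt = tokens[:2], completion = tokens[2:], vocabulary size 4.
tokens = [0, 2, 1, 3]
mask = [0, 0, 1, 1]
uniform_logits = [[0.0] * 4 for _ in range(4)]
# Each of the two completion tokens has probability 1/4, so this equals -2 * math.log(4).
example = logp_completion_reference(uniform_logits, tokens, mask)
```

Because position 0 has no preceding logits, `mask[0]` is never read, which matches the `mask[..., 1:]` slice in the original.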
```python
def spo_loss(logp_1, logp_2, logp_ref_1, logp_ref_2, reward_1, reward_2, beta):
    """Compute the Scalar Preference Optimization loss.

    The SPO loss takes pairs of completions of the same prompt. For each completion it takes
    the log probability under the model and under a reference model, plus a scalar reward.
    It regresses the difference between the model's implied rewards for the two completions
    toward the difference between their actual rewards, scaled by the inverse of the KL
    penalty coefficient.

    Args:
        logp_1: Log probabilities of the first completions given their prompts under the
            model. Should be differentiable w.r.t. the model parameters. Shape: (...).
        logp_2: Log probabilities of the second completions given their prompts under the
            model. Should be differentiable w.r.t. the model parameters. Shape: (...).
        logp_ref_1: Log probabilities of the first completions given their prompts under the
            reference model. Shape: (...).
        logp_ref_2: Log probabilities of the second completions given their prompts under the
            reference model. Shape: (...).
        reward_1: Rewards for the first completions. Shape: (...).
        reward_2: Rewards for the second completions. Shape: (...).
        beta: The KL penalty coefficient (temperature).

    Returns:
        The Scalar Preference Optimization losses for each pair of completions. Shape: (...).
    """
    # Implied reward: log ratio of the model to the reference model (not multiplied by beta).
    implied_reward_1 = logp_1 - logp_ref_1
    implied_reward_2 = logp_2 - logp_ref_2
    # Squared error between the implied reward gap and the actual reward gap divided by beta.
    return (beta / 4) * ((implied_reward_1 - implied_reward_2) - (reward_1 - reward_2) / beta) ** 2
```
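As a sanity check on what `spo_loss` rewards, the sketch below assumes the standard KL-regularized objective (maximize expected reward minus beta times the KL divergence to the reference model). Over a finite set of completions, that objective is maximized by the policy pi*(y) proportional to pi_ref(y) * exp(r(y) / beta). The helper `kl_optimal_logprobs` and the toy numbers are hypothetical, not from the gist. Since `spo_loss` uses only arithmetic, it is copied here and run on plain Python floats.

```python
import itertools
import math


def spo_loss(logp_1, logp_2, logp_ref_1, logp_ref_2, reward_1, reward_2, beta):
    # Same formula as the gist's spo_loss; works on floats as well as tensors.
    implied_reward_1 = logp_1 - logp_ref_1
    implied_reward_2 = logp_2 - logp_ref_2
    return (beta / 4) * ((implied_reward_1 - implied_reward_2) - (reward_1 - reward_2) / beta) ** 2


def kl_optimal_logprobs(logp_ref, rewards, beta):
    """Hypothetical helper: log pi*(y) with pi*(y) proportional to pi_ref(y) * exp(r(y) / beta)."""
    scores = [lp + r / beta for lp, r in zip(logp_ref, rewards)]
    m = max(scores)
    log_z = m + math.log(sum(math.exp(s - m) for s in scores))
    return [s - log_z for s in scores]


# Hypothetical toy setup: three completions of one prompt.
beta = 0.5
logp_ref = [math.log(0.5), math.log(0.3), math.log(0.2)]
rewards = [1.0, 0.0, -0.5]
logp_star = kl_optimal_logprobs(logp_ref, rewards, beta)

# Under pi*, every pair has (numerically) zero loss.
losses = [
    spo_loss(logp_star[i], logp_star[j], logp_ref[i], logp_ref[j], rewards[i], rewards[j], beta)
    for i, j in itertools.combinations(range(len(rewards)), 2)
]

# The reference policy itself is penalized: (0.5 / 4) * (0 - 1.0 / 0.5) ** 2 = 0.5.
ref_loss = spo_loss(logp_ref[0], logp_ref[1], logp_ref[0], logp_ref[1], rewards[0], rewards[1], beta)
```

The normalizer log Z appears in both implied rewards and cancels in their difference. Comparing pairs of completions therefore lets the loss reach zero at pi* without ever computing Z.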