Skip to content

Instantly share code, notes, and snippets.

@brandonwillard
Last active November 17, 2023 23:55
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save brandonwillard/e1f41053c599bb584d4b922251cd96f5 to your computer and use it in GitHub Desktop.
Save brandonwillard/e1f41053c599bb584d4b922251cd96f5 to your computer and use it in GitHub Desktop.
Computing sequence probabilities in Outlines
import torch
import outlines.models as models
from outlines.text.generate.regex import choice
from outlines.text.generate.continuation import continuation
from outlines.text.generate.sample import greedy
def chosen_token_log_probs(
    logits: torch.Tensor, next_token_ids: torch.Tensor
) -> torch.Tensor:
    """Return the log-probability of each chosen token under ``logits``.

    ``next_token_ids`` is reshaped to ``(logits.shape[0], -1)`` so that row
    ``i`` of the ids is paired with row ``i`` of ``logits``. Each sequence is
    therefore scored only against its own token, never against the other
    sequences' tokens. With one token per sequence this pairing is the same
    whether the ids arrive as ``(n_seqs, 1)`` or ``(1, n_seqs)``.

    Assumes ``logits`` is 2-D ``(n_seqs, vocab_size)`` with the vocabulary on
    the last axis. TODO confirm against the Outlines version in use.

    The result is squeezed, so a single sequence yields a 0-d tensor.
    """
    # log_softmax is the numerically stable form of log(softmax(...)).
    log_probs = torch.log_softmax(logits, dim=-1)
    ids = next_token_ids.reshape(logits.shape[0], -1)
    return log_probs.gather(-1, ids).squeeze()


def make_greedy_tracker(generator):
    """Make ``generator`` decode greedily while tracking its sequence log-probability.

    ``generator.sampler`` is replaced by a greedy sampler. Each time that
    sampler picks tokens, it adds their log-probabilities to
    ``generator.running_sequence_log_prob``. Any sampler previously set on
    ``generator`` is discarded.

    ``generator.postprocess_completions`` is wrapped. Whenever it runs, the
    accumulated total is copied to ``generator.last_sequence_log_prob`` and
    the accumulator is reset to ``0.0``. This happens even if
    post-processing raises.

    NOTE(review): the reset relies on ``postprocess_completions`` running
    once at the end of every generation call; verify this for the Outlines
    version in use. If generation aborts before that hook runs, the partial
    total carries over into the next call.

    NOTE(review): the total is a single running value. It is clearly
    meaningful only when each call produces one sequence. With batched
    prompts or ``samples > 1``, the per-step tensors may not line up across
    steps.

    Parameters
    ----------
    generator
        An Outlines generator object exposing assignable ``sampler`` and
        ``postprocess_completions`` attributes.

    Returns
    -------
    The same ``generator`` object, modified in place.
    """
    import types

    # None until the first completed generation publishes a total.
    generator.last_sequence_log_prob = None
    generator.running_sequence_log_prob = 0.0

    def tracking_greedy(
        logits: torch.DoubleTensor, samples: int, *_
    ) -> torch.DoubleTensor:
        # Extra positional arguments (e.g. an RNG) are ignored:
        # greedy decoding is deterministic.
        next_token_ids = greedy(logits, samples)
        generator.running_sequence_log_prob += chosen_token_log_probs(
            logits, next_token_ids
        )
        return next_token_ids

    generator.sampler = tracking_greedy

    old_postprocess_completions = generator.postprocess_completions

    def new_postprocess_completions(self, *args, **kwargs):
        try:
            return old_postprocess_completions(*args, **kwargs)
        finally:
            # Publish the finished sequence's total and reset the
            # accumulator so the next call starts from zero, even if
            # post-processing raised.
            self.last_sequence_log_prob = self.running_sequence_log_prob
            self.running_sequence_log_prob = 0.0

    # Bind the wrapper to this instance only; the generator's class is untouched.
    generator.postprocess_completions = types.MethodType(
        new_postprocess_completions, generator
    )
    return generator
def _show_generation(gen, text):
    """Run ``gen`` on ``text``, print the completion and then its tracked
    log-probability, and return the completion."""
    completion = gen(text)
    print(completion)
    print(gen.last_sequence_log_prob)
    return completion


# GPT-2 loaded through Outlines' transformers integration; both generators
# below share this one model.
model = models.transformers("gpt2")

# A free-form continuation and a choice between the two color patterns.
# Each is wrapped so greedy decoding records the sequence log-probability.
generator = make_greedy_tracker(continuation(model, max_tokens=50))
choice_generator = make_greedy_tracker(
    choice(model, ["[Bb]lue", "[Rr]ed"], max_tokens=50)
)

prompt = "Which color do you prefer: blue or red?"

# The comments after each call show the output of the author's sample run.
# Exact text and values depend on the library versions used.
sequence = _show_generation(generator, prompt)
#
#
# The answer is blue.
#
# The color of the car is the color of the car.
#
# The color of the car is the color of the car.
#
# The color of the car is the color of the car.
# tensor(-44.7725)

sequence = _show_generation(generator, "Which color do you prefer: red or blue?")
# The answer is: red.
#
# The red color is the color of the color of the car. It's the color of the car that's the most important.
#
# The red color is the color of the car that's the
# tensor(-69.0348)

sequence = _show_generation(choice_generator, prompt)
# Blue
# tensor(-0.9262)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment