# outlines/processors/structured.py
...
class GuideLogitsProcessor(OutlinesLogitsProcessor):
    """Bias generation using a finite state machine.

    Attributes
    ----------
    tokenizer
        The tokenizer used to convert tokens to ids.
    guide
        The `outlines.fsm.Guide` which is used to bias the logits.
    """

    tokenizer: "Tokenizer"
    guide: Guide
    _guide_states: Dict[int, Any]
    _seq_start_idx: Optional[int]

    def __init__(self, tokenizer: "Tokenizer", guide: Guide):
        """A Guide-based logits processor.

        Parameters
        ----------
        tokenizer
            The tokenizer used to convert tokens to ids.
        guide
            The `outlines.fsm.Guide` which is used to bias the logits.
        """
        self.tokenizer = tokenizer
        self.guide = guide
        # Map a hash of the generated token ids to the guide state they lead to.
        self._guide_states = {hash(tuple([])): self.guide.initial_state}
        self._seq_start_idx = None

    def process_logits(
        self, input_ids: torch.LongTensor, logits: torch.FloatTensor
    ) -> torch.Tensor:
        if self._seq_start_idx is None:
            self._seq_start_idx = len(input_ids[0])

        sequence_states: List[int] = []
        for seq_ids in input_ids:
            gen_ids = seq_ids[self._seq_start_idx :]
            curr_state_key = hash(tuple(gen_ids.tolist()))
            if curr_state_key not in self._guide_states:
                prev_state = self._guide_states[hash(tuple(gen_ids[:-1].tolist()))]
                curr_state = self.guide.get_next_state(prev_state, gen_ids[-1].item())
                self._guide_states[curr_state_key] = curr_state
            sequence_states.append(self._guide_states[curr_state_key])

        allowed_tokens_batch = []
        batch_indices = []
        for i, guide_state in enumerate(sequence_states):
            # Pass the current sequence so the guide can use it when building
            # a continuation instruction.
            instruction = self.guide.get_next_instruction(guide_state, input_ids[i : i + 1])
            allowed_tokens = instruction.tokens
            if allowed_tokens is not None:
                # Convert to a LongTensor on the same device as the logits.
                allowed_tokens = torch.tensor(
                    allowed_tokens, dtype=torch.long, device=logits.device
                )
                allowed_tokens_batch.append(allowed_tokens)
                batch_indices.append(torch.full_like(allowed_tokens, i, dtype=torch.long))

        if not allowed_tokens_batch:  # No sequence has any allowed token left.
            return torch.full_like(logits, float("-inf"))

        allowed_tokens_concat = torch.cat(allowed_tokens_batch)
        batch_indices_concat = torch.cat(batch_indices)

        # Mask out every token that is not allowed for its sequence.
        mask = torch.ones_like(logits, dtype=torch.bool)
        mask[batch_indices_concat, allowed_tokens_concat] = False
        logits.masked_fill_(mask, float("-inf"))
        return logits

    def copy(self) -> "GuideLogitsProcessor":
        """Return a copy of the logits processor."""
        return GuideLogitsProcessor(tokenizer=self.tokenizer, guide=self.guide.copy())
...
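
# --- Standalone sketch (not part of outlines; illustration only) ---
# How the masking step in `process_logits` behaves: build a boolean mask that
# is False only at the allowed (batch_index, token_id) positions, then fill
# everything else with -inf. The allowed ids [1, 3] are arbitrary examples.
import torch

logits = torch.zeros(1, 5)  # batch of 1 sequence, vocabulary of 5 tokens
allowed = torch.tensor([1, 3], dtype=torch.long)
batch_idx = torch.full_like(allowed, 0)

mask = torch.ones_like(logits, dtype=torch.bool)
mask[batch_idx, allowed] = False
print(logits.masked_fill(mask, float("-inf")))
# tensor([[-inf, 0., -inf, 0., -inf]])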
# outlines_core/fsm/guide.py
...
class EOSContinueRegexGuide(RegexGuide):
    """Guide that continues generation after EOS or when the next state is None."""

    def __init__(
        self,
        states_to_token_maps,
        empty_token_ids,
        eos_tensor,
        initial_state,
        continue_text: str,
        max_tries: int = 1,
    ):
        super().__init__(states_to_token_maps, empty_token_ids, eos_tensor, initial_state)
        self.continue_text = continue_text
        self.max_tries = max_tries
        self.current_tries = 0
        self._tokenizer = None
        self.needs_continuation = False

    @classmethod
    def from_regex(
        cls,
        regex_string: str,
        tokenizer,
        continue_text: str,
        max_tries: int = 1,
        device=None,
    ):
        states_to_token_maps, empty_token_ids, fsm_finals = create_states_mapping(
            regex_string, tokenizer
        )
        eos_tensor = torch.tensor([tokenizer.eos_token_id], device=device)
        initial_state = states_to_token_maps.get_initial_state()
        guide = cls(
            states_to_token_maps,
            empty_token_ids,
            eos_tensor,
            initial_state,
            continue_text=continue_text,
            max_tries=max_tries,
        )
        guide._tokenizer = tokenizer
        return guide

    def _try_continuation(self, input_ids: Optional[torch.Tensor] = None) -> Optional[Instruction]:
        """Handle continuation logic while preserving context."""
        if self.current_tries < self.max_tries:
            # Build a new regex that forces the continuation text, then lets
            # the sentence pattern resume.
            continue_regex = f"{re.escape(self.continue_text)}[A-Za-z\\s]+\\."
            # Compile a fresh FSM for the continuation regex.
            new_states_map, new_empty_ids, _ = create_states_mapping(
                continue_regex, self._tokenizer
            )
            # Swap the guide's FSM for the continuation FSM.
            self.states_to_token_maps = new_states_map
            self.empty_token_ids = new_empty_ids
            self.initial_state = new_states_map.get_initial_state()
            self.current_tries += 1
            # Decode and store the text generated so far.
            if input_ids is not None:
                self.generated_text = self._tokenizer.decode(input_ids[0])
            # Get the token ids of the continuation text.
            continue_tokens = self._tokenizer.encode(self.continue_text)[0]
            if isinstance(continue_tokens, torch.Tensor):
                continue_tokens = continue_tokens.tolist()
            # For the first token after completion, ONLY allow the first token
            # of continue_text.
            return Generate([continue_tokens[0]])
        return None

    def get_next_state(self, state: int, token_id: int) -> int:
        """Override to treat a None next_state as a continuation trigger."""
        if state == -1:
            return -1
        next_state = self.states_to_token_maps.get_next_state(state, token_id)
        if next_state is None and self.current_tries < self.max_tries:
            # Reset to the initial state so the continuation FSM takes over.
            return self.initial_state
        return -1 if next_state is None else next_state

    def get_next_instruction(self, state: int, input_ids: Optional[torch.Tensor] = None) -> Instruction:
        """Handle both the EOS and the None next_state cases, with sequence context."""
        if state == -1:
            continuation = self._try_continuation(input_ids)
            if continuation:
                return continuation
            return Write(self.eos_tensor)
        next_tokens_mask = self.states_to_token_maps.get_allowed_tokens(state)
        if next_tokens_mask is None:
            continuation = self._try_continuation(input_ids)
            if continuation:
                return continuation
            return Write(self.eos_tensor)
        return Generate(torch.tensor(next_tokens_mask))

    def copy(self):
        copied = EOSContinueRegexGuide(
            self.states_to_token_maps,
            self.empty_token_ids,
            self.eos_tensor,
            self.initial_state,
            self.continue_text,
            self.max_tries,
        )
        copied.current_tries = self.current_tries
        copied._tokenizer = self._tokenizer
        return copied
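
# --- Standalone sketch (illustration only, using plain `re`) ---
# The continuation regex built inside `_try_continuation`: the escaped
# continuation text must appear verbatim, then the sentence pattern resumes.
# The strings below are example inputs, not values from the gist.
import re

continue_text = " But wait,"
continue_regex = f"{re.escape(continue_text)}[A-Za-z\\s]+\\."
print(bool(re.fullmatch(continue_regex, " But wait, there are three.")))  # True
print(bool(re.fullmatch(continue_regex, " There are three.")))            # False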
# Inference
import outlines
from outlines_core.fsm.guide import EOSContinueRegexGuide
from outlines.processors.structured import GuideLogitsProcessor

# 1. Use with a llama.cpp model
model = outlines.models.llamacpp(
    "microsoft/Phi-3-mini-4k-instruct-gguf",
    "Phi-3-mini-4k-instruct-q4.gguf",
)

guide = EOSContinueRegexGuide.from_regex(
    regex_string=r"[A-Za-z\s]+\.",
    tokenizer=model.tokenizer,
    continue_text=" But wait,",
    max_tries=3,
)

# Create the processor with the guide
processor = GuideLogitsProcessor(tokenizer=model.tokenizer, guide=guide)

# Create a generator and attach the processor
generator = outlines.generate.text(model)
generator.logits_processor = processor

# Use the generator
prompt = """<|im_start|>system You are a helpful assistant.
<|im_end|>
<|im_start|>user
How many r's are in the word strawberry?
<|im_end|>
<|im_start|>assistant"""

structured = generator.stream(prompt, max_tokens=1000)
for chunk in structured:
    print(chunk, end="", flush=True)
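
# A non-streaming variant of the same setup (a sketch, assuming the same
# `generator` and `prompt` as above): outlines generators are also directly
# callable and return the completed text once the guide finally emits EOS
# after exhausting its continuation budget.
result = generator(prompt, max_tokens=1000)
print(result)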