# outlines/processors/structured.py
...
class GuideLogitsProcessor(OutlinesLogitsProcessor):
    """Bias generation using a finite state machine.

    Attributes
    ----------
    tokenizer
        The tokenizer used to convert tokens to ids.
    guide
        The `outlines.fsm.Guide` which is used to bias the logits.
    """

    tokenizer: "Tokenizer"
    guide: Guide
    _guide_states: Dict[int, Any]
    _seq_start_idx: Optional[int]

    def __init__(self, tokenizer: "Tokenizer", guide: Guide):
        """A Guide-based logits processor.

        Parameters
        ----------
        tokenizer
            The tokenizer used to convert tokens to ids.
        guide
            The `outlines.fsm.Guide` which is used to bias the logits.
        """
        self.tokenizer = tokenizer
        self.guide = guide
        self._guide_states = {hash(tuple([])): self.guide.initial_state}
        self._seq_start_idx = None
    def process_logits(
        self, input_ids: torch.LongTensor, logits: torch.FloatTensor
    ) -> torch.Tensor:
        if self._seq_start_idx is None:
            self._seq_start_idx = len(input_ids[0])

        # Advance (or look up) the guide state for each sequence in the batch,
        # keyed by the hash of the tokens generated so far.
        sequence_states: List[int] = []
        for seq_ids in input_ids:
            gen_ids = seq_ids[self._seq_start_idx :]
            curr_state_key = hash(tuple(gen_ids.tolist()))
            if curr_state_key not in self._guide_states:
                prev_state = self._guide_states[hash(tuple(gen_ids[:-1].tolist()))]
                curr_state = self.guide.get_next_state(prev_state, gen_ids[-1].item())
                self._guide_states[curr_state_key] = curr_state
            sequence_states.append(self._guide_states[curr_state_key])

        allowed_tokens_batch = []
        batch_indices = []
        for i, guide_state in enumerate(sequence_states):
            instruction = self.guide.get_next_instruction(guide_state, input_ids[i : i + 1])
            allowed_tokens = instruction.tokens
            if allowed_tokens is not None:
                # Convert to LongTensor and ensure device matching
                allowed_tokens = torch.tensor(
                    allowed_tokens, dtype=torch.long, device=logits.device
                )
                allowed_tokens_batch.append(allowed_tokens)
                batch_indices.append(torch.full_like(allowed_tokens, i, dtype=torch.long))

        if not allowed_tokens_batch:  # Handle the case where no tokens are allowed
            return torch.full_like(logits, float("-inf"))

        allowed_tokens_concat = torch.cat(allowed_tokens_batch)
        batch_indices_concat = torch.cat(batch_indices)

        # Create the mask and ensure proper types: True everywhere, then clear
        # the (sequence, token) positions the guide allows.
        mask = torch.ones_like(logits, dtype=torch.bool)
        mask[batch_indices_concat, allowed_tokens_concat] = False
        logits.masked_fill_(mask, float("-inf"))
        return logits

    def copy(self) -> "GuideLogitsProcessor":
        """Return a copy of the logits processor."""
        return GuideLogitsProcessor(tokenizer=self.tokenizer, guide=self.guide.copy())
...
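# For intuition, a minimal self-contained sketch of the masking step above
# (illustrative tensors, not part of outlines): a batch of 2 sequences over a
# 5-token vocabulary, where sequence 0 may emit tokens {0, 2} and sequence 1
# only token 4. Every other position is filled with -inf, so sampling can only
# pick guide-approved tokens.
import torch

logits = torch.zeros(2, 5)
batch_indices = torch.tensor([0, 0, 1])
allowed_tokens = torch.tensor([0, 2, 4])
mask = torch.ones_like(logits, dtype=torch.bool)
mask[batch_indices, allowed_tokens] = False
print(logits.masked_fill(mask, float("-inf")))
# tensor([[0., -inf, 0., -inf, -inf],
#         [-inf, -inf, -inf, -inf, 0.]])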
# outlines_core/fsm/guide.py
import re  # needed for the continuation regex; torch, Optional, RegexGuide,
           # Generate, Write, Instruction and create_states_mapping are assumed
           # to already be in scope in this module

class EOSContinueRegexGuide(RegexGuide):
    """Guide that continues generation after EOS, or when the next state is None."""

    def __init__(
        self,
        states_to_token_maps,
        empty_token_ids,
        eos_tensor,
        initial_state,
        continue_text: str,
        max_tries: int = 1,
    ):
        super().__init__(states_to_token_maps, empty_token_ids, eos_tensor, initial_state)
        self.continue_text = continue_text
        self.max_tries = max_tries
        self.current_tries = 0
        self._tokenizer = None
        self.needs_continuation = False
    @classmethod
    def from_regex(
        cls,
        regex_string: str,
        tokenizer,
        continue_text: str,
        max_tries: int = 1,
        device=None,
    ):
        states_to_token_maps, empty_token_ids, fsm_finals = create_states_mapping(
            regex_string, tokenizer
        )
        eos_tensor = torch.tensor([tokenizer.eos_token_id], device=device)
        initial_state = states_to_token_maps.get_initial_state()
        guide = cls(
            states_to_token_maps,
            empty_token_ids,
            eos_tensor,
            initial_state,
            continue_text=continue_text,
            max_tries=max_tries,
        )
        guide._tokenizer = tokenizer
        return guide
    def _try_continuation(self, input_ids: Optional[torch.Tensor] = None) -> Optional[Instruction]:
        """Handle continuation logic while preserving context."""
        if self.current_tries < self.max_tries:
            # Build a new regex for the continuation: the escaped continuation
            # text followed by another sentence-shaped match
            continue_regex = f"{re.escape(self.continue_text)}[A-Za-z\\s]+\\."
            # Create a new FSM for the continuation
            new_states_map, new_empty_ids, _ = create_states_mapping(
                continue_regex, self._tokenizer
            )
            # Swap the guide over to the continuation FSM
            self.states_to_token_maps = new_states_map
            self.empty_token_ids = new_empty_ids
            self.initial_state = new_states_map.get_initial_state()
            self.current_tries += 1
            # Decode and store the text generated so far
            if input_ids is not None:
                self.generated_text = self._tokenizer.decode(input_ids[0])
            # Get the continuation tokens
            continue_tokens = self._tokenizer.encode(self.continue_text)[0]
            if isinstance(continue_tokens, torch.Tensor):
                continue_tokens = continue_tokens.tolist()
            # For the first token after completion, ONLY allow the first token
            # of continue_text
            return Generate([continue_tokens[0]])
        return None  # continuation budget exhausted; the caller falls back to EOS
    def get_next_state(self, state: int, token_id: int) -> int:
        """Override to treat a None next_state as a continuation trigger."""
        if state == -1:
            return -1
        next_state = self.states_to_token_maps.get_next_state(state, token_id)
        if next_state is None and self.current_tries < self.max_tries:
            # Reset to the initial state (of the continuation FSM) to continue
            return self.initial_state
        return -1 if next_state is None else next_state
    def get_next_instruction(self, state: int, input_ids: Optional[torch.Tensor] = None) -> Instruction:
        """Handle both the EOS and None next_state cases, with context."""
        if state == -1:
            continuation = self._try_continuation(input_ids)
            if continuation:
                return continuation
            return Write(self.eos_tensor)
        next_tokens_mask = self.states_to_token_maps.get_allowed_tokens(state)
        if next_tokens_mask is None:
            continuation = self._try_continuation(input_ids)
            if continuation:
                return continuation
            return Write(self.eos_tensor)
        return Generate(torch.tensor(next_tokens_mask))
    def copy(self):
        copied = EOSContinueRegexGuide(
            self.states_to_token_maps,
            self.empty_token_ids,
            self.eos_tensor,
            self.initial_state,
            self.continue_text,
            self.max_tries,
        )
        copied.current_tries = self.current_tries
        copied._tokenizer = self._tokenizer
        return copied
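# Sanity check on the continuation pattern built in _try_continuation: the
# regex is just the escaped continuation text prefixed onto another
# sentence-shaped match (the test strings below are illustrative).
import re

continue_text = " But wait,"
continue_regex = f"{re.escape(continue_text)}[A-Za-z\\s]+\\."
print(bool(re.fullmatch(continue_regex, " But wait, there is more to say.")))  # True
print(bool(re.fullmatch(continue_regex, "42 is not a match.")))                # False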
# Inference
import outlines
from outlines_core.fsm.guide import EOSContinueRegexGuide
from outlines.processors.structured import GuideLogitsProcessor

# 1. Using a LlamaCpp model
model = outlines.models.llamacpp(
    "microsoft/Phi-3-mini-4k-instruct-gguf",
    "Phi-3-mini-4k-instruct-q4.gguf",
)

guide = EOSContinueRegexGuide.from_regex(
    regex_string=r"[A-Za-z\s]+\.",
    tokenizer=model.tokenizer,
    continue_text=" But wait,",
    max_tries=3,
)

# Create the processor with the guide
processor = GuideLogitsProcessor(tokenizer=model.tokenizer, guide=guide)

# Create the generator and set its processor
generator = outlines.generate.text(model)
generator.logits_processor = processor

# Use the generator
prompt = """<|im_start|>system
You are a helpful assistant.
<|im_end|>
<|im_start|>user
How many r's are in the word strawberry?
<|im_end|>
<|im_start|>assistant"""
structured = generator.stream(prompt, max_tokens=1000)
for chunk in structured:
    print(chunk, end="", flush=True)
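
# Caveat: both objects above are stateful. The processor caches per-sequence
# FSM states in _guide_states, and the guide mutates current_tries and swaps in
# the continuation FSM. A sketch (untested assumption) of re-arming them for a
# second prompt, reusing only names defined above:
fresh_guide = EOSContinueRegexGuide.from_regex(
    regex_string=r"[A-Za-z\s]+\.",
    tokenizer=model.tokenizer,
    continue_text=" But wait,",
    max_tries=3,
)
generator.logits_processor = GuideLogitsProcessor(
    tokenizer=model.tokenizer, guide=fresh_guide
)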