@cmrfrd
Created August 13, 2023 06:34
grammar_sampling.py
from typing import Generator
import llama_cpp
import numpy as np
import regex
from lark import Lark, UnexpectedInput, UnexpectedToken
from llama_cpp import Llama
from pydantic import BaseModel, Field
import rstr
from itertools import islice
## Because I'm a plebeian when it comes to parsers, instead of being a chad
## and finding future sequences of terminals straight from the grammar file,
## I get them by "live" constructing parsable text and letting the parser
## state tell me what's next. To do that, we need example strings for some
## terminals; the rest I fill in with rstr.
COMMON_TOKEN_MAP = {
    "ESCAPED_STRING": '" "',
    "SIGNED_NUMBER": "123",
}
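# For terminals not listed above, a synthetic lexable string is drawn from the
# terminal's regex via rstr.xeger, e.g. (illustrative):
#   rstr.xeger(r"-?\d+")  # -> some matching string such as "42" or "-7"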
def is_ascii_decodable(s: str) -> bool:
    """Return whether the given string is ASCII-decodable.

    Args:
        s (str): The string to check.

    Returns:
        bool: Whether the given string is ASCII-decodable.
    """
    try:
        s.encode(encoding="utf-8").decode("ascii")
    except UnicodeDecodeError:
        return False
    else:
        return True
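# Examples (illustrative): is_ascii_decodable("hello") is True, while
# is_ascii_decodable("héllo") is False because "é" is not ASCII.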
def build_token_map(llm: Llama) -> dict[int, str]:
    """Return a map of token id to token string for the given Llama instance.

    Args:
        llm (Llama): The Llama instance to build the token map for.

    Returns:
        dict[int, str]: A map of token id to token string.
    """
    token_map: dict[int, str] = {}
    for i in range(llama_cpp.llama_n_vocab(llm.ctx)):
        val = llama_cpp.llama_token_to_str(llm.ctx, i).decode("utf-8", errors="ignore")
        token_map[i] = val
    return token_map
class TokenFilter(BaseModel):
    """
    TokenFilter filters tokens by regex patterns. It is used to keep only the
    tokens that are (at least partially) parsable given the current completion.
    """

    token_map: dict[int, str] = Field(..., description="A map of token id to token string.")
    patterns: list[regex.Pattern] = Field(
        ..., description="A list of regex patterns to filter tokens by."
    )

    def is_partial_parsable_token(self, token_id: int, partial_completion: str) -> bool:
        """Return whether the given token is partially parsable given the regex patterns.

        If the token value is empty, then it is considered parsable.

        Args:
            token_id (int): The token id to check.
            partial_completion (str): The partial completion to check against.

        Returns:
            bool: Whether the given token is partially parsable given the regex patterns.
        """
        token_val = self.token_map[token_id]
        if not is_ascii_decodable(token_val):
            return False
        for pattern in self.patterns:
            match = pattern.fullmatch(partial_completion + token_val, partial=True)
            if not match:
                break
            if match.span() == (0, 0):
                break
            return True
        return False

    def filter_partial_parsable_tokens(self, partial_completion: str) -> set[int]:
        """Return a set of token ids that are partially parsable given the regex patterns.

        Args:
            partial_completion (str): The partial completion to check against.

        Returns:
            set[int]: A set of token ids that are partially parsable given the regex patterns.
        """
        result = set(
            filter(
                lambda token_id: self.is_partial_parsable_token(token_id, partial_completion),
                self.token_map.keys(),
            )
        )
        return result

    class Config:
        arbitrary_types_allowed = True
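# Minimal sketch (hypothetical, not called anywhere below): exercise TokenFilter
# against a tiny hand-made vocabulary instead of a real llama.cpp token map.
def _demo_token_filter() -> set[int]:
    toy_map = {0: "12", 1: "abc", 2: "-", 3: "3.5"}
    token_filter = TokenFilter(
        token_map=toy_map,
        patterns=[regex.compile(r"-?\d+(\.\d+)?")],
    )
    # "12", "-" and "3.5" are (partial) matches for a signed number; "abc" is not,
    # so this is expected to return {0, 2, 3}.
    return token_filter.filter_partial_parsable_tokens("")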
class LogitMask(BaseModel):
    """Masks out non-parsable tokens by pushing their logits to a large negative value."""

    partial_parsable_tokens: set[int] = Field(
        ...,
        description="A set of token ids that are partially parsable given the regex patterns.",
        min_length=1,
    )

    def __call__(self, input_ids: list[int], scores: list[float]) -> list[float]:
        """Return the scores with all non partially-parsable tokens masked out.

        Args:
            input_ids (list[int]): The input ids.
            scores (list[float]): The scores.

        Returns:
            list[float]: The scores with non partially-parsable tokens masked out.
        """
        mask = np.ones_like(scores) * -1e10
        partial_parsable_tokens = np.array(list(self.partial_parsable_tokens))
        mask[partial_parsable_tokens] = 0.0
        return_scores = np.array(scores) + mask
        return list(return_scores.tolist())
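# Minimal sketch (hypothetical, not called anywhere below): with a 4-token
# vocabulary, allow only tokens 0 and 2 and push the rest to roughly -1e10.
def _demo_logit_mask() -> list[float]:
    logit_mask = LogitMask(partial_parsable_tokens={0, 2})
    # Tokens 1 and 3 end up near -1e10; tokens 0 and 2 keep their scores.
    return logit_mask(input_ids=[], scores=[1.0, 2.0, 3.0, 4.0])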
def generate_from_regular_expressions(
    llm: Llama, prompt: str, patterns: list[regex.Pattern], max_tokens: int = 5, **kwargs: dict
) -> str:
    """Complete a prompt with the output matching the regex(es).

    Returns when the output fully matches the regex(es) or when max_tokens is reached.

    Args:
        llm (Llama): The Llama instance to use.
        prompt (str): The prompt to complete.
        patterns (list[regex.Pattern]): The regex patterns to use.
        max_tokens (int, optional): The maximum number of tokens to generate. Defaults to 5.
        kwargs: Additional arguments to pass to Llama.generate.

    Returns:
        str: The completed prompt.
    """
    partial_completion: str = ""
    prompt_plus_completion: str = prompt + partial_completion
    token_map = build_token_map(llm)
    gen_tokens = 0
    while gen_tokens < max_tokens:
        token_filter = TokenFilter(
            token_map=token_map,
            patterns=patterns,
        )
        logit_mask = LogitMask(
            partial_parsable_tokens=token_filter.filter_partial_parsable_tokens(partial_completion),
        )
        # Sample exactly one token with the logits masked to parsable tokens only.
        next_tokens = list(
            islice(
                llm.generate(
                    llm.tokenize(prompt_plus_completion.encode("utf-8")),
                    logits_processor=logit_mask,
                    **kwargs,
                ),
                1,
            )
        )
        output_text = llm.detokenize(next_tokens).decode("utf-8", errors="ignore")
        previous_partial_completion = partial_completion
        partial_completion += output_text
        prompt_plus_completion = prompt_plus_completion + output_text
        # Stop early only when at least 3 of the patterns match and the first
        # match spans the entire completion.
        matches = [p.match(partial_completion) for p in patterns]
        has_match = len([m for m in matches if m]) >= 3
        if has_match:
            first_match = next(m for m in matches if m)
            if first_match.start() == 0 and first_match.end() == (len(partial_completion)):
                return str(first_match[0])
        gen_tokens += 1
    return partial_completion
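# Hypothetical usage sketch (needs a loaded Llama model): constrain a completion
# to an integer. Note that the early-return check above requires at least 3 of
# the supplied patterns to match, so with a single pattern this simply runs for
# max_tokens steps and returns whatever matched so far.
def _demo_constrained_integer(llm: Llama) -> str:
    return generate_from_regular_expressions(
        llm,
        prompt="Q: What is 6 times 7?\nA: ",
        patterns=[regex.compile(r"\d+")],
        max_tokens=4,
    )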
class ParserState(BaseModel):
    """A class to hold the state of the parser to determine the next parsable tokens."""

    parser: Lark = Field(..., description="The Lark parser.")

    def next_lexables(
        self, text: str, n: int = 3, tokens: list | None = None
    ) -> Generator[list[str], None, None]:
        """Yield all possible sequences of the next 'n' lexable terminals given a piece of text.

        Args:
            text: input text to parse
            n: number of tokens to look ahead
            tokens: already seen tokens
        """
        if tokens is None:
            tokens = []
        try:
            pi = self.parser.parse_interactive(text)
            pi.exhaust_lexer()
        except UnexpectedToken:
            return
        for token in list(pi.accepts()):  # type: ignore
            tokens += [token]
            if token == "$END":
                yield list(tokens)
                tokens.pop(-1)
                continue
            elif n >= 0:
                yield list(tokens)
            else:
                tokens.pop(-1)
                return
            # Synthesize example text for the terminal so the interactive parser
            # can advance and report what may legally follow it.
            if token in COMMON_TOKEN_MAP:
                gen = COMMON_TOKEN_MAP[token]
            else:
                gen = rstr.xeger(self.parser.get_terminal(name=token).pattern.to_regexp())
            yield from self.next_lexables(
                text + gen,
                n - 1,
                tokens,
            )
            tokens.pop(-1)

    def extract_terminal_regex(self, stop_token: str) -> dict[str, str]:
        """Return a map of terminal name to regex pattern string.

        Args:
            stop_token (str): The stop token to use.

        Returns:
            dict[str, str]: A map of terminal name to regex pattern string.
        """
        regex_map = {}
        for term in self.parser.terminals:
            if term.pattern:
                regex_map[term.name] = term.pattern.to_regexp()
        regex_map["$END"] = stop_token
        return regex_map

    class Config:
        arbitrary_types_allowed = True
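# Minimal sketch (hypothetical, not called anywhere below): enumerate the first
# few candidate terminal sequences the parser state sees for a partial completion.
def _demo_next_lexables(parser: Lark, text: str = "") -> list[list[str]]:
    parser_state = ParserState(parser=parser)
    return list(islice(parser_state.next_lexables(text), 5))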
def generate_from_cfg(
    llm: Llama, prompt: str, parser: Lark, max_tokens: int = 5, **kwargs: dict
) -> Generator[str, None, None]:
    """
    Complete a prompt with output constrained by a Lark context-free grammar.

    Args:
        llm (Llama): The Llama instance to use.
        prompt (str): The prompt to complete.
        parser (Lark): The Lark parser to use.
        max_tokens (int, optional): The maximum number of tokens to generate. Defaults to 5.
        kwargs: Additional arguments to pass to Llama.generate.

    Yields:
        Generator[str, None, None]: A generator of completion chunks, one per grammar step.
    """
    partial_completion: str = ""
    prompt_plus_completion = prompt + partial_completion
    parser_state = ParserState(parser=parser)
    terminal_regexes = parser_state.extract_terminal_regex(
        llm.detokenize([llm._token_eos]).decode("utf-8")  # pylint: disable=protected-access
    )
    gen_tokens = 0
    while gen_tokens < max_tokens:
        valid_next_lexables = list(parser_state.next_lexables(partial_completion))
        # print(f"valid nxt lxbls: {valid_next_lexables}")
        if len(valid_next_lexables) == 0:
            break
        elif len(valid_next_lexables) == 1:
            if len(valid_next_lexables[0]) == 1:
                if valid_next_lexables[0][0] == "$END":
                    break
        ## flatten each candidate terminal sequence into a single regex
        regexes = [
            regex.compile("".join(terminal_regexes[t] for t in lexables))
            for lexables in valid_next_lexables
        ]
        # print(f"valid nxt regexes: {regexes}")
        next_token_completion = generate_from_regular_expressions(
            llm, prompt_plus_completion, regexes, max_tokens=max_tokens, **kwargs
        )
        yield next_token_completion
        partial_completion += next_token_completion
        prompt_plus_completion = prompt_plus_completion + next_token_completion
        gen_tokens += 1

## Usage example (separate snippet; assumes the code above is packaged as grammar_llm.utils)
import json
from lark import Lark
from llama_cpp import Llama
from grammar_llm.utils import generate_from_cfg
# llm = Llama(model_path="./models/llama-2-13b.ggmlv3.q4_0.bin", verbose=False)
llm = Llama(model_path="./models/vicuna-13b-v1.5.ggmlv3.q8_0.bin", verbose=False)
json_grammar = r"""
?start: json_schema
json_schema: "{" name ", " age ", " height ", " location "}"
name: "\"name\"" ": " ESCAPED_STRING
age: "\"age\"" ": " NUMBER
height: "\"height\"" ": " NUMBER
location: "\"location\"" ": " ESCAPED_STRING
%import common.ESCAPED_STRING
%import common.NUMBER
%import common.WS
%ignore WS
"""
### Create the JSON parser with Lark, using the LALR algorithm
json_parser = Lark(
    json_grammar,
    parser="lalr",
    lexer="basic",
    propagate_positions=False,
    maybe_placeholders=False,
    regex=True,
)
prompt = f"""
You are an AI assistant that takes raw text data and generates JSON with the following grammar:
```
{json_grammar}
```
Input: "Alex is 27 years old, 72 inches tall, and lives in Miami."
Output: """
out = ""
for t in generate_from_cfg(llm, prompt, json_parser, max_tokens=50):
    print(t, end="")
    out += t
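
# Illustrative follow-up: the grammar-constrained output should parse as JSON.
# print(json.loads(out))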