Skip to content

Instantly share code, notes, and snippets.

@wassname
Last active May 10, 2024 10:08
Show Gist options
  • Save wassname/c6f660f92501a017e8f5792b7a125a3f to your computer and use it in GitHub Desktop.
Save wassname/c6f660f92501a017e8f5792b7a125a3f to your computer and use it in GitHub Desktop.
For Hugging Face Transformers: sometimes you want to constrain the output to a JSON schema and record the probabilities of the choices/enums. I use it when rating or judging; it is much more efficient than sampling multiple times.
from jaxtyping import Float, Int
import torch
from torch.nn import functional as F
from torch import Tensor
from typing import List, Callable, Tuple, Dict, Optional
import pandas as pd
from transformers import AutoModelForCausalLM, AutoTokenizer
def get_valid_next_choices(choices_tokens, current_tokens):
    """Return the distinct token ids that may follow ``current_tokens``.

    A choice contributes its token at position ``len(current_tokens)`` when it is
    strictly longer than ``current_tokens`` and begins with exactly those tokens.
    An empty result therefore means that no choice both matches the prefix and
    still has tokens remaining.

    Args:
        choices_tokens: iterable of 1-D token-id tensors, one per allowed choice.
        current_tokens: 1-D tensor of the tokens chosen so far.

    Returns:
        1-D LongTensor of unique token ids, in set-iteration order.
    """
    depth = len(current_tokens)
    candidates = [
        seq[depth].item()
        for seq in choices_tokens
        # length check first, so the prefix comparison only runs on longer choices
        if depth < len(seq) and (seq[:depth] == current_tokens).all()
    ]
    # several choices can share the same next token; keep each id once
    return torch.LongTensor(list(set(candidates)))
def choice_tree(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    input_ids: Int[Tensor, "seq"],
    choices_tokens: List[Int[Tensor, "seq"]],
    choice: Optional[Int[Tensor, ""]] = None,
    prob: float = 1,
    current_tokens: Optional[Int[Tensor, "seq"]] = None,
    z: Optional[List[int]] = None,
):
    """Yield each allowed choice together with its probability under ``model``.

    Performs a depth-first walk over the prefix tree formed by ``choices_tokens``.
    At every internal node, the model is run once on ``input_ids`` (the prompt
    plus the tokens chosen so far). Its last-position logits are restricted to
    the tokens that keep the sequence on the path to at least one choice, and a
    softmax over just those tokens gives each branch's conditional probability.
    A completed choice's probability is the product of the branch probabilities
    along its path.

    Args:
        model: causal LM, called as ``model(input_ids[None])``; the result must
            expose ``.logits`` indexable as ``[batch, position, vocab]``.
        tokenizer: used only to decode the finished choice tokens to a string.
        input_ids: 1-D prompt token ids (a batch dimension is added internally).
        choices_tokens: one 1-D token-id tensor per allowed choice.
        choice: token id appended at this level; ``None`` at the root call.
        prob: probability of the path so far; 1 at the root call.
        current_tokens: tokens of the partial choice so far; ``None`` (the
            default) means empty.
        z: branch indices taken so far; recursion bookkeeping only, not used
            in the output.

    Yields:
        ``dict(prob=float, choice=str)`` for each reachable choice.

    NOTE(review): if one choice's tokens are a strict prefix of another's, the
    shorter choice is never yielded. At that node a continuation still exists,
    so the walk always descends, and there is no "stop here" branch to assign
    probability to.
    """
    # Fresh defaults per call instead of objects shared across every call.
    if current_tokens is None:
        current_tokens = torch.LongTensor([])
    if z is None:
        z = []

    if choice is not None:
        current_tokens = torch.cat(
            [current_tokens, choice[None].to(current_tokens.device)], dim=-1
        )
        input_ids = torch.cat([input_ids, choice[None].to(input_ids.device)], dim=-1)

    next_choices = get_valid_next_choices(choices_tokens, current_tokens)
    if len(next_choices) == 0:
        # Leaf: no choice can be extended further, so this path is complete.
        yield dict(prob=prob, choice=tokenizer.decode(current_tokens))
        return

    # Scoring only, so no autograd graph is needed. Without no_grad, when the
    # model's parameters require grad, each node's graph stays referenced
    # (via the probabilities) for the whole recursive descent below it.
    with torch.no_grad():
        logits = model(input_ids[None]).logits[0, -1]
    logits_constrained = logits[next_choices.to(logits.device)]
    # Softmax in float32 so fp16/bf16 models don't quantize the probabilities.
    probs = F.softmax(logits_constrained.float(), dim=-1).tolist()

    for i, (next_choice, branch_prob) in enumerate(zip(next_choices, probs)):
        yield from choice_tree(
            model=model,
            tokenizer=tokenizer,
            choices_tokens=choices_tokens,
            input_ids=input_ids,
            choice=next_choice,
            prob=prob * branch_prob,
            current_tokens=current_tokens,
            z=z + [i],
        )
@wassname
Copy link
Author

wassname commented May 10, 2024

See https://github.com/wassname/prob_jsonformer.git, which builds this into Jsonformer as a `choice_probs` schema type. Example:

# Demo: fill a JSON schema with Jsonformer, where "choice_probs" fields return
# every enum value with its probability (see the sample output below).
from jsonformer import Jsonformer
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "databricks/dolly-v2-12b"
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Property order is preserved and matches the order of the generated output.
json_schema = dict(
    type="object",
    properties=dict(
        name=dict(type="string"),
        age=dict(type="choice_probs", enum=["8", "9", "10", "11"]),
        age2=dict(type="number"),
        is_student=dict(type="boolean"),
        is_student2=dict(type="choice_probs", enum=["true", "false"]),
        courses=dict(
            type="array",
            items=dict(type="string"),
        ),
    ),
)

prompt = "Generate a person's information based on the following schema:"
jsonformer = Jsonformer(model, tokenizer, json_schema, prompt)
generated_data = jsonformer()

print(generated_data)
# {'name': 'John Doe',
#  'age': [{'prob': 0.1497802734375, 'choice': '8'},
#   {'prob': 0.159423828125, 'choice': '9'},
#   {'prob': 0.0982666015625, 'choice': '11'},
#   {'prob': 0.59228515625, 'choice': '10'}],
#  'age2': 10.0201,
#  'is_student': True,
#  'is_student2': [{'prob': 0.94580078125, 'choice': 'true'},
#   {'prob': 0.05419921875, 'choice': 'false'}],
#  'courses': ['C++']}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment