@JoshC8C7
Created October 9, 2024 10:06
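import json

from lmformatenforcer import JsonSchemaParser
from lmformatenforcer.integrations.transformers import build_transformers_prefix_allowed_tokens_fn
from transformers import AutoModelForCausalLM, AutoTokenizer

# Schema: an object with a required "name" string constrained by the pattern "Łukas"
json_schema = """{"type": "object","properties": {"name": {"title": "name","type": "string", "pattern":"Łukas"}}, "required": ["name"]}"""
parser = JsonSchemaParser(json.loads(json_schema))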
model_id = "meta-llama/Llama-3.2-3B-Instruct"
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token
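# device_map="auto" decides where the model lives, so follow model.device instead of hard-coding "cuda"
inputs = tokenizer(["what is lucas in polish"], return_tensors="pt", padding=True).to(model.device)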
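# Restrict each decoding step to tokens the schema parser allows
prefix_function = build_transformers_prefix_allowed_tokens_fn(tokenizer, parser)
outputs = model.generate(**inputs, prefix_allowed_tokens_fn=prefix_function, max_new_tokens=200)
# Decode only the newly generated tokens (skip the prompt)
print(tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True))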
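
To judge whether the constrained generation respected the schema, the decoded text can be checked by hand. The following is a minimal sketch using only the standard library. `check_name_output` and `SCHEMA` are hypothetical names introduced here, not part of lm-format-enforcer or transformers, and the sketch assumes the prompt has already been stripped from the decoded string.

```python
import json
import re

# Same schema as in the snippet above, as a Python dict
SCHEMA = {
    "type": "object",
    "properties": {"name": {"title": "name", "type": "string", "pattern": "Łukas"}},
    "required": ["name"],
}


def check_name_output(text: str, schema: dict = SCHEMA) -> bool:
    """Return True if `text` is a JSON object whose "name" satisfies the schema's pattern.

    Hypothetical helper for illustration; it checks only the "name" property,
    not the full JSON Schema specification.
    """
    try:
        obj = json.loads(text)
    except json.JSONDecodeError:
        return False
    if not isinstance(obj, dict):
        return False
    name = obj.get("name")
    if not isinstance(name, str):
        return False
    pattern = schema["properties"]["name"]["pattern"]
    # JSON Schema patterns are not implicitly anchored, so use search;
    # switch to re.fullmatch if an exact match is required.
    return re.search(pattern, name) is not None
```

For example, `check_name_output('{"name": "Łukas"}')` passes, while the ASCII spelling `'{"name": "Lukas"}'` and non-JSON output both fail. Because `json.loads` decodes escapes, the escaped form `'{"name": "\u0141ukas"}'` also passes.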