@BorisTheBrave
Last active February 12, 2023 17:04
constrainted-language-models.ipynb
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/BorisTheBrave/969f303a082c9da1916d04ee1eb04452/constrainted-language-models.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"source": [
"This is sample code for the article [**Constrained Text Generation with AI**](https://www.boristhebrave.com/2023/02/11/constrained-text-generation-with-ai/)."
],
"metadata": {
"id": "K62fFqAZhpH3"
}
},
{
"cell_type": "markdown",
"source": [
"# Prologue\n",
"\n",
"Standard setup: install `transformers` and create a GPT-2 model."
],
"metadata": {
"id": "0PaXNwPC0oyk"
}
},
{
"cell_type": "code",
"source": [
"!pip install -q git+https://github.com/huggingface/transformers.git"
],
"metadata": {
"id": "cZv8Gwja8PqJ",
"outputId": "590999ba-4b75-45df-b48b-251ec5d61182",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
" Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n",
" Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n",
" Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"from transformers import GPT2LMHeadModel, GPT2Tokenizer\n",
"model_name = \"gpt2\"\n",
"tokenizer = GPT2Tokenizer.from_pretrained(model_name)\n",
"model = GPT2LMHeadModel.from_pretrained(model_name, pad_token_id=tokenizer.eos_token_id)\n",
"\n",
"#from transformers import AutoTokenizer, AutoModelForSequenceClassification\n",
"#tokenizer = AutoTokenizer.from_pretrained(\"bert-base-uncased\")\n",
"#model = AutoModelForSequenceClassification.from_pretrained(\"bert-base-uncased\")"
],
"metadata": {
"id": "sETCgWwapW0d"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def generate_and_print(prompt, **kwargs):\n",
"    \"\"\"Generate continuations of `prompt` and print each with its beam score.\"\"\"\n",
"    input_ids = tokenizer(prompt, return_tensors=\"pt\").input_ids\n",
"    output = model.generate(\n",
"        input_ids,\n",
"        return_dict_in_generate=True,\n",
"        output_scores=True,\n",
"        **kwargs\n",
"    )\n",
"    print(repr(prompt))\n",
"    print(\"Output:\\n\" + 100 * '-')\n",
"    # sequences_scores is only populated for beam search (num_beams > 1)\n",
"    for seq, score in zip(output.sequences, output.sequences_scores):\n",
"        # Strip the prompt so only the generated continuation is shown\n",
"        decoded = tokenizer.decode(seq, skip_special_tokens=True)[len(prompt):]\n",
"        print(f\"score {score.item():.2f}: {repr(decoded)}\")"
],
"metadata": {
"id": "iGkrNW4Z00O2"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# Unmodified"
],
"metadata": {
"id": "DKgUjtv13dXj"
}
},
{
"cell_type": "code",
"source": [
"input_str = \"\"\"How many quarts in a gallon?\"\"\"\n",
"\n",
"generate_and_print(\n",
"    input_str,\n",
"    num_beams=10,\n",
"    num_return_sequences=5,\n",
"    no_repeat_ngram_size=3,\n",
"    max_new_tokens=10,\n",
"    remove_invalid_values=True,\n",
")"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Lr-2geYv3Edd",
"outputId": "dd594cf4-1e40-4f2a-e404-9155909da86e"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"'How many quarts in a gallon?'\n",
"Output:\n",
"----------------------------------------------------------------------------------------------------\n",
"score -0.67: '\\n\\nIf you want to know how many qu'\n",
"score -0.68: '\\n\\nHow many gallons of water do you need'\n",
"score -0.71: '\\n\\nIt depends on the size of the qu'\n",
"score -0.72: '\\n\\nIt depends on how much water you use'\n",
"score -0.73: '\\n\\nIt depends on the size of the container'\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"With no fine-tuning or guidance, GPT-2 rarely answers the question directly. It tends to follow a question with another question, or with \"it depends\"."
],
"metadata": {
"id": "tUzYfN_I50z_"
}
},
{
"cell_type": "markdown",
"source": [
"# Fixed phrases\n",
"\n",
"Using `prefix_allowed_tokens_fn`, we can constrain which tokens may be selected next during generation, forcing the output into the shape we want.\n",
"\n",
"The generator then searches for the best output under these constraints.\n",
"\n",
"For lists this small, search is rather pointless, as we could just ask the model for the probability of every phrase directly. But you could imagine a trie of phrases being explored."
],
"metadata": {
"id": "0ghCknTI4AAk"
}
},
{
"cell_type": "code",
"source": [
"vocab = tokenizer.get_vocab()\n",
"all_inputs = list(vocab.values())\n",
"\n",
"def get_inputs(f):\n",
"    \"\"\"Gets all input ids with tokens matching a function\"\"\"\n",
"    return [v for (k, v) in vocab.items() if f(k)]\n",
"\n",
"def normalize(t):\n",
"    \"\"\"Normalize tokens: drop GPT-2's leading-space marker and lowercase\"\"\"\n",
"    return t.replace(\"Ġ\", \"\").lower()\n",
"\n",
"def yes_or_no(input_str):\n",
"    \"\"\"Only allows answers to start with yes or no\"\"\"\n",
"    input_len = len(tokenizer.encode(input_str))\n",
"    yes_or_no_inputs = get_inputs(lambda t: normalize(t) in ('yes', 'no'))\n",
"    def prefix_allowed_tokens(batchId, inputIds):\n",
"        # After the first generated token, anything goes\n",
"        if len(inputIds) > input_len:\n",
"            return all_inputs\n",
"        return yes_or_no_inputs\n",
"    return prefix_allowed_tokens\n",
"\n",
"allowed_phrases = [\"My name is Bob.\", \"My name is Alice.\", \"Yes\", \"No\", \"13\"]\n",
"def restrict_phrases(input_str):\n",
"    \"\"\"Restricts the answer to a fixed set of allowed phrases\"\"\"\n",
"    # This implementation is inefficient, but there are plenty of obvious ways to speed it up\n",
"    # https://discuss.huggingface.co/t/example-of-prefix-allowed-tokens-fn-while-text-generation/6635/2\n",
"    def prefix_allowed_tokens(batchId, inputIds):\n",
"        # Get the answer so far\n",
"        decoded = tokenizer.decode(inputIds)[len(input_str):]\n",
"        # How could we continue this into a phrase?\n",
"        phrases = [p[len(decoded):] for p in allowed_phrases if p.startswith(decoded)]\n",
"        # What token comes next? An empty remainder means the phrase is complete, so allow EOS\n",
"        next_tokens = set(tokenizer.encode(p)[0] if p else tokenizer.eos_token_id for p in phrases)\n",
"        return list(next_tokens)\n",
"    return prefix_allowed_tokens\n"
],
"metadata": {
"id": "w0PqYI5Xqeet"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"for input_str in [\"What is your name?\", \"Is Everest a mountain?\", \"How many quarts in a gallon?\"]:\n",
"    generate_and_print(\n",
"        input_str,\n",
"        num_beams=10,\n",
"        num_return_sequences=5,\n",
"        no_repeat_ngram_size=3,\n",
"        max_new_tokens=10,\n",
"        remove_invalid_values=True,\n",
"        prefix_allowed_tokens_fn=restrict_phrases(input_str),\n",
"    )\n",
"    print()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "zT_v61nb1Rgm",
"outputId": "510b47fa-97a2-4438-9137-f6df4f57d0cb"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"'What is your name?'\n",
"Output:\n",
"----------------------------------------------------------------------------------------------------\n",
"score -2.49: 'My name is Bob.'\n",
"score -2.55: 'My name is Alice.'\n",
"score -3.05: 'Yes'\n",
"score -3.12: '13'\n",
"score -3.15: 'No'\n",
"\n",
"'Is Everest a mountain?'\n",
"Output:\n",
"----------------------------------------------------------------------------------------------------\n",
"score -2.73: 'No'\n",
"score -2.79: 'Yes'\n",
"score -3.09: 'My name is Bob.'\n",
"score -3.13: '13'\n",
"score -3.21: 'My name is Alice.'\n",
"\n",
"'How many quarts in a gallon?'\n",
"Output:\n",
"----------------------------------------------------------------------------------------------------\n",
"score -2.06: '13'\n",
"score -2.21: 'No'\n",
"score -2.24: 'Yes'\n",
"score -2.55: 'My name is Bob.'\n",
"score -2.68: 'My name is Alice.'\n",
"\n"
]
}
]
},
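{
"cell_type": "markdown",
"source": [
"The output is now always one of the specified phrases. The model prefers phrases of the right kind for each question (a name, yes/no, a number), though not necessarily the correct answer: it ranks \"No\" above \"Yes\" for Everest, and picks \"13\" for quarts in a gallon."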
],
"metadata": {
"id": "QMquR3E055AH"
}
},
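{
"cell_type": "markdown",
"source": [
"`restrict_phrases` above re-decodes the whole sequence and re-encodes every remaining phrase on each step. One way to explore a trie of phrases instead is sketched below. This is a minimal sketch, not the article's implementation: `build_trie`, `trie_prefix_fn` and the `END` marker are hypothetical names, and the toy token ids stand in for real ones. With the real tokenizer you would presumably pass `[tokenizer.encode(p) for p in allowed_phrases]` and `len(tokenizer.encode(input_str))`, assuming each phrase tokenizes the same way on its own as it does after the prompt, which is not guaranteed."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"END = -1  # hypothetical sentinel key for a complete phrase (assumes no real token id is negative)\n",
"\n",
"def build_trie(token_sequences):\n",
"    \"\"\"Build a trie of nested dicts keyed by token id.\"\"\"\n",
"    root = {}\n",
"    for seq in token_sequences:\n",
"        node = root\n",
"        for tok in seq:\n",
"            node = node.setdefault(tok, {})\n",
"        node[END] = True\n",
"    return root\n",
"\n",
"def trie_prefix_fn(trie, prompt_len, eos_token_id):\n",
"    \"\"\"Return a function with the (batch_id, input_ids) shape used by prefix_allowed_tokens_fn.\"\"\"\n",
"    def prefix_allowed_tokens(batch_id, input_ids):\n",
"        node = trie\n",
"        # Walk the trie along the tokens generated after the prompt\n",
"        for tok in list(input_ids)[prompt_len:]:\n",
"            node = node.get(int(tok))\n",
"            if node is None:\n",
"                # Off the trie (e.g. already past EOS): only allow EOS\n",
"                return [eos_token_id]\n",
"        allowed = [t for t in node if t != END]\n",
"        if END in node:\n",
"            # A phrase is complete here, so the sequence may stop\n",
"            allowed.append(eos_token_id)\n",
"        return allowed\n",
"    return prefix_allowed_tokens\n",
"\n",
"# Toy demo with made-up token ids: three phrases, a 3-token prompt, EOS id 0\n",
"toy_trie = build_trie([[10, 11, 12], [10, 11, 13], [20]])\n",
"toy_fn = trie_prefix_fn(toy_trie, prompt_len=3, eos_token_id=0)\n",
"print(toy_fn(0, [1, 2, 3]))          # [10, 20]\n",
"print(toy_fn(0, [1, 2, 3, 10, 11]))  # [12, 13]\n",
"print(toy_fn(0, [1, 2, 3, 20]))      # [0]"
],
"metadata": {},
"execution_count": null,
"outputs": []
},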
{
"cell_type": "markdown",
"source": [
"# Open constraints\n",
"\n",
"The constraints don't need to be a fixed list. As long as we can work out which tokens are allowed next, we can enforce them.\n",
"\n",
"That means we can handle any format we can parse incrementally. This constraint restricts the output to a list of strings in JSON format."
],
"metadata": {
"id": "uMNzXvmQ6gU9"
}
},
{
"cell_type": "code",
"source": [
"import functools\n",
"\n",
"# We define a state machine that can recognize JSON character by character.\n",
"# This is very simple, but it would be perfectly possible to do the same thing\n",
"# with the full JSON spec, or even enforce a schema.\n",
"STATE_INVALID = \"invalid\"\n",
"STATE_START = \"start\"\n",
"STATE_AFTER_COMMA = \"after_comma\"\n",
"STATE_AFTER_STRING = \"after_string\"\n",
"STATE_IN_STRING = \"in_string\"\n",
"STATE_END = \"end\"\n",
"states = [STATE_INVALID, STATE_START, STATE_AFTER_COMMA, STATE_AFTER_STRING, STATE_IN_STRING, STATE_END]\n",
"\n",
"def next_json_state(state, char):\n",
"    # INVALID is a dead-end\n",
"    if state == STATE_INVALID:\n",
"        return STATE_INVALID\n",
"    # Whitespace is ignored in JSON (\"Ġ\" is how GPT-2 tokens spell a leading space)\n",
"    if char == \" \" or char == \"Ġ\":\n",
"        return state\n",
"    # Mini state machine\n",
"    if state == STATE_START:\n",
"        if char == \"[\":\n",
"            return STATE_AFTER_COMMA\n",
"        return STATE_INVALID\n",
"    if state == STATE_AFTER_COMMA:\n",
"        if char == \"\\\"\":\n",
"            return STATE_IN_STRING\n",
"        return STATE_INVALID\n",
"    if state == STATE_AFTER_STRING:\n",
"        if char == \"]\":\n",
"            return STATE_END\n",
"        if char == \",\":\n",
"            return STATE_AFTER_COMMA\n",
"        return STATE_INVALID\n",
"    if state == STATE_IN_STRING:\n",
"        if char == \"\\\"\":\n",
"            return STATE_AFTER_STRING\n",
"        # Hack: ban commas from occurring inside strings, to force the generation to actually use JSON\n",
"        if char == \",\":\n",
"            return STATE_INVALID\n",
"        return STATE_IN_STRING\n",
"    if state == STATE_END:\n",
"        return STATE_INVALID\n",
"    raise Exception(\"Unknown state\")\n",
"\n",
"# Find all allowed tokens in a given state.\n",
"# A token is allowed if it doesn't necessarily lead to invalid JSON.\n",
"# Tokens often have multiple characters in them, so we feed each character\n",
"# through the state machine in turn.\n",
"\n",
"def token_ok_in_state(state, token):\n",
"    state = functools.reduce(next_json_state, token, state)\n",
"    return state != STATE_INVALID\n",
"\n",
"tokens_by_state = {state: get_inputs(lambda t: token_ok_in_state(state, t)) for state in states}\n",
"\n",
"# Once the list is closed, the sequence is allowed to end\n",
"tokens_by_state[STATE_END].append(tokenizer.eos_token_id)\n",
"\n",
"def json_list(input_str):\n",
"    \"\"\"Restricts the answer to be a JSON array of strings\"\"\"\n",
"    def prefix_allowed_tokens(batchId, inputIds):\n",
"        # Get the answer so far\n",
"        decoded = tokenizer.decode(inputIds)[len(input_str):]\n",
"        # Parse what we have so far\n",
"        state = functools.reduce(next_json_state, decoded, STATE_START)\n",
"        # Only allow tokens that don't lead to invalid JSON\n",
"        return tokens_by_state[state]\n",
"    return prefix_allowed_tokens"
],
"metadata": {
"id": "Fe0VkWDc0YB7"
},
"execution_count": null,
"outputs": []
},
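{
"cell_type": "markdown",
"source": [
"To sanity-check the machine without loading a model, it can help to restate it as a transition table and run it over a few strings. The cell below is a minimal, self-contained sketch of the same rules; `TRANSITIONS`, `step` and `final_state` are hypothetical names, not part of the original code. Note that, like the machine above, it rejects the empty list `[]`, because a string is required after `[`."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"import functools\n",
"\n",
"# (state, char) -> next state; any pair not listed is invalid, except that\n",
"# inside a string every char other than a quote or comma stays in the string.\n",
"TRANSITIONS = {\n",
"    ('start', '['): 'after_comma',\n",
"    ('after_comma', '\"'): 'in_string',\n",
"    ('in_string', '\"'): 'after_string',\n",
"    ('after_string', ','): 'after_comma',\n",
"    ('after_string', ']'): 'end',\n",
"}\n",
"\n",
"def step(state, char):\n",
"    if state == 'invalid':\n",
"        return 'invalid'\n",
"    # Whitespace never changes the state\n",
"    if char in (' ', 'Ġ'):\n",
"        return state\n",
"    if state == 'in_string' and char not in ('\"', ','):\n",
"        return 'in_string'\n",
"    return TRANSITIONS.get((state, char), 'invalid')\n",
"\n",
"def final_state(text):\n",
"    \"\"\"State reached after feeding every character of text from the start state.\"\"\"\n",
"    return functools.reduce(step, text, 'start')\n",
"\n",
"print(final_state('[ \"London\", \"Paris\" ]'))  # end\n",
"print(final_state('[ \"London\", \"'))          # in_string\n",
"print(final_state('[ \"London, Paris\" ]'))    # invalid (comma inside a string)\n",
"print(final_state('[]'))                     # invalid (empty list not accepted)"
],
"metadata": {},
"execution_count": null,
"outputs": []
},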
{
"cell_type": "code",
"source": [
"input_str = \"\"\"The names of the world's largest cities as a json array of strings:\"\"\"\n",
"\n",
"generate_and_print(\n",
"    input_str,\n",
"    num_beams=10,\n",
"    num_return_sequences=5,\n",
"    no_repeat_ngram_size=3,\n",
"    max_new_tokens=30,\n",
"    remove_invalid_values=True,\n",
"    prefix_allowed_tokens_fn=json_list(input_str),\n",
")"
],
"metadata": {
"id": "D_RWjRaCs0DE",
"outputId": "ff05efe5-cf7c-4010-849b-639c202a5ffe",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"\"The names of the world's largest cities as a json array of strings:\"\n",
"Output:\n",
"----------------------------------------------------------------------------------------------------\n",
"score -0.73: ' [ \"London\", \"New York\", \"Paris\", \"Tokyo\", \"Moscow\", \"San Francisco\", \"Los Angeles\", \"Miami\", \"'\n",
"score -0.74: ' [ \"London\", \"New York\", \"Paris\", \"Tokyo\", \"Los Angeles\", \"San Francisco\", \"Toronto\", \"London\" ]'\n",
"score -0.74: ' [ \"London\", \"New York\", \"Paris\", \"Tokyo\", \"Los Angeles\", \"San Francisco\", \"Chicago\", \"London\" ]'\n",
"score -0.74: ' [ \"London\", \"New York\", \"Paris\", \"Tokyo\", \"San Francisco\", \"Los Angeles\", \"Chicago\", \"Washington\", \"'\n",
"score -0.74: ' [ \"London\", \"New York\", \"Paris\", \"Tokyo\", \"Los Angeles\", \"San Francisco\", \"Chicago\", \"Washington\", \"'\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"We get useful, parsable results whenever the list is closed within the token budget. The outputs that end in an open quote were cut off by `max_new_tokens=30`. Without the constraint, you often get JSON, but not in any sensible format.\n",
"\n",
"Note: I had to ban commas from string literals, as otherwise GPT would prefer to just dump all the results in a single string literal.\n",
"\n",
"Note: GPT has chosen the whitespace; it's not forced by the constraint."
],
"metadata": {
"id": "---Nt1_68HOr"
}
}
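,
{
"cell_type": "markdown",
"source": [
"Beam search can hit `max_new_tokens` before the list is closed, as in the outputs above that end in an open quote. One simple way to deal with that is to keep only the candidates that parse. The cell below is a minimal sketch using the standard `json` module; `parse_complete_lists` is a hypothetical helper, and the two candidates are copied from the output above."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"import json\n",
"\n",
"def parse_complete_lists(candidates):\n",
"    \"\"\"Return the parsed value of every candidate that is a complete JSON list of strings.\"\"\"\n",
"    parsed = []\n",
"    for text in candidates:\n",
"        try:\n",
"            value = json.loads(text)\n",
"        except json.JSONDecodeError:\n",
"            # Truncated or otherwise malformed: skip it\n",
"            continue\n",
"        if isinstance(value, list) and all(isinstance(item, str) for item in value):\n",
"            parsed.append(value)\n",
"    return parsed\n",
"\n",
"candidates = [\n",
"    ' [ \"London\", \"New York\", \"Paris\", \"Tokyo\", \"Moscow\", \"San Francisco\", \"Los Angeles\", \"Miami\", \"',\n",
"    ' [ \"London\", \"New York\", \"Paris\", \"Tokyo\", \"Los Angeles\", \"San Francisco\", \"Toronto\", \"London\" ]',\n",
"]\n",
"# Only the second candidate is complete, so only it is returned\n",
"print(parse_complete_lists(candidates))"
],
"metadata": {},
"execution_count": null,
"outputs": []
}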
]
}