Skip to content

Instantly share code, notes, and snippets.

@mgraczyk
Created April 8, 2024 03:16
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save mgraczyk/a4152785419c32c0c6829ced4b687616 to your computer and use it in GitHub Desktop.
Save mgraczyk/a4152785419c32c0c6829ced4b687616 to your computer and use it in GitHub Desktop.
import asyncio as asyncio # keep
import enum
import random
from typing import Callable
from typing import Literal
from typing import NotRequired
from typing import TypedDict
import lxml.etree
from llm.types import ChatModelName
from llm.utils import answer_few_shot_chat as answer_few_shot_chat
from llm.utils import answer_one_shot_chat as answer_one_shot_chat
class Tokens(enum.Enum):
    """The four tokens of the A#/#A/B#/#B pair-rewriting puzzle.

    Rendered via TOKEN_TO_STR ("A#", "#A", ...) for the direct algorithm and
    via CONVERSION ("A", "X", ...) when embedded in the LLM prompt.
    """

    A_SHARP = 1  # spelled "A#"
    SHARP_A = 2  # spelled "#A"
    B_SHARP = 3  # spelled "B#"
    SHARP_B = 4  # spelled "#B"
# Canonical text spelling of each token, and the inverse lookup for parsing.
TOKEN_TO_STR = {
    Tokens.A_SHARP: "A#",
    Tokens.SHARP_A: "#A",
    Tokens.B_SHARP: "B#",
    Tokens.SHARP_B: "#B",
}
STR_TO_TOKEN = {v: k for k, v in TOKEN_TO_STR.items()}
# Rewrite rules: an adjacent pair matching a key is replaced by the value list.
# A# #A and B# #B annihilate (empty replacement); mixed pairs swap so the
# "#"-side token moves left past the other token.
RULES: dict[tuple[Tokens, Tokens], list[Tokens]] = {
    (Tokens.A_SHARP, Tokens.SHARP_A): [],
    (Tokens.A_SHARP, Tokens.SHARP_B): [Tokens.SHARP_B, Tokens.A_SHARP],
    (Tokens.B_SHARP, Tokens.SHARP_A): [Tokens.SHARP_A, Tokens.B_SHARP],
    (Tokens.B_SHARP, Tokens.SHARP_B): [],
}
# Single-character aliases substituted into the LLM prompt (see
# _serialize_tokens_with_conversion and _do_prompt_replacement).
CONVERSION = {
    Tokens.A_SHARP: "A",
    Tokens.SHARP_A: "X",
    Tokens.B_SHARP: "B",
    Tokens.SHARP_B: "Y",
}
def _parse_sentence(sentence: str) -> list[Tokens]:
    """Parse a space-separated sentence like "B# #A" into a token list.

    Empty fields (from repeated spaces) are skipped.
    """
    parsed: list[Tokens] = []
    for piece in sentence.split(" "):
        if piece:
            parsed.append(STR_TO_TOKEN[piece])
    return parsed
def _serialize_tokens(tokens: list[Tokens]) -> str:
    """Render tokens as a space-separated sentence like "B# #A"."""
    spelled = [TOKEN_TO_STR[token] for token in tokens]
    return " ".join(spelled)
def _serialize_tokens_with_conversion(tokens: list[Tokens]) -> str:
    """Render tokens using the single-character CONVERSION aliases."""
    aliases = [CONVERSION[token] for token in tokens]
    return " ".join(aliases)
def _do_one_replacement(tokens: list[Tokens]) -> list[Tokens]:
    """Apply the first applicable rule to the leftmost matching pair.

    Returns a new list with the matched pair spliced out and its replacement
    spliced in, or the *same* list object when no rule matches — callers
    (see _do_replacement) rely on the identity check to detect a fixpoint.
    """
    for i in range(len(tokens) - 1):
        replacement = RULES.get((tokens[i], tokens[i + 1]))
        if replacement is not None:
            # `replacement` is already bound by .get(); the original code
            # redundantly looked the same key up a second time here.
            return tokens[:i] + list(replacement) + tokens[i + 2 :]
    return tokens
def _do_replacement(tokens: list[Tokens]) -> list[Tokens]:
    """Rewrite repeatedly until no rule applies; return the normal form."""
    current = tokens
    while True:
        reduced = _do_one_replacement(current)
        if reduced is current:
            # _do_one_replacement returns the identical object at a fixpoint.
            return current
        current = reduced
class State(TypedDict):
    """Sliding-window state of the LLM-executable algorithm.

    The token stream is split as prefix + current + suffix; ``current`` is
    the window (at most two tokens) that the rules inspect, and ``reverse``
    signals that the window must back up one token on the next step.
    """

    # Stringly-typed boolean so it serializes directly into the prompt XML.
    reverse: Literal["false", "true"]
    prefix: list[Tokens]
    current: list[Tokens]
    suffix: list[Tokens]
class Transition(TypedDict):
    """One rule of the state machine executed by _do_llm_algorithm."""

    # List of properties and values to compare, and whether == or != should be used.
    matcher: list[tuple[str, str | list[Tokens], bool]]
    # State transition
    transition: Callable[[State], State]
    # When set, the algorithm emits the <solution> and stops after this rule.
    is_halting: NotRequired[bool]
def _state_to_string(s: State) -> str:
    """Serialize a State as the XML fragment the LLM is expected to emit."""
    fields = [
        f"<reverse>{s['reverse']}</reverse>",
        f"<prefix>{_serialize_tokens_with_conversion(s['prefix'])}</prefix>",
        f"<current>{_serialize_tokens_with_conversion(s['current'])}</current>",
        f"<suffix>{_serialize_tokens_with_conversion(s['suffix'])}</suffix>",
    ]
    # Inner lines are indented by a single space inside <state>.
    return "<state>\n " + "\n ".join(fields) + "\n</state>"
def _do_llm_algorithm(tokens: list[Tokens]) -> tuple[list[Tokens], str]:
    """Run the prompt's state machine locally on ``tokens``.

    Executes exactly the rule set described in SYSTEM_PROMPT and returns
    ``(final_prefix, transcript)``, where ``transcript`` is the full
    <thinking>/<state>/<solution> text that an LLM following the prompt
    should produce (used as the few-shot AI message in _test_llm).
    """

    def default_transition(s: State) -> State:
        # Shift one token from current to prefix and pull one in from suffix.
        return {
            **s,
            "prefix": s["prefix"] + s["current"][:1],
            "current": s["current"][1:] + s["suffix"][:1],
            "suffix": s["suffix"][1:],
        }

    # Ordered rule table: each step applies the FIRST rule whose matcher holds.
    transitions = [
        # Reverse requested but there is no prefix to back up over: clear flag.
        Transition(
            matcher=[("reverse", "true", True), ("prefix", [], True)],
            transition=lambda s: {**s, "reverse": "false"},
        ),
        # Reverse with empty current: rebuild the window from prefix tail +
        # suffix head.
        Transition(
            matcher=[("reverse", "true", True), ("prefix", [], False), ("current", [], True)],
            transition=lambda s: {
                "reverse": "false",
                "current": [s["prefix"][-1], *s["suffix"][:1]],
                "prefix": s["prefix"][:-1],
                "suffix": s["suffix"][1:],
            },
        ),
        # Reverse with a non-empty current: slide the window left by one token.
        Transition(
            matcher=[("reverse", "true", True), ("prefix", [], False), ("current", [], False)],
            transition=lambda s: {
                "reverse": "false",
                "suffix": [s["current"][1], *s["suffix"]],
                "current": [s["prefix"][-1], s["current"][0]],
                "prefix": s["prefix"][:-1],
            },
        ),
        # Empty window, more input: load the next two tokens from suffix.
        Transition(
            matcher=[("reverse", "false", True), ("current", [], True), ("suffix", [], False)],
            transition=lambda s: {
                **s,
                "current": s["suffix"][:2],
                "suffix": s["suffix"][2:],
            },
        ),
        # Everything consumed: halt; prefix holds the normal form.
        Transition(
            matcher=[("reverse", "false", True), ("current", [], True), ("suffix", [], True)],
            transition=lambda s: {
                **s,
                "current": [],
            },
            is_halting=True,
        ),
        # A# #A annihilates; back up in case a new pair formed behind us.
        Transition(
            matcher=[("reverse", "false", True), ("current", [Tokens.A_SHARP, Tokens.SHARP_A], True)],
            transition=lambda s: {
                **s,
                "current": [],
                "reverse": "true",
            },
        ),
        # B# #B annihilates likewise.
        Transition(
            matcher=[("reverse", "false", True), ("current", [Tokens.B_SHARP, Tokens.SHARP_B], True)],
            transition=lambda s: {
                **s,
                "current": [],
                "reverse": "true",
            },
        ),
        # A# #B swaps to #B A#, then backs up.
        Transition(
            matcher=[("reverse", "false", True), ("current", [Tokens.A_SHARP, Tokens.SHARP_B], True)],
            transition=lambda s: {
                **s,
                "current": [Tokens.SHARP_B, Tokens.A_SHARP],
                "reverse": "true",
            },
        ),
        # B# #A swaps to #A B#, then backs up.
        Transition(
            matcher=[("reverse", "false", True), ("current", [Tokens.B_SHARP, Tokens.SHARP_A], True)],
            transition=lambda s: {
                **s,
                "current": [Tokens.SHARP_A, Tokens.B_SHARP],
                "reverse": "true",
            },
        ),
        # Convert the rest one by one into separate transition rules, instead
        # of grouping with a logical OR: every remaining pair is inert, so the
        # window simply advances one token.
        Transition(
            matcher=[("reverse", "false", True), ("current", [Tokens.SHARP_A, Tokens.A_SHARP], True)],
            transition=default_transition,
        ),
        Transition(
            matcher=[("reverse", "false", True), ("current", [Tokens.SHARP_B, Tokens.B_SHARP], True)],
            transition=default_transition,
        ),
        Transition(
            matcher=[("reverse", "false", True), ("current", [Tokens.SHARP_A, Tokens.B_SHARP], True)],
            transition=default_transition,
        ),
        Transition(
            matcher=[("reverse", "false", True), ("current", [Tokens.SHARP_B, Tokens.A_SHARP], True)],
            transition=default_transition,
        ),
        Transition(
            matcher=[("reverse", "false", True), ("current", [Tokens.SHARP_A, Tokens.SHARP_A], True)],
            transition=default_transition,
        ),
        Transition(
            matcher=[("reverse", "false", True), ("current", [Tokens.SHARP_A, Tokens.SHARP_B], True)],
            transition=default_transition,
        ),
        Transition(
            matcher=[("reverse", "false", True), ("current", [Tokens.SHARP_B, Tokens.SHARP_A], True)],
            transition=default_transition,
        ),
        Transition(
            matcher=[("reverse", "false", True), ("current", [Tokens.SHARP_B, Tokens.SHARP_B], True)],
            transition=default_transition,
        ),
        Transition(
            matcher=[("reverse", "false", True), ("current", [Tokens.A_SHARP, Tokens.A_SHARP], True)],
            transition=default_transition,
        ),
        Transition(
            matcher=[("reverse", "false", True), ("current", [Tokens.A_SHARP, Tokens.B_SHARP], True)],
            transition=default_transition,
        ),
        Transition(
            matcher=[("reverse", "false", True), ("current", [Tokens.B_SHARP, Tokens.A_SHARP], True)],
            transition=default_transition,
        ),
        Transition(
            matcher=[("reverse", "false", True), ("current", [Tokens.B_SHARP, Tokens.B_SHARP], True)],
            transition=default_transition,
        ),
        # Single-token windows (odd tail of the input) also just advance.
        Transition(
            matcher=[("reverse", "false", True), ("current", [Tokens.A_SHARP], True)],
            transition=default_transition,
        ),
        Transition(
            matcher=[("reverse", "false", True), ("current", [Tokens.B_SHARP], True)],
            transition=default_transition,
        ),
        Transition(
            matcher=[("reverse", "false", True), ("current", [Tokens.SHARP_A], True)],
            transition=default_transition,
        ),
        Transition(
            matcher=[("reverse", "false", True), ("current", [Tokens.SHARP_B], True)],
            transition=default_transition,
        ),
    ]
    s = State(reverse="false", prefix=[], current=[], suffix=tokens)

    def _serialize_rule_value(v: str | list[Tokens]) -> str:
        # Render a matcher value the way the prompt's rules spell it
        # (quoted string; token lists use the CONVERSION aliases).
        if isinstance(v, str):
            return f'"{v}"'
        if isinstance(v, list):
            return f'"{_serialize_tokens_with_conversion(v)}"'

    result_strings = []
    # NOTE(review): the "inital" typo below also appears in SYSTEM_PROMPT's
    # worked example; the two must stay in sync for the few-shot transcript.
    result_strings.append("<thinking>Convert the problem input into the inital state</thinking>")
    result_strings.append(_state_to_string(s))
    while True:
        # First-match-wins scan over the ordered rule table.
        for transition in transitions:
            if all((s[k] == v if eq else s[k] != v) for k, v, eq in transition["matcher"]):
                break
        else:
            assert False, f"No transition found for state {s}"
        s = transition["transition"](s)
        # Reproduce the rule text exactly as the prompt's examples spell it.
        match_str = " and ".join(
            f"{k} {'==' if eq else '!='} {_serialize_rule_value(v)}" for k, v, eq in transition["matcher"]
        )
        result_strings.append(f"<thinking>rule {match_str}</thinking>")
        result_strings.append(_state_to_string(s))
        if transition.get("is_halting"):
            result_strings.append(f"<solution>{_serialize_tokens(s['prefix'])}</solution>")
            return s["prefix"], "\n".join(result_strings)
    assert False  # unreachable: the loop only exits via the halting rule
def _show_expected_llm_rules(tokens: list[Tokens]) -> None:
    """Print which rewrite rules apply to each adjacent pair in ``tokens``."""
    print("Expected LLM rules:")
    for left, right in zip(tokens, tokens[1:]):
        replacement = RULES.get((left, right))
        if replacement is not None:
            print(f"Rule: {TOKEN_TO_STR[left]} {TOKEN_TO_STR[right]} -> {replacement}")
def _get_examples() -> list[tuple[str, str]]:
return [("B# A# #B #A B#", "B#"), ("#B B# #A B# #A B# B#", "#B #A #A B# B# B# B#")]
def _test_on_examples():
    """Run the direct rewriter on the canned examples and report pass/fail."""
    examples = _get_examples()
    failures = 0
    for sentence, expected in examples:
        result = _serialize_tokens(_do_replacement(_parse_sentence(sentence)))
        if result != expected:
            print(f"Test failed for {sentence}: expected {expected}, got {result}")
            failures += 1
    if failures:
        print(f"{failures}/{len(examples)} tests failed.")
    else:
        print("All tests passed!")
# System prompt teaching the model to execute the rewriting state machine step
# by step.  "{A#}"-style placeholders are replaced by the single-character
# CONVERSION aliases via _do_prompt_replacement before the prompt is sent.
# Fixes vs. the original: pseudocode rules 9+ wrongly compared the string
# literal `"reverse"` (always False) instead of the variable `reverse`, and
# the second rule started a new `if` chain (allowing two rules per step)
# instead of `elif`; both now match the first eight rules and the generated
# match strings.  The "inital" typo is kept to match _do_llm_algorithm's
# few-shot transcript.
SYSTEM_PROMPT = r"""
You will execute a string manipulation algorithm on a user provided string to find the solution to a puzzle.
You are given a problem input string like this:
<problem>B# A# #B #A B#</problem>
You must execute the algorithm specified below exactly, applying each step where appropriate to find the solution.
The algorithm consists of a current state, and rules that you must apply to the state to get the next state.
One of the rules will terminate the algorithm and return a final solution.
In each step of the algorithm, you will first explain what you are thinking, then provide the current state after applying the rule for that step.
THE ALGORITHM
The algorithm consists of converting the input into the initial state, then applying rules repeatedly until you return the solution and terminate.
The state consists of four parts, prefix, suffix, the current string, and a boolean called reverse. Here is an example state:
<state>
 <reverse>false</reverse>
 <prefix>{B#} {A#}</prefix>
 <current>{#B} {#A}</current>
 <suffix>{#B} {#A} {#A} {#B}</suffix>
</state>
Next I will describe the steps of the algorithm.
Follow these steps exactly.
First you must convert the input.
You will be given a problem string like this:
<problem>B# A# #B #A B#</problem>
Convert the problem to an initial state by copying the string into the suffix and setting reverse to false.
When converting the problem to the suffix, convert each token using these conversions:
"A#" -> "{A#}"
"#A" -> "{#A}"
"B#" -> "{B#}"
"#B" -> "{#B}"
For example, given this problem as input:
<problem>B# A# #B #A B#</problem>
You would begin your response with this combination of your thoughts, and the new state.
You would not end your response until you have a solution:
<thinking>Convert the problem input into the inital state</thinking>
<state>
 <reverse>false</reverse>
 <prefix></prefix>
 <current></current>
 <suffix>{B#} {A#} {#B} {#A} {B#}</suffix>
</state>
Next, you will repeatedly apply EXACTLY ONE rule to the state.
I will write the rules in python pseudocode with comments.
Tell me if any of these rules or code is ambiguous or confusing, and I will clarify.
In the code, local variables "reverse", "prefix", "current", and "suffix" represent the state.
```python
if reverse == "true" and prefix == "":
    # reverse with no prefix, skip
    reverse = "false"
elif reverse == "true" and prefix != "" and current == "":
    # reverse by moving a token from both prefix and suffix.
    reverse = "false"
    current = prefix[-1] + suffix[0]
    prefix = prefix[:-1]
    suffix = suffix[1:]
elif reverse == "true" and prefix != "" and current != "":
    # reverse by shifting one token from current to suffix, and one token from prefix to current.
    reverse = "false"
    suffix = current[1] + suffix
    current = prefix[-1] + current[0]
    prefix = prefix[:-1]
elif reverse == "false" and current == "" and suffix != "":
    # Move two tokens from the beginning of suffix to current.
    current = suffix[:2]
    suffix = suffix[2:]
elif reverse == "false" and current == "" and suffix == "":
    # Terminate the algorithm and return the solution as XML
    solution = prefix
    return solution
elif reverse == "false" and current == "{A#} {#A}":
    current = ""
    reverse = "true"
elif reverse == "false" and current == "{B#} {#B}":
    current = ""
    reverse = "true"
elif reverse == "false" and current == "{A#} {#B}":
    current = "{#B} {A#}"
    reverse = "true"
elif reverse == "false" and current == "{B#} {#A}":
    current = "{#A} {B#}"
    reverse = "true"
elif reverse == "false" and current == "{#A} {A#}":
    # Shift one token from current to prefix, and one token from suffix to current.
    prefix = prefix + current[:1]
    current = current[1:] + suffix[:1]
    suffix = suffix[1:]
elif reverse == "false" and current == "{#B} {B#}":
    # Shift one token from current to prefix, and one token from suffix to current.
    prefix = prefix + current[:1]
    current = current[1:] + suffix[:1]
    suffix = suffix[1:]
elif reverse == "false" and current == "{#A} {B#}":
    # Shift one token from current to prefix, and one token from suffix to current.
    prefix = prefix + current[:1]
    current = current[1:] + suffix[:1]
    suffix = suffix[1:]
elif reverse == "false" and current == "{#B} {A#}":
    # Shift one token from current to prefix, and one token from suffix to current.
    prefix = prefix + current[:1]
    current = current[1:] + suffix[:1]
    suffix = suffix[1:]
elif reverse == "false" and current == "{#A} {#A}":
    # Shift one token from current to prefix, and one token from suffix to current.
    prefix = prefix + current[:1]
    current = current[1:] + suffix[:1]
    suffix = suffix[1:]
elif reverse == "false" and current == "{#A} {#B}":
    # Shift one token from current to prefix, and one token from suffix to current.
    prefix = prefix + current[:1]
    current = current[1:] + suffix[:1]
    suffix = suffix[1:]
elif reverse == "false" and current == "{#B} {#A}":
    # Shift one token from current to prefix, and one token from suffix to current.
    prefix = prefix + current[:1]
    current = current[1:] + suffix[:1]
    suffix = suffix[1:]
elif reverse == "false" and current == "{#B} {#B}":
    # Shift one token from current to prefix, and one token from suffix to current.
    prefix = prefix + current[:1]
    current = current[1:] + suffix[:1]
    suffix = suffix[1:]
elif reverse == "false" and current == "{A#} {A#}":
    # Shift one token from current to prefix, and one token from suffix to current.
    prefix = prefix + current[:1]
    current = current[1:] + suffix[:1]
    suffix = suffix[1:]
elif reverse == "false" and current == "{A#} {B#}":
    # Shift one token from current to prefix, and one token from suffix to current.
    prefix = prefix + current[:1]
    current = current[1:] + suffix[:1]
    suffix = suffix[1:]
elif reverse == "false" and current == "{B#} {A#}":
    # Shift one token from current to prefix, and one token from suffix to current.
    prefix = prefix + current[:1]
    current = current[1:] + suffix[:1]
    suffix = suffix[1:]
elif reverse == "false" and current == "{B#} {B#}":
    # Shift one token from current to prefix, and one token from suffix to current.
    prefix = prefix + current[:1]
    current = current[1:] + suffix[:1]
    suffix = suffix[1:]
elif reverse == "false" and current == "{A#}":
    # Shift one token from current to prefix, and one token from suffix to current.
    prefix = prefix + current[:1]
    current = current[1:] + suffix[:1]
    suffix = suffix[1:]
elif reverse == "false" and current == "{B#}":
    # Shift one token from current to prefix, and one token from suffix to current.
    prefix = prefix + current[:1]
    current = current[1:] + suffix[:1]
    suffix = suffix[1:]
elif reverse == "false" and current == "{#A}":
    # Shift one token from current to prefix, and one token from suffix to current.
    prefix = prefix + current[:1]
    current = current[1:] + suffix[:1]
    suffix = suffix[1:]
elif reverse == "false" and current == "{#B}":
    # Shift one token from current to prefix, and one token from suffix to current.
    prefix = prefix + current[:1]
    current = current[1:] + suffix[:1]
    suffix = suffix[1:]
```
These rules cover every possible case.
Remember to always apply exactly one rule in each step, and always explain your thinking before providing the new state.
Here is an example state transition.
Suppose the previous state was:
<state>
 <reverse>false</reverse>
 <prefix>{#A} {#A}</prefix>
 <current>{A#} {#A}</current>
 <suffix>{A#} {#A} {#A} {B#} {#A} {#B} {#A} {#B}</suffix>
</state>
You would continue your reply with something like this:
<thinking>rule reverse == "false" and current == "{A#} {#A}"</thinking>
<state>
 <reverse>true</reverse>
 <prefix>{#A} {#A}</prefix>
 <current></current>
 <suffix>{A#} {#A} {A#} {B#} {#A} {#B} {#A} {#B}</suffix>
</state>
Here is another example.
Previous state:
<state>
 <reverse>true</reverse>
 <prefix>{B#} {A#} {B#} {#B} {B#}</prefix>
 <current>{#B} {#A}</current>
 <suffix>{#A} {B#} {#A} {B#} {B#}</suffix>
</state>
You would continue your reply with something like this:
<thinking>rule reverse == "true" and prefix != "" and current == ""</thinking>
<state>
 <reverse>true</reverse>
 <prefix>{B#} {A#} {B#} {#B}</prefix>
 <current>{B#} {#B}</current>
 <suffix>{#A} {#A} {B#} {#A} {B#} {B#}</suffix>
</state>
When the state should have the final rule applied with current == "" and suffix == "", return the solution as an XML tag.
Remember to convert back to the input formatting, using these rules:
"{A#}" -> "A#"
"{#A}" -> "#A"
"{B#}" -> "B#"
"{#B}" -> "#B"
For example, if the state is:
<state>
 <reverse>false</reverse>
 <prefix>{#B} {#B} {#B} {A#} {A#} {B#}</prefix>
 <current></current>
 <suffix></suffix>
</state>
You would convert the output back to the original format and finish your reply with:
<thinking>rule reverse == "false" and current == "" and suffix == ""</thinking>
<solution>#B #B #B A# A# B#</solution>
END OF THE ALGORITHM
Do not say anything that is not in the XML tags <thinking>, <state>, <solution>.
No other explanation of any kind should be given outside of the thinking tag.
Keep responding until you have a final solution, and respond with your final solution in a <solution> tag.
Do not stop until you have responded with a solution!
""".strip()
def _make_user_message(instance: str) -> str:
return f"<problem>{instance}</problem>"
def _gen_instances(num_instances: int, token_length: int) -> list[str]:
    """Generate ``num_instances`` random sentences of ``token_length`` tokens.

    Uses the module-level `random` state, so results depend on the seed.
    """
    alphabet = (Tokens.A_SHARP, Tokens.SHARP_A, Tokens.B_SHARP, Tokens.SHARP_B)
    sentences = []
    for _ in range(num_instances):
        sentences.append(_serialize_tokens(random.choices(alphabet, k=token_length)))
    return sentences
def _test_llm_algo_on_random_instances():
    """Check the prompt state machine against the direct rewriter on random input."""
    sentences = _gen_instances(1, 7)
    error_count = 0
    for sentence in sentences:
        parsed = _parse_sentence(sentence)
        expected = _do_replacement(parsed)
        print(f"{sentence} -> {_serialize_tokens(expected)}")
        actual, _ = _do_llm_algorithm(parsed)
        if actual != expected:
            print(f"Test failed for {sentence}: expected {expected}, got {actual}")
            error_count += 1
    if error_count:
        print(
            f"{error_count}/{len(sentences)} tests failed: {100 * error_count / len(sentences):.2f}% failed"
        )
    else:
        print(f"All {len(sentences)} tests passed!")
def _do_prompt_replacement(prompt: str) -> str:
    """Replace "{A#}"-style placeholders in ``prompt`` with CONVERSION aliases."""
    for token, spelled in TOKEN_TO_STR.items():
        placeholder = "{" + spelled + "}"
        prompt = prompt.replace(placeholder, CONVERSION[token])
    return prompt
def _do_reverse_prompt_replacement(prompt: str) -> str:
    """Map single-character CONVERSION aliases back to token spellings."""
    for token, spelled in TOKEN_TO_STR.items():
        alias = CONVERSION[token]
        prompt = prompt.replace(alias, spelled)
    return prompt
async def _test_llm():
    """End-to-end test: have the LLM execute the algorithm, check its answer.

    Builds a one-shot worked transcript from _do_llm_algorithm, sends each
    random instance as a few-shot chat, extracts the <solution> tag from the
    reply, and compares it with the locally computed normal form.  Failing
    transcripts are written to llm_output.txt / expected_llm_output.txt so
    they can be diffed offline.
    """
    instances = _gen_instances(1, 8)
    xml_from_string = lxml.etree.fromstring
    # Swap "{A#}"-style placeholders for their single-character aliases.
    system_message = _do_prompt_replacement(SYSTEM_PROMPT)
    # Few-shot example: the full expected transcript for a fixed instance.
    example_instance = "#B B# #A B# #A B# #A"
    example_user_message = _make_user_message(example_instance)
    _, example_ai_message = _do_llm_algorithm(_parse_sentence(example_instance))
    num_errors = 0
    for instance in instances:
        tokens = _parse_sentence(instance)
        expected = _do_replacement(tokens)
        expected_str = _serialize_tokens(expected)
        print("*" * 60)
        print(f"Expecting: {instance} -> {_serialize_tokens(expected)}")
        user_message = _make_user_message(instance)
        result = await answer_few_shot_chat(
            system_message=system_message,
            user_messages=[example_user_message, user_message],
            ai_messages=[example_ai_message],
            model_name=ChatModelName.CLAUDE_3_HAIKU,
            max_tokens=8000,
        )
        # The reply must contain exactly one line with a <solution> tag.
        lines = [l.strip() for l in result.split("\n") if l.strip() and "<solution>" in l]
        if len(lines) != 1:
            print("LLM output:", result)
            raise ValueError(f"llm output had {len(lines)} solutions instead of 1")
        doc = xml_from_string(lines[0])
        solutions = [e.text for e in doc.xpath("//solution")]
        assert len(solutions) == 1
        solution = solutions[0]
        if solution == expected_str:
            print(f"Test passed with {len(result)} characters!")
        else:
            print(f"Test failed for {instance}: expected {expected_str}, got {solution}")
            # Persist both transcripts for offline comparison.
            with open("llm_output.txt", "w") as f:
                f.write(result)
            with open("expected_llm_output.txt", "w") as f:
                _, expected_llm_output = _do_llm_algorithm(_parse_sentence(instance))
                f.write(expected_llm_output)
            num_errors += 1
        print("")
    if num_errors == 0:
        print("All tests passed!")
    else:
        print(
            f"{num_errors}/{len(instances)} tests failed: {100 * num_errors / len(instances):.2f}% failed"
        )
if __name__ == "__main__":
    # Fixed seed so the randomly generated instances are reproducible.
    random.seed(1338)
    # _run_on_random_instances()
    # _test_on_examples()
    # _test_llm_algo_on_random_instances()
    asyncio.run(_test_llm())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment