Created
May 19, 2024 03:52
-
-
Save alecchendev/a1cc7fe4cf24c489a5809f717a67ad68 to your computer and use it in GitHub Desktop.
Using an LLM to do information retrieval super easily
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import requests | |
import sys | |
import json | |
def check_message(message):
    """Ask a local LLM whether *message* says that Bob likes chocolate ice cream.

    Sends a few-shot prompt to the generate endpoint at
    ``http://localhost:11434/api/generate`` using the ``llama3`` model and
    interprets the model's reply as a boolean.

    Args:
        message: Free-form text to classify; it is interpolated verbatim
            into the prompt.

    Returns:
        True if the model's full reply, stripped of surrounding whitespace
        and lowercased, is exactly ``"true"``; False otherwise.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        ValueError: If the response body is empty, is not valid JSON
            (``json.JSONDecodeError`` is a ``ValueError``), or its final
            chunk is not marked ``done``.
    """
    url = "http://localhost:11434/api/generate"
    prompt = f'''
Based on the following message, does Bob like chocolate ice cream? Answer true or false.
Examples:
"sally likes chocolate ice cream": false
"Bob said he really enjoys chocolate ice cream.": true
"bob really likes vanilla ice cream": false
"Bob is quite a fan of ice cream that is chocolate flavored.": true
"Chocolate ice cream is the favorite flavor of a man named bob": true
Return all lowercase and do not include any other characters as a part of your answer.
Only answer true if you have enough information, otherwise answer false.
"{message}": '''
    payload = {
        "model": "llama3",
        "prompt": prompt,
    }
    response = requests.post(url, json=payload)
    response.raise_for_status()

    # The body is newline-delimited JSON: one object per streamed chunk.
    # The answer may arrive split across any number of chunks, so parse
    # every line rather than assuming a fixed count.
    chunks = [
        json.loads(line)
        for line in response.text.splitlines()
        if line.strip()
    ]
    if not chunks:
        raise ValueError("Empty response from the generate endpoint")
    if chunks[-1].get("done") is not True:
        raise ValueError(
            "Response stream ended without a final 'done' chunk:\n"
            + "\n".join(response.text.splitlines())
        )

    # Concatenate the pieces of the reply; the terminating chunk normally
    # carries an empty "response", so including it is harmless.
    answer = "".join(chunk.get("response", "") for chunk in chunks)
    return answer.strip().lower() == "true"
def run_tests():
    """Measure how reliably check_message classifies a set of unseen messages.

    Every message is classified ``rounds`` times. The per-message success
    count and the overall success rate are printed, and the run fails with
    an AssertionError when more than two classifications in total disagree
    with the expected value.
    """
    # The five messages used as few-shot examples inside the prompt are
    # deliberately left out; only unseen phrasings are scored here:
    #   "sally likes chocolate ice cream": False
    #   "Bob said he really enjoys chocolate ice cream.": True
    #   "bob really likes vanilla ice cream": False
    #   "Bob is quite a fan of ice cream that is chocolate flavored.": True
    #   "Chocolate ice cream is the favorite flavor of a man named bob": True
    test_cases = {
        "Bob says he finds chocolate ice cream to be fantastic": True,
        "Sally says bob likes chocolate ice cream": False,
        "Bob friggin loves strawberry ice cream": False,
        "Bob eats chocolate ice cream often but doesn't like it": False,
        "Bill loves chocolate ice cream": False,
        "bobby loves chocolate ice cream": False,
        # Interesting case, currently disabled:
        # "Sally says bob likes chocolate ice cream and she knows him really well": True,
    }
    rounds = 10
    total_wrong = 0

    for message, expected in test_cases.items():
        # One boolean per round: did the model agree with the expectation?
        outcomes = [check_message(message) == expected for _ in range(rounds)]
        success = sum(outcomes)
        total_wrong += rounds - success
        print(f'{message}:\n{success}/{rounds}\n')

    attempts = rounds * len(test_cases)
    print("Success rate:", 1 - total_wrong / attempts)
    assert total_wrong <= 2
def main():
    """Classify one hard-coded sample message and print the verdict."""
    sample = "Bob said he really enjoys chocolate ice cream."
    print(f'Test message: "{sample}"')
    verdict = check_message(sample)
    print(
        'The statement that Bob likes chocolate ice cream based on the message is: '
        f'{verdict}'
    )
if __name__ == "__main__":
    # `python <script> test` runs the accuracy check; anything else runs the demo.
    # Slicing avoids an explicit length check: an absent argument gives [].
    entry_point = run_tests if sys.argv[1:2] == ["test"] else main
    entry_point()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment