Last active
October 25, 2021 06:12
-
-
Save Aivean/afc9fe02743acc5d8a651081e87c4c80 to your computer and use it in GitHub Desktop.
Evaluate Codex accuracy on random list sorting task
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os
import random
import re

import openai
from dotenv import load_dotenv
# Few-shot style prompt: the generated list is substituted for "{}" and the
# model is expected to continue the final line with the sorted list.
template = """The sort function can be used to sort a list in ascending, descending or user defined
order.
To sort the list in ascending order, simply call list.sort(). This will sort a list
of integers in ascending order so that the smallest integer will be first in the list
and the largest integer will be the last.
For example:
list = {}
list.sort() => """

# Number of random integers in each generated list.
elements_n = 20
# Number of independent trials (one API request per trial).
n = 20


def parse_int_list(text: str) -> list:
    """Extract every integer found in *text*, in order of appearance.

    The completion is expected to look like ``[1, 2, 3]``, but the parser is
    deliberately lenient: brackets, spacing and separators are ignored, and
    text containing no integers yields an empty list instead of raising.
    """
    return [int(token) for token in re.findall(r'-?\d+', text)]


def lcs_length(a: list, b: list) -> int:
    """Return the length of the longest common subsequence of *a* and *b*.

    Classic dynamic programme, keeping only the previous row of the table.
    """
    prev = [0] * (len(b) + 1)
    for x in a:
        curr = [0]
        for j, y in enumerate(b, start=1):
            if x == y:
                curr.append(prev[j - 1] + 1)
            else:
                curr.append(max(prev[j], curr[j - 1]))
        prev = curr
    return prev[-1]


def count_errors(expected: list, actual: list) -> int:
    """Return how many positions of the longer sequence are not matched.

    Measuring against the longer of the two sequences means that missing
    elements *and* spurious extra elements both count as errors; the result
    is 0 only when *actual* equals *expected*.
    """
    return max(len(expected), len(actual)) - lcs_length(expected, actual)


def main() -> None:
    """Run ``n`` sorting trials against the completion API and print a tally.

    For each trial the unsorted list, the reference sorted list, the parsed
    model output and the error count are printed, followed by a blank line.
    The last line printed is ``<correct>/<n>``.
    """
    # The API key is read from the environment (optionally via a .env file).
    load_dotenv()
    openai.api_key = os.getenv('OPENAI_KEY')

    correct = 0
    for _ in range(n):
        # generate a list of random integers
        numbers = [random.randint(0, 100) for _ in range(elements_n)]
        result = openai.Completion.create(
            engine='davinci-codex',
            prompt=template.format(numbers),
            stop='\n',  # the answer is expected on a single line
            max_tokens=100,
            temperature=0,
            n=1,
        )
        output = parse_int_list(result.choices[0].text)
        eta = sorted(numbers)
        print(numbers)
        print(eta)
        print(output)

        errors = count_errors(eta, output)
        print(f'errors: {errors}')
        if errors == 0:
            correct += 1
        print()
    print(f'{correct}/{n}')


if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment