import os

import dotenv

import autogen
from autogen import AssistantAgent, UserProxyAgent

dotenv.load_dotenv()

config_list = [
    {"model": "gpt-3.5-turbo", "api_key": os.environ["OPENAI_API_KEY"]},
]
llm_config = {"config_list": config_list, "temperature": 0}

autogen.ChatCompletion.start_logging()
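
# A user proxy standing in for the human; it relays the task and does not execute code.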
human = UserProxyAgent(
    name="human",
    human_input_mode="TERMINATE",
    system_message="The human who asks questions and gives tasks.",
    code_execution_config=False,
)

acceptable_tags = ["OK", "BAD", "AWESOME"]
tags_as_bullets = ""
for tag in acceptable_tags:
    tags_as_bullets += f"\n- {tag}"

judge_prompt = f"You are presented with text and your job is to judge it. Reply with a tag from the following list. Choose ONLY from the list of tags provided here. The tag must be the final word in your response. Please start your answer with your rationale for choosing the tag.\n\n{tags_as_bullets}"
judge_prompt += "\n\nText:\n{text_to_judge}"
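
# The judge sees the tag list in its system message and must end every reply with one of the tags.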
judge = AssistantAgent(
    name="judge",
    system_message=judge_prompt,
    llm_config=llm_config,
)


def validate_judgement(recipient, messages, sender, config):
    """
    Validate the message from the judge. If invalid, send feedback to the judge agent.
    """
    raw_message = messages[-1]["content"]  # messages are dicts; take the text content
    processed_message = extract_judgement(raw_message)
    # judgement_ok = is_valid(processed_message)
    judgement_ok = False  # hard-coded to False for testing
    if not judgement_ok:
        feedback = f"Invalid judgement. The final word must be one of the admissible tags: {acceptable_tags}. Do not apologize for your previous answer, just give me the tag."
        if "callback" in config and config["callback"] is not None:
            callback = config["callback"]
            callback(sender, recipient, feedback)
            print(f"Feedback sent to: {recipient.name}")
        else:
            print("Callback not found.")
    return False, None  # required to ensure the agent communication flow continues


judge.register_reply(
    [judge, None],  # only run this validator when the judge is the sender
    reply_func=validate_judgement,
    config={"callback": None},
)


def _reset_agents():
    human.reset()
    judge.reset()


def chat(problem):
    _reset_agents()
    groupchat = autogen.GroupChat(
        agents=[human, judge],
        messages=[],
        max_round=5,
    )
    manager = autogen.GroupChatManager(groupchat=groupchat, llm_config=llm_config)
    human.initiate_chat(manager, message=problem)


def is_valid(judgement):
    return judgement in acceptable_tags


def extract_judgement(judgement):
    """Extract the final word from the response."""
    print(f"raw judgement: {judgement}")
    judgement = judgement.split()[-1]  # split on any whitespace so a trailing newline is handled
    judgement = judgement.replace(".", "")
    judgement = judgement.strip()
    judgement = judgement.upper()
    print(f"Extracted judgement: {judgement}")
    return judgement
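

# Demo: ask the judge to tag a placeholder paragraph.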
def main():
    text = "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua."
    problem = judge_prompt.format(text_to_judge=text)
    chat(problem)


if __name__ == "__main__":
    main()