Last active
March 29, 2023 20:09
-
-
Save cryppadotta/5455cd713040e409a4592719823881ff to your computer and use it in GitHub Desktop.
PickOptionChain: Given a prompt, call an LLM to generate options and present the user with a terminal UI to pick from the options. The result is the option the user picked.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
# PickOptionChain: Given a prompt, call an LLM to generate options and present the user with a terminal UI to pick from the options. The result is the option the user picked. | |
## Example: | |
```python | |
theme_query = "Let's create a poem. The first thing we need to do is pick a message or a theme. List out 8 themes and then ask the human to pick one of them" | |
llm = OpenAI(temperature=.7) | |
chain = PickOptionChain(llm=llm) | |
chain_results = chain.run(theme_query) | |
print(chain_results) | |
``` | |
""" | |
from typing import Dict, List | |
from pydantic import BaseModel | |
from langchain.output_parsers import PydanticOutputParser | |
from langchain.prompts import PromptTemplate | |
from langchain.chains.base import Chain | |
from langchain.chains import LLMChain | |
from langchain.schema import BaseLanguageModel | |
from pydantic import Field | |
import inquirer | |
class OptionModel(BaseModel):
    # Structure the LLM's reply is parsed into. PickOptionChain passes this
    # model to PydanticOutputParser and embeds parser.get_format_instructions()
    # in the LLM prompt. Those instructions are presumably derived from this
    # model's schema (field names and descriptions below), so this class is
    # documented with comments instead of a docstring, to avoid leaking extra
    # text into the prompt. Confirm against the parser if this matters.

    # Free-text context accompanying the options (parsed, but not read by
    # PickOptionChain._call).
    setup: str = Field(description="context for the options")
    # The choices that are presented to the user in the terminal picker.
    options: List[str] = Field(description="list of options")
class PickOptionChain(Chain, BaseModel):
    """Chain that generates options with an LLM and lets the user pick one.

    The input text is sent to ``llm`` together with format instructions
    built from :class:`OptionModel`. The reply is parsed into an
    ``OptionModel``, and its ``options`` are shown in an interactive
    ``inquirer`` list in the terminal. The chain's single output is the
    option the user selected.
    """

    llm: BaseLanguageModel
    input_key: str = "input"  #: :meta private:
    output_key: str = "output"  #: :meta private:

    @property
    def input_keys(self) -> List[str]:
        """Expect input key.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Return output key.

        :meta private:
        """
        return [self.output_key]

    @property
    def _chain_type(self) -> str:
        # Type label recorded when the chain is serialized. It must name this
        # chain, not another chain class, or a saved config would be mislabeled.
        return "pick_option_chain"

    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
        """Generate options with the LLM and return the one the user picks.

        Args:
            inputs: Mapping holding the prompt text under ``self.input_key``.

        Returns:
            ``{self.output_key: <selected option>}``.

        Raises:
            ValueError: If the LLM returned no options, or if the terminal
                prompt produced no answer (e.g. the user cancelled it).
        """
        options = self._generate_options(inputs[self.input_key])
        if not options:
            # inquirer cannot present an empty list; fail with a clear message.
            raise ValueError("The LLM did not return any options to choose from.")
        selected = self._ask_user(options)
        return {self.output_key: selected}

    def _generate_options(self, text: str) -> List[str]:
        """Ask ``self.llm`` to answer ``text`` as an :class:`OptionModel`; return its options.

        Any error raised by ``PydanticOutputParser.parse`` (reply not matching
        the schema) propagates to the caller.
        """
        parser = PydanticOutputParser(pydantic_object=OptionModel)
        prompt = PromptTemplate(
            template="Answer the user query.\n{format_instructions}\n{query}\n",
            input_variables=["query"],
            partial_variables={"format_instructions": parser.get_format_instructions()},
        )
        llm_chain = LLMChain(llm=self.llm, prompt=prompt)
        raw_reply = llm_chain.run(text)
        return parser.parse(raw_reply).options

    def _ask_user(self, options: List[str]) -> str:
        """Show ``options`` in an interactive terminal list and return the chosen one.

        Raises:
            ValueError: If ``inquirer.prompt`` yields no answers.
        """
        print("Select an option using arrow keys and press Enter to confirm:")
        questions = [
            inquirer.List(
                "selected_option",
                message="Please choose an option",
                choices=options,
            ),
        ]
        answers = inquirer.prompt(questions)
        if not answers:
            # inquirer.prompt can return None (e.g. when the user cancels).
            # Indexing into it would crash with an unhelpful TypeError, and there
            # is no retry loop, so report the missing selection explicitly.
            raise ValueError("No option was selected.")
        selected = answers["selected_option"]
        print(f"You have selected: {selected}")
        return selected
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment