I used an Anthropic model to suggest code for fixing langchain-decorators. Here is the prompt and the model's response. I didn't use all of its suggestions, but it had a working idea.
Human:
<file path=setup.py>
import setuptools

with open("src/langchain_decorators/__init__.py", "rt") as f:
    for line in f.readlines():
        if line.startswith("__version__"):
            __version__ = line.split("=")[1].strip(" \n\"")

setuptools.setup(name='langchain-decorators',
                 version=__version__,
                 description='syntactic sugar for langchain',
                 long_description=open('README.md').read(),
                 long_description_content_type='text/markdown',
                 author='Juraj Bezdek',
                 author_email='juraj.bezdek@blip.solutions',
                 url='https://github.com/ju-bezdek/langchain-decorators',
                 package_dir={"": "src"},
                 packages=setuptools.find_packages(where="src"),
                 license='MIT License',
                 zip_safe=False,
                 keywords='langchain',
                 classifiers=[
                 ],
                 python_requires='>=3.9',
                 install_requires=[
                     "langchain"
                 ]
                 )
</file>
<file path=crawl.py>
import os
from pathlib import Path

git_dir = Path(".")  # set your git directory path here
output_file = "output.txt"

with open(output_file, "w") as f:
    for file in git_dir.glob("**/*.py"):
        f.write(f'<file path={str(file)}>\n')
        with open(file, "r") as py_file:
            f.write(py_file.read())
        f.write("\n</file>\n")
</file>
<file path=src/test.py>
from langchain_decorators import PromptTypes
from pydantic import BaseModel
from pydantic import Field
from langchain_decorators import llm_prompt
from langchain.chat_models import ChatAnthropic
from langchain_decorators import GlobalSettings

GlobalSettings.define_settings(
    default_llm=ChatAnthropic(temperature=0.0),
    default_streaming_llm=ChatAnthropic(temperature=0.0, streaming=True),
)
PromptTypes.AGENT_REASONING.llm = ChatAnthropic()


class ScoreFormat(BaseModel):
    score: float = Field(description="How calm the topic is from 0 to 1")


@llm_prompt()
def rater(_sentence: str) -> ScoreFormat:
    """Score how calm this sentence is: {_sentence}
    {FORMAT_INSTRUCTIONS}
    """
    return


out = rater(_sentence="people on fire!")
output = out.score
print("output", output, type(output))
</file>
<file path=src/langchain_decorators/__init__.py>
from .common import LogColors, GlobalSettings, print_log, PromptTypes, PromptTypeSettings
from .prompt_decorator import PromptDecoratorTemplate
from .streaming_context import StreamingContext
from .prompt_decorator import llm_prompt
__version__="0.0.2"
</file>
<file path=src/langchain_decorators/common.py>
import logging
import yaml
from enum import Enum
from typing import Any, Union
from pydantic import BaseModel, Extra
from langchain.llms.base import BaseLanguageModel
from langchain.chat_models import ChatOpenAI


class GlobalSettings(BaseModel):
    default_llm: BaseLanguageModel = None
    default_streaming_llm: BaseLanguageModel = None
    logging_level: int = logging.INFO
    stdout_logging: bool = True
    verbose: bool = False

    class Config:
        allow_population_by_field_name = True
        extra = Extra.allow

    @classmethod
    def define_settings(cls,
                        settings_type="default",
                        default_llm=ChatOpenAI(temperature=0.0),
                        default_streaming_llm=ChatOpenAI(
                            temperature=0.0, streaming=True),
                        logging_level=logging.INFO,
                        stdout_logging: bool = True,
                        verbose=False,
                        **kwargs
                        ):
        settings = cls(default_llm=default_llm, default_streaming_llm=default_streaming_llm,
                       logging_level=logging_level, stdout_logging=stdout_logging, verbose=verbose, **kwargs)
        if not hasattr(GlobalSettings, "registry"):
            setattr(GlobalSettings, "registry", {})
        GlobalSettings.registry[settings_type] = settings

    @classmethod
    def get_current_settings(cls) -> "GlobalSettings":
        if not hasattr(GlobalSettings, "settings_type"):
            setattr(GlobalSettings, "settings_type", "default")
        if not hasattr(GlobalSettings, "registry"):
            GlobalSettings.define_settings()
        return GlobalSettings.registry[GlobalSettings.settings_type]

    @classmethod
    def switch_settings(cls, project_name):
        GlobalSettings.settings_type = project_name


class LogColors(Enum):
    WHITE_BOLD = "\033[1m"
    RED = '\033[31m'
    GREEN = '\033[32m'
    YELLOW = '\033[33m'
    BLUE = '\033[34m'
    MAGENTA = '\033[35m'
    CYAN = '\033[36m'
    DARK_GRAY = '\033[90m'
    # Define some reset codes to restore the default text color
    RESET = '\033[0m'


def print_log(log_object: Any, log_level: int, color: LogColors = None):
    settings = GlobalSettings.get_current_settings()
    if settings.logging_level <= log_level or settings.verbose:
        if isinstance(log_object, str):
            pass
        elif isinstance(log_object, dict):
            log_object = yaml.safe_dump(log_object)
        elif isinstance(log_object, BaseModel):
            log_object = yaml.safe_dump(log_object.dict())
        if color is None:
            if log_level >= logging.ERROR:
                color = LogColors.RED
            elif log_level >= logging.WARNING:
                color = LogColors.YELLOW
            elif log_level >= logging.INFO:
                color = LogColors.GREEN
            else:
                color = LogColors.DARK_GRAY
        if type(color) is LogColors:
            color = color.value
        reset = LogColors.RESET.value if color else ""
        print(f"{color}{log_object}{reset}\n", flush=True)


class PromptTypeSettings:
    def __init__(self, llm: BaseLanguageModel = None, color: LogColors = None, log_level: Union[int, str] = "info", capture_stream: bool = False):
        self.color = color
        if isinstance(log_level, str):
            log_level = getattr(logging, log_level.upper())
        self.log_level = log_level
        self.capture_stream = capture_stream
        self.llm = llm

    def as_verbose(self):
        return PromptTypeSettings(llm=self.llm, color=self.color, log_level=100, capture_stream=self.capture_stream)


class PromptTypes:
    UNDEFINED: PromptTypeSettings = PromptTypeSettings(
        color=LogColors.GREEN, log_level=logging.DEBUG)
    AGENT_REASONING: PromptTypeSettings = PromptTypeSettings(
        color=LogColors.GREEN, log_level=logging.INFO)
    TOOL: PromptTypeSettings = PromptTypeSettings(
        color=LogColors.BLUE, log_level=logging.INFO)
    FINAL_OUTPUT: PromptTypeSettings = PromptTypeSettings(
        color=LogColors.YELLOW, log_level=logging.INFO)
</file>
<file path=src/langchain_decorators/streaming_context.py>
import contextvars
from typing import Any, Callable


class StreamingContext():
    from langchain.callbacks.base import AsyncCallbackHandler

    class StreamingContextCallback(AsyncCallbackHandler):
        async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
            StreamingContext.get_context().on_new_token(token)

        async def on_llm_end(self, response, *args, **kwargs):
            if StreamingContext.get_context().stream_to_stdout:
                print()

    context_var = contextvars.ContextVar('streaming_context')

    def __init__(self, callback: Callable[[str], None] = None, stream_to_stdout: bool = False) -> None:
        self.callback = callback
        self.stream_to_stdout = stream_to_stdout
        self.token_colors = ['\033[90m', '\033[0m']

    def __enter__(self):
        self.__class__.context_var.set(self)

    @classmethod
    def get_context(cls) -> 'StreamingContext':
        return cls.context_var.get("streaming_context")

    def on_new_token(self, token: str):
        if self.callback:
            self.callback(token)
        if self.stream_to_stdout:
            reset_color = '\033[0m'
            current_color = self.token_colors[0]
            self.token_colors.reverse()
            print('{}{}{}'.format(current_color, token if token !=
                                  "" else '\u2022', reset_color), end='')

    def __exit__(self, exc_type, exc_value, traceback):
        self.context_var.set(None)
</file>
<file path=src/langchain_decorators/prompt_decorator.py>
from ast import Tuple
import logging
import inspect
from functools import wraps
from textwrap import dedent
from typing import Callable, List, Optional, Sequence, Union
from langchain import LLMChain, PromptTemplate
from langchain.schema import BaseOutputParser
from langchain.llms.base import BaseLanguageModel
try:
from promptwatch import register_prompt_template
except ImportError:
print('To use promptwatch try: pip install promptwatch')
from .common import *
from .prompt_template import PromptDecoratorTemplate
from .output_parsers import *
from .streaming_context import StreamingContext
def llm_prompt(
# do not change the order of this first parameter unless you will change also the fist few lines... since we are handling cases when decorator is used with and without arguments too, than this will be the func
prompt_type: PromptTypeSettings = PromptTypes.UNDEFINED,
template_format: str = "f-string-extra",
output_parser: Union[str, None, BaseOutputParser] = "auto",
stop_tokens: List[str] = None,
template_name: str = None,
template_version: str = None,
capture_stream: bool = None,
llm: Optional[BaseLanguageModel] = None,
format_instructions_parameter_key: str = "FORMAT_INSTRUCTIONS",
retry_on_output_parsing_error: bool = True,
verbose: bool = None,
):
"""
Decorator for functions that turns a regular function into a LLM prompt executed with default model and settings.
This can be applied on any function that has a docstring with a prompt template.
If the function is async, the prompt will be executed asynchronously (with all the langchain async infrastructure).
Note that the code of the function will never be executed...
Args:
`prompt_type`: (Optional[PromptTypeSettings]) - This allows you mark your prompt with one of the predefined prompt types (see PromptTypes class - but you can subclass it!) to predefine some settings like LLM or style and color of logging into console.
`template_format` (Optional[str]): one of [ `f-string` | `f-string-extra` ] ... f-string-extra is a superset of f-string template formats, enabling for optional sections.
`output_formatting` (Optional[str]): one of [ `auto` | `json` | `str` | `list` ] or `None` or langchain OutputParser object - you can control how will the output be parsed.
`auto` - default - determine the output type automatically based on output type annotations
`str` or `None` - will return plain string output
`json` - will parse the output as json
`markdown` - will parse the output as markdown sections, the name of each section will be returned as a key and the content as a value. For nested sections, the value will be a dict with the same structure.
`pydantic` - will parse the output as json and then convert into a pydantic model
`list` - will parse bullet or numbered list (each item on a new line) as a list
`stop_tokens` (Optional[List[str]]): list of stop tokens to instruct the LLM to stop generating text when it encounters any of these tokens. If not provided, the default stop tokens of the LLM will be used.
`format_instructions_parameter_key` - name of the format instructions parameter - this will enable you to include the instructions on how LLM should format the output, generated by the output_parsers
... if you include this into your prompt (docs), you don't need to reinvent the formatting instructions.
This works pretty well if you have an annotated pydantic model as an function output. If you are expecting a dict, you should probably include your own formatting instructions, since there is not much to infer from a dict structure.
`retry_on_output_parsing_error` - whether to try to re-format the output if the output parser fails to parse the output by another LLM call
`verbose` - whether to print the response from LLM into console
"""
if callable(prompt_type):
# this is the case when the decorator is called without arguments
# we initialize params with default values
func = prompt_type
prompt_type = None
else:
func = None
if verbose is None:
verbose = GlobalSettings.get_current_settings().verbose
if verbose:
if prompt_type:
prompt_type = prompt_type.as_verbose()
else:
prompt_type = PromptTypeSettings(
color=LogColors.DARK_GRAY, log_level=100, capture_stream=capture_stream)
def decorator(func):
prompt_str = dedent(func.__doc__)
name = func.__name__
full_name = f"{func.__module__}.{name}" if func.__module__ != "__main__" else name
is_async = inspect.iscoroutinefunction(func)
if prompt_type:
_capture_stream = prompt_type.capture_stream if capture_stream is None else capture_stream
else:
_capture_stream = capture_stream
if _capture_stream:
if not is_async:
print_log(
f"Warning: capture_stream=True is only supported for async functions. Ignoring capture_stream for {full_name}", logging.WARNING, LogColors.YELLOW)
_capture_stream = False
else:
if not StreamingContext.get_context():
print_log(
f"Debug: Not inside StreamingContext. Ignoring capture_stream for {full_name}", logging.DEBUG, LogColors.DARK_GRAY)
_capture_stream = False
def prepare_call_args(*args, **kwargs):
global_settings = GlobalSettings.get_current_settings()
if not llm:
if _capture_stream:
if not global_settings.default_streaming_llm:
print_log(
f"Warning: capture_stream on {name} is on, but the default LLM {llm} doesn't seem to be supporting streaming.", logging.WARNING, LogColors.YELLOW)
prompt_llm = global_settings.default_streaming_llm or global_settings.default_llm
else:
prompt_llm = global_settings.default_llm
else:
prompt_llm = llm
if _capture_stream:
if hasattr(llm, "streaming"):
if not getattr(llm, "streaming"):
print_log(
f"Warning: capture_stream on {name} is on, but the provided LLM {llm} doesn't have streaming on! Stream wont be captured", logging.WARNING, LogColors.YELLOW)
else:
print_log(
f"Warning: capture_stream on {name} is on, but the provided LLM {llm} doesn't seem to be supporting streaming.", logging.WARNING, LogColors.YELLOW)
input_variables_source = None
if len(args) == 1 and hasattr(args[0], "__dict__"):
# is a proper object
input_variables_source = args[0]
elif len(args) > 1:
raise Exception(
f"Positional arguments are not supported for prompt functions. Only one positional argument as an object with attributes as a source of inputs is supported. Got: {args}")
if _capture_stream:
if "callbacks" in kwargs:
kwargs["callbacks"].append(
StreamingContext.StreamingContextCallback())
else:
kwargs["callbacks"] = [
StreamingContext.StreamingContextCallback()]
if "memory" in kwargs:
memory = kwargs.pop("memory")
else:
memory = None
prompt_template = PromptDecoratorTemplate.from_func(func,
template_format=template_format,
output_parser=output_parser,
format_instructions_parameter_key=format_instructions_parameter_key,
template_name=template_name,
template_version=template_version,
prompt_type=prompt_type,
)
if prompt_template.default_values:
kwargs = {**prompt_template.default_values, **kwargs}
llmChain = LLMChain(
llm=prompt_llm, prompt=prompt_template, memory=memory)
other_supported_kwargs = {"stop", "callbacks"}
unexpected_inputs = [
key for key in kwargs if key not in prompt_template.input_variables and key not in other_supported_kwargs]
if unexpected_inputs:
raise TypeError(
f"Unexpected inputs for prompt function {full_name}: {unexpected_inputs}. \nValid inputs are: {prompt_template.input_variables}")
missing_inputs = [
key for key in prompt_template.input_variables if key not in kwargs]
if format_instructions_parameter_key in missing_inputs:
missing_inputs.remove(format_instructions_parameter_key)
# init the format instructions with None... will be filled later
kwargs[format_instructions_parameter_key] = None
if missing_inputs:
if input_variables_source:
missing_value = {}
for key in missing_inputs:
value = getattr(input_variables_source,
key, missing_value)
if value is missing_value:
raise TypeError(
f"Missing a input for prompt function {full_name}: {key}.")
else:
kwargs[key] = value
else:
raise TypeError(
f"{full_name}: missing 1 required keyword-only argument: {missing_inputs}")
if stop_tokens:
kwargs["stop"] = stop_tokens
chain_args = kwargs
return llmChain, chain_args
def get_retry_parse_call_args(prompt_template: PromptDecoratorTemplate, exception: OutputParserExceptionWithOriginal):
logging.warning(
msg=f"Failed to parse output for {full_name}: {exception}\nRetrying...")
if format_instructions_parameter_key not in prompt_str:
logging.warning(
f"Please note that we didn't find a {format_instructions_parameter_key} parameter in the prompt string. If you don't include it in your prompt template, you need to provide your custom formatting instructions.")
retry_parse_template = PromptTemplate.from_template(
"This is an input {original} but it's not in correct format, please convert it into following format:\n{format_instructions}\n\nIf the input doesn't seem to be relevant to the expected format instructions, return 'N/A'")
prompt_llm = llm or GlobalSettings.get_current_settings().default_llm
retryChain = LLMChain(llm=prompt_llm, prompt=retry_parse_template)
format_instructions = prompt_template.output_parser.get_format_instructions()
if not format_instructions:
raise Exception(
f"Failed to get format instructions for {full_name} from output parser {prompt_template.output_parser}.")
call_kwargs = {"original": exception.original,
"format_instructions": format_instructions}
return retryChain, call_kwargs
if not is_async:
@wraps(func)
def wrapper(*args, **kwargs):
print_log(log_object=f"> Entering {name} prompt decorator chain",
log_level=prompt_type.log_level if prompt_type else logging.DEBUG, color=LogColors.WHITE_BOLD)
llmChain, chain_args = prepare_call_args(*args, **kwargs)
try:
result = llmChain.predict(**chain_args)
if verbose or prompt_type:
print_log(log_object=f"\nResult:\n{result}", log_level=prompt_type.log_level if verbose else 100,
color=prompt_type.color if prompt_type else LogColors.BLUE)
if llmChain.prompt.output_parser:
result = llmChain.prompt.output_parser.parse(result)
except OutputParserException as e:
if retry_on_output_parsing_error and isinstance(e, OutputParserExceptionWithOriginal):
prompt_template = llmChain.prompt
retryChain, call_kwargs = get_retry_parse_call_args(
prompt_template, e)
result = retryChain.predict(**call_kwargs)
parsed = prompt_template.output_parser.parse(result)
print_log(log_object=f"> Finished chain",
log_level=prompt_type.log_level if prompt_type else logging.DEBUG, color=LogColors.WHITE_BOLD)
return parsed
else:
raise e
return result
return wrapper
else:
@wraps(func)
async def async_wrapper(*args, **kwargs):
print_log(log_object=f"> Entering {name} prompt decorator chain",
log_level=prompt_type.log_level if prompt_type else logging.DEBUG, color=LogColors.WHITE_BOLD)
llmChain, chain_args = prepare_call_args(*args, **kwargs)
try:
result = await llmChain.apredict(**chain_args)
if verbose or prompt_type:
print_log(log_object=f"\nResult:\n{result}", log_level=prompt_type.log_level if not verbose else 100,
color=prompt_type.color if prompt_type else LogColors.BLUE)
if llmChain.prompt.output_parser:
result = llmChain.prompt.output_parser.parse(result)
print_log(log_object=f"> Finished chain",
log_level=prompt_type.log_level if prompt_type else logging.DEBUG, color=LogColors.WHITE_BOLD)
except OutputParserException as e:
if retry_on_output_parsing_error and isinstance(e, OutputParserExceptionWithOriginal):
prompt_template = llmChain.prompt
retryChain, call_kwargs = get_retry_parse_call_args(
prompt_template, e)
result = await retryChain.apredict(**call_kwargs)
parsed = prompt_template.output_parser.parse(result)
print_log(log_object=f"> Finished chain",
log_level=prompt_type.log_level if prompt_type else logging.DEBUG, color=LogColors.WHITE_BOLD)
return parsed
else:
raise e
return result
return async_wrapper
if func:
return decorator(func)
else:
return decorator
</file>
<file path=src/langchain_decorators/pydantic_helpers.py>
from pydantic.fields import ModelField


def get_field_type(field_info: ModelField):
    _item_type = None
    if field_info.type_ == field_info.outer_type_:
        _type = field_info.type_
    elif list == getattr(field_info.outer_type_, '__origin__', None):
        # is list
        _type = list
    elif dict == getattr(field_info.outer_type_, '__origin__', None):
        _type = dict
    else:
        raise Exception(f"Unknown type: {field_info.annotation}")
    return _type


def is_field_nullable(field_info: ModelField):
    _nullable = field_info.allow_none


def get_field_item_type(field_info: ModelField):
    if list == getattr(field_info.outer_type_, '__origin__', None):
        return field_info.outer_type_.__args__[0]
</file>
<file path=src/langchain_decorators/output_parsers.py>
import datetime
import logging
from textwrap import dedent, indent
from typing import Dict, List, Type, TypeVar, Union
from venv import logger
from langchain import LLMChain, PromptTemplate
from langchain.output_parsers import PydanticOutputParser
from langchain.schema import BaseOutputParser, OutputParserException
import re
import json
from sqlalchemy import desc, null
import yaml
from pydantic import BaseModel, ValidationError
from pydantic.fields import ModelField
from .pydantic_helpers import *
class OutputParserExceptionWithOriginal(OutputParserException):
"""Exception raised when an output parser fails to parse the output of an LLM call."""
def __init__(self, message: str, original: str) -> None:
super().__init__(message)
self.original = original
def __str__(self) -> str:
return f"{super().__str__()}\nOriginal output:\n{self.original}"
class ListOutputParser(BaseOutputParser):
"""Class to parse the output of an LLM call to a list."""
@property
def _type(self) -> str:
return "list"
def parse(self, text: str) -> List[str]:
"""Parse the output of an LLM call."""
pattern = r"^[ \t]*(?:[\-\*\+]|\d+\.)[ \t]+(.+)$"
matches = re.findall(pattern, text, flags=re.MULTILINE)
if not matches and text:
logging.warning(
f"{self.__class__.__name__} : LLM returned {text} but we could not parse it into a list")
return matches
def get_format_instructions(self) -> str:
"""Instructions on how the LLM output should be formatted."""
return "Return result a s bulleted list."
class JsonOutputParser(BaseOutputParser):
"""Class to parse the output of an LLM call to a list."""
@property
def _type(self) -> str:
return "json"
def parse(self, text: str) -> List[str]:
try:
# Greedy search for 1st json candidate.
match = re.search(r"\{.*\}", text.strip(),
re.MULTILINE | re.IGNORECASE | re.DOTALL)
json_str = ""
if match:
json_str = match.group()
json_dict = json.loads(json_str, strict=False)
return json_dict
except (json.JSONDecodeError) as e:
msg = f"Invalid JSON\n {text}\nGot: {e}"
raise OutputParserExceptionWithOriginal(msg, text)
def get_format_instructions(self) -> str:
"""Instructions on how the LLM output should be formatted."""
return "Return result as a valid JSON"
T = TypeVar("T", bound=BaseModel)
class PydanticOutputParser(BaseOutputParser[T]):
"""Class to parse the output of an LLM call to a list."""
model: Type[T]
instructions_as_json_example: bool = True
def __init__(self, model: Type[T], instructions_as_json_example: bool = True):
super().__init__(model=model, instructions_as_json_example=instructions_as_json_example)
@property
def _type(self) -> str:
return "pydantic"
def parse(self, text: str) -> T:
try:
# Greedy search for 1st json candidate.
match = re.search(r"\{.*\}", text.strip(),
re.MULTILINE | re.IGNORECASE | re.DOTALL)
json_str = ""
if match:
json_str = match.group()
json_dict = json.loads(json_str, strict=False)
return self.model.parse_obj(json_dict)
except (json.JSONDecodeError) as e:
msg = f"Invalid JSON\n {text}\nGot: {e}"
raise OutputParserExceptionWithOriginal(msg, text)
except ValidationError as e:
raise OutputParserExceptionWithOriginal(
f"Data are not in correct format: {text}\nGot: {e}", text)
def get_json_example_description(self, model: Type[BaseModel], indentation_level=0):
field_descriptions = {}
for field, field_info in model.__fields__.items():
_item_type = None
if field_info.type_ == field_info.outer_type_:
_type = field_info.type_
elif list == getattr(field_info.outer_type_, '__origin__', None):
# is list
_type = list
_item_type = field_info.outer_type_.__args__[0]
elif dict == getattr(field_info.outer_type_, '__origin__', None):
_type = dict
else:
raise Exception(f"Unknown type: {field_info.annotation}")
_nullable = field_info.allow_none
_description = field_info.field_info.description
if issubclass(_type, BaseModel):
field_descriptions[field] = (
self.get_json_example_description(_type, indentation_level+1))
elif _type == str:
desc = f'\" {_get_str_field_description(field_info)} "'
field_descriptions[field] = (desc)
elif _type == datetime:
field_descriptions[field] = (
"an ISO formatted datetime string")
elif _type == list:
list_desc = f"[ {_description} ... list of {_item_type} ]"
field_descriptions[field] = (list_desc)
elif _type == dict:
dict_desc = f"{{ ... {_description} ... }}"
field_descriptions[field] = (dict_desc)
elif _type == int:
field_descriptions[field] = ("an integer")
elif _type == float:
field_descriptions[field] = ("a float number")
if _nullable:
field_descriptions[field] = field_descriptions[field] + f" or null"
lines = []
for field, field_info in model.__fields__.items():
desc_lines = "\n".join(
("\t"*indentation_level+line for line in field_descriptions[field].splitlines())).strip()
lines.append("\t"*indentation_level + f"\"{field}\": {desc_lines}")
return "{\n" + ",\n".join(lines) + "\n}"
def get_format_instructions(self) -> str:
"""Instructions on how the LLM output should be formatted."""
if not self.instructions_as_json_example:
return "Return result as a valid JSON that matched this json schema definition:\n" + yaml.safe_dump(self.model.schema())
else:
return dedent(f"""```json\n{self.get_json_example_description(self.model)}\n```""").strip()
class CheckListParser(ListOutputParser):
"""Parses list a a dictionary... assume this format:
- KeyParma1: Value1
- KeyPara2: Value2
...
"""
def __init__(self, model: Type[T] = None):
self.model = model
@property
def _type(self) -> str:
return "checklist"
def get_instructions_for_model(self, model: Type[T]) -> str:
fields_bullets = []
for field in model.__fields__.values():
description = [field.field_info.description]
if field.field_info.extra.get("one_of"):
description += "one of these values: [ "
description += " | ".join(field.field_info.extra.get("one_of"))
description += " ]"
if field.field_info.extra.get("example"):
description += f"e.g. {field.field_info.extra.get('example')}"
if description:
description = " ".join(description)
else:
description = "?"
fields_bullets.append(f"- {field.name}: {description}")
def parse(self, text: str) -> Union[dict, T]:
"""Parse the output of an LLM call."""
pattern = r"^[ \t]*(?:[\-\*\+]|\d+\.)[ \t]+(.+)$"
matches = re.findall(pattern, text, flags=re.MULTILINE)
result = {}
for match in matches:
key, value = match.split(":", 1)
result[key.strip()] = value.strip()
return matches
def get_format_instructions(self) -> str:
"""Instructions on how the LLM output should be formatted."""
res = "Return result a s bulleted list in this format:\n"
if self.model:
res += self.get_instructions_for_model(self.model)
else:
res += "\n- Key1: Value1\n- Key2: Value2\n- ..."
class MarkdownStructureParser(ListOutputParser):
model: Type[T] = None
level: int = 1
sections_parsers: Dict[str, Union[BaseOutputParser, dict]] = None
def __init__(self, model: Type[T] = None, sections_parsers: Dict[str, Union[dict, BaseOutputParser]] = None, level=1):
super().__init__(model=model, sections_parsers=sections_parsers, level=level)
if model:
for field, field_info in model.__fields__.items():
if sections_parsers and field in self.sections_parsers:
# if section parser was already provided, skip
if not type(self.sections_parsers.get(field)) == dict:
continue
field_type = get_field_type(field_info)
if get_field_type(field_info) == list:
item_type = get_field_item_type(field_info)
if item_type == str or item_type is None:
self.sections_parsers[field] = ListOutputParser()
else:
raise ValueError(
f"Unsupported item type {item_type} for property {model}.{field}. Only list of strings is supported.")
elif field_type == dict:
self.sections_parsers[field] = CheckListParser()
elif field_type and issubclass(field_type, BaseModel):
all_sub_str = all(True for sub_field_info in field_type.__fields__.values(
) if get_field_type(sub_field_info) == str)
if all_sub_str:
self.sections_parsers[field] = MarkdownStructureParser(
field_type, sections_parsers=sections_parsers.get(field), level=level+1)
else:
self.sections_parsers[field] = PydanticOutputParser(
model=field_type)
elif field_type == str:
self.sections_parsers[field] = None
else:
raise ValueError(
f"Unsupported type {field_type} for property {field}.")
elif sections_parsers:
for property, property_parser in sections_parsers.items():
if type(property_parser) == dict:
sections_parsers[property] = MarkdownStructureParser(
model=None, sections_parsers=property_parser, level=level+1)
elif type(property_parser) == str:
sections_parsers[property] = None
elif isinstance(property_parser, BaseOutputParser):
continue
else:
raise ValueError(
f"Unsupported type {model.__fields__[property].annotation} for property {property}. Use a dict or a pydantic model.")
else:
self.sections_parsers = {}
@property
def _type(self) -> str:
return "checklist"
def get_instructions_for_sections(self, model: Type[T] = None, sections_parsers: Dict[str, BaseOutputParser] = None) -> str:
section_instructions = []
if model:
for field, field_info in model.__fields__.items():
name: str = field_info.field_info.title or field
section_instructions.append(self.level*"#" + f" {name}")
if sections_parsers and sections_parsers.get(field):
section_instructions.append(
sections_parsers.get(field).get_format_instructions())
continue
else:
description = _get_str_field_description(field_info)
section_instructions.append(description)
else:
for section, parser in sections_parsers.items():
section_instructions.append(self.level*"#" + f" {section}")
if isinstance(parser, BaseOutputParser):
section_instructions.append(
parser.get_format_instructions())
else:
section_instructions.append("?")
return "\n\n".join(section_instructions)
def parse(self, text: str) -> List[str]:
"""Parse the output of an LLM call."""
sections_separators = list(re.finditer(
r"^#+[ |\t]+(.*)$", text, flags=re.MULTILINE))
res = {}
for i, section_separator_match in enumerate(sections_separators):
section_name = section_separator_match.group(1)
if self.model:
section_name = next((field for field, field_info in self.model.__fields__.items() if field_info.field_info.title ==
section_name or field.lower() == section_name.lower() or field_info.alias == section_name), section_name)
if i < len(sections_separators)-1:
section_content = text[section_separator_match.end(
):sections_separators[i+1].start()]
else:
section_content = text[section_separator_match.end():]
parsed_content = None
if self.sections_parsers and self.sections_parsers.get(section_name, None) or self.sections_parsers.get(section_separator_match.group(1)):
parser = self.sections_parsers.get(
section_name, None) or self.sections_parsers.get(section_separator_match.group(1))
if isinstance(parser, BaseOutputParser):
parsed_content = parser.parse(section_content)
if not parsed_content:
parsed_content = section_content.strip()
res[section_name] = parsed_content
if self.model:
try:
return self.model(**res)
except ValidationError as e:
raise OutputParserExceptionWithOriginal(
f"Data are not in correct format: {text}\nGot: {e}", text)
else:
return res
def get_format_instructions(self) -> str:
"""Instructions on how the LLM output should be formatted."""
res = "Return result as a markdown in this format:\n"
if self.model or self.sections_parsers:
res += self.get_instructions_for_sections(
self.model, self.sections_parsers)
else:
res += "# Section 1\n\ndescription\n\n#Section 2\n\ndescription\n\..."
return res
def _get_str_field_description(field_info: ModelField, ignore_nullable: bool = False):
_nullable = field_info.allow_none
_description = field_info.field_info.description
_example = field_info.field_info.extra.get("example")
_one_of = field_info.field_info.extra.get("one_of")
_regex = field_info.field_info.extra.get("regex")
_one_of = field_info.field_info.extra.get("one_of")
_regex = field_info.field_info.extra.get("regex")
description = []
if _description:
description.append(_description)
if _one_of:
description += "one of these values: [ "
description += " | ".join(_one_of)
description += " ]"
if _example:
description += f"e.g. {_example}"
if _nullable and not ignore_nullable:
description += "... or 'N/A' if not available"
if _regex and not _one_of:
description += f"... must match this regex: {_regex}"
if description:
description = " ".join(description)
else:
description = "?"
return description
</file>
<file path=src/langchain_decorators/prompt_template.py>
import logging
import re
import inspect
from string import Formatter
from typing import Any, Callable, Coroutine, Dict, List, Optional, Union
from pydantic import BaseModel
from textwrap import dedent
from langchain import PromptTemplate
from langchain.prompts import StringPromptTemplate
from langchain.prompts.chat import MessagesPlaceholder, ChatMessagePromptTemplate, ChatPromptTemplate, ChatPromptValue
from langchain.schema import PromptValue, BaseOutputParser
from .common import GlobalSettings, LogColors, PromptTypeSettings, print_log
from .output_parsers import *
def parse_prompts_from_docs(docs: str):
prompts = []
for i, prompt_block in enumerate(re.finditer(r"```[^\S\n]*<prompt(?P<role>:\w+)?>\n(?P<prompt>.*?)\n```[ |\t]*\n", docs, re.MULTILINE | re.DOTALL)):
role = prompt_block.group("role")
prompt = prompt_block.group("prompt")
# remove \ escape before ```
prompt = re.sub(r"((?<=\s)\\(?=```))|^\\(?=```)",
"", prompt, flags=re.MULTILINE)
prompt.strip()
if not role:
if i > 1:
raise ValueError(
"Only one prompt can be defined in code block. If you intend to define messages, you need to specify a role.\nExample:\n```<prompt:role>\nFoo {bar}\n```")
else:
prompts.append(prompt)
else:
prompts.append((role[1:], prompt))
if not prompts:
# the whole document is a prompt
prompts.append(docs.strip())
return prompts
class PromptTemplateDraft(BaseModel):
role: str = None
input_variables: List[str]
template: str
partial_variables_builders: Optional[Dict[str, Callable[[dict], str]]]
def finalize_template(self, input_variable_values: dict) -> Union[MessagesPlaceholder, ChatMessagePromptTemplate, StringPromptTemplate]:
if self.role == "placeholder":
return MessagesPlaceholder(variable_name=self.input_variables[0])
else:
final_template_string = self.template
if self.partial_variables_builders:
for final_partial_key, partial_builder in self.partial_variables_builders.items():
final_partial_value = partial_builder(
input_variable_values)
final_template_string = final_template_string.replace(
f"{{{final_partial_key}}}", final_partial_value)
final_template_string = final_template_string.strip()
content_template = PromptTemplate.from_template(
final_template_string)
if self.role:
return ChatMessagePromptTemplate(role=self.role, prompt=content_template)
else:
return content_template
def build_template_drafts(template: str, format: str, role: str = None) -> PromptTemplateDraft:
partials_with_params = {}
if role != "placeholder" and format == "f-string-extra":
optional_blocks_regex = list(re.finditer(
r"\{\?(?P<optional_partial>.+?)(?=\?\})\?\}", template, re.MULTILINE | re.DOTALL))
for optional_block in optional_blocks_regex:
optional_partial = optional_block.group("optional_partial")
partial_input_variables = {v for _, v, _, _ in Formatter().parse(
optional_partial) if v is not None}
if not partial_input_variables:
raise ValueError(
f"Optional partial {optional_partial} does not contain any optional variables.")
# replace {} with [] and all other non-word characters with underscore
partial_name = re.sub(
r"[^\w\[\]]+", "_", optional_partial.replace("{", "[").replace("}", "]"))
partials_with_params[partial_name] = (
optional_partial, partial_input_variables)
# replace optional partial with a placeholder
template = template.replace(
optional_block.group(0), f"{{{partial_name}}}")
# partial_name: a function that takes in a dict of variables and returns a string...
partial_builders = {}
for partial_name, (partial, partial_input_variables) in partials_with_params.items():
# create function that will render the partial if all the input variables are present. Otherwise, it will return an empty string...
# it needs to be unique for each partial, since we check only for the variables that are present in the partial
def partial_formatter(inputs, _partial=partial):
""" This will render the partial if all the input variables are present. Otherwise, it will return an empty string."""
missing_param = next(
(param for param in partial_input_variables if param not in inputs or not inputs[param]), None)
if missing_param:
return ""
else:
return _partial
partial_builders[partial_name] = partial_formatter
input_variables = [v for _, v, _, _ in Formatter().parse(
template) if v is not None and v not in partials_with_params]
for partial_name, (partial, partial_input_variables) in partials_with_params.items():
input_variables.extend(partial_input_variables)
input_variables = list(set(input_variables))
if not partials_with_params:
partials_with_params = None
partial_builders = None
if not role:
return PromptTemplateDraft(input_variables=input_variables, template=template, partial_variables_builders=partial_builders)
elif role == "placeholder":
if len(input_variables) > 1:
raise ValueError(
f"Placeholder prompt can only have one input variable, got {input_variables}")
elif len(input_variables) == 0:
raise ValueError(
f"Placeholder prompt must have one input variable, got none.")
return PromptTemplateDraft(template=template, input_variables=input_variables, partial_variables_builders=partial_builders, role="placeholder")
else:
return PromptTemplateDraft(role=role, input_variables=input_variables, template=template, partial_variables_builders=partial_builders)
class PromptDecoratorTemplate(StringPromptTemplate):
template_string: str
prompt_template_drafts: Union[PromptTemplateDraft,
List[PromptTemplateDraft]]
template_name: str
template_format: str
optional_variables: List[str]
optional_variables_none_behavior: str
default_values: Dict[str, Any]
format_instructions_parameter_key: str
template_version: str = None
prompt_type: PromptTypeSettings = None
@classmethod
def build(cls,
template_string: str,
template_name: str,
template_format: str = "f-string-extra",
output_parser: Union[None, BaseOutputParser] = None,
optional_variables: Optional[List[str]] = None,
optional_variables_none_behavior: str = "skip_line",
default_values: Optional[Dict[str, Any]] = None,
format_instructions_parameter_key: str = "FORMAT_INSTRUCTIONS",
template_version: str = None,
prompt_type: PromptTypeSettings = None
) -> "PromptDecoratorTemplate":
if template_format not in ["f-string", "f-string-extra"]:
raise ValueError(
f"template_format must be one of [f-string, f-string-extra], got {template_format}")
prompts = parse_prompts_from_docs(template_string)
if isinstance(prompts, list):
prompt_template_drafts = []
input_variables = []
for prompt in prompts:
if isinstance(prompt, str):
prompt_template_drafts = build_template_drafts(
prompt, format=template_format)
input_variables = prompt_template_drafts.input_variables
# there should be only one prompt if it's a string
break
else:
(role, content_template) = prompt
message_template = build_template_drafts(
content_template, format=template_format, role=role)
input_variables.extend(message_template.input_variables)
prompt_template_drafts.append(message_template)
return cls(
input_variables=input_variables, # defined in base
output_parser=output_parser, # defined in base
prompt_template_drafts=prompt_template_drafts,
template_name=template_name,
template_version=template_version,
template_string=template_string,
template_format=template_format,
optional_variables=optional_variables,
optional_variables_none_behavior=optional_variables_none_behavior,
default_values=default_values,
format_instructions_parameter_key=format_instructions_parameter_key,
prompt_type=prompt_type
)
@classmethod
def from_func(cls,
func: Union[Callable, Coroutine],
template_name: str = None,
template_version: str = None,
output_parser: Union[str, None, BaseOutputParser] = "auto",
template_format: str = "f-string-extra",
format_instructions_parameter_key: str = "FORMAT_INSTRUCTIONS",
prompt_type: PromptTypeSettings = None
) -> "PromptDecoratorTemplate":
fist_line, rest = func.__doc__.split('\n', 1)
# we dedent the first line separately,because its common that it often starts right after """
fist_line = fist_line.strip()
if fist_line:
fist_line += "\n"
template_string = fist_line + dedent(rest)
template_name = template_name or f"{func.__module__}.{func.__name__}" if not func.__module__ == "__main__" else func.__name__
return_type = func.__annotations__.get("return", None)
if inspect.iscoroutinefunction(func):
if return_type and issubclass(return_type, Coroutine):
return_type_args = getattr(return_type, '__args__', None)
if return_type_args and len(return_type_args) == 3:
return_type = return_type_args[2]
else:
raise Exception(
f"Invalid Coroutine annotation {return_type}. Expected Coroutine[ any , any, <return_type>] or just <return_type>")
if output_parser == "auto":
if return_type == str or return_type == None:
output_parser = "str"
elif return_type == dict:
output_parser = "json"
elif return_type == list:
output_parser = "list"
elif issubclass(return_type, BaseModel):
output_parser = PydanticOutputParser(model=return_type)
else:
raise Exception(f"Unsupported return type {return_type}")
if isinstance(output_parser, str):
if output_parser == "str":
output_parser = None
elif output_parser == "json":
output_parser = JsonOutputParser()
elif output_parser == "markdown":
if return_type and return_type != dict:
raise Exception(
f"Conflicting output parsing instructions. Markdown output parser only supports return type dict, got {return_type}.")
else:
output_parser = MarkdownStructureParser()
elif output_parser == "list":
output_parser = ListOutputParser()
elif output_parser == "pydantic":
if issubclass(return_type, BaseModel):
output_parser = PydanticOutputParser(model=return_type)
elif return_type == None:
raise Exception(
f"You must annotate the return type for pydantic output parser, so that we can infer the model")
else:
raise Exception(
f"Unsupported return type {return_type} for pydantic output parser")
default_values = {k: v.default for k, v in inspect.signature(
func).parameters.items() if v.default != inspect.Parameter.empty}
return cls.build(
template_string=template_string,
template_name=template_name,
template_version=template_version,
output_parser=output_parser,
template_format=template_format,
optional_variables=[*default_values.keys()],
default_values=default_values,
format_instructions_parameter_key=format_instructions_parameter_key,
prompt_type=prompt_type
)
def get_final_template(self, **kwargs: Any) -> PromptTemplate:
"""Create Chat Messages."""
if self.default_values:
# if we have default values, we will use them to fill in missing values
kwargs = {**self.default_values, **kwargs}
kwargs = {k: v for k, v in kwargs.items() if v is not None}
if isinstance(self.prompt_template_drafts, list):
message_templates = []
for message_draft in self.prompt_template_drafts:
msg_template = message_draft.finalize_template(kwargs)
message_templates.append(msg_template)
template = ChatPromptTemplate(
messages=message_templates, input_variables=self.input_variables, output_parser=self.output_parser)
else:
template = self.prompt_template_drafts.finalize_template(kwargs)
template.output_parser = self.output_parser
return template
def format_prompt(self, **kwargs: Any) -> PromptValue:
if self.format_instructions_parameter_key in self.input_variables and not kwargs.get(self.format_instructions_parameter_key) and self.output_parser:
# add format instructions to inputs
kwargs[self.format_instructions_parameter_key] = self.output_parser.get_format_instructions()
final_template = self.get_final_template(**kwargs)
kwargs = {k: v for k, v in kwargs.items(
) if k in final_template.input_variables}
if isinstance(final_template, ChatPromptTemplate):
for msg in list(final_template.messages):
if isinstance(msg, MessagesPlaceholder):
if not kwargs.get(msg.variable_name):
kwargs[msg.variable_name] = []
formatted = final_template.format_prompt(**kwargs)
if isinstance(formatted, ChatPromptValue):
for msg in list(formatted.messages):
if not msg.content or not msg.content.strip():
formatted.messages.remove(msg)
self.on_prompt_formatted(formatted.to_string())
return formatted
def format(self, **kwargs: Any) -> str:
formatted = self.get_final_template(**kwargs).format(**kwargs)
self.on_prompt_formatted(formatted)
return formatted
def on_prompt_formatted(self, formatted: str):
if not self.prompt_type:
log_level = logging.DEBUG
log_color = LogColors.DARK_GRAY
else:
log_level = self.prompt_type.log_level
log_color = self.prompt_type.color
print_log(f"Prompt:\n{formatted}", log_level, log_color)
</file>
<error>
Traceback (most recent call last):
File "/workspaces/langchain-decorators/src/test.py", line 6, in <module>
from langchain_decorators import GlobalSettings
File "/workspaces/langchain-decorators/src/langchain_decorators/__init__.py", line 1, in <module>
from .common import LogColors, GlobalSettings, print_log, PromptTypes, PromptTypeSettings
File "/workspaces/langchain-decorators/src/langchain_decorators/common.py", line 12, in <module>
class GlobalSettings(BaseModel):
File "/workspaces/langchain-decorators/src/langchain_decorators/common.py", line 27, in GlobalSettings
default_llm=ChatOpenAI(temperature=0.0),
File "pydantic/main.py", line 341, in pydantic.main.BaseModel.__init__
pydantic.error_wrappers.ValidationError: 1 validation error for ChatOpenAI
__root__
Did not find openai_api_key, please add an environment variable `OPENAI_API_KEY` which contains it, or pass `openai_api_key` as a named parameter. (type=value_error)
</error>
Suggest how to fix the common.py library so there is no error when running test.py.
ChatOpenAI should not be loaded if the settings are configured to use another LLM class.
I don't want to use ChatOpenAI at all, but others will. It is OK as a default, but I must be able to use ChatAnthropic instead if I set:
```python
from langchain_decorators import GlobalSettings
from langchain.chat_models import ChatAnthropic

GlobalSettings.define_settings(
    default_llm=ChatAnthropic(temperature=0.0),
    default_streaming_llm=ChatAnthropic(temperature=0.0, streaming=True),
)
```
Let's think step by step about how to prevent ChatOpenAI from loading when define_settings has a different LLM set.
Assistant: To fix this, you should:
1. Make the default_llm and default_streaming_llm fields optional in the GlobalSettings model:
```python
default_llm: Optional[BaseLanguageModel] = None
default_streaming_llm: Optional[BaseLanguageModel] = None
```
2. In the define_settings method, check if a default_llm is passed. If so, use that. Otherwise, default to ChatOpenAI:
```python
def define_settings(cls,
                    default_llm=None,
                    default_streaming_llm=None,
                    ...):
    if default_llm is None:
        default_llm = ChatOpenAI(temperature=0.0)
    if default_streaming_llm is None:
        default_streaming_llm = ChatOpenAI(temperature=0.0, streaming=True)
    settings = cls(default_llm=default_llm, default_streaming_llm=default_streaming_llm, ...)
```
3. Remove the ChatOpenAI import from line 12, since it is now only conditionally used in define_settings.
4. Add a check in the __init__ method to only load the default LLM if it is set:
```python
def __init__(self, default_llm=None, default_streaming_llm=None, ...):
    ...
    if default_llm:
        self.default_llm = default_llm
    if default_streaming_llm:
        self.default_streaming_llm = default_streaming_llm
```
With these changes, ChatOpenAI will only be loaded if define_settings is called without specifying an LLM, fixing the error. Let me know if this helps explain the issue!
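For reference, this is roughly what the working idea looks like applied to `GlobalSettings.define_settings` in common.py. It is a minimal sketch assembled from the file above, not necessarily the exact patch that ended up in langchain-decorators; making the `ChatOpenAI` import lazy (inside the function body) is my own addition on top of the model's suggestions:

```python
# Sketch of common.py with lazy LLM defaults. Assumes the rest of the
# original GlobalSettings class (logging, registry handling, etc.) stays as-is.
import logging
from typing import Optional

from pydantic import BaseModel, Extra
from langchain.llms.base import BaseLanguageModel


class GlobalSettings(BaseModel):
    default_llm: Optional[BaseLanguageModel] = None
    default_streaming_llm: Optional[BaseLanguageModel] = None
    logging_level: int = logging.INFO
    stdout_logging: bool = True
    verbose: bool = False

    class Config:
        allow_population_by_field_name = True
        extra = Extra.allow

    @classmethod
    def define_settings(cls,
                        settings_type="default",
                        default_llm: Optional[BaseLanguageModel] = None,
                        default_streaming_llm: Optional[BaseLanguageModel] = None,
                        logging_level=logging.INFO,
                        stdout_logging: bool = True,
                        verbose=False,
                        **kwargs):
        # Only fall back to ChatOpenAI when no LLM was supplied, and import it
        # lazily so OPENAI_API_KEY is not required when another provider is used.
        if default_llm is None or default_streaming_llm is None:
            from langchain.chat_models import ChatOpenAI
            if default_llm is None:
                default_llm = ChatOpenAI(temperature=0.0)
            if default_streaming_llm is None:
                default_streaming_llm = ChatOpenAI(temperature=0.0, streaming=True)
        settings = cls(default_llm=default_llm,
                       default_streaming_llm=default_streaming_llm,
                       logging_level=logging_level,
                       stdout_logging=stdout_logging,
                       verbose=verbose,
                       **kwargs)
        if not hasattr(GlobalSettings, "registry"):
            setattr(GlobalSettings, "registry", {})
        GlobalSettings.registry[settings_type] = settings
```

With the defaults set to `None` and the `ChatOpenAI` construction moved into the function body, importing the package or calling `define_settings(default_llm=ChatAnthropic(...))` never touches `OPENAI_API_KEY`; an OpenAI client is only created when someone actually relies on the built-in defaults.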