Last active
May 17, 2024 08:32
-
-
Save rogeriochaves/3bb334c57ea5a095c9cf5cc6c53bc499 to your computer and use it in GitHub Desktop.
monadic langchain, a more FP interface to langchain
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# This is how langchain chain composition looks with a more FP approach | |
import re | |
from typing import ( | |
Literal, | |
TypedDict, | |
cast, | |
) | |
from dotenv import load_dotenv | |
from monadic import Chain, ConstantChain | |
load_dotenv() | |
import langchain | |
import langchain.schema | |
from langchain.chat_models import ChatOpenAI | |
# Setup
langchain.debug = True  # verbose chain tracing while developing
# Deterministic completions (temperature=0).
# NOTE(review): client=None presumably lets the wrapper build its own API client — confirm.
llm = ChatOpenAI(client=None, model="gpt-3.5-turbo", temperature=0)
def simple_key_extract(key: str, output: str) -> str:
    """Return the text after ``"<key>: "`` on a line of *output*.

    Args:
        key: label to look for (e.g. "Action", "Param").
        output: raw LLM completion text.

    Raises:
        Exception: "Parsing Error" when the key is absent.
    """
    # re.escape guards against keys that contain regex metacharacters.
    found = re.search(f"{re.escape(key)}: (.*)", output)
    if found is None:
        raise Exception("Parsing Error")
    return found[1]
# Example
# Parsed routing decision: which action to take and its single argument.
RoutingChainOutput = TypedDict(
    "RoutingChainOutput",
    {"action": Literal["SEARCH", "REPLY"], "param": str},
)
def _routing_inputs(input: str) -> Dict[str, str]:
    """Wrap the raw user message as the prompt's {question} variable."""
    return {"question": input}


def _routing_parser(output: str) -> RoutingChainOutput:
    """Parse the "Action: ..." / "Param: ..." lines of the completion."""
    return {
        "action": cast(
            Literal["SEARCH", "REPLY"], simple_key_extract("Action", output)
        ),
        "param": simple_key_extract("Param", output),
    }


# Few-shot router: decides between searching the docs and replying directly.
routing_chain = Chain[str, RoutingChainOutput](
    "RoutingChain",
    llm=llm,
    prompt="""
You are a chatbot that helps users search on the documentation, but you can also do some chit-chatting with them.
Choose the action REPLY if the user is just chit-chatting, like greeting, asking how are you, etc, but choose SEARCH \
for everything else, so you can actually do the search and help them.
=============================
Input: hello there
Action: REPLY
Param: hey there, what are you looking for?
Input: how does langchain work?
Action: SEARCH
Param: langchain how it works
Input: code example of vector db
Action: SEARCH
Param: vector db code example
Input: how is it going?
Action: REPLY
Param: I'm going well, how about you?
Input: {question}
""",
    input_mapper=_routing_inputs,
    output_parser=_routing_parser,
)
def _summarizer_inputs(input: str) -> Dict[str, str]:
    """Expose the text to summarize as the {text} prompt variable."""
    return {"text": input}


# Condenses the (pretend) search results into a short answer.
summarizer_chain = Chain[str, str](
    "SummarizerChain",
    llm=llm,
    prompt="Summarize the following text: {text}\nSummary: ",
    input_mapper=_summarizer_inputs,
)
def _search_inputs(input: RoutingChainOutput) -> Dict[str, str]:
    """Use the router's param as the search {query}."""
    return {"query": input["param"]}


# this would be replaced with a proper vector db search
search_chain = Chain[RoutingChainOutput, str](
    "SearchChain",
    llm=llm,
    prompt="Pretend to search for the user. Query: {query}\nResults: ",
    input_mapper=_search_inputs,
).and_then(lambda _: summarizer_chain)  # always summarize the results afterwards
def _route(output: RoutingChainOutput) -> "Chain":
    """Pick the next chain: REPLY short-circuits with the canned answer."""
    if output["action"] == "REPLY":
        return ConstantChain(output["param"])
    return search_chain


# Full pipeline: route the message, then either reply or search-and-summarize.
conversation_chain = routing_chain.and_then(_route)
import chainlit as cl
@cl.langchain_factory(use_async=False)
def factory():
    # chainlit entry point: serve the composed conversation chain.
    return conversation_chain
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# For comparison, this is how the same functionality looks with default langchain interface | |
import re | |
from typing import Any, Dict, Literal, TypedDict, cast | |
from dotenv import load_dotenv | |
load_dotenv() | |
import langchain | |
import langchain.schema | |
from langchain.chat_models import ChatOpenAI | |
from langchain import ( | |
LLMChain, | |
PromptTemplate, | |
) | |
from langchain.schema import BaseOutputParser | |
from langchain.callbacks.manager import ( | |
Callbacks, | |
) | |
import chainlit as cl | |
# Setup
langchain.debug = True  # verbose chain tracing while developing
# Deterministic completions (temperature=0).
# NOTE(review): client=None presumably lets the wrapper build its own API client — confirm.
llm = ChatOpenAI(client=None, model="gpt-3.5-turbo", temperature=0)
def simple_key_extract(key: str, output: str) -> str:
    """Return the text after ``"<key>: "`` on a line of *output*.

    Args:
        key: label to look for (e.g. "Action", "Param").
        output: raw LLM completion text.

    Raises:
        Exception: "Parsing Error" when the key is absent.
    """
    # re.escape guards against keys that contain regex metacharacters.
    found = re.search(f"{re.escape(key)}: (.*)", output)
    if found is None:
        raise Exception("Parsing Error")
    return found[1]
# Example
# Parsed routing decision: which action to take and its single argument.
RoutingChainOutput = TypedDict(
    "RoutingChainOutput",
    {"action": Literal["SEARCH", "REPLY"], "param": str},
)
class RoutingParser(BaseOutputParser[RoutingChainOutput]):
    """Parse an LLM completion into a RoutingChainOutput dict."""

    def parse(self, output: str) -> RoutingChainOutput:
        # Return annotation matches BaseOutputParser[RoutingChainOutput]
        # (it previously claimed Dict[str, Any]).
        # The prompt instructs the model to emit "Action: ..." / "Param: ..." lines.
        return {
            "action": cast(
                Literal["SEARCH", "REPLY"], simple_key_extract("Action", output)
            ),
            "param": simple_key_extract("Param", output),
        }
class RoutingChain(LLMChain):
    """Few-shot LLMChain that decides between SEARCH and REPLY actions."""

    def __init__(self):
        # No `return` on super().__init__ — __init__ must return None.
        super().__init__(
            llm=llm,
            prompt=PromptTemplate(
                template="""
You are a chatbot that helps users search on the documentation, but you can also do some chit-chatting with them.
Choose the action REPLY if the user is just chit-chatting, like greeting, asking how are you, etc, but choose SEARCH \
for everything else, so you can actually do the search and help them.
=============================
Input: hello there
Action: REPLY
Param: hey there, what are you looking for?
Input: how does langchain work?
Action: SEARCH
Param: langchain how it works
Input: code example of vector db
Action: SEARCH
Param: vector db code example
Input: how is it going?
Action: REPLY
Param: I'm going well, how about you?
Input: {question}
""",
                input_variables=["question"],
                output_parser=RoutingParser(),
            ),
        )
class SummarizerChain(LLMChain):
    """LLMChain that condenses a block of text into a summary."""

    def __init__(self):
        # No `return` on super().__init__ — __init__ must return None.
        super().__init__(
            llm=llm,
            prompt=PromptTemplate(
                template="Summarize the following text: {text}\nSummary:",
                input_variables=["text"],
            ),
        )
class SearchChain(LLMChain):
    """LLMChain that fakes a documentation search for the given query."""

    def __init__(self):
        # No `return` on super().__init__ — __init__ must return None.
        super().__init__(
            llm=llm,
            prompt=PromptTemplate(
                template="Pretend to search for the user. Query: {query}\nResults: ",
                input_variables=["query"],
            ),
        )
def conversation(input: str, callbacks: Callbacks) -> str:
    """Route the user message, then reply directly or search-and-summarize.

    Args:
        input: raw user message.
        callbacks: langchain callbacks forwarded to every sub-chain.

    Returns:
        The final text to show the user.
    """
    route = cast(
        RoutingChainOutput,
        RoutingChain().predict_and_parse(callbacks=callbacks, question=input),
    )
    if route["action"] == "REPLY":
        return route["param"]
    elif route["action"] == "SEARCH":
        # Chains are callable — no need to spell out __call__; distinct names
        # keep the search results and the summary from shadowing each other.
        search_result = SearchChain()({"query": route["param"]}, callbacks=callbacks)
        summary = SummarizerChain()(
            {"text": search_result["text"]}, callbacks=callbacks
        )
        return summary["text"]
    else:
        return f"unknown action {route['action']}"
@cl.langchain_factory(use_async=False)
def factory():
    # chainlit entry point: hand over the plain `conversation` function.
    return conversation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# This is the interface implementation | |
from typing import ( | |
Any, | |
Callable, | |
Dict, | |
Generic, | |
Optional, | |
Type, | |
TypeVar, | |
) | |
import langchain | |
import langchain.schema | |
from langchain.base_language import BaseLanguageModel | |
from langchain.callbacks.manager import Callbacks | |
T = TypeVar("T")  # input type a chain accepts
U = TypeVar("U")  # parsed output type a chain produces
V = TypeVar("V")  # output type of the follow-up chain in and_then/PipeChain
# Monadic Chain
def identity_parser(x: str) -> str:
    """Default output parser: return the raw LLM text unchanged."""
    return x
class Chain(Generic[T, U]):
    """A composable (monadic) wrapper around langchain's LLMChain.

    ``T`` is the input type accepted by :meth:`call`; ``U`` is the parsed
    output type produced by ``output_parser``.
    """

    llm: BaseLanguageModel
    prompt: str
    input_mapper: Optional[Callable[[T], Dict[str, str]]]
    output_parser: Callable[[str], U]
    # Dynamically created LLMChain subclass carrying `name`, so debug
    # traces and callback UIs show a meaningful chain name.
    named_class: Type[langchain.LLMChain]

    def __init__(
        self,
        name: str,
        llm: BaseLanguageModel,
        prompt: str,
        input_mapper: Optional[Callable[[T], Dict[str, str]]] = None,
        output_parser: Callable[[str], U] = identity_parser,
    ) -> None:
        self.llm = llm
        self.prompt = prompt
        self.input_mapper = input_mapper
        self.output_parser = output_parser
        self.named_class = type(name, (langchain.LLMChain,), {})

    def call(self, input: T, callbacks: Callbacks = None) -> U:
        """Map the input to prompt variables, run the LLM, parse the result.

        Raises:
            Exception: if the (mapped) input is not a dict of prompt variables.
        """
        # Default is None, not a shared mutable [] — avoids the
        # mutable-default-argument pitfall; None is the conventional
        # "no callbacks" value in langchain.
        input_values = self.input_mapper(input) if self.input_mapper else input
        # isinstance also admits dict subclasses, unlike `type(...) is dict`.
        if not isinstance(input_values, dict):
            raise Exception("cannot extract input_values")
        prompt = langchain.PromptTemplate(
            template=self.prompt,
            input_variables=[str(k) for k in input_values.keys()],
        )
        chain = self.named_class(llm=self.llm, prompt=prompt, callbacks=callbacks)
        result = chain.run(input_values)
        return self.output_parser(result)

    def __call__(self, input: T, callbacks: Callbacks = None) -> U:
        return self.call(input, callbacks)

    def and_then(self, fn: Callable[[U], "Chain[U, V]"]):
        """Monadic bind: feed this chain's output to the chain `fn` returns."""
        return PipeChain(chain_a=self, chain_b_fn=fn)
class IdentityChain(Chain[T, T]):
    """A no-op chain that returns its input unchanged (the monadic identity)."""

    def __init__(self) -> None:
        # Intentionally skips Chain.__init__: no llm/prompt is needed.
        pass

    def call(self, input: T, _callbacks: Callbacks = None):
        # None default instead of a shared mutable [].
        return input
class ConstantChain(Chain[Any, U]):
    """A chain that ignores its input and always returns a fixed output."""

    output: U

    def __init__(self, output: U) -> None:
        # Intentionally skips Chain.__init__: no llm/prompt is needed.
        self.output = output

    def call(self, _input: Any, _callbacks: Callbacks = None):
        # None default instead of a shared mutable [].
        return self.output
class PipeChain(Chain[T, V]):
    """Sequential composition of two chains — what ``and_then`` returns.

    Runs ``chain_a``, builds the second chain from its output via
    ``chain_b_fn``, then runs that chain on the same output, forwarding
    callbacks to both so tracing stays intact.
    """

    chain_a: Chain[T, Any]
    chain_b_fn: Callable[[Any], Chain[Any, V]]

    def __init__(
        self, chain_a: Chain[T, U], chain_b_fn: Callable[[U], Chain[Any, V]]
    ) -> None:
        # Intentionally skips Chain.__init__: this chain only delegates.
        self.chain_a = chain_a
        self.chain_b_fn = chain_b_fn

    def call(self, input: T, callbacks: Callbacks = None):
        # None default instead of a shared mutable [].
        output = self.chain_a.call(input, callbacks)
        chain_b = self.chain_b_fn(output)
        return chain_b.call(output, callbacks)
Hi, do you think this is still relevant now that they have LCEL?
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
This is the message on discord that inspired me to play with that:
Comparing the FP version I drafted with the OO version (by OO read the original langchain interface, I know it could have some improvements even staying OOish):
On the FP one we have real composition, using
and_then
we connect the chains, and there is no need for imperative plumbing; in the OO version we had to build the `conversation`
function, which calls the chains in order, plumbing the inputs and outputs from one chain to the next, imperatively. This would be okay, and imperative code might be even easier to follow anyway; however, the responsibility of converting data input now lives in the "flow manager" (the
conversation
function), which can be a bit distant from the chain definitions, and if input definitions change there, we will have to change the flow manager as well. In the FP version, the `input_mapper`
lives on the chain itself. And we have no type checks there, because we use just string lists for inputs and outputs in the OO version; if an input doesn't match an output, things fail at runtime. Python types are not the most powerful or strict, but in the FP version I managed to have it all very nicely typed, enough for the IDE to prevent me from connecting two chains when their input/output do not match — much better dev ergonomics
The OO version has implicit inputs and outputs; for example, if you are not explicit about the
output_parser
and explicitly calling `predict_and_parse`
, you have `text`
as the output key; other built-in chains have other input and output keys, and it's hard to figure out what they are, because there is no type safety. In the FP version it's all explicit right at the creation of the chain, and in the editor's type hints; you don't have to spend time digging into the implementation to find it (forget about docs, it's also not there). To try to use type safety in the OO version I had to forcefully `cast`
it to the `RoutingChainOutput`
, since by default it only has very open str lists and dicts as inputs and outputs. To ease debugging, and get all the benefits of
chainlit
for example, it's better to have named chains and pass down the callbacks. In the OO version, you always need to create a subclass to do that, and manually pass down the callbacks; changing the behaviour of a chain is also tiresome, requiring overriding both the call and acall methods and not forgetting to copy and paste some boilerplate to keep the callbacks. In the FP version, since we don't need to do manual plumbing and can use `and_then`
, the callback forwarding can stay inside the lib, without needing to use a provided composer like `SequentialChain`
or `RouterChain`
, it's trivial to have your own chains in sequence, or fork, without specialized classes — this is because in the FP version we can follow the monad laws. In the FP version I can define everything together when creating the chain, with no need for a separate PromptTemplate instance or an external RoutingParser subclass, for example; we can just use lambdas, which makes the code more co-located and easier to see the whole chain definition at a glance
The OO version, however, still feels more Pythonic, and at some points typing the FP version becomes a bit of a struggle. Python doesn't provide good support for typing lambdas, or for having anonymous typed inline dicts (although there are proposals), etc., making it easier to just follow the standard Python way of classes and named methods and so on
Even though langchain by default has this more OO interface, I am not sure (at least for the things I'm doing) if it holds mutable state anywhere (other than very explicit memory for example), so it anyway feels more like it is an OO clothing but without having state in the objects (which is a good thing maybe?)
As a next step, we can erase the differences between LLMs and Agents, having them to compose seamlessly, as mentioned on the screenshot above
More thoughts?