@rogeriochaves
Last active May 17, 2024 08:32
monadic langchain, a more FP interface to langchain
# This is how langchain chain composition looks with a more FP approach

import re
from typing import (
    Literal,
    TypedDict,
    cast,
)

from dotenv import load_dotenv

from monadic import Chain, ConstantChain

load_dotenv()

import langchain
import langchain.schema
from langchain.chat_models import ChatOpenAI

# Setup

langchain.debug = True

llm = ChatOpenAI(client=None, model="gpt-3.5-turbo", temperature=0)


def simple_key_extract(key: str, output: str) -> str:
    found = re.search(f"{key}: (.*)", output)
    if found is None:
        raise Exception("Parsing Error")
    return found[1]
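
# For instance, given an LLM output like "Action: SEARCH\nParam: vector db",
# simple_key_extract("Action", output) returns "SEARCH" (illustrative example,
# not part of the original gist)
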
# Example


class RoutingChainOutput(TypedDict):
    action: Literal["SEARCH", "REPLY"]
    param: str


routing_chain = Chain[str, RoutingChainOutput](
    "RoutingChain",
    llm=llm,
    prompt="""
You are a chatbot that helps users search on the documentation, but you can also do some chit-chatting with them.
Choose the action REPLY if the user is just chit-chatting, like greeting, asking how are you, etc, but choose SEARCH \
for everything else, so you can actually do the search and help them.
=============================
Input: hello there
Action: REPLY
Param: hey there, what are you looking for?
Input: how does langchain work?
Action: SEARCH
Param: langchain how it works
Input: code example of vector db
Action: SEARCH
Param: vector db code example
Input: how is it going?
Action: REPLY
Param: I'm going well, how about you?
Input: {question}
""",
    input_mapper=lambda input: {"question": input},
    output_parser=lambda output: {
        "action": cast(
            Literal["SEARCH", "REPLY"], simple_key_extract("Action", output)
        ),
        "param": simple_key_extract("Param", output),
    },
)
summarizer_chain = Chain[str, str](
    "SummarizerChain",
    llm=llm,
    prompt="Summarize the following text: {text}\nSummary: ",
    input_mapper=lambda input: {"text": input},
)

search_chain = Chain[RoutingChainOutput, str](
    "SearchChain",
    llm=llm,
    prompt="Pretend to search for the user. Query: {query}\nResults: ",  # this would be replaced with a proper vector db search
    input_mapper=lambda input: {"query": input["param"]},
).and_then(lambda _: summarizer_chain)
conversation_chain = routing_chain.and_then(
    lambda output: ConstantChain(output["param"])
    if output["action"] == "REPLY"
    else search_chain
)

import chainlit as cl


@cl.langchain_factory(use_async=False)
def factory():
    return conversation_chain
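
# A quick way to exercise the composed chain outside chainlit (my sketch, not
# part of the original gist; it makes real OpenAI calls, so OPENAI_API_KEY must
# be set in the environment):
if __name__ == "__main__":
    # chit-chat is routed to REPLY and answered directly by the router
    print(conversation_chain("hello there"))
    # everything else goes through SearchChain and then SummarizerChain
    print(conversation_chain("how does langchain work?"))
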
# For comparison, this is how the same functionality looks with the default langchain interface

import re
from typing import Any, Dict, Literal, TypedDict, cast

from dotenv import load_dotenv

load_dotenv()

import langchain
import langchain.schema
from langchain.chat_models import ChatOpenAI
from langchain import (
    LLMChain,
    PromptTemplate,
)
from langchain.schema import BaseOutputParser
from langchain.callbacks.manager import (
    Callbacks,
)

import chainlit as cl

# Setup

langchain.debug = True

llm = ChatOpenAI(client=None, model="gpt-3.5-turbo", temperature=0)


def simple_key_extract(key: str, output: str) -> str:
    found = re.search(f"{key}: (.*)", output)
    if found is None:
        raise Exception("Parsing Error")
    return found[1]
# Example


class RoutingChainOutput(TypedDict):
    action: Literal["SEARCH", "REPLY"]
    param: str


class RoutingParser(BaseOutputParser[RoutingChainOutput]):
    def parse(self, output: str) -> Dict[str, Any]:
        return {
            "action": cast(
                Literal["SEARCH", "REPLY"], simple_key_extract("Action", output)
            ),
            "param": simple_key_extract("Param", output),
        }


class RoutingChain(LLMChain):
    def __init__(self):
        return super().__init__(
            llm=llm,
            prompt=PromptTemplate(
                template="""
You are a chatbot that helps users search on the documentation, but you can also do some chit-chatting with them.
Choose the action REPLY if the user is just chit-chatting, like greeting, asking how are you, etc, but choose SEARCH \
for everything else, so you can actually do the search and help them.
=============================
Input: hello there
Action: REPLY
Param: hey there, what are you looking for?
Input: how does langchain work?
Action: SEARCH
Param: langchain how it works
Input: code example of vector db
Action: SEARCH
Param: vector db code example
Input: how is it going?
Action: REPLY
Param: I'm going well, how about you?
Input: {question}
""",
                input_variables=["question"],
                output_parser=RoutingParser(),
            ),
        )
class SummarizerChain(LLMChain):
    def __init__(self):
        return super().__init__(
            llm=llm,
            prompt=PromptTemplate(
                template="Summarize the following text: {text}\nSummary:",
                input_variables=["text"],
            ),
        )


class SearchChain(LLMChain):
    def __init__(self):
        return super().__init__(
            llm=llm,
            prompt=PromptTemplate(
                template="Pretend to search for the user. Query: {query}\nResults: ",
                input_variables=["query"],
            ),
        )
def conversation(input: str, callbacks: Callbacks) -> str:
    route = cast(
        RoutingChainOutput,
        RoutingChain().predict_and_parse(callbacks=callbacks, question=input),
    )
    if route["action"] == "REPLY":
        return route["param"]
    elif route["action"] == "SEARCH":
        result = SearchChain().__call__({"query": route["param"]}, callbacks=callbacks)
        result = SummarizerChain().__call__(
            {"text": result["text"]}, callbacks=callbacks
        )
        return result["text"]
    else:
        return f"unknown action {route['action']}"


@cl.langchain_factory(use_async=False)
def factory():
    return conversation
# This is the interface implementation

from typing import (
    Any,
    Callable,
    Dict,
    Generic,
    Optional,
    Type,
    TypeVar,
)

import langchain
import langchain.schema
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import Callbacks

T = TypeVar("T")
U = TypeVar("U")
V = TypeVar("V")

# Monadic Chain


def identity_parser(x: str):
    return x
class Chain(Generic[T, U]):
    llm: BaseLanguageModel
    prompt: str
    input_mapper: Optional[Callable[[T], Dict[str, str]]]
    output_parser: Callable[[str], U]
    named_class: Type[langchain.LLMChain]

    def __init__(
        self,
        name: str,
        llm: BaseLanguageModel,
        prompt: str,
        input_mapper: Optional[Callable[[T], Dict[str, str]]] = None,
        output_parser: Callable[[str], U] = identity_parser,
    ) -> None:
        self.llm = llm
        self.prompt = prompt
        self.input_mapper = input_mapper
        self.output_parser = output_parser
        # a dynamically named LLMChain subclass, so the chain shows up with a
        # readable name in debug output and chainlit traces
        self.named_class = type(name, (langchain.LLMChain,), {})

    def call(self, input: T, callbacks: Callbacks = []) -> U:
        input_values = self.input_mapper(input) if self.input_mapper else input
        if type(input_values) is not dict:
            raise Exception("cannot extract input_values")
        prompt = langchain.PromptTemplate(
            template=self.prompt,
            input_variables=[str(k) for k in input_values.keys()],
        )
        chain = self.named_class(llm=self.llm, prompt=prompt, callbacks=callbacks)
        result = chain.run(input_values)
        result = self.output_parser(result)
        return result

    def __call__(self, input: T, callbacks: Callbacks = []) -> U:
        return self.call(input, callbacks)

    def and_then(self, fn: Callable[[U], "Chain[U, V]"]):
        # monadic bind: fn sees this chain's output and picks the next chain
        return PipeChain(chain_a=self, chain_b_fn=fn)
class IdentityChain(Chain[T, T]):
    def __init__(self) -> None:
        pass

    def call(self, input: T, _callbacks: Callbacks = []):
        return input


class ConstantChain(Chain[Any, U]):
    output: U

    def __init__(self, output: U) -> None:
        self.output = output

    def call(self, _input: Any, _callbacks: Callbacks = []):
        return self.output


class PipeChain(Chain[T, V]):
    chain_a: Chain[T, Any]
    chain_b_fn: Callable[[Any], Chain[Any, V]]

    def __init__(
        self, chain_a: Chain[T, U], chain_b_fn: Callable[[U], Chain[Any, V]]
    ) -> None:
        self.chain_a = chain_a
        self.chain_b_fn = chain_b_fn

    def call(self, input: T, callbacks: Callbacks = []):
        output = self.chain_a.call(input, callbacks)
        chain_b = self.chain_b_fn(output)
        return chain_b.call(output, callbacks)
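
# A minimal sanity check of the composition itself (my sketch, not part of the
# original file): ConstantChain ignores its input, so no LLM is needed to see
# and_then thread one chain's output into the next.
if __name__ == "__main__":
    greet: Chain[Any, str] = ConstantChain("hello")
    shout = greet.and_then(lambda text: ConstantChain(text.upper()))
    assert shout(None) == "HELLO"
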
rogeriochaves commented Jun 17, 2023

This is the message on Discord that inspired me to play with this:

[screenshot of the Discord message]

Comparing the FP version I drafted with the OO version (by OO I mean the original langchain interface; I know it could be improved even while staying OO-ish):

  1. In the FP one we have real composition: using and_then we connect the chains, with no need for imperative plumbing. In the OO version we had to build the conversation function, which calls the chains in order, imperatively plumbing the inputs and outputs from one chain to the next.

  2. This would be okay, and imperative code might even be easier to follow; however, the responsibility of converting the input data now lives in the "flow manager" (the conversation function), which can be quite distant from the chain definitions, and if the input definitions change there, the flow manager has to change as well. In the FP version, the input_mapper lives on the chain itself.

  3. And we have no type checks there, because the OO version uses plain string lists and dicts for inputs and outputs; if an input doesn't match an output, things fail at runtime. Python's types are not the most powerful or strict, but in the FP version I managed to have it all very nicely typed, enough for the IDE to prevent me from connecting two chains whose input/output don't match, which is much better dev ergonomics (see the sketch after this list).

  4. The OO version has implicit inputs and outputs. For example, if you are not explicit about the output_parser and don't explicitly call predict_and_parse, you get text as the output key; other built-in chains have other input and output keys, and it's hard to figure out what they are, because there is no type safety. In the FP version it's all explicit right at the creation of the chain, and in the editor's type hints, so you don't have to spend time digging into the implementation to find it (forget about the docs, it's not there either). To get some type safety in the OO version I had to forcefully cast to RoutingChainOutput, since by default it only has very open str lists and dicts as inputs and outputs.

  5. To ease debugging, and to get all the benefits of chainlit for example, it's better to have named chains and to pass down the callbacks. In the OO version you always need to create a subclass to do that and manually pass down the callbacks; changing the behaviour of a chain is also tiresome, since you have to override both the call and acall methods without forgetting to copy and paste some boilerplate to keep the callbacks. In the FP version, since we don't need manual plumbing and can use and_then, the callback forwarding can stay inside the lib. Without needing a provided composer like SequentialChain or RouterChain, it's trivial to run your own chains in sequence, or fork them, without specialized classes; this is because the FP version can follow the monad laws.

  6. In the FP version I can define everything together when creating the chain: no need for a separate PromptTemplate instance or an external class like RoutingParser, we can just use lambdas. This makes the code more co-located, and the whole chain definition is easier to see at a glance.

  7. The OO version, however, still feels more pythonic, and at some points typing the FP version becomes a bit of a struggle. Python doesn't provide good support for typing lambdas, or for anonymous typed inline dicts (although there are proposals), etc., making it easier to just follow the standard Python way of classes and named methods and so on.

  8. Even though langchain by default has this more OO interface, I am not sure (at least for the things I'm doing) that it holds mutable state anywhere (other than very explicit memory, for example), so it feels more like OO clothing without actually keeping state in the objects (which is maybe a good thing?).

  9. As a next step, we could erase the differences between LLMs and Agents, having them compose seamlessly, as mentioned in the screenshot above.
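
To make point 3 concrete, here is a small sketch (mine, simplified from the code above) of the kind of mismatch the type checker catches:

    # summarizer_chain is Chain[str, str], but routing_chain outputs a
    # RoutingChainOutput dict, so the IDE flags this before anything runs:
    bad_chain = routing_chain.and_then(lambda output: summarizer_chain)

    # mapping the output to a str first makes the types line up:
    ok_chain = routing_chain.and_then(lambda output: ConstantChain(output["param"]))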

More thoughts?

xingjianll commented

Hi, do you think this is still relevant now that they have LCEL?
