@rogeriochaves
Last active May 17, 2024 08:32
monadic langchain, a more FP interface to langchain
# This is how langchain chain composition looks with a more FP approach
import re
from typing import (
    Literal,
    TypedDict,
    cast,
)
from dotenv import load_dotenv
from monadic import Chain, ConstantChain
load_dotenv()
import langchain
import langchain.schema
from langchain.chat_models import ChatOpenAI
# Setup
langchain.debug = True
llm = ChatOpenAI(client=None, model="gpt-3.5-turbo", temperature=0)


def simple_key_extract(key: str, output: str) -> str:
    found = re.search(f"{key}: (.*)", output)
    if found is None:
        raise Exception("Parsing Error")
    return found[1]
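

# For illustration (not in the original gist): given a completion such as
# "Action: REPLY\nParam: hi", this helper pulls out the value after a key:
#
#   simple_key_extract("Action", "Action: REPLY\nParam: hi")  # -> "REPLY"
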
# Example
class RoutingChainOutput(TypedDict):
    action: Literal["SEARCH", "REPLY"]
    param: str


routing_chain = Chain[str, RoutingChainOutput](
    "RoutingChain",
    llm=llm,
    prompt="""
You are a chatbot that helps users search on the documentation, but you can also do some chit-chatting with them.
Choose the action REPLY if the user is just chit-chatting, like greeting, asking how are you, etc, but choose SEARCH \
for everything else, so you can actually do the search and help them.
=============================
Input: hello there
Action: REPLY
Param: hey there, what are you looking for?
Input: how does langchain work?
Action: SEARCH
Param: langchain how it works
Input: code example of vector db
Action: SEARCH
Param: vector db code example
Input: how is it going?
Action: REPLY
Param: I'm going well, how about you?
Input: {question}
""",
    input_mapper=lambda input: {"question": input},
    output_parser=lambda output: {
        "action": cast(
            Literal["SEARCH", "REPLY"], simple_key_extract("Action", output)
        ),
        "param": simple_key_extract("Param", output),
    },
)
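
# For example, following the few-shot examples in the prompt, the input
# "how does langchain work?" would parse to something like
# {"action": "SEARCH", "param": "langchain how it works"}
# (illustrative; the exact param wording depends on the model's completion)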


summarizer_chain = Chain[str, str](
    "SummarizerChain",
    llm=llm,
    prompt="Summarize the following text: {text}\nSummary: ",
    input_mapper=lambda input: {"text": input},
)


search_chain = Chain[RoutingChainOutput, str](
    "SearchChain",
    llm=llm,
    prompt="Pretend to search for the user. Query: {query}\nResults: ",  # this would be replaced with a proper vector db search
    input_mapper=lambda input: {"query": input["param"]},
).and_then(lambda _: summarizer_chain)


conversation_chain = routing_chain.and_then(
    lambda output: ConstantChain(output["param"])
    if output["action"] == "REPLY"
    else search_chain
)
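

# A quick sketch of invoking the composed chain directly (illustrative, not in
# the original gist; assumes OPENAI_API_KEY is set so the LLM calls succeed):
#
#   conversation_chain("hello there")
#   # REPLY branch: returns the param directly, e.g. "hey there, what are you looking for?"
#
#   conversation_chain("how does langchain work?")
#   # SEARCH branch: runs search_chain, then pipes its output into summarizer_chain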


import chainlit as cl


@cl.langchain_factory(use_async=False)
def factory():
    return conversation_chain


# For comparison, this is how the same functionality looks with the default langchain interface
import re
from typing import Any, Dict, Literal, TypedDict, cast
from dotenv import load_dotenv
load_dotenv()
import langchain
import langchain.schema
from langchain.chat_models import ChatOpenAI
from langchain import (
    LLMChain,
    PromptTemplate,
)
from langchain.schema import BaseOutputParser
from langchain.callbacks.manager import (
    Callbacks,
)
import chainlit as cl
# Setup
langchain.debug = True
llm = ChatOpenAI(client=None, model="gpt-3.5-turbo", temperature=0)


def simple_key_extract(key: str, output: str) -> str:
    found = re.search(f"{key}: (.*)", output)
    if found is None:
        raise Exception("Parsing Error")
    return found[1]
# Example
class RoutingChainOutput(TypedDict):
    action: Literal["SEARCH", "REPLY"]
    param: str


class RoutingParser(BaseOutputParser[RoutingChainOutput]):
    def parse(self, output: str) -> Dict[str, Any]:
        return {
            "action": cast(
                Literal["SEARCH", "REPLY"], simple_key_extract("Action", output)
            ),
            "param": simple_key_extract("Param", output),
        }


class RoutingChain(LLMChain):
    def __init__(self):
        return super().__init__(
            llm=llm,
            prompt=PromptTemplate(
                template="""
You are a chatbot that helps users search on the documentation, but you can also do some chit-chatting with them.
Choose the action REPLY if the user is just chit-chatting, like greeting, asking how are you, etc, but choose SEARCH \
for everything else, so you can actually do the search and help them.
=============================
Input: hello there
Action: REPLY
Param: hey there, what are you looking for?
Input: how does langchain work?
Action: SEARCH
Param: langchain how it works
Input: code example of vector db
Action: SEARCH
Param: vector db code example
Input: how is it going?
Action: REPLY
Param: I'm going well, how about you?
Input: {question}
""",
                input_variables=["question"],
                output_parser=RoutingParser(),
            ),
        )


class SummarizerChain(LLMChain):
    def __init__(self):
        return super().__init__(
            llm=llm,
            prompt=PromptTemplate(
                template="Summarize the following text: {text}\nSummary:",
                input_variables=["text"],
            ),
        )


class SearchChain(LLMChain):
    def __init__(self):
        return super().__init__(
            llm=llm,
            prompt=PromptTemplate(
                template="Pretend to search for the user. Query: {query}\nResults: ",
                input_variables=["query"],
            ),
        )


def conversation(input: str, callbacks: Callbacks) -> str:
    route = cast(
        RoutingChainOutput,
        RoutingChain().predict_and_parse(callbacks=callbacks, question=input),
    )
    if route["action"] == "REPLY":
        return route["param"]
    elif route["action"] == "SEARCH":
        result = SearchChain().__call__({"query": route["param"]}, callbacks=callbacks)
        result = SummarizerChain().__call__(
            {"text": result["text"]}, callbacks=callbacks
        )
        return result["text"]
    else:
        return f"unknown action {route['action']}"


@cl.langchain_factory(use_async=False)
def factory():
    return conversation

# This is the interface implementation
from typing import (
    Any,
    Callable,
    Dict,
    Generic,
    Optional,
    Type,
    TypeVar,
)
import langchain
import langchain.schema
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import Callbacks
T = TypeVar("T")
U = TypeVar("U")
V = TypeVar("V")
# Monadic Chain
def identity_parser(x: str):
    return x


class Chain(Generic[T, U]):
    llm: BaseLanguageModel
    prompt: str
    input_mapper: Optional[Callable[[T], Dict[str, str]]]
    output_parser: Callable[[str], U]
    named_class: Type[langchain.LLMChain]

    def __init__(
        self,
        name: str,
        llm: BaseLanguageModel,
        prompt: str,
        input_mapper: Optional[Callable[[T], Dict[str, str]]] = None,
        output_parser: Callable[[str], U] = identity_parser,
    ) -> None:
        self.llm = llm
        self.prompt = prompt
        self.input_mapper = input_mapper
        self.output_parser = output_parser
        # create a dynamically named LLMChain subclass so each chain shows up
        # under its own name (e.g. in langchain's debug output)
        self.named_class = type(name, (langchain.LLMChain,), {})

    def call(self, input: T, callbacks: Callbacks = []) -> U:
        input_values = self.input_mapper(input) if self.input_mapper else input
        if type(input_values) is not dict:
            raise Exception("cannot extract input_values")
        prompt = langchain.PromptTemplate(
            template=self.prompt,
            input_variables=[str(k) for k in input_values.keys()],
        )
        chain = self.named_class(llm=self.llm, prompt=prompt, callbacks=callbacks)
        result = chain.run(input_values)
        result = self.output_parser(result)
        return result

    def __call__(self, input: T, callbacks: Callbacks = []) -> U:
        return self.call(input, callbacks)

    def and_then(self, fn: Callable[[U], "Chain[U, V]"]):
        # monadic bind: the next chain may be chosen based on this chain's output
        return PipeChain(chain_a=self, chain_b_fn=fn)


class IdentityChain(Chain[T, T]):
    def __init__(self) -> None:
        pass

    def call(self, input: T, _callbacks: Callbacks = []):
        return input


class ConstantChain(Chain[Any, U]):
    output: U

    def __init__(self, output: U) -> None:
        self.output = output

    def call(self, _input: Any, _callbacks: Callbacks = []):
        return self.output


class PipeChain(Chain[T, V]):
    chain_a: Chain[T, Any]
    chain_b_fn: Callable[[Any], Chain[Any, V]]

    def __init__(
        self, chain_a: Chain[T, U], chain_b_fn: Callable[[U], Chain[Any, V]]
    ) -> None:
        self.chain_a = chain_a
        self.chain_b_fn = chain_b_fn

    def call(self, input: T, callbacks: Callbacks = []):
        output = self.chain_a.call(input, callbacks)
        chain_b = self.chain_b_fn(output)
        return chain_b.call(output, callbacks)
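

# A minimal self-test sketch (not part of the original gist): `and_then` is the
# monadic bind, so pure chains compose the same way LLM-backed ones do.
if __name__ == "__main__":
    double = ConstantChain(2).and_then(lambda n: ConstantChain(n * 2))
    assert double.call(None) == 4

    echo: Chain[str, str] = IdentityChain()
    assert echo.call("hello") == "hello"
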
@xingjianll
Hi, do you think this is still relevant now that they have LCEL?
