Last active
May 17, 2024 08:32
-
-
Save rogeriochaves/3bb334c57ea5a095c9cf5cc6c53bc499 to your computer and use it in GitHub Desktop.
Monadic langchain: a more functional-programming (FP) style interface to langchain chain composition.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# This is how langchain chain composition looks with a more FP approach | |
import re | |
from typing import ( | |
Literal, | |
TypedDict, | |
cast, | |
) | |
from dotenv import load_dotenv | |
from monadic import Chain, ConstantChain | |
load_dotenv() | |
import langchain | |
import langchain.schema | |
from langchain.chat_models import ChatOpenAI | |
# Setup
langchain.debug = True  # log every chain invocation for easier inspection
# temperature=0 keeps completions deterministic for the routing examples
llm = ChatOpenAI(client=None, model="gpt-3.5-turbo", temperature=0)
def simple_key_extract(key: str, output: str) -> str:
    """Extract the value following ``<key>: `` on a line of LLM output.

    Args:
        key: Label to look for (e.g. "Action"). Escaped with ``re.escape``
            so regex metacharacters in the key are matched literally.
        output: Raw completion text returned by the LLM.

    Returns:
        The remainder of the first line starting with ``<key>: ``.

    Raises:
        ValueError: If no ``<key>: value`` line is present in the output.
    """
    found = re.search(f"{re.escape(key)}: (.*)", output)
    if found is None:
        raise ValueError(f"Parsing Error: could not find '{key}: ...' in output")
    return found.group(1)
# Example
class RoutingChainOutput(TypedDict):
    """Parsed routing decision produced by the routing chain."""

    # Which branch to take: search the docs, or reply directly.
    action: Literal["SEARCH", "REPLY"]
    # Search query (for SEARCH) or the reply text itself (for REPLY).
    param: str
def _routing_input_mapper(question: str) -> dict:
    # The routing prompt template exposes a single {question} slot.
    return {"question": question}


def _routing_output_parser(output: str) -> RoutingChainOutput:
    # The LLM answers with "Action: ..." / "Param: ..." lines, following
    # the few-shot examples in the prompt.
    return {
        "action": cast(
            Literal["SEARCH", "REPLY"], simple_key_extract("Action", output)
        ),
        "param": simple_key_extract("Param", output),
    }


# Decides whether a user message is chit-chat (REPLY) or a docs query (SEARCH).
routing_chain = Chain[str, RoutingChainOutput](
    "RoutingChain",
    llm=llm,
    prompt="""
You are a chatbot that helps users search on the documentation, but you can also do some chit-chatting with them.
Choose the action REPLY if the user is just chit-chatting, like greeting, asking how are you, etc, but choose SEARCH \
for everything else, so you can actually do the search and help them.
=============================
Input: hello there
Action: REPLY
Param: hey there, what are you looking for?
Input: how does langchain work?
Action: SEARCH
Param: langchain how it works
Input: code example of vector db
Action: SEARCH
Param: vector db code example
Input: how is it going?
Action: REPLY
Param: I'm going well, how about you?
Input: {question}
""",
    input_mapper=_routing_input_mapper,
    output_parser=_routing_output_parser,
)
# Condenses a blob of text (the fake search results) into a short answer.
summarizer_chain = Chain[str, str](
    "SummarizerChain",
    llm=llm,
    prompt="Summarize the following text: {text}\nSummary: ",
    input_mapper=lambda text: {"text": text},
)
# Fakes a vector-db lookup, then pipes the raw results into the summarizer.
search_chain = Chain[RoutingChainOutput, str](
    "SearchChain",
    llm=llm,
    # This would be replaced with a proper vector db search.
    prompt="Pretend to search for the user. Query: {query}\nResults: ",
    input_mapper=lambda routed: {"query": routed["param"]},
).and_then(lambda _results: summarizer_chain)
def _pick_branch(routed: RoutingChainOutput) -> Chain:
    # Chit-chat already carries the reply text in `param`; everything else
    # goes through the search + summarize pipeline.
    if routed["action"] == "REPLY":
        return ConstantChain(routed["param"])
    return search_chain


# Full conversation pipeline: route first, then reply or search.
conversation_chain = routing_chain.and_then(_pick_branch)
import chainlit as cl | |
@cl.langchain_factory(use_async=False)
def factory():
    """Expose the composed conversation chain to Chainlit."""
    return conversation_chain
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# For comparison, this is how the same functionality looks with default langchain interface | |
import re | |
from typing import Any, Dict, Literal, TypedDict, cast | |
from dotenv import load_dotenv | |
load_dotenv() | |
import langchain | |
import langchain.schema | |
from langchain.chat_models import ChatOpenAI | |
from langchain import ( | |
LLMChain, | |
PromptTemplate, | |
) | |
from langchain.schema import BaseOutputParser | |
from langchain.callbacks.manager import ( | |
Callbacks, | |
) | |
import chainlit as cl | |
# Setup
langchain.debug = True  # log every chain invocation for easier inspection
# temperature=0 keeps completions deterministic for the routing examples
llm = ChatOpenAI(client=None, model="gpt-3.5-turbo", temperature=0)
def simple_key_extract(key: str, output: str) -> str:
    """Extract the value following ``<key>: `` on a line of LLM output.

    Args:
        key: Label to look for (e.g. "Action"). Escaped with ``re.escape``
            so regex metacharacters in the key are matched literally.
        output: Raw completion text returned by the LLM.

    Returns:
        The remainder of the first line starting with ``<key>: ``.

    Raises:
        ValueError: If no ``<key>: value`` line is present in the output.
    """
    found = re.search(f"{re.escape(key)}: (.*)", output)
    if found is None:
        raise ValueError(f"Parsing Error: could not find '{key}: ...' in output")
    return found.group(1)
# Example
class RoutingChainOutput(TypedDict):
    """Parsed routing decision produced by the routing chain."""

    # Which branch to take: search the docs, or reply directly.
    action: Literal["SEARCH", "REPLY"]
    # Search query (for SEARCH) or the reply text itself (for REPLY).
    param: str
class RoutingParser(BaseOutputParser[RoutingChainOutput]):
    """Parses ``Action: ...`` / ``Param: ...`` lines out of a raw completion."""

    # Return annotation fixed to match BaseOutputParser[RoutingChainOutput]
    # (it previously declared Dict[str, Any]).
    def parse(self, output: str) -> RoutingChainOutput:
        """Return the routing decision encoded in *output*.

        Raises:
            Exception: Propagated from simple_key_extract when either
                expected line is missing.
        """
        return {
            "action": cast(
                Literal["SEARCH", "REPLY"], simple_key_extract("Action", output)
            ),
            "param": simple_key_extract("Param", output),
        }
class RoutingChain(LLMChain):
    """LLMChain that decides between SEARCH and REPLY for a user message."""

    def __init__(self):
        routing_prompt = PromptTemplate(
            template="""
You are a chatbot that helps users search on the documentation, but you can also do some chit-chatting with them.
Choose the action REPLY if the user is just chit-chatting, like greeting, asking how are you, etc, but choose SEARCH \
for everything else, so you can actually do the search and help them.
=============================
Input: hello there
Action: REPLY
Param: hey there, what are you looking for?
Input: how does langchain work?
Action: SEARCH
Param: langchain how it works
Input: code example of vector db
Action: SEARCH
Param: vector db code example
Input: how is it going?
Action: REPLY
Param: I'm going well, how about you?
Input: {question}
""",
            input_variables=["question"],
            output_parser=RoutingParser(),
        )
        super().__init__(llm=llm, prompt=routing_prompt)
class SummarizerChain(LLMChain):
    """LLMChain that condenses a blob of text into a short summary."""

    def __init__(self):
        summary_prompt = PromptTemplate(
            template="Summarize the following text: {text}\nSummary:",
            input_variables=["text"],
        )
        super().__init__(llm=llm, prompt=summary_prompt)
class SearchChain(LLMChain):
    """LLMChain that fakes a documentation search for a query."""

    def __init__(self):
        search_prompt = PromptTemplate(
            template="Pretend to search for the user. Query: {query}\nResults: ",
            input_variables=["query"],
        )
        super().__init__(llm=llm, prompt=search_prompt)
def conversation(input: str, callbacks: Callbacks) -> str:
    """Run one user turn: route the message, then reply or search+summarize."""
    route = cast(
        RoutingChainOutput,
        RoutingChain().predict_and_parse(callbacks=callbacks, question=input),
    )
    action = route["action"]
    if action == "REPLY":
        # Chit-chat: the routed param already is the reply text.
        return route["param"]
    if action == "SEARCH":
        search_result = SearchChain()({"query": route["param"]}, callbacks=callbacks)
        summary = SummarizerChain()(
            {"text": search_result["text"]}, callbacks=callbacks
        )
        return summary["text"]
    # Defensive fallback; unreachable while action is the Literal above.
    return f"unknown action {route['action']}"
@cl.langchain_factory(use_async=False)
def factory():
    """Hand Chainlit the plain-function conversation runner."""
    return conversation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# This is the interface implementation | |
from typing import ( | |
Any, | |
Callable, | |
Dict, | |
Generic, | |
Optional, | |
Type, | |
TypeVar, | |
) | |
import langchain | |
import langchain.schema | |
from langchain.base_language import BaseLanguageModel | |
from langchain.callbacks.manager import Callbacks | |
# Generic type parameters for the chain pipeline:
# T = chain input, U = chain output, V = output of a downstream chain.
T = TypeVar("T")
U = TypeVar("U")
V = TypeVar("V")


# Monadic Chain
def identity_parser(x: str) -> str:
    """Default output parser: pass the raw LLM completion through unchanged."""
    return x
class Chain(Generic[T, U]):
    """A composable, typed wrapper around a langchain LLMChain.

    A ``Chain[T, U]`` maps an input of type ``T`` to an output of type ``U``
    by (1) mapping the input to prompt variables, (2) running the filled
    prompt through the LLM, and (3) parsing the raw completion. Chains
    compose monadically via :meth:`and_then`.
    """

    llm: BaseLanguageModel
    # Prompt template text; {placeholders} are filled from the mapped input.
    prompt: str
    # Converts the typed input into the dict of prompt variables; when None,
    # the input itself must already be such a dict.
    input_mapper: Optional[Callable[[T], Dict[str, str]]]
    output_parser: Callable[[str], U]
    # Dynamically created LLMChain subclass so debug traces show `name`.
    named_class: Type[langchain.LLMChain]

    def __init__(
        self,
        name: str,
        llm: BaseLanguageModel,
        prompt: str,
        input_mapper: Optional[Callable[[T], Dict[str, str]]] = None,
        output_parser: Callable[[str], U] = identity_parser,
    ) -> None:
        """Build a chain named *name* around *llm* and *prompt*."""
        self.llm = llm
        self.prompt = prompt
        self.input_mapper = input_mapper
        self.output_parser = output_parser
        self.named_class = type(name, (langchain.LLMChain,), {})

    # Default changed from a mutable `[]` to None (langchain's Callbacks is
    # already Optional); behaviorally equivalent and avoids the shared
    # mutable-default pitfall.
    def call(self, input: T, callbacks: Callbacks = None) -> U:
        """Run the chain on *input* and return the parsed output.

        Raises:
            Exception: If the (mapped) input is not a dict of prompt values.
        """
        input_values = self.input_mapper(input) if self.input_mapper else input
        # isinstance instead of `type(...) is not dict` so dict subclasses pass.
        if not isinstance(input_values, dict):
            raise Exception("cannot extract input_values")
        prompt = langchain.PromptTemplate(
            template=self.prompt,
            input_variables=[str(k) for k in input_values.keys()],
        )
        chain = self.named_class(llm=self.llm, prompt=prompt, callbacks=callbacks)
        return self.output_parser(chain.run(input_values))

    def __call__(self, input: T, callbacks: Callbacks = None) -> U:
        """Alias for :meth:`call` so a Chain is directly callable."""
        return self.call(input, callbacks)

    def and_then(self, fn: Callable[[U], "Chain[U, V]"]) -> "Chain[T, V]":
        """Monadic bind: feed this chain's output into the chain *fn* picks."""
        return PipeChain(chain_a=self, chain_b_fn=fn)
class IdentityChain(Chain[T, T]):
    """A no-op chain: returns its input untouched (the monadic unit)."""

    def __init__(self) -> None:
        # Intentionally skips Chain.__init__: no LLM or prompt is needed.
        pass

    # Default changed from mutable `[]` to None, matching Chain.call.
    def call(self, input: T, _callbacks: Callbacks = None) -> T:
        return input
class ConstantChain(Chain[Any, U]):
    """A chain that ignores its input and always yields a fixed output."""

    output: U  # the constant value returned by every call

    def __init__(self, output: U) -> None:
        # Intentionally skips Chain.__init__: no LLM call is ever made.
        self.output = output

    # Default changed from mutable `[]` to None, matching Chain.call.
    def call(self, _input: Any, _callbacks: Callbacks = None) -> U:
        return self.output
class PipeChain(Chain[T, V]):
    """Sequential composition: run ``chain_a``, then the chain chosen by
    ``chain_b_fn`` from its output (the monadic bind)."""

    chain_a: Chain[T, Any]
    # Picks the next chain based on chain_a's output; that same output is
    # also fed to the picked chain as its input.
    chain_b_fn: Callable[[Any], Chain[Any, V]]

    def __init__(
        self, chain_a: Chain[T, U], chain_b_fn: Callable[[U], Chain[Any, V]]
    ) -> None:
        self.chain_a = chain_a
        self.chain_b_fn = chain_b_fn

    # Default changed from mutable `[]` to None, matching Chain.call.
    def call(self, input: T, callbacks: Callbacks = None) -> V:
        output = self.chain_a.call(input, callbacks)
        chain_b = self.chain_b_fn(output)
        return chain_b.call(output, callbacks)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Hi, do you think this is still relevant now that they have LCEL?