Skip to content

Instantly share code, notes, and snippets.

@snopoke
Created March 20, 2024 15:24
Show Gist options
  • Save snopoke/f88ed941c92bf04150a9e64d3da8c4d3 to your computer and use it in GitHub Desktop.
Save snopoke/f88ed941c92bf04150a9e64d3da8c4d3 to your computer and use it in GitHub Desktop.
LangChain runnable split/merge proof of concept (POC)
"""
Demonstrate a method to create a pipeline that can handle dynamic splits in the pipeline based on the input type.
"""
import functools
import operator
from typing import Any
from langchain_core.callbacks import CallbackManagerForChainRun
from langchain_core.runnables import Runnable, RunnableConfig, RunnableLambda, RunnableSerializable, patch_config
from langchain_core.runnables.base import RunnableEach
from langchain_core.runnables.utils import Input, Output
# Trace switch read by the helper functions below: an empty list means quiet,
# any element means each helper prints its input and output. The script's
# ``__main__`` guard switches it on by adding an element to this list.
verbose: list[int] = list()
def get_input_list(item: int) -> list[int]:
    """Expand ``item`` into the list ``[0, 1, ..., item - 1]``.

    Returning a list is what makes a following ``FlexibleRunnableEach`` step
    fan out over the elements.
    """
    out = list(range(item))
    # Explicit ``if`` instead of ``verbose and print(...)``: don't use a boolean
    # operator purely for its side effect.
    if verbose:
        print(f" get_input_list: in: {item}, out: {out}")
    return out
def get_input_string(item: Any) -> str:
    """Return ``str(item)`` repeated twice, e.g. ``5 -> "55"``."""
    out = str(item) * 2
    # Explicit ``if`` instead of ``verbose and print(...)``: don't use a boolean
    # operator purely for its side effect.
    if verbose:
        print(f" get_input_string: in: {item}, out: {out}")
    return out
def process(item: Any) -> Any:
    """Return ``item * 2``.

    Numbers are doubled; strings and lists are repeated twice.
    """
    out = item * 2
    # Explicit ``if`` instead of ``verbose and print(...)``: don't use a boolean
    # operator purely for its side effect.
    if verbose:
        print(f" process: in: {item}, out: {out}")
    return out
def process_merge(item: list[Any]) -> str:
    """Join the string form of every element of ``item`` with ``"-"``.

    Used as the bound runnable of ``MergeRunnable``, which passes it a list.
    """
    out = "-".join(map(str, item))
    # Explicit ``if`` instead of ``verbose and print(...)``: don't use a boolean
    # operator purely for its side effect.
    if verbose:
        print(f" process_merge: in: {item}, out: {out}")
    return out
class FlexibleRunnableEach(RunnableEach):
    """Runnable that splits the pipeline when its input is a list.

    A non-list input is passed straight to ``self.bound``. A list input is
    handed to ``self.batch``, which invokes this same runnable once per element.
    Nested lists are therefore split recursively until non-list elements reach
    ``self.bound``, and the result keeps the nesting of the input.
    """

    # This shouldn't really extend `RunnableEach`, this is just for demonstration purposes
    # It should rather be a standalone runnable extending RunnableSerializable

    def _invoke(
        self,
        inputs: Input | list[Input],
        run_manager: CallbackManagerForChainRun,
        config: RunnableConfig,
        **kwargs: Any,
    ) -> Output | list[Output]:
        # Nested calls report their callbacks as children of this run.
        child_config = patch_config(config, callbacks=run_manager.get_child())
        if isinstance(inputs, list):
            # Deliberately ``self.batch`` rather than ``self.bound.batch``:
            # elements that are themselves lists are split again instead of
            # being handed to ``bound`` whole.
            return self.batch(inputs, child_config, **kwargs)
        # NOTE(review): ``kwargs`` are not forwarded to ``bound`` here (same as
        # the original); confirm whether that is intended.
        return self.bound.invoke(inputs, child_config)
class MergeRunnable(RunnableSerializable[Input, Output]):
    """Collapse a fanned-out list back into one value.

    ``bound`` is invoked exactly once with the whole input. A non-list input is
    first wrapped in a one-element list, so ``bound`` always receives a list.
    """

    # Runnable that receives the whole list; its result is this step's output.
    bound: Runnable[Input, Output]

    class Config:
        # Allow the non-pydantic ``Runnable`` type as a field type.
        arbitrary_types_allowed = True

    def _invoke(
        self,
        inputs: Input | list[Input],
        run_manager: CallbackManagerForChainRun,
        config: RunnableConfig,
        **kwargs: Any,
    ) -> Output:
        # Normalise scalars so ``bound`` sees the same shape either way.
        batch = inputs if isinstance(inputs, list) else [inputs]
        return self.bound.invoke(batch, patch_config(config, callbacks=run_manager.get_child()))

    def invoke(
        self, inputs: Input | list[Input], config: RunnableConfig | None = None, **kwargs: Any
    ) -> Output:
        """Run ``bound`` on ``inputs`` inside a traced run of this runnable.

        Returns whatever ``bound.invoke`` returns (a single value, not a list).
        """
        return self._call_with_config(self._invoke, inputs, config, **kwargs)
def main():
    """Run each of the four demo pipelines through ``run_pipe``.

    The pipelines combine three kinds of step:

    - list-producing steps (``get_input_list``);
    - scalar steps (``get_input_string`` / ``process``);
    - ``MergeRunnable`` steps.

    Together they show how lists fan out and get merged back.
    """
    # set_debug(True)

    def lam(fn):
        # Plain step: wrap a helper function as a runnable.
        return RunnableLambda(fn)

    def merge():
        # Merge step: join a whole list into one "-"-separated string.
        return MergeRunnable(bound=RunnableLambda(process_merge))

    pipes = (
        [lam(get_input_list), lam(get_input_list), lam(process)],
        [lam(get_input_list), merge(), lam(get_input_string)],
        [lam(get_input_string), lam(get_input_string), lam(process)],
        [lam(get_input_string), lam(get_input_string), merge()],
    )
    for steps in pipes:
        run_pipe(steps)
def run_pipe(steps, inputs=None):
    """Compose ``steps`` into one chain and print its output for each input.

    Every step except a ``MergeRunnable`` is wrapped in ``FlexibleRunnableEach``.
    A list produced by one step is therefore split element-by-element for the
    next step, while a ``MergeRunnable`` receives the whole list.

    Args:
        steps: Runnables composed left to right with ``|``. Must be non-empty;
            ``functools.reduce`` raises ``TypeError`` on an empty sequence.
        inputs: Values to invoke the chain with, one at a time. ``None`` (the
            default) uses the demo inputs ``[5, 6]`` (a list, which fans out)
            and ``4`` (a scalar).
    """
    # Built per call (not as a default argument) so the list is never shared.
    if inputs is None:
        inputs = ([5, 6], 4)

    # Wrap each step in a runnable that will split the pipeline if the input is a list
    # This could be made more explicit if we know that a step is expected to return a list or not
    def _wrap_step(step: Runnable) -> Runnable:
        if isinstance(step, MergeRunnable):
            return step
        return FlexibleRunnableEach(bound=step)

    chain = functools.reduce(operator.or_, map(_wrap_step, steps))
    print(f"\nRunning chain: {chain}\n================================")
    for val in inputs:
        output = chain.invoke(val)
        print(f" input: '{val}', output: '{output}'")
        print(" -----------------------------")
if __name__ == "__main__":
    # Switch on the helpers' trace output, then run every demo pipeline.
    verbose.extend([1])
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment