Created
September 3, 2025 15:11
-
-
Save veenaramesh/2c91cc8b8d0f7c688b4f238c5526b12c to your computer and use it in GitHub Desktop.
[blog] Agents are like onions (they have layers)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import mlflow | |
| from mlflow.genai.scorers import scorer | |
| from mlflow.genai.scorers import Correctness, RelevanceToQuery | |
| from custom_scorer import ToolUsageScorer | |
| import re | |
# Workaround for serverless compatibility: override MLflow's private helper that
# derives the registry URI from the Spark session so that it always resolves to
# Unity Catalog ("databricks-uc").
# NOTE(review): this patches a private MLflow API and may break on upgrade.
_registry_utils = mlflow.tracking._model_registry.utils
_registry_utils._get_registry_uri_from_spark_session = lambda: "databricks-uc"

# NOTE(review): `model_name` is not defined in this snippet; presumably it is set
# in an earlier notebook cell to the registered UC model name — TODO confirm.
_model_uri = f"models:/{model_name}@agent_latest"
model = mlflow.pyfunc.load_model(_model_uri)
def evaluate_model(question):
    """Ask the loaded agent ``question`` as a single user turn and return its prediction.

    The parameter name matches the ``"question"`` key under ``"inputs"`` in the
    evaluation rows this function is paired with.
    """
    user_turn = {"role": "user", "content": question}
    payload = {"messages": [user_turn]}
    return model.predict(payload)
# Single-row evaluation dataset: "inputs" holds the arguments for the predict
# function and "expectations" holds the reference answer.
# NOTE(review): the question asks for an answer in Spanish, but the expected
# response is in English. Confirm whether the scorer(s) used read
# `expected_response` at all.
_spanish_mlflow_question = {
    "inputs": {
        "question": "What is MLflow? Please answer in spanish. ",
    },
    "expectations": {
        "expected_response": "MLFlow is a open source library.",
    },
}
# `data` is kept as an alias of `test_eval_set` in case later cells refer to it.
test_eval_set = data = [_spanish_mlflow_question]
# Run `evaluate_model` over every row of the eval set and score the results.
# Only the custom ToolUsageScorer is applied; the built-in scorers imported
# alongside it are not used here.
_scorers = [ToolUsageScorer()]
eval_results = mlflow.genai.evaluate(
    data=test_eval_set,
    predict_fn=evaluate_model,
    scorers=_scorers,
)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from typing import Any, Generator, Optional, Sequence, Union, Literal | |
| from datetime import datetime | |
| import mlflow | |
| from databricks_langchain import ( | |
| ChatDatabricks, | |
| UCFunctionToolkit, | |
| VectorSearchRetrieverTool, | |
| ) | |
| from langchain_core.language_models import LanguageModelLike | |
| from langchain_core.runnables import RunnableConfig, RunnableLambda | |
| from langchain_core.tools import BaseTool, tool | |
| from langgraph.graph import START, END, StateGraph | |
| from langgraph.graph.graph import CompiledGraph | |
| from langgraph.graph.state import CompiledStateGraph | |
| from langgraph.prebuilt.tool_node import ToolNode | |
| from mlflow.langchain.chat_agent_langgraph import ChatAgentState, ChatAgentToolNode | |
| from mlflow.pyfunc import ChatAgent | |
| from mlflow.types.agent import ( | |
| ChatAgentChunk, | |
| ChatAgentMessage, | |
| ChatAgentResponse, | |
| ChatContext, | |
| ) | |
config_file = "ModelConfig.yml"
model_config = mlflow.models.ModelConfig(development_config=config_file)

# Unity Catalog catalog/schema under which the agent's UC functions and the
# vector search index live.
uc_catalog = model_config.get("catalog")
schema = model_config.get("schema")
_uc_prefix = f"{uc_catalog}.{schema}"

# Fully qualified Unity Catalog function names.
python_execution_function_name = f"{_uc_prefix}.execute_python_code"
ask_ai_function_name = f"{_uc_prefix}.ask_ai"
summarization_function_name = f"{_uc_prefix}.summarize"
translate_function_name = f"{_uc_prefix}.translate"
# Lazily-built retriever shared by every `retrieve_function` call. Building it
# once avoids re-creating the vector search client on each tool invocation.
_vector_search_tool = None


def _get_vector_search_tool():
    """Return the shared VectorSearchRetrieverTool, building it on first use."""
    global _vector_search_tool
    if _vector_search_tool is None:
        vs_config = model_config.get("vector_search_config")
        _vector_search_tool = VectorSearchRetrieverTool(
            index_name=f"{uc_catalog}.{schema}.databricks_documentation_vs_index",
            tool_name="vector_search_retriever",
            tool_description="Retrieves information from Databricks Vector Search.",
            embedding_model_name=vs_config.get("embedding_model_name"),
            num_results=vs_config.get("num_results"),
            columns=vs_config.get("columns"),
            query_type=vs_config.get("query_type"),
        )
    return _vector_search_tool


# NOTE: @tool uses the docstring below as the tool description shown to the LLM,
# so it is intentionally kept unchanged.
@tool
def retrieve_function(query: str) -> str:
    """Retrieve from Databricks Vector Search using the query."""
    response = _get_vector_search_tool().invoke(query)
    # An empty result previously raised IndexError on response[0]; give the LLM
    # an explicit answer instead.
    if not response:
        return "No matching documents found."
    # Only the top hit is returned, formatted as "<url> \n<content>".
    # NOTE(review): any extra hits from num_results > 1 are discarded.
    top_hit = response[0]
    return f"{top_hit.metadata['url']} \n{top_hit.page_content}"
# Unity Catalog functions exposed to the agent as tools. `ask_ai_function_name`
# is intentionally left out so that the vector-search retriever is used instead.
_uc_function_names = [
    python_execution_function_name,
    summarization_function_name,
    translate_function_name,
]
toolkit = UCFunctionToolkit(function_names=_uc_function_names)
uc_tools = toolkit.tools

# Full tool list bound to the model: the UC function tools plus the retriever.
tools = [*uc_tools, retrieve_function]
# Chat model backing the agent; endpoint and sampling settings come from the
# "llm_config" section of the model config.
_llm_config = model_config.get("llm_config")
model = ChatDatabricks(
    endpoint=_llm_config.get("endpoint"),
    temperature=_llm_config.get("temperature"),
    max_tokens=_llm_config.get("max_tokens"),
)

system_prompt = model_config.get("system_prompt")
def create_tool_calling_agent(
    model: LanguageModelLike,
    tools: Union[ToolNode, Sequence[BaseTool]],
    system_prompt: Optional[str] = None,
    max_retries: int = 10,
    retry_backoff_seconds: float = 0.5,
) -> CompiledGraph:
    """Build a LangGraph agent that alternates between the LLM and its tools.

    The graph starts at the "agent" node (an LLM call). If the LLM's last
    message contains tool calls, control moves to the "tools" node, and its
    results are fed back to "agent". Otherwise the graph ends.

    Args:
        model: Chat model; ``bind_tools(tools)`` is called on it.
        tools: Tools the model may call; also executed by the "tools" node.
        system_prompt: Optional system message prepended to every LLM call.
            When ``None`` or empty, no system message is sent.
        max_retries: Extra attempts after a failed LLM call. The default of 10
            means up to 11 attempts in total. Negative values are treated as 0.
        retry_backoff_seconds: Delay before the first retry. It doubles on each
            later retry, capped at 8 seconds. Use 0 to retry immediately.

    Returns:
        The compiled LangGraph graph.
    """
    import time

    # Upper bound on a single backoff sleep so repeated failures don't stall for minutes.
    max_backoff_seconds = 8.0
    total_attempts = max(max_retries, 0) + 1
    bound_model = model.bind_tools(tools)

    def should_continue(state: ChatAgentState) -> Literal["tools", END]:
        # Route to the tool node only when the latest LLM message requests tools.
        last_message = state["messages"][-1]
        if last_message.get("tool_calls"):
            return "tools"
        return END

    def prepend_system_prompt(state):
        # Sending {"role": "system", "content": None} is not a valid chat
        # message, so skip the system turn entirely when there is no prompt.
        if system_prompt:
            return [{"role": "system", "content": system_prompt}] + state["messages"]
        return state["messages"]

    model_runnable = RunnableLambda(prepend_system_prompt) | bound_model

    def call_model(state: ChatAgentState, config: RunnableConfig):
        # Retry failed LLM calls with exponential backoff rather than
        # re-hitting the serving endpoint in a tight loop.
        for attempt in range(total_attempts):
            try:
                response = model_runnable.invoke(state, config)
            except Exception:
                # Broad on purpose: the endpoint's transient error types are not
                # known here. The final failure propagates unchanged.
                if attempt == total_attempts - 1:
                    raise
                delay = min(retry_backoff_seconds * 2**attempt, max_backoff_seconds)
                if delay > 0:
                    time.sleep(delay)
            else:
                return {"messages": [response]}

    workflow = StateGraph(ChatAgentState)
    workflow.add_node("agent", RunnableLambda(call_model))
    workflow.add_node("tools", ChatAgentToolNode(tools))
    workflow.add_edge(START, "agent")
    workflow.add_conditional_edges("agent", should_continue)
    workflow.add_edge("tools", "agent")
    return workflow.compile()
class LangGraphChatAgent(ChatAgent):
    """MLflow ``ChatAgent`` adapter around a compiled LangGraph graph.

    Every message emitted by a graph node (LLM replies and tool results) is
    returned to the caller. ``predict`` collects them into one response, and
    ``predict_stream`` yields one chunk per message.
    """

    def __init__(self, agent: CompiledStateGraph):
        # Compiled graph whose ``stream`` method drives the conversation.
        self.agent = agent

    def _stream_node_messages(
        self, messages: list[ChatAgentMessage]
    ) -> Generator[dict[str, Any], None, None]:
        """Run the graph on ``messages`` and yield each message dict any node emits."""
        request = {"messages": self._convert_messages_to_dict(messages)}
        # stream_mode="updates" yields {node_name: node_output} per node run.
        for event in self.agent.stream(request, stream_mode="updates"):
            for node_data in event.values():
                # A node update may carry no "messages" key; treat it as empty
                # rather than raising KeyError.
                yield from node_data.get("messages", [])

    def predict(
        self,
        messages: list[ChatAgentMessage],
        context: Optional[ChatContext] = None,
        custom_inputs: Optional[dict[str, Any]] = None,
    ) -> ChatAgentResponse:
        """Run the agent to completion and return every emitted message.

        ``context`` and ``custom_inputs`` are accepted for interface
        compatibility but are not used.
        """
        response_messages = [
            ChatAgentMessage(**msg) for msg in self._stream_node_messages(messages)
        ]
        return ChatAgentResponse(messages=response_messages)

    def predict_stream(
        self,
        messages: list[ChatAgentMessage],
        context: Optional[ChatContext] = None,
        custom_inputs: Optional[dict[str, Any]] = None,
    ) -> Generator[ChatAgentChunk, None, None]:
        """Run the agent and yield one ``ChatAgentChunk`` per emitted message.

        ``context`` and ``custom_inputs`` are accepted for interface
        compatibility but are not used.
        """
        for msg in self._stream_node_messages(messages):
            yield ChatAgentChunk(delta=msg)
# Enable MLflow autologging for LangChain before the graph is built.
mlflow.langchain.autolog()

# Build the graph and wrap it in the ChatAgent interface. set_model() registers
# AGENT as the object to use when the agent is loaded back for inference.
agent = create_tool_calling_agent(model, tools, system_prompt)
AGENT = LangGraphChatAgent(agent=agent)
mlflow.models.set_model(AGENT)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment