Skip to content

Instantly share code, notes, and snippets.

@veenaramesh
Created September 3, 2025 15:11
Show Gist options
  • Select an option

  • Save veenaramesh/2c91cc8b8d0f7c688b4f238c5526b12c to your computer and use it in GitHub Desktop.

Select an option

Save veenaramesh/2c91cc8b8d0f7c688b4f238c5526b12c to your computer and use it in GitHub Desktop.
[blog] Agents are like onions (they have layers)
import mlflow
from mlflow.genai.scorers import scorer
from mlflow.genai.scorers import Correctness, RelevanceToQuery
from custom_scorer import ToolUsageScorer
import re
# Workaround for serverless compatibility: replace MLflow's private
# `_get_registry_uri_from_spark_session` helper with a stub that always
# returns "databricks-uc".
# NOTE(review): this patches a private MLflow attribute, so it may stop
# working after an MLflow upgrade - confirm it is still needed.
mlflow.tracking._model_registry.utils._get_registry_uri_from_spark_session = lambda: "databricks-uc"
# Load the model version behind the "agent_latest" alias as a pyfunc model.
# NOTE(review): `model_name` is not defined in this snippet - presumably it is
# set earlier (e.g. in a previous notebook cell); confirm.
model = mlflow.pyfunc.load_model(f"models:/{model_name}@agent_latest")
def evaluate_model(question):
    """Send ``question`` to the loaded agent as a single user message.

    Args:
        question: Text placed in the ``content`` field of the user message.

    Returns:
        Whatever ``model.predict`` returns for the one-message request.
    """
    user_turn = {"role": "user", "content": question}
    payload = {"messages": [user_turn]}
    return model.predict(payload)
# Evaluation records: each one pairs the inputs passed to the predict function
# with the expectations that scorers can compare the response against.
data = [
    {
        "inputs": {
            "question": "What is MLflow? Please answer in spanish. ",
        },
        "expectations": {
            "expected_response": "MLFlow is a open source library.",
        },
    },
]
# Both names refer to the same list object.
test_eval_set = data
# Run every record of the evaluation set through `evaluate_model` and score
# the outputs with the custom tool-usage scorer.
eval_results = mlflow.genai.evaluate(
    data=test_eval_set,
    predict_fn=evaluate_model,
    scorers=[ToolUsageScorer()],
)
from typing import Any, Generator, Optional, Sequence, Union, Literal
from datetime import datetime
import mlflow
from databricks_langchain import (
ChatDatabricks,
UCFunctionToolkit,
VectorSearchRetrieverTool,
)
from langchain_core.language_models import LanguageModelLike
from langchain_core.runnables import RunnableConfig, RunnableLambda
from langchain_core.tools import BaseTool, tool
from langgraph.graph import START, END, StateGraph
from langgraph.graph.graph import CompiledGraph
from langgraph.graph.state import CompiledStateGraph
from langgraph.prebuilt.tool_node import ToolNode
from mlflow.langchain.chat_agent_langgraph import ChatAgentState, ChatAgentToolNode
from mlflow.pyfunc import ChatAgent
from mlflow.types.agent import (
ChatAgentChunk,
ChatAgentMessage,
ChatAgentResponse,
ChatContext,
)
# Agent settings are read from "ModelConfig.yml" through MLflow's ModelConfig.
config_file = "ModelConfig.yml"
model_config = mlflow.models.ModelConfig(development_config=config_file)
# Catalog and schema used to qualify every function and index name below.
uc_catalog = model_config.get("catalog")
schema = model_config.get("schema")
# Fully qualified (catalog.schema.function) names of the functions that can be
# handed to the agent as tools.
python_execution_function_name = f"{uc_catalog}.{schema}.execute_python_code"
ask_ai_function_name = f"{uc_catalog}.{schema}.ask_ai"
summarization_function_name = f"{uc_catalog}.{schema}.summarize"
translate_function_name = f"{uc_catalog}.{schema}.translate"
@tool
def retrieve_function(query: str) -> str:
    """Retrieve from Databricks Vector Search using the query."""
    # NOTE: the docstring above is intentionally left unchanged - the `@tool`
    # decorator uses it as the tool description shown to the model.
    #
    # Builds a VectorSearchRetrieverTool over the documentation index, runs the
    # query, and returns the URL and text of the first hit only (even when the
    # configured `num_results` is larger). Returns a short "nothing found"
    # message when the search yields no results.

    # Look the settings up once instead of once per keyword argument.
    vs_config = model_config.get("vector_search_config")
    index = f"{uc_catalog}.{schema}.databricks_documentation_vs_index"
    vs_tool = VectorSearchRetrieverTool(
        index_name=index,
        tool_name="vector_search_retriever",
        tool_description="Retrieves information from Databricks Vector Search.",
        embedding_model_name=vs_config.get("embedding_model_name"),
        num_results=vs_config.get("num_results"),
        columns=vs_config.get("columns"),
        query_type=vs_config.get("query_type"),
    )
    response = vs_tool.invoke(query)

    # An empty result set used to raise IndexError on response[0], which
    # surfaced as a tool failure; give the model a readable answer instead.
    if not response:
        return "No matching documents were found for this query."

    top_hit = response[0]
    # NOTE(review): assumes the configured `columns` include "url" so that it
    # shows up in the document metadata - confirm against ModelConfig.yml.
    return f"{top_hit.metadata['url']} \n{top_hit.page_content}"
# Wrap the selected catalog functions as agent tools.
toolkit = UCFunctionToolkit(
function_names=[
python_execution_function_name,
# ask_ai_function_name, # commenting out to showcase retriever
summarization_function_name,
translate_function_name,
]
)
uc_tools = toolkit.tools
# Full tool list: the function tools plus the vector-search retriever tool.
tools = uc_tools + [retrieve_function]
# Chat model; endpoint, temperature and max_tokens all come from the
# "llm_config" section of the model config.
model = ChatDatabricks(endpoint=model_config.get("llm_config").get("endpoint"),
temperature=model_config.get("llm_config").get("temperature"),
max_tokens=model_config.get("llm_config").get("max_tokens")
)
# System prompt read from the model config and passed to the agent factory.
system_prompt = model_config.get("system_prompt")
def create_tool_calling_agent(
    model: LanguageModelLike,
    tools: Union[ToolNode, Sequence[BaseTool]],
    system_prompt: Optional[str] = None,
) -> CompiledGraph:
    """Build and compile a two-node tool-calling agent graph.

    The graph starts at the "agent" node, which invokes the model. When the
    latest message carries tool calls, control moves to the "tools" node and
    then back to "agent"; otherwise the graph ends.

    Args:
        model: Chat model; the tools are bound to it with ``bind_tools``.
        tools: Tools the model may call; also used to build the tool node.
        system_prompt: Optional system message prepended to the conversation
            on every model call. When it is ``None`` or empty, no system
            message is added.

    Returns:
        The compiled graph.

    Raises:
        Exception: When the graph runs, the last error raised by the model is
            re-raised once every attempt has failed.
    """
    # Retries after the first attempt, i.e. 11 attempts in total (unchanged).
    max_retries = 10

    model = model.bind_tools(tools)

    def should_continue(state: ChatAgentState) -> Literal["tools", END]:
        # Route to the tool node only when the newest message asks for tools.
        last_message = state["messages"][-1]
        if last_message.get("tool_calls"):
            return "tools"
        return END

    if system_prompt:
        preprocessor = RunnableLambda(
            lambda state: [{"role": "system", "content": system_prompt}]
            + state["messages"]
        )
    else:
        # Without a prompt, pass the messages through untouched rather than
        # sending a system message whose content is None.
        preprocessor = RunnableLambda(lambda state: state["messages"])
    model_runnable = preprocessor | model

    def call_model(state: ChatAgentState, config: RunnableConfig):
        # Retry the model call on any error; give up after the last attempt
        # and surface the most recent failure to the caller.
        last_error = None
        for _ in range(max_retries + 1):
            try:
                response = model_runnable.invoke(state, config)
            except Exception as e:
                last_error = e
            else:
                return {"messages": [response]}
        raise last_error

    workflow = StateGraph(ChatAgentState)
    tool_node = ChatAgentToolNode(tools)
    workflow.add_node("agent", RunnableLambda(call_model))
    workflow.add_node("tools", tool_node)
    workflow.add_edge(START, "agent")
    workflow.add_conditional_edges("agent", should_continue)
    workflow.add_edge("tools", "agent")
    return workflow.compile()
class LangGraphChatAgent(ChatAgent):
    """MLflow ``ChatAgent`` wrapper around a compiled LangGraph agent."""

    def __init__(self, agent: CompiledStateGraph):
        """Store the compiled graph that serves every request."""
        self.agent = agent

    def predict(
        self,
        messages: list[ChatAgentMessage],
        context: Optional[ChatContext] = None,
        custom_inputs: Optional[dict[str, Any]] = None,
    ) -> ChatAgentResponse:
        """Run the agent to completion and return every message it produced.

        Args:
            messages: Conversation so far.
            context: Accepted for interface compatibility; not used here.
            custom_inputs: Accepted for interface compatibility; not used here.

        Returns:
            A response holding the messages emitted by each graph node, in
            the order the graph streamed them.
        """
        request = {"messages": self._convert_messages_to_dict(messages)}
        # Separate name so the `messages` parameter is not rebound.
        collected = []
        for event in self.agent.stream(request, stream_mode="updates"):
            for node_data in event.values():
                collected.extend(
                    ChatAgentMessage(**msg) for msg in node_data.get("messages", [])
                )
        return ChatAgentResponse(messages=collected)

    def predict_stream(
        self,
        messages: list[ChatAgentMessage],
        context: Optional[ChatContext] = None,
        custom_inputs: Optional[dict[str, Any]] = None,
    ) -> Generator[ChatAgentChunk, None, None]:
        """Run the agent and yield one chunk per message as nodes finish.

        Args:
            messages: Conversation so far.
            context: Accepted for interface compatibility; not used here.
            custom_inputs: Accepted for interface compatibility; not used here.

        Yields:
            A chunk whose ``delta`` is the next message emitted by the graph.
        """
        request = {"messages": self._convert_messages_to_dict(messages)}
        for event in self.agent.stream(request, stream_mode="updates"):
            for node_data in event.values():
                # Same tolerant lookup as `predict`: a node update without a
                # "messages" key yields nothing instead of raising KeyError.
                for msg in node_data.get("messages", []):
                    yield ChatAgentChunk(delta=msg)
# Create the agent object, and specify it as the agent object to use when
# loading the agent back for inference via mlflow.models.set_model().
# Autologging for LangChain is switched on first, before the graph is built.
mlflow.langchain.autolog()
# Build the compiled graph from the module-level model, tools and prompt,
# then wrap it in the ChatAgent adapter.
agent = create_tool_calling_agent(model, tools, system_prompt)
AGENT = LangGraphChatAgent(agent)
mlflow.models.set_model(AGENT)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment