Build a tool-calling model with mlflow.pyfunc.ChatModel
Welcome to the notebook tutorial on building a simple tool-calling model using the mlflow.pyfunc.ChatModel wrapper. ChatModel is a subclass of MLflow’s highly customizable PythonModel, designed specifically to make creating GenAI workflows easier.
Briefly, here are some of the benefits of using ChatModel:
No need to define a complex signature! Chat models often accept complex inputs with many levels of nesting, and this can be cumbersome to define yourself.
Support for JSON / dict inputs (no need to wrap inputs or convert to a Pandas DataFrame); see the short sketch after this list
Dataclasses for defining expected inputs / outputs, for a simpler development experience
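To make the dict-input point concrete, here is a minimal sketch of calling a logged ChatModel with a plain dictionary. It assumes a model has already been logged and loaded (as we do later in this tutorial); the model URI below is a hypothetical placeholder:

import mlflow

# load a previously logged ChatModel (placeholder URI, for illustration only)
loaded_model = mlflow.pyfunc.load_model("runs:/<run_id>/weather-model")

# a plain dict input: no signature definition or Pandas DataFrame wrapping required
response = loaded_model.predict({"messages": [{"role": "user", "content": "Hello!"}]})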
For a more in-depth exploration of ChatModel, please check out the detailed guide.
In this tutorial, we’ll be building a simple OpenAI wrapper that makes use of the tool calling support (released in MLflow 2.17.0).
Environment setup
First, let’s set up the environment. We’ll need the OpenAI Python SDK, as well as MLflow >= 2.17.0. We’ll also need to set our OpenAI API key in order to use the SDK.
[16]:
%pip install 'mlflow>=2.17.0' 'openai>=1.0' -qq
Note: you may need to restart the kernel to use updated packages.
[1]:
import os
from getpass import getpass
os.environ["OPENAI_API_KEY"] = getpass("Enter your OpenAI API key: ")
Step 1: Creating the tool definition
Let’s begin to define our model! As mentioned in the introduction, we’ll be subclassing mlflow.pyfunc.ChatModel. For this example, we’ll build a toy model that uses a tool to retrieve the weather for a given city.
The first step is to create a tool definition that we can pass to OpenAI. We do this by using mlflow.types.llm.FunctionToolDefinition to describe the parameters that our tool accepts. The format of this dataclass is aligned with the OpenAI spec:
[2]:
import mlflow
from mlflow.types.llm import (
    FunctionToolDefinition,
    ParamProperty,
    ToolParamsSchema,
)


class WeatherModel(mlflow.pyfunc.ChatModel):
    def __init__(self):
        # a sample tool definition. we use the `FunctionToolDefinition`
        # class to describe the name and expected params for the tool.
        # for this example, we're defining a simple tool that returns
        # the weather for a given city.
        weather_tool = FunctionToolDefinition(
            name="get_weather",
            description="Get weather information",
            parameters=ToolParamsSchema(
                {
                    "city": ParamProperty(
                        type="string",
                        description="City name to get weather information for",
                    ),
                }
            ),
            # make sure to call `to_tool_definition()` to convert the `FunctionToolDefinition`
            # to a `ToolDefinition` object. this step is necessary to normalize the data format,
            # as multiple types of tools (besides just functions) might be available in the future.
        ).to_tool_definition()

        # OpenAI expects tools to be provided as a list of dictionaries
        self.tools = [weather_tool.to_dict()]
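To get a feel for what will actually be sent to OpenAI, you can inspect the dictionary form of the tool definition. This is just a quick sketch: the exact output comes from MLflow’s ToolDefinition dataclass, but it should roughly follow the OpenAI function-tool shape, i.e. a "type": "function" entry wrapping the name, description, and parameters.

import json

# instantiate the model and inspect the tool list that will be passed to OpenAI
model = WeatherModel()
print(json.dumps(model.tools, indent=2))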
Step 2: Implementing the tool
Now that we have a definition for the tool, we need to actually implement it. For the purposes of this tutorial, we’re just going to mock a response, but the implementation can be arbitrary—you might make an API call to an actual weather service, for example.
[3]:
class WeatherModel(mlflow.pyfunc.ChatModel):
    def __init__(self):
        weather_tool = FunctionToolDefinition(
            name="get_weather",
            description="Get weather information",
            parameters=ToolParamsSchema(
                {
                    "city": ParamProperty(
                        type="string",
                        description="City name to get weather information for",
                    ),
                }
            ),
        ).to_tool_definition()

        self.tools = [weather_tool.to_dict()]

    def get_weather(self, city: str) -> str:
        # in a real-world scenario, the implementation might be more complex
        return f"It's sunny in {city}, with a temperature of 20C"
Step 3: Implementing the predict method
The next thing we need to do is define a predict() function that accepts the following arguments:
context: PythonModelContext (not used in this tutorial)
messages: List[ChatMessage]. This is the chat input that the model uses for generation.
params: ChatParams. These are commonly used parameters for configuring the chat model, e.g. temperature, max_tokens, etc. This is also where the tool specifications can be found.
This is the function that will ultimately be called during inference.
For the implementation, we’ll simply forward the user’s input to OpenAI, and provide the get_weather tool as an option for the LLM to use if it chooses to do so. If we receive a tool call request, we’ll call the get_weather() function and return the response back to OpenAI. We’ll need to use what we’ve defined in the previous two steps in order to do this.
[4]:
import json

from openai import OpenAI

import mlflow
from mlflow.types.llm import (
    ChatMessage,
    ChatParams,
    ChatResponse,
)


class WeatherModel(mlflow.pyfunc.ChatModel):
    def __init__(self):
        weather_tool = FunctionToolDefinition(
            name="get_weather",
            description="Get weather information",
            parameters=ToolParamsSchema(
                {
                    "city": ParamProperty(
                        type="string",
                        description="City name to get weather information for",
                    ),
                }
            ),
        ).to_tool_definition()

        self.tools = [weather_tool.to_dict()]

    def get_weather(self, city: str) -> str:
        return "It's sunny in {}, with a temperature of 20C".format(city)

    # the core method that needs to be implemented. this function
    # will be called every time a user sends messages to our model
    def predict(self, context, messages: list[ChatMessage], params: ChatParams):
        # instantiate the OpenAI client
        client = OpenAI()

        # convert the messages to a format that the OpenAI API expects
        messages = [m.to_dict() for m in messages]

        # call the OpenAI API
        response = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=messages,
            # pass the tools in the request
            tools=self.tools,
        )

        # if OpenAI returns a tool_calling response, then we call
        # our tool. otherwise, we just return the response as is
        tool_calls = response.choices[0].message.tool_calls
        if tool_calls:
            print("Received a tool call, calling the weather tool...")

            # for this example, we only provide the model with one tool,
            # so we can assume the tool call is for the weather tool. if
            # we had more, we'd need to check the name of the tool that
            # was called
            city = json.loads(tool_calls[0].function.arguments)["city"]
            tool_call_id = tool_calls[0].id

            # call the tool and construct a new chat message
            tool_response = ChatMessage(
                role="tool", content=self.get_weather(city), tool_call_id=tool_call_id
            ).to_dict()

            # send another request to the API, making sure to append
            # the assistant's tool call along with the tool response.
            messages.append(response.choices[0].message)
            messages.append(tool_response)
            response = client.chat.completions.create(
                model="gpt-4o-mini",
                messages=messages,
                tools=self.tools,
            )

        # return the result as a ChatResponse, as this
        # is the expected output of the predict method
        return ChatResponse.from_dict(response.to_dict())
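As the comments above note, we can assume every tool call is for get_weather because only one tool is registered. If the model exposed several tools, you would typically dispatch on the tool name instead. One way to structure that is a small helper method on the class; this is just a sketch, and get_forecast is a hypothetical second tool:

# hypothetical dispatch helper for a model that registers more than one tool
def call_tool(self, tool_call) -> str:
    args = json.loads(tool_call.function.arguments)
    if tool_call.function.name == "get_weather":
        return self.get_weather(args["city"])
    if tool_call.function.name == "get_forecast":  # hypothetical second tool
        return self.get_forecast(args["city"], args["days"])
    raise ValueError(f"Unknown tool: {tool_call.function.name}")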
Step 4 (optional, but recommended): Enable tracing for the model
This step is optional, but highly recommended to improve observability in your app. We’ll be using MLflow Tracing to log the inputs and outputs of our model’s internal functions, so we can easily debug when things go wrong. Agent-style tool calling models can make many layers of function calls during the lifespan of a single request, so tracing is invaluable in helping us understand what’s going on at each step.
Integrating tracing is easy: we simply decorate the functions we’re interested in (get_weather() and predict()) with @mlflow.trace! MLflow Tracing also has integrations with many popular GenAI frameworks, such as LangChain, OpenAI, LlamaIndex, and more. For the full list, check out this documentation page. In this tutorial, we’re using the OpenAI SDK to make API calls, so we can enable tracing for it by calling mlflow.openai.autolog().
To view the traces in the UI, run mlflow ui in a separate terminal shell, and navigate to the Traces tab after using the model for inference below.
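For reference, starting the UI is a single command. The --port flag is optional; it is shown here only to avoid clashing with the model server we start on port 5000 later in this tutorial:

$ mlflow ui --port 8080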
[5]:
from mlflow.entities.span import (
    SpanType,
)

# automatically trace OpenAI SDK calls
mlflow.openai.autolog()


class WeatherModel(mlflow.pyfunc.ChatModel):
    def __init__(self):
        weather_tool = FunctionToolDefinition(
            name="get_weather",
            description="Get weather information",
            parameters=ToolParamsSchema(
                {
                    "city": ParamProperty(
                        type="string",
                        description="City name to get weather information for",
                    ),
                }
            ),
        ).to_tool_definition()

        self.tools = [weather_tool.to_dict()]

    @mlflow.trace(span_type=SpanType.TOOL)
    def get_weather(self, city: str) -> str:
        return "It's sunny in {}, with a temperature of 20C".format(city)

    @mlflow.trace(span_type=SpanType.AGENT)
    def predict(self, context, messages: list[ChatMessage], params: ChatParams):
        client = OpenAI()

        messages = [m.to_dict() for m in messages]

        response = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=messages,
            tools=self.tools,
        )

        tool_calls = response.choices[0].message.tool_calls
        if tool_calls:
            print("Received a tool call, calling the weather tool...")

            city = json.loads(tool_calls[0].function.arguments)["city"]
            tool_call_id = tool_calls[0].id

            tool_response = ChatMessage(
                role="tool", content=self.get_weather(city), tool_call_id=tool_call_id
            ).to_dict()

            messages.append(response.choices[0].message)
            messages.append(tool_response)
            response = client.chat.completions.create(
                model="gpt-4o-mini",
                messages=messages,
                tools=self.tools,
            )

        return ChatResponse.from_dict(response.to_dict())
Step 5: Logging the model
Finally, we need to log the model. This saves the model as an artifact in MLflow Tracking, and allows us to load and serve it later on.
(Note: this is a fundamental pattern in MLflow. To learn more, check out the Quickstart guide!)
In order to do this, we need to do a few things:
Define an input example to inform users about the input we expect
Instantiate the model
Call mlflow.pyfunc.log_model() with the above as arguments
Take note of the Model URI printed out at the end of the cell—we’ll need it when serving the model later!
[6]:
# messages to use as input examples
messages = [
    {"role": "system", "content": "Please use the provided tools to answer user queries."},
    {"role": "user", "content": "What's the weather in Singapore?"},
]

input_example = {
    "messages": messages,
}

# instantiate the model
model = WeatherModel()

# log the model
with mlflow.start_run():
    model_info = mlflow.pyfunc.log_model(
        artifact_path="weather-model",
        python_model=model,
        input_example=input_example,
    )

    print("Successfully logged the model at the following URI: ", model_info.model_uri)
2024/10/29 09:30:14 INFO mlflow.pyfunc: Predicting on input example to validate output
Received a tool call, calling the weather tool...
Received a tool call, calling the weather tool...
Successfully logged the model at the following URI: runs:/8051850efa194a3b8b2450c4c9f4d42f/weather-model
Using the model for inference
Now that the model is logged, our work is more or less done! In order to use the model for inference, let’s load it back using mlflow.pyfunc.load_model().
[7]:
import mlflow

# Load the previously logged ChatModel
tool_model = mlflow.pyfunc.load_model(model_info.model_uri)

system_prompt = {
    "role": "system",
    "content": "Please use the provided tools to answer user queries.",
}

messages = [
    system_prompt,
    {"role": "user", "content": "What's the weather in Singapore?"},
]

# Call the model's predict method
response = tool_model.predict({"messages": messages})
print(response["choices"][0]["message"]["content"])

messages = [
    system_prompt,
    {"role": "user", "content": "What's the weather in San Francisco?"},
]

# Generating another response
response = tool_model.predict({"messages": messages})
print(response["choices"][0]["message"]["content"])
2024/10/29 09:30:27 WARNING mlflow.tracing.processor.mlflow: Creating a trace within the default experiment with id '0'. It is strongly recommended to not use the default experiment to log traces due to ambiguous search results and probable performance issues over time due to directory table listing performance degradation with high volumes of directories within a specific path. To avoid performance and disambiguation issues, set the experiment for your environment using `mlflow.set_experiment()` API.
Received a tool call, calling the weather tool...
The weather in Singapore is sunny, with a temperature of 20°C.
Received a tool call, calling the weather tool...
The weather in San Francisco is sunny, with a temperature of 20°C.
Serving the model
MLflow also allows you to serve models using the mlflow models serve CLI tool. In another terminal shell, run the following from the same folder as this notebook:
$ export OPENAI_API_KEY=<YOUR OPENAI API KEY>
$ mlflow models serve -m <MODEL_URI>
This will start serving the model on http://127.0.0.1:5000, and the model can be queried via a POST request to the /invocations route.
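If you prefer to test from the shell, an equivalent curl request should look roughly like this (assuming the default host and port):

$ curl -X POST http://127.0.0.1:5000/invocations \
    -H "Content-Type: application/json" \
    -d '{"messages": [{"role": "user", "content": "What is the weather in Tokyo?"}]}'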
[8]:
import requests

messages = [
    system_prompt,
    {"role": "user", "content": "What's the weather in Tokyo?"},
]

response = requests.post("http://127.0.0.1:5000/invocations", json={"messages": messages})
response.raise_for_status()
response.json()
[8]:
{'choices': [{'index': 0,
'message': {'role': 'assistant',
'content': 'The weather in Tokyo is sunny, with a temperature of 20°C.'},
'finish_reason': 'stop'}],
'usage': {'prompt_tokens': 100, 'completion_tokens': 16, 'total_tokens': 116},
'id': 'chatcmpl-ANVOhWssEiyYNFwrBPxp1gmQvZKsy',
'model': 'gpt-4o-mini-2024-07-18',
'object': 'chat.completion',
'created': 1730165599}
Conclusion
In this tutorial, we covered how to use MLflow’s ChatModel class to create a convenient OpenAI wrapper that supports tool calling. Though the use case was simple, the concepts covered here can be easily extended to support more complex functionality.
If you’re looking to dive deeper into building quality GenAI apps, you might also be interested in checking out MLflow Tracing, an observability tool you can use to trace the execution of arbitrary functions (such as your tool calls, for example).