Build a tool-calling model with mlflow.pyfunc.ChatModel


Welcome to the notebook tutorial on building a simple tool-calling model using the mlflow.pyfunc.ChatModel wrapper. ChatModel is a subclass of MLflow’s highly customizable PythonModel, designed specifically to make building GenAI workflows easier.

Briefly, here are some of the benefits of using ChatModel:

  1. No need to define a complex signature! Chat models often accept complex inputs with many levels of nesting, and this can be cumbersome to define yourself.

  2. Support for JSON / dict inputs (no need to wrap inputs or convert them to a pandas DataFrame; see the sketch after this list)

  3. Dataclasses for defining the expected inputs / outputs, making for a simpler development experience
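
For example, once a ChatModel has been logged and loaded (the full workflow is shown later in this tutorial), inference takes a plain dict in the standard chat format. A minimal sketch, assuming loaded_model is a ChatModel that has already been loaded via mlflow.pyfunc.load_model():

# inference inputs are plain dicts in the standard chat completions format
response = loaded_model.predict(
    {"messages": [{"role": "user", "content": "Hello!"}]}
)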

For a more in-depth exploration of ChatModel, please check out the detailed guide.

In this tutorial, we’ll be building a simple OpenAI wrapper that makes use of the tool-calling support released in MLflow 2.17.0.

Environment setup

First, let’s set up the environment. We’ll need the OpenAI Python SDK, as well as MLflow >= 2.17.0. We’ll also need to set our OpenAI API key in order to use the SDK.

[16]:
%pip install 'mlflow>=2.17.0' 'openai>=1.0' -qq
Note: you may need to restart the kernel to use updated packages.
[1]:
import os
from getpass import getpass

os.environ["OPENAI_API_KEY"] = getpass("Enter your OpenAI API key: ")

Step 1: Creating the tool definition

Let’s begin to define our model! As mentioned in the introduction, we’ll be subclassing mlflow.pyfunc.ChatModel. For this example, we’ll build a toy model that uses a tool to retrieve the weather for a given city.

The first step is to create a tool definition that we can pass to OpenAI. We do this by using mlflow.types.llm.FunctionToolDefinition to describe the parameters that our tool accepts. The format of this dataclass is aligned with the OpenAI spec:

[2]:
import mlflow
from mlflow.types.llm import (
    FunctionToolDefinition,
    ParamProperty,
    ToolParamsSchema,
)


class WeatherModel(mlflow.pyfunc.ChatModel):
    def __init__(self):
        # a sample tool definition. we use the `FunctionToolDefinition`
        # class to describe the name and expected params for the tool.
        # for this example, we're defining a simple tool that returns
        # the weather for a given city.
        weather_tool = FunctionToolDefinition(
            name="get_weather",
            description="Get weather information",
            parameters=ToolParamsSchema(
                {
                    "city": ParamProperty(
                        type="string",
                        description="City name to get weather information for",
                    ),
                }
            ),
            # make sure to call `to_tool_definition()` to convert the `FunctionToolDefinition`
            # to a `ToolDefinition` object. this step is necessary to normalize the data format,
            # as multiple types of tools (besides just functions) might be available in the future.
        ).to_tool_definition()

        # OpenAI expects tools to be provided as a list of dictionaries
        self.tools = [weather_tool.to_dict()]
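
If you’re curious what will actually be sent to OpenAI, you can inspect the serialized tool definition. A minimal sketch; since the dataclass is aligned with the OpenAI spec, expect roughly a {"type": "function", "function": {...}} structure:

import json

# print the serialized tool definition that will be passed to OpenAI.
# it should look roughly like:
#   {"type": "function",
#    "function": {"name": "get_weather", "description": "...",
#                 "parameters": {"type": "object", "properties": {"city": {...}}}}}
print(json.dumps(WeatherModel().tools, indent=2))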

Step 2: Implementing the tool

Now that we have a definition for the tool, we need to actually implement it. For the purposes of this tutorial, we’re just going to mock a response, but the implementation can be arbitrary: you might call an actual weather service’s API, for example (see the sketch after the next cell).

[3]:
class WeatherModel(mlflow.pyfunc.ChatModel):
    def __init__(self):
        weather_tool = FunctionToolDefinition(
            name="get_weather",
            description="Get weather information",
            parameters=ToolParamsSchema(
                {
                    "city": ParamProperty(
                        type="string",
                        description="City name to get weather information for",
                    ),
                }
            ),
        ).to_tool_definition()

        self.tools = [weather_tool.to_dict()]

    def get_weather(self, city: str) -> str:
        # in a real-world scenario, the implementation might be more complex
        return f"It's sunny in {city}, with a temperature of 20C"

Step 3: Implementing the predict method

The next thing we need to do is define a predict() function that accepts the following arguments:

  1. context: PythonModelContext (not used in this tutorial)

  2. messages: List[ChatMessage]. This is the chat input that the model uses for generation.

  3. params: ChatParams. Commonly used parameters for configuring the chat model, e.g. temperature, max_tokens. This is also where the tool specifications can be found.

This is the function that will ultimately be called during inference.

For the implementation, we’ll simply forward the user’s input to OpenAI and provide the get_weather tool as an option for the LLM to use if it chooses to do so. If we receive a tool call request, we’ll call the get_weather() function and send the result back to OpenAI. We’ll use what we defined in the previous two steps to do this.

[4]:
import json

from openai import OpenAI

import mlflow
from mlflow.types.llm import (
    ChatMessage,
    ChatParams,
    ChatResponse,
)


class WeatherModel(mlflow.pyfunc.ChatModel):
    def __init__(self):
        weather_tool = FunctionToolDefinition(
            name="get_weather",
            description="Get weather information",
            parameters=ToolParamsSchema(
                {
                    "city": ParamProperty(
                        type="string",
                        description="City name to get weather information for",
                    ),
                }
            ),
        ).to_tool_definition()

        self.tools = [weather_tool.to_dict()]

    def get_weather(self, city: str) -> str:
        return "It's sunny in {}, with a temperature of 20C".format(city)

    # the core method that needs to be implemented. this function
    # will be called every time a user sends messages to our model
    def predict(self, context, messages: list[ChatMessage], params: ChatParams):
        # instantiate the OpenAI client
        client = OpenAI()

        # convert the messages to a format that the OpenAI API expects
        messages = [m.to_dict() for m in messages]

        # call the OpenAI API
        response = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=messages,
            # pass the tools in the request
            tools=self.tools,
        )

        # if OpenAI returns a tool_calling response, then we call
        # our tool. otherwise, we just return the response as is
        tool_calls = response.choices[0].message.tool_calls
        if tool_calls:
            print("Received a tool call, calling the weather tool...")

            # for this example, we only provide the model with one tool,
            # so we can assume the tool call is for the weather tool. if
            # we had more, we'd need to check the name of the tool that
            # was called
            city = json.loads(tool_calls[0].function.arguments)["city"]
            tool_call_id = tool_calls[0].id

            # call the tool and construct a new chat message
            tool_response = ChatMessage(
                role="tool", content=self.get_weather(city), tool_call_id=tool_call_id
            ).to_dict()

            # send another request to the API, making sure to append
            # the assistant's tool call along with the tool response.
            messages.append(response.choices[0].message)
            messages.append(tool_response)
            response = client.chat.completions.create(
                model="gpt-4o-mini",
                messages=messages,
                tools=self.tools,
            )

        # return the result as a ChatResponse, as this
        # is the expected output of the predict method
        return ChatResponse.from_dict(response.to_dict())
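
Before logging the model, you can optionally sanity-check predict() by calling it directly with the dataclass inputs. A minimal sketch; this sends a real request to OpenAI, and passing None for the context works because our implementation doesn’t use it:

# quick sanity check: call predict() directly before logging the model
model = WeatherModel()
response = model.predict(
    None,  # context is unused by this model
    [ChatMessage(role="user", content="What's the weather in Paris?")],
    ChatParams(),
)
print(response.choices[0].message.content)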

Step 4 (optional, but recommended): Enable tracing for the model

This step is optional, but highly recommended to improve observability in your app. We’ll be using MLflow Tracing to log the inputs and outputs of our model’s internal functions, so we can easily debug when things go wrong. Agent-style tool calling models can make many layers of function calls during the lifespan of a single request, so tracing is invaluable in helping us understand what’s going on at each step.

Integrating tracing is easy: we simply decorate the functions we’re interested in (get_weather() and predict()) with @mlflow.trace! MLflow Tracing also has integrations with many popular GenAI frameworks, such as LangChain, OpenAI, LlamaIndex, and more. For the full list, check out this documentation page. In this tutorial, we’re using the OpenAI SDK to make API calls, so we can enable tracing for it by calling mlflow.openai.autolog().

To view the traces in the UI, run mlflow ui in a separate terminal shell, and navigate to the Traces tab after using the model for inference below.

[5]:
from mlflow.entities.span import (
    SpanType,
)

# automatically trace OpenAI SDK calls
mlflow.openai.autolog()


class WeatherModel(mlflow.pyfunc.ChatModel):
    def __init__(self):
        weather_tool = FunctionToolDefinition(
            name="get_weather",
            description="Get weather information",
            parameters=ToolParamsSchema(
                {
                    "city": ParamProperty(
                        type="string",
                        description="City name to get weather information for",
                    ),
                }
            ),
        ).to_tool_definition()

        self.tools = [weather_tool.to_dict()]

    @mlflow.trace(span_type=SpanType.TOOL)
    def get_weather(self, city: str) -> str:
        return "It's sunny in {}, with a temperature of 20C".format(city)

    @mlflow.trace(span_type=SpanType.AGENT)
    def predict(self, context, messages: list[ChatMessage], params: ChatParams):
        client = OpenAI()

        messages = [m.to_dict() for m in messages]

        response = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=messages,
            tools=self.tools,
        )

        tool_calls = response.choices[0].message.tool_calls
        if tool_calls:
            print("Received a tool call, calling the weather tool...")

            city = json.loads(tool_calls[0].function.arguments)["city"]
            tool_call_id = tool_calls[0].id

            tool_response = ChatMessage(
                role="tool", content=self.get_weather(city), tool_call_id=tool_call_id
            ).to_dict()

            messages.append(response.choices[0].message)
            messages.append(tool_response)
            response = client.chat.completions.create(
                model="gpt-4o-mini",
                messages=messages,
                tools=self.tools,
            )

        return ChatResponse.from_dict(response.to_dict())
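
Besides viewing traces in the UI, you can fetch them programmatically once some predictions have been made. A minimal sketch, assuming your MLflow version provides mlflow.search_traces(), which returns recent traces as a pandas DataFrame:

# fetch the most recent traces from the active experiment as a DataFrame
traces = mlflow.search_traces(max_results=5)
print(traces.head())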

Step 5: Logging the model

Finally, we need to log the model. This saves the model as an artifact in MLflow Tracking, and allows us to load and serve it later on.

(Note: this is a fundamental pattern in MLflow. To learn more, check out the Quickstart guide!)

In order to do this, we need to do a few things:

  1. Define an input example to inform users about the input we expect

  2. Instantiate the model

  3. Call mlflow.pyfunc.log_model() with the above as arguments

Take note of the Model URI printed out at the end of the cell—we’ll need it when serving the model later!

[6]:
# messages to use as input examples
messages = [
    {"role": "system", "content": "Please use the provided tools to answer user queries."},
    {"role": "user", "content": "What's the weather in Singapore?"},
]

input_example = {
    "messages": messages,
}

# instantiate the model
model = WeatherModel()

# log the model
with mlflow.start_run():
    model_info = mlflow.pyfunc.log_model(
        artifact_path="weather-model",
        python_model=model,
        input_example=input_example,
    )

    print("Successfully logged the model at the following URI: ", model_info.model_uri)
2024/10/29 09:30:14 INFO mlflow.pyfunc: Predicting on input example to validate output
Received a tool call, calling the weather tool...
Received a tool call, calling the weather tool...
Successfully logged the model at the following URI:  runs:/8051850efa194a3b8b2450c4c9f4d42f/weather-model

Using the model for inference

Now that the model is logged, our work is more or less done! In order to use the model for inference, let’s load it back using mlflow.pyfunc.load_model().

[7]:
import mlflow

# Load the previously logged ChatModel
tool_model = mlflow.pyfunc.load_model(model_info.model_uri)

system_prompt = {
    "role": "system",
    "content": "Please use the provided tools to answer user queries.",
}

messages = [
    system_prompt,
    {"role": "user", "content": "What's the weather in Singapore?"},
]

# Call the model's predict method
response = tool_model.predict({"messages": messages})
print(response["choices"][0]["message"]["content"])

messages = [
    system_prompt,
    {"role": "user", "content": "What's the weather in San Francisco?"},
]

# Generating another response
response = tool_model.predict({"messages": messages})
print(response["choices"][0]["message"]["content"])
2024/10/29 09:30:27 WARNING mlflow.tracing.processor.mlflow: Creating a trace within the default experiment with id '0'. It is strongly recommended to not use the default experiment to log traces due to ambiguous search results and probable performance issues over time due to directory table listing performance degradation with high volumes of directories within a specific path. To avoid performance and disambiguation issues, set the experiment for your environment using `mlflow.set_experiment()` API.
Received a tool call, calling the weather tool...
The weather in Singapore is sunny, with a temperature of 20°C.
Received a tool call, calling the weather tool...
The weather in San Francisco is sunny, with a temperature of 20°C.
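
The warning above appears because traces are being logged to the default experiment. To avoid it in your own environment, set a named experiment before logging or tracing; the experiment name below is arbitrary:

# create (or reuse) a named experiment so runs and traces don't go to the default one
mlflow.set_experiment("weather-tool-calling-tutorial")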

Serving the model

MLflow also allows you to serve models using the mlflow models serve CLI tool. In another terminal shell, run the following from the same folder as this notebook:

$ export OPENAI_API_KEY=<YOUR OPENAI API KEY>
$ mlflow models serve -m <MODEL_URI>

This will start serving the model on http://127.0.0.1:5000, and the model can be queried via a POST request to the /invocations route.

[8]:
import requests

messages = [
    system_prompt,
    {"role": "user", "content": "What's the weather in Tokyo?"},
]

response = requests.post("http://127.0.0.1:5000/invocations", json={"messages": messages})
response.raise_for_status()
response.json()
[8]:
{'choices': [{'index': 0,
   'message': {'role': 'assistant',
    'content': 'The weather in Tokyo is sunny, with a temperature of 20°C.'},
   'finish_reason': 'stop'}],
 'usage': {'prompt_tokens': 100, 'completion_tokens': 16, 'total_tokens': 116},
 'id': 'chatcmpl-ANVOhWssEiyYNFwrBPxp1gmQvZKsy',
 'model': 'gpt-4o-mini-2024-07-18',
 'object': 'chat.completion',
 'created': 1730165599}
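
To use the served response programmatically rather than just displaying it, you can pull the assistant’s reply out of the payload:

# extract just the assistant's message text from the /invocations response
print(response.json()["choices"][0]["message"]["content"])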

Conclusion

In this tutorial, we covered how to use MLflow’s ChatModel class to create a convenient OpenAI wrapper that supports tool calling. Though the use case was simple, the concepts covered here can easily be extended to support more complex functionality.

If you’re looking to dive deeper into building quality GenAI apps, you might also be interested in checking out MLflow Tracing, an observability tool you can use to trace the execution of arbitrary functions (such as your tool calls).