Tutorial: Custom GenAI Models using ChatModel

The rapidly evolving landscape of Generative Artificial Intelligence (GenAI) presents exciting opportunities and integration challenges. To leverage the latest GenAI advancements effectively, developers need a framework that balances flexibility with standardization. MLflow addresses this need with the mlflow.pyfunc.ChatModel class introduced in version 2.11.0, providing a consistent interface for GenAI applications while simplifying deployment and testing.

Choosing Between ChatModel and PythonModel

When building GenAI applications in MLflow, it’s essential to choose the right model abstraction that balances ease of use with the level of customization you need. MLflow offers two primary classes for this purpose: mlflow.pyfunc.ChatModel and mlflow.pyfunc.PythonModel. Each has its own strengths and trade-offs, making it crucial to understand which one best suits your use case.

When to use

  • ChatModel: Use when you want to develop and deploy a conversational model with a standard chat schema compatible with the OpenAI spec.

  • PythonModel: Use when you want full control over the model’s interface or want to customize every aspect of your model’s behavior.

Interface

  • ChatModel: Fixed to OpenAI’s chat schema.

  • PythonModel: Full control over the model’s input and output schema.

Setup

  • ChatModel: Quick. Works out of the box for conversational applications, with a pre-defined model signature and input example.

  • PythonModel: Custom. You need to define the model signature or input example yourself.

Complexity

  • ChatModel: Low. The standardized interface simplifies model deployment and integration.

  • PythonModel: High. Deploying and integrating a custom PythonModel may not be straightforward. For example, the model needs to handle Pandas DataFrames, because MLflow converts input data to DataFrames before passing it to a PythonModel.
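To make the "fixed to OpenAI’s chat schema" row more concrete, the sketch below checks that a request dictionary has the general shape a chat model consumes: a messages list whose entries each carry a role and a string content. It is a stdlib-only approximation written for illustration; the helper name looks_like_chat_request and the set of allowed roles are assumptions, and this is not MLflow’s own validation logic.

```python
from typing import Any

# Roles commonly used in OpenAI-style chat messages (assumed set for this sketch).
ALLOWED_ROLES = {"system", "user", "assistant", "tool"}


def looks_like_chat_request(payload: Any) -> bool:
    """Return True if payload resembles an OpenAI-style chat request."""
    if not isinstance(payload, dict):
        return False
    messages = payload.get("messages")
    if not isinstance(messages, list) or not messages:
        return False
    for message in messages:
        if not isinstance(message, dict):
            return False
        if message.get("role") not in ALLOWED_ROLES:
            return False
        # This sketch requires string content; assistant messages that only
        # carry tool calls would need extra handling.
        if not isinstance(message.get("content"), str):
            return False
    return True


print(looks_like_chat_request({"messages": [{"role": "user", "content": "Hi"}]}))  # True
print(looks_like_chat_request({"prompt": "Hi"}))  # False
```

With ChatModel, this contract is part of the class; with a PythonModel, any equivalent checks are yours to write and maintain.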

Purpose of this tutorial

This tutorial will guide you through the process of creating a custom chat agent using MLflow’s mlflow.pyfunc.ChatModel class.

By the end of this tutorial you will:

Prerequisites

  • Familiarity with MLflow logging APIs and GenAI concepts.

  • MLflow 2.11.0 or higher, to use mlflow.pyfunc.ChatModel.

  • MLflow 2.14.0 or higher, to use MLflow Tracing.

This tutorial uses the Databricks Foundation Model APIs purely as an example of interfacing with an external service. You can easily swap in any other managed LLM hosting service (Amazon Bedrock, Azure AI Studio, OpenAI, Anthropic, and many others).

Core Concepts

Tracing Customization for GenAI


MLflow Tracing allows you to monitor and log the execution of your model’s methods, providing valuable insights during debugging and performance optimization.

Our example BasicAgent implementation uses two separate APIs to create trace spans: the decorator API and the fluent API.

Decorator API

@mlflow.trace
def _get_system_message(self, role: str) -> Dict:
    if role not in self.models:
        raise ValueError(f"Unknown role: {role}")

    instruction = self.models[role]["instruction"]
    return ChatMessage(role="system", content=instruction).to_dict()

Using the @mlflow.trace decorator is the simplest way to add tracing to functions and methods. By default, the span created by the decorator is named after the decorated function. You can override the name, as well as other span parameters such as attributes and the span type, as follows:

@mlflow.trace(name="custom_span_name", attributes={"key": "value"}, span_type="func")
def _get_system_message(self, role: str) -> Dict:
    if role not in self.models:
        raise ValueError(f"Unknown role: {role}")

    instruction = self.models[role]["instruction"]
    return ChatMessage(role="system", content=instruction).to_dict()

Tip

It is always advised to set a human-readable name for any span that you generate, particularly if you are instrumenting private or generically named functions or methods. The MLflow Trace UI will display the name of the function or method by default, which can be confusing to follow if your functions and methods are ambiguously named.

Fluent API

The fluent API’s context manager for creating spans, mlflow.start_span, is useful when you need full control over the data logged to each span.

The example below, taken from our application, captures the parameters that were set when the model was loaded via the load_context method. It reads the instance attributes self.models_config and self.models and records them as attributes of the span.

with mlflow.start_span("Audit Agent") as root_span:
    root_span.set_inputs(messages)
    attributes = {**(params.to_dict() if params is not None else {}), **self.models_config, **self.models}
    root_span.set_attributes(attributes)
    # More span manipulation...
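The attribute dictionary above is built with {**a, **b, **c} unpacking, so when the same key appears in more than one source, the right-most dictionary wins. The stdlib sketch below shows that precedence, with plain dicts standing in for params.to_dict(), self.models_config, and self.models; it also falls back to an empty dict when params is None, which the Optional[ChatParams] annotation allows. The helper name build_span_attributes and the sample values are hypothetical.

```python
from typing import Any, Dict, Optional


def build_span_attributes(
    params: Optional[Dict[str, Any]],
    models_config: Dict[str, Any],
    models: Dict[str, Any],
) -> Dict[str, Any]:
    """Merge span attributes; on a key clash, the right-most source wins."""
    # Fall back to an empty dict when params is None instead of raising AttributeError.
    return {**(params or {}), **models_config, **models}


# Plain dicts standing in for params.to_dict(), self.models_config and self.models.
params = {"temperature": 0.2, "max_tokens": 100}
models_config = {"user_response_instruction": "Can you make the answer better?", "max_tokens": 256}
models = {"judge": {"endpoint": "judge-endpoint"}, "oracle": {"endpoint": "oracle-endpoint"}}

merged = build_span_attributes(params, models_config, models)
print(merged["max_tokens"])  # 256: models_config is unpacked after params
print(merged["temperature"])  # 0.2: no later source defines it
print(sorted(build_span_attributes(None, models_config, models)))
```

Because clashes are resolved silently, it is worth keeping role names and parameter names distinct so that one never overwrites the other in the recorded attributes.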

Traces in the MLflow UI

After running our example, which combines both of these span-generation patterns, the resulting traces can be inspected in the MLflow UI.

Figure: Traces in the MLflow UI for the Agent example

Key Classes and Methods in our example

  • BasicAgent: Our custom chat agent class that extends ChatModel.

  • _get_system_message: Retrieves the system message configuration for a specific role.

  • _get_agent_response: Sends messages to an endpoint and retrieves responses.

  • _call_agent: Manages the conversation flow between the agent roles.

  • _prepare_message_list: Prepares the list of messages for sending.

  • load_context: Initializes the model context and configurations.

  • predict: Handles the prediction logic for the chat model.

  • _get_model_params: Retrieves the generation parameters (temperature and max_tokens) for a given role.

Of the methods listed above, load_context and predict override the base implementations in ChatModel. A subclass of ChatModel must implement, at a minimum, the predict method. The load_context method is needed only when, as in the example below, the model requires custom loading logic: for instance, when a static configuration must be loaded for the model object to work, or when additional setup must run for the object to be initialized correctly.
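The division of labor between the constructor, load_context, and predict can be sketched without MLflow. In the stdlib-only example below, types.SimpleNamespace stands in for the context object that MLflow passes to load_context, and SketchAgent is a hypothetical class that mirrors BasicAgent’s pattern: every attribute gets a default in __init__, and load_context only fills those attributes in from model_config.

```python
from types import SimpleNamespace
from typing import Any, Dict


class SketchAgent:
    def __init__(self) -> None:
        # Defaults mean the object is usable (and inspectable) before load_context runs.
        self.models: Dict[str, Any] = {}
        self.models_config: Dict[str, Any] = {}

    def load_context(self, context: Any) -> None:
        # Only overwrites attributes that already exist from __init__.
        self.models = context.model_config.get("models", {})
        self.models_config = context.model_config.get("configuration", {})

    def predict(self, context: Any, messages: list) -> Dict[str, Any]:
        # A real ChatModel would call an endpoint here; this just reports the configured roles.
        return {"roles": sorted(self.models), "n_messages": len(messages)}


agent = SketchAgent()
print(agent.predict(None, []))  # {'roles': [], 'n_messages': 0}

# SimpleNamespace stands in for the MLflow context; only model_config is modeled.
context = SimpleNamespace(model_config={"models": {"oracle": {}, "judge": {}}})
agent.load_context(context)
print(agent.predict(context, [{"role": "user", "content": "Hi"}]))
# {'roles': ['judge', 'oracle'], 'n_messages': 1}
```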

Example of a custom ChatModel

In the full example below, we’re creating a custom chat agent by subclassing the mlflow.pyfunc.ChatModel. This agent, named BasicAgent, takes advantage of several important features that help streamline the development, deployment, and tracking of GenAI applications. By subclassing ChatModel, we ensure a consistent interface for handling conversational agents, while also avoiding common pitfalls associated with more general-purpose models.

The implementation below highlights the following key aspects:

  • Tracing: We leverage MLflow’s tracing functionality to track and log critical operations using both the decorator and the fluent API context manager approaches.

    • Decorator API: This is used to easily trace methods such as _get_agent_response and _call_agent for automatic span creation.

    • Fluent API: Provides fine-grained control over span creation, as shown in the predict method for auditing key inputs and outputs during agent interactions.

    • Tip: We ensure human-readable span names for easier debugging in the MLflow Trace UI and when fetching logged traces via the client API.

  • Custom Configuration:

    • Model Configuration: By passing custom configurations during model logging (using the model_config parameter), we decouple model behavior from hard-coded values. This allows rapid testing of different agent configurations without modifying the source code.

    • load_context Method: Ensures that configurations are loaded at runtime, initializing the agent with the necessary settings and preventing runtime failures due to missing configurations.

    • Tip: We avoid directly setting undefined instance attributes within load_context. Instead, all attributes are initialized with default values in the class constructor to ensure proper loading of our model.

  • Conversation Management:

    • We implement a multi-step agent interaction pattern using methods like _get_system_message, _get_agent_response, and _call_agent. These methods manage the flow of communication between multiple agents, such as an “oracle” and a “judge” role, each configured with specific instructions and parameters.

    • Static Input/Output Structures: By adhering to the ChatModel’s required input (List[ChatMessage]) and output (ChatCompletionResponse) formats, we eliminate the complexities associated with converting JSON or tabular data, which is common in more general models like PythonModel.

  • Common Pitfalls Avoided:

    • Model Validation via Input Examples: We provide an input example during model logging, allowing MLflow to validate the input interface and catch structural issues early, reducing debugging time during deployment.

import mlflow
from mlflow.types.llm import ChatCompletionResponse, ChatMessage, ChatParams, ChatChoice
from mlflow.pyfunc import ChatModel
from mlflow import deployments
from typing import List, Optional, Dict


class BasicAgent(ChatModel):
    def __init__(self):
        """Initialize the BasicAgent with placeholder values."""
        self.deploy_client = None
        self.models = {}
        self.models_config = {}
        self.conversation_history = []

    def load_context(self, context):
        """Initialize the connectors and model configurations."""
        self.deploy_client = deployments.get_deploy_client("databricks")
        self.models = context.model_config.get("models", {})
        self.models_config = context.model_config.get("configuration", {})

    def _get_system_message(self, role: str) -> Dict:
        """
        Get the system message configuration for the specified role.

        Args:
            role (str): The role of the agent (e.g., "oracle" or "judge").

        Returns:
            dict: The system message for the given role.
        """
        if role not in self.models:
            raise ValueError(f"Unknown role: {role}")

        instruction = self.models[role]["instruction"]
        return ChatMessage(role="system", content=instruction).to_dict()

    @mlflow.trace(name="Raw Agent Response")
    def _get_agent_response(
        self, message_list: List[Dict], endpoint: str, params: Optional[dict] = None
    ) -> Dict:
        """
        Call the agent endpoint to get a response.

        Args:
            message_list (List[Dict]): List of messages for the agent.
            endpoint (str): The agent's endpoint.
            params (Optional[dict]): Additional parameters for the call.

        Returns:
            dict: The response from the agent.
        """
        response = self.deploy_client.predict(
            endpoint=endpoint, inputs={"messages": message_list, **(params or {})}
        )
        return response["choices"][0]["message"]

    @mlflow.trace(name="Agent Call")
    def _call_agent(
        self, message: ChatMessage, role: str, params: Optional[dict] = None
    ) -> Dict:
        """
        Prepares and sends the request to a specific agent based on the role.

        Args:
            message (ChatMessage): The message to be processed.
            role (str): The role of the agent (e.g., "oracle" or "judge").
            params (Optional[dict]): Additional parameters for the call.

        Returns:
            dict: The response from the agent.
        """
        system_message = self._get_system_message(role)
        message_list = self._prepare_message_list(system_message, message)

        # Fetch agent response
        agent_config = self.models[role]
        response = self._get_agent_response(
            message_list, agent_config["endpoint"], params
        )

        # Update conversation history
        self.conversation_history.extend([message.to_dict(), response])
        return response

    @mlflow.trace(name="Assemble Conversation")
    def _prepare_message_list(
        self, system_message: Dict, user_message: ChatMessage
    ) -> List[Dict]:
        """
        Prepare the list of messages to send to the agent.

        Args:
            system_message (dict): The system message dictionary.
            user_message (ChatMessage): The user message.

        Returns:
            List[dict]: The complete list of messages to send.
        """
        user_prompt = {
            "role": "user",
            "content": self.models_config.get(
                "user_response_instruction", "Can you make the answer better?"
            ),
        }
        if self.conversation_history:
            return [system_message, *self.conversation_history, user_prompt]
        else:
            return [system_message, user_message.to_dict()]

    def predict(
        self, context, messages: List[ChatMessage], params: Optional[ChatParams] = None
    ) -> ChatCompletionResponse:
        """
        Predict method to handle agent conversation.

        Args:
            context: The MLflow context.
            messages (List[ChatMessage]): List of messages to process.
            params (Optional[ChatParams]): Additional parameters for the conversation.

        Returns:
            ChatCompletionResponse: The structured response object.
        """
        # Use the fluent API context manager for finer control over what is included in the span
        with mlflow.start_span(name="Audit Agent") as root_span:
            # Add the user input to the root span
            root_span.set_inputs(messages)

            # Add attributes to the root span; params is Optional, so fall back to an empty dict
            params_dict = params.to_dict() if params is not None else {}
            attributes = {**params_dict, **self.models_config, **self.models}
            root_span.set_attributes(attributes)

            # Initiate the conversation with the oracle
            oracle_params = self._get_model_params("oracle")
            oracle_response = self._call_agent(messages[0], "oracle", oracle_params)

            # Process the response with the judge
            judge_params = self._get_model_params("judge")
            judge_response = self._call_agent(
                ChatMessage(**oracle_response), "judge", judge_params
            )

            # Reset the conversation history and return the final response
            self.conversation_history = []

            output = ChatCompletionResponse(
                choices=[ChatChoice(index=0, message=ChatMessage(**judge_response))],
                usage={},
                # Report the judge's endpoint; judge_params only holds temperature and max_tokens
                model=self.models.get("judge", {}).get("endpoint", "unknown"),
            )

            root_span.set_outputs(output)

        return output

    def _get_model_params(self, role: str) -> dict:
        """
        Retrieves model parameters for a given role.

        Args:
            role (str): The role of the agent (e.g., "oracle" or "judge").

        Returns:
            dict: A dictionary of parameters for the agent.
        """
        role_config = self.models.get(role, {})

        return {
            "temperature": role_config.get("temperature", 0.5),
            "max_tokens": role_config.get("max_tokens", 500),
        }
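Because the agent’s orchestration logic is plain Python, you can exercise it offline before logging anything. The sketch below is a stdlib-only reimplementation (FakeDeployClient and OfflineAgent are hypothetical names) that mirrors _prepare_message_list and _call_agent with plain dicts and replaces the Databricks deployments client with a fake that records each request. It is meant to show which messages each role receives, not to replace the MLflow class above.

```python
from typing import Any, Dict, List


class FakeDeployClient:
    """Hypothetical stand-in for the Databricks deployments client; records each request."""

    def __init__(self) -> None:
        self.requests: List[Dict[str, Any]] = []

    def predict(self, endpoint: str, inputs: Dict[str, Any]) -> Dict[str, Any]:
        self.requests.append({"endpoint": endpoint, "inputs": inputs})
        reply = f"answer from {endpoint} after {len(inputs['messages'])} messages"
        # Same response layout that BasicAgent._get_agent_response indexes into.
        return {"choices": [{"message": {"role": "assistant", "content": reply}}]}


class OfflineAgent:
    """Hypothetical mirror of BasicAgent's message handling, using plain dicts."""

    def __init__(self, models: Dict[str, Any], client: FakeDeployClient) -> None:
        self.models = models
        self.client = client
        self.conversation_history: List[Dict[str, str]] = []

    def _prepare_message_list(
        self, system_message: Dict[str, str], user_message: Dict[str, str]
    ) -> List[Dict[str, str]]:
        # BasicAgent reads this prompt from model_config; here it is the code's default.
        user_prompt = {"role": "user", "content": "Can you make the answer better?"}
        if self.conversation_history:
            return [system_message, *self.conversation_history, user_prompt]
        return [system_message, user_message]

    def _call_agent(self, message: Dict[str, str], role: str) -> Dict[str, str]:
        system_message = {"role": "system", "content": self.models[role]["instruction"]}
        message_list = self._prepare_message_list(system_message, message)
        result = self.client.predict(
            endpoint=self.models[role]["endpoint"], inputs={"messages": message_list}
        )
        response = result["choices"][0]["message"]
        self.conversation_history.extend([message, response])
        return response

    def run(self, question: str) -> Dict[str, str]:
        oracle_response = self._call_agent({"role": "user", "content": question}, "oracle")
        judge_response = self._call_agent(oracle_response, "judge")
        self.conversation_history = []  # reset between requests, as predict does
        return judge_response


models = {
    "oracle": {"endpoint": "oracle-endpoint", "instruction": "Answer briefly."},
    "judge": {"endpoint": "judge-endpoint", "instruction": "Improve the answer."},
}
client = FakeDeployClient()
final = OfflineAgent(models, client).run("How do I bake scones?")
print(final["content"])  # answer from judge-endpoint after 4 messages
print([m["role"] for m in client.requests[1]["inputs"]["messages"]])
# ['system', 'user', 'assistant', 'user']
```

The second recorded request shows that the judge sees the original question, the oracle’s answer, and the follow-up instruction, which is the multi-step pattern described earlier.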

With the model defined, one required step remains before we can log it: defining the configuration that the model is initialized with, which is passed to MLflow through the model_config parameter.

Setting our model_config values

Before logging the model, we need to define the configuration that governs the behavior of our model’s agents. This decoupling of configuration from the core logic of the model allows us to easily test and compare different agent behaviors without needing to modify the model implementation. By using a flexible configuration system, we can efficiently experiment with different settings, making it much easier to iterate and fine-tune our model.

Why Decouple Configuration?

In the context of Generative AI (GenAI), agent behavior can vary greatly depending on the instruction sets and parameters (such as temperature or max_tokens) given to each agent. If we hardcoded these configurations directly into our model’s logic, each new test would require changing the model’s source code, leading to:

  • Inefficiency: Changing source code for each test slows down the experimentation process.

  • Increased Risk of Errors: Constantly modifying the source increases the chance of introducing bugs or unintended side effects.

  • Lack of Reproducibility: Without a clear separation between code and configuration, tracking and reproducing the exact configuration used for a particular result becomes challenging.

By setting these values externally via the model_config parameter, we make the model flexible and adaptable to different test scenarios. This approach also integrates seamlessly with MLflow’s evaluation tools, such as mlflow.evaluate(), which allows you to compare model outputs across different configurations systematically.

Defining the Model Configuration

The configuration consists of two main sections:

  1. Models: This section defines agent-specific configurations, such as the judge and oracle roles in this example. Each agent has:

    • An endpoint: Specifies the model serving endpoint that this agent calls.

    • An instruction: Defines the role and responsibilities of the agent (e.g., answering questions, evaluating responses).

    • Temperature and Max Tokens: Control the variability of generation (temperature) and the token limit for responses (max_tokens).

  2. General Configuration: Additional settings for the overall behavior of the model, such as how user responses should be framed for subsequent agents.

Note

There are two ways to set a model configuration: directly within the logging code (shown below), or by writing the configuration to a local YAML file and passing that file’s path as the model_config argument during logging. To learn more about how the model_config parameter is used, see the guide on model_config usage.

Here’s how we set the configuration for our agents:

model_config = {
    "models": {
        "judge": {
            "endpoint": "databricks-meta-llama-3-1-405b-instruct",
            "instruction": (
                "You are an evaluator of answers provided by others. Based on the context of both the question and the answer, "
                "provide a corrected answer if it is incorrect; otherwise, enhance the answer with additional context and explanation."
            ),
            "temperature": 0.5,
            "max_tokens": 2000,
        },
        "oracle": {
            "endpoint": "databricks-mixtral-8x7b-instruct",
            "instruction": (
                "You are a knowledgeable source of information that excels at providing detailed, but brief answers to questions. "
                "Provide an answer to the question based on the information provided."
            ),
            "temperature": 0.9,
            "max_tokens": 5000,
        },
    },
    "configuration": {
        "user_response_instruction": "Can you evaluate and enhance this answer with the provided contextual history?"
    },
}
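To see what each agent will actually receive, the stdlib sketch below applies the same defaulting rule as _get_model_params (temperature 0.5 and max_tokens 500 when a role omits them) to a trimmed copy of this configuration. The resolve_params helper, the trimmed_config name, and the "critic" role are illustrative only, and the instructions are omitted for brevity.

```python
from typing import Any, Dict

# Trimmed copy of the model_config above; instructions are omitted for brevity.
trimmed_config = {
    "models": {
        "judge": {
            "endpoint": "databricks-meta-llama-3-1-405b-instruct",
            "temperature": 0.5,
            "max_tokens": 2000,
        },
        "oracle": {
            "endpoint": "databricks-mixtral-8x7b-instruct",
            "temperature": 0.9,
            "max_tokens": 5000,
        },
    }
}


def resolve_params(config: Dict[str, Any], role: str) -> Dict[str, Any]:
    """Apply the same defaulting rule as BasicAgent._get_model_params."""
    role_config = config["models"].get(role, {})
    return {
        "temperature": role_config.get("temperature", 0.5),
        "max_tokens": role_config.get("max_tokens", 500),
    }


print(resolve_params(trimmed_config, "judge"))   # {'temperature': 0.5, 'max_tokens': 2000}
print(resolve_params(trimmed_config, "oracle"))  # {'temperature': 0.9, 'max_tokens': 5000}
# "critic" is a made-up role that is not configured, so the defaults apply.
print(resolve_params(trimmed_config, "critic"))  # {'temperature': 0.5, 'max_tokens': 500}
```

Note the asymmetry this exposes: _get_system_message raises a ValueError for an unknown role, while _get_model_params silently falls back to the defaults.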

Benefits of External Configuration

  • Flexibility: The decoupled configuration allows us to easily switch or adjust model behavior without modifying the core logic. For example, we can change the model’s instructions or adjust the temperature to test different levels of creativity in the responses.

  • Scalability: As more agents are added to the system or new roles are introduced, we can extend this configuration without cluttering the model’s code. This separation keeps the codebase cleaner and more maintainable.

  • Reproducibility and Comparison: By keeping configuration external, we can log the specific settings used in each run with MLflow. This makes it easier to reproduce results and compare different experiments, ensuring a robust evaluation and adjudication process to select the best performing configuration.

With the configuration in place, we’re now ready to log the model and run experiments using these settings. By leveraging MLflow’s powerful tracking and logging features, we’ll be able to manage the experiments efficiently and extract valuable insights from the agent’s responses.

Defining an Input Example

Before logging our model, it’s important to provide an input_example that demonstrates how to interact with the model. This example serves several key purposes:

  • Validation at Logging Time: Including an input_example allows MLflow to execute the predict method using this example during the logging process. This helps validate that your model can handle the expected input format and catch any issues early.

  • UI Representation: The input_example is displayed in the MLflow UI under the model’s artifacts. This provides a convenient reference for users to understand the expected input structure when interacting with the deployed model.

By providing an input example, you ensure that your model is exercised with a realistic request at logging time, increasing confidence that it will behave as expected when deployed.

Tip

When defining your GenAI application using the mlflow.pyfunc.ChatModel, a default placeholder input example will be used if none is provided. If you notice an unfamiliar or generic input example in the MLflow UI’s artifact viewer, it’s likely the default placeholder assigned by the system. To avoid this, ensure you specify a custom input example when saving your model.

Here’s the input example we’ll use:

input_example = {
    "messages": [
        {
            "role": "user",
            "content": "What is a good recipe for baking scones that doesn't require a lot of skill?",
        }
    ]
}

This example represents a user asking for an easy scone recipe. It aligns with the input structure expected by our BasicAgent model, which processes a list of messages where each message includes a role and content.

Benefits of Providing an Input Example:

  • Execution and Validation: MLflow will pass this input_example to the model’s predict method during logging to ensure that it can process the input without errors. Any issues with input handling, such as incorrect data types or missing fields, will be caught at this stage, saving you time debugging later.

  • User Interface Display: The input_example will be visible in the MLflow UI within the model artifact view section. This helps users understand the format of input data the model expects, making it easier to interact with the model once it’s deployed.

  • Deployment Confidence: By validating the model with an example input upfront, you gain additional assurance that the model will function correctly in a production environment, reducing the risk of unexpected behavior after deployment.

Including an input_example is a simple yet powerful step to verify that your model is ready for deployment and will behave as expected when receiving input from users.

Logging and Loading our custom Agent

To log and load the model using MLflow, use:

with mlflow.start_run():
    model_info = mlflow.pyfunc.log_model(
        "model",
        python_model=BasicAgent(),
        model_config=model_config,
        input_example=input_example,
    )

loaded = mlflow.pyfunc.load_model(model_info.model_uri)

response = loaded.predict(
    {
        "messages": [
            {
                "role": "user",
                "content": "What is the best material to make a baseball bat out of?",
            }
        ]
    }
)
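A ChatModel produces an OpenAI-style chat completion, which loaded.predict should hand back as a dictionary. The sketch below shows one way to pull out the final answer. The sample_response dict is a hand-written placeholder with the same general layout, not real model output, and extract_answer is a hypothetical helper.

```python
from typing import Any, Dict

# Hand-written placeholder mirroring the chat completion layout; not real output.
sample_response: Dict[str, Any] = {
    "model": "databricks-meta-llama-3-1-405b-instruct",
    "choices": [
        {
            "index": 0,
            "message": {"role": "assistant", "content": "Placeholder answer text."},
            "finish_reason": "stop",
        }
    ],
}


def extract_answer(response: Dict[str, Any]) -> str:
    """Return the content of the first choice, or raise if there are no choices."""
    choices = response.get("choices") or []
    if not choices:
        raise ValueError("response has no choices")
    return choices[0]["message"]["content"]


print(extract_answer(sample_response))  # Placeholder answer text.
```

With the real model, the same call would be extract_answer(response) on the value returned by loaded.predict above.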

Conclusion

In this tutorial, you have explored the process of creating a custom GenAI chat agent using MLflow’s mlflow.pyfunc.ChatModel class. We demonstrated how to implement a flexible, scalable, and standardized approach to managing the deployment of GenAI applications, enabling you to harness the latest advancements in AI, even for libraries and frameworks that are not yet natively supported with a named flavor in MLflow.

By using ChatModel instead of the more generic PythonModel, you avoid many of the common pitfalls of deploying GenAI applications. ChatModel’s fixed signature is the same across all of your deployed GenAI models, which gives every solution a consistent, predictable interface.

Key takeaways from this tutorial include:

  • Tracing and Monitoring: By integrating tracing directly into the model, you gain valuable insights into the internal workings of your application, making debugging and optimization more straightforward. Both the decorator and fluent API approaches offer versatile ways to manage tracing for critical operations.

  • Flexible Configuration Management: Decoupling configurations from your model code ensures that you can rapidly test and iterate without modifying source code. This approach not only streamlines experimentation but also enhances reproducibility and scalability as your application evolves.

  • Standardized Input and Output Structures: Leveraging the static signatures of ChatModel simplifies the complexities of deploying and serving GenAI models. By adhering to established standards, you reduce the friction typically associated with integrating and validating input/output formats.

  • Avoiding Common Pitfalls: Throughout the implementation, we highlighted best practices to avoid common issues, such as proper handling of secrets, validating input examples, and understanding the nuances of loading context. Following these practices ensures that your model remains secure, robust, and reliable in production environments.

  • Validation and Deployment Readiness: The importance of validating your model before deployment cannot be overstated. By using tools like mlflow.models.validate_serving_input(), you can catch and resolve potential deployment issues early, saving time and effort during the production deployment process.

As the landscape of Generative AI continues to evolve, building adaptable and standardized models will be crucial to leveraging the exciting and powerful capabilities that will be unlocked in the months and years ahead. The approach covered in this tutorial equips you with a robust framework for integrating and managing GenAI technologies within MLflow, empowering you to develop, track, and deploy sophisticated AI solutions with ease.

We encourage you to extend and customize this foundational example to suit your specific needs and explore further enhancements. By leveraging MLflow’s growing capabilities, you can continue to refine your GenAI models, ensuring they deliver impactful and reliable results in any application.