Tutorial: Custom GenAI Models using ChatModel
Starting in MLflow 3.0.0, we recommend ResponsesAgent instead of ChatModel. See more details in the ResponsesAgent Introduction.
The rapidly evolving landscape of Generative Artificial Intelligence (GenAI) presents exciting opportunities and integration challenges.
To leverage the latest GenAI advancements effectively, developers need a framework that balances flexibility with standardization.
MLflow addresses this need with the mlflow.pyfunc.ChatModel class, introduced in version 2.11.0, which provides a consistent interface for GenAI applications while simplifying deployment and testing.
Choosing Between ChatModel and PythonModel
When building GenAI applications in MLflow, it's essential to choose a model abstraction that balances ease of use with the level of customization you need. MLflow offers two primary classes for this purpose: mlflow.pyfunc.ChatModel and mlflow.pyfunc.PythonModel. Each has its own strengths and trade-offs, making it crucial to understand which one best suits your use case.
| | ChatModel | PythonModel |
|---|---|---|
| When to use | Use when you want to develop and deploy a conversational model with a standard chat schema compatible with the OpenAI spec. | Use when you want full control over the model's interface or to customize every aspect of your model's behavior. |
| Interface | Fixed to OpenAI's chat schema. | Full control over the model's input and output schema. |
| Setup | Quick. Works out of the box for conversational applications, with a pre-defined model signature and input example. | Custom. You need to define the model signature and input example yourself. |
| Complexity | Low. The standardized interface simplifies model deployment and integration. | High. Deploying and integrating a custom PythonModel may not be straightforward. For example, the model needs to handle Pandas DataFrames, as MLflow converts input data to DataFrames before passing it to a PythonModel. |
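To make the comparison concrete, here is a minimal sketch of a ChatModel subclass. The EchoAgent class and its echo logic are purely illustrative, and the response dataclass names in mlflow.types.llm have shifted slightly across MLflow versions:

```python
from typing import List

from mlflow.pyfunc import ChatModel
from mlflow.types.llm import ChatChoice, ChatCompletionResponse, ChatMessage, ChatParams


class EchoAgent(ChatModel):
    """Illustrative agent that echoes the last user message."""

    def predict(
        self, context, messages: List[ChatMessage], params: ChatParams
    ) -> ChatCompletionResponse:
        # `messages` arrives pre-validated against the OpenAI chat schema,
        # so no custom signature or input example is needed at logging time.
        last_message = messages[-1].content
        return ChatCompletionResponse(
            choices=[
                ChatChoice(
                    index=0,
                    message=ChatMessage(
                        role="assistant", content=f"You said: {last_message}"
                    ),
                )
            ]
        )
```

Because the interface is fixed, MLflow can attach the chat signature and an input example automatically when this model is logged.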
Purpose of this tutorial
This tutorial will guide you through the process of creating a custom chat agent using MLflow's mlflow.pyfunc.ChatModel class.
By the end of this tutorial, you will:

- Integrate MLflow Tracing into a custom mlflow.pyfunc.ChatModel instance.
- Customize your model using the model_config parameter within mlflow.pyfunc.log_model() (see the sketch after this list).
- Leverage standardized signature interfaces for simplified deployment.
- Recognize and avoid common pitfalls when extending the mlflow.pyfunc.ChatModel class.
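As a preview of the second objective, logging with a model_config looks roughly like this. This is a minimal sketch: the configuration keys and values are placeholders, and BasicAgent is the agent class developed later in this tutorial:

```python
import mlflow

# Hypothetical configuration; the keys shown are placeholders, not a fixed schema.
model_config = {
    "temperature": 0.1,
    "max_tokens": 500,
}

with mlflow.start_run():
    model_info = mlflow.pyfunc.log_model(
        artifact_path="basic_agent",
        python_model=BasicAgent(),  # the ChatModel subclass built in this tutorial
        model_config=model_config,  # made available to the model in load_context
    )
```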
Prerequisites
- Familiarity with MLflow logging APIs and GenAI concepts.
- MLflow version 2.11.0 or higher installed for use of mlflow.pyfunc.ChatModel.
- MLflow version 2.14.0 or higher installed for use of MLflow Tracing.
This tutorial uses the Databricks Foundation Model APIs purely as an example of interfacing with an external service. You can easily swap the provider in the example for any managed LLM hosting service (Amazon Bedrock, Azure AI Studio, OpenAI, Anthropic, and many others).
Core Concepts
- Tracing
- Customization
- Standardization
- Pitfalls
Tracing Customization for GenAI
MLflow Tracing allows you to monitor and log the execution of your model's methods, providing valuable insights during debugging and performance optimization.
In our example BasicAgent implementation, we use two separate APIs to initiate trace spans: the decorator API and the fluent API.
Decorator API
```python
@mlflow.trace
def _get_system_message(self, role: str) -> Dict:
    if role not in self.models:
        raise ValueError(f"Unknown role: {role}")
    instruction = self.models[role]["instruction"]
    return ChatMessage(role="system", content=instruction).to_dict()
```
Using the @mlflow.trace decorator is the simplest way to add tracing functionality to functions and methods. By default, a span generated by this decorator uses the name of the function as the span name. You can override this naming, as well as other parameters associated with the span, as follows:
```python
@mlflow.trace(name="custom_span_name", attributes={"key": "value"}, span_type="func")
def _get_system_message(self, role: str) -> Dict:
    if role not in self.models:
        raise ValueError(f"Unknown role: {role}")
    instruction = self.models[role]["instruction"]
    return ChatMessage(role="system", content=instruction).to_dict()
```
It is always advised to set a human-readable name for any span that you generate, particularly if you are instrumenting private or generically named functions or methods. The MLflow Trace UI will display the name of the function or method by default, which can be confusing to follow if your functions and methods are ambiguously named.
Fluent API
The fluent API's context-handler approach to initiating spans is useful when you need full control over the logging of each aspect of the span's data. The example below, from our application, ensures that we capture the parameters set when loading the model via the load_context method. We pull from the instance attributes self.models_config and self.models to set the attributes of the span.
```python
with mlflow.start_span("Audit Agent") as root_span:
    root_span.set_inputs(messages)
    attributes = {**params.to_dict(), **self.models_config, **self.models}
    root_span.set_attributes(attributes)
    # More span manipulation...
```
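For reference, here is a self-contained sketch of the same fluent-API pattern outside of our agent; the function name, span name, and attribute values are illustrative:

```python
import mlflow


def answer_question(question: str) -> str:
    # Open a span via the fluent API and set each field explicitly.
    with mlflow.start_span(name="Answer Question") as span:
        span.set_inputs({"question": question})
        span.set_attributes({"provider": "example-provider"})  # placeholder attribute
        response = question.upper()  # stand-in for a real LLM call
        span.set_outputs({"response": response})
    return response
```

Explicitly setting inputs, attributes, and outputs this way gives the span the same structured fields that the decorator API captures automatically.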
Traces in the MLflow UI
After running our example, which includes these combined usage patterns for trace span generation and instrumentation, the resulting traces can be explored in the MLflow UI.