Tasks in MLflow Transformers Flavor
This page provides an overview of how to use the task
parameter in the MLflow Transformers flavor to control the inference interface of the model.
Overview
In the MLflow Transformers flavor, the task
parameter plays a crucial role in determining the input and output format of the model. The task
is a fundamental concept in the Transformers library that describes the structure of each model's API (inputs and outputs) and is used to determine which Inference API and widget to display for a given model.
MLflow uses this concept to determine the input and output format of the model, persist the correct Model Signature, and provide a consistent Pyfunc Inference API for serving different types of models. Additionally, on top of the native Transformers task types, MLflow defines a few additional task types to support more complex use cases, such as chat-style applications.
Native Transformers Task Types
For native Transformers tasks, MLflow will automatically infer the task type from the pipeline when you save a pipeline with mlflow.transformers.log_model()
. You can also specify the task type explicitly by passing the task
parameter. The full list of supported task types is available in the Transformers documentation, but note that not all task types are supported in MLflow.
import mlflow
import transformers

pipeline = transformers.pipeline("text-generation", model="gpt2")

with mlflow.start_run():
    model_info = mlflow.transformers.log_model(
        transformers_model=pipeline,
        artifact_path="model",
        save_pretrained=False,
    )

print(f"Inferred task: {model_info.flavors['transformers']['task']}")
# >> Inferred task: text-generation
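If you want to pin the task instead of relying on automatic inference, you can pass the native task name yourself. A minimal sketch, reusing the pipeline defined above:

# A minimal sketch: explicitly passing the native Transformers task name
# (reuses the `pipeline` object defined above).
with mlflow.start_run():
    model_info = mlflow.transformers.log_model(
        transformers_model=pipeline,
        artifact_path="model",
        task="text-generation",  # explicit native task type
        save_pretrained=False,
    )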
Advanced Tasks for OpenAI-Compatible Inference
In addition to the native Transformers task types, MLflow defines several additional task types. These advanced task types allow you to extend the Transformers pipeline with an OpenAI-compatible inference interface to serve models for specific use cases.
For example, the Transformers text-generation
pipeline takes and returns a single string or a list of strings. However, when serving a model, a more structured input and output format is often necessary. For instance, in a chat-style application, the input may be a list of messages.
To support these use cases, MLflow defines a set of advanced task types prefixed with llm/v1:

- "llm/v1/chat" for chat-style applications
- "llm/v1/completions" for generic completions
- "llm/v1/embeddings" for text embeddings generation
To use these advanced task types, simply specify the task
parameter as one of the llm/v1
tasks when logging the model.
import mlflow
import transformers

pipeline = transformers.pipeline("text-generation", model="gpt2")

with mlflow.start_run():
    mlflow.transformers.log_model(
        transformers_model=pipeline,
        artifact_path="model",
        task="llm/v1/chat",  # <= Specify the llm/v1 task type
        # Optional, recommended for large models to avoid creating a local copy of the model weights
        save_pretrained=False,
    )
Note
This feature is only available in MLflow 2.11.0 and above. Also, the llm/v1/chat
task type is only available for models saved with transformers >= 4.34.0
.
Input and Output Formats
| Task | Supported pipeline | Input | Output |
|---|---|---|---|
| llm/v1/chat | text-generation | A list of chat messages, plus optional inference parameters such as temperature, max_tokens, and stop, following the OpenAI Chat API spec. | Returns a Chat Completion object in the JSON format. |
| llm/v1/completions | text-generation | A prompt string, plus optional inference parameters, following the OpenAI Completions API spec. | Returns a Completion object in the JSON format. |
| llm/v1/embeddings | feature-extraction | A text string or a list of text strings to embed. | Returns a list of Embedding objects. Additionally, the model returns a usage field containing the number of input tokens. |
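As an illustration of the embeddings format, the following is a minimal sketch of logging a feature-extraction pipeline with the llm/v1/embeddings task and querying it with an OpenAI-style payload (the model name and input values here are only examples):

import mlflow
import transformers

# A feature-extraction pipeline; the specific model name is illustrative.
pipeline = transformers.pipeline(
    "feature-extraction", model="sentence-transformers/all-MiniLM-L6-v2"
)

with mlflow.start_run():
    model_info = mlflow.transformers.log_model(
        transformers_model=pipeline,
        artifact_path="model",
        task="llm/v1/embeddings",
        save_pretrained=False,
    )

pyfunc_model = mlflow.pyfunc.load_model(model_info.model_uri)
# OpenAI-style embeddings request: a single string or a list of strings
embeddings = pyfunc_model.predict({"input": ["MLflow is great!", "Hello, world!"]})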
Note
The Completion API is considered legacy, but it is still supported in MLflow for backward compatibility. We recommend using the Chat API for compatibility with the latest APIs from OpenAI and other model providers.
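For reference, a completions-style model is logged in the same way, only with the llm/v1/completions task and a prompt-based payload. A minimal sketch (the prompt and parameters follow the OpenAI Completions spec; the model name is just an example):

import mlflow
import transformers

pipeline = transformers.pipeline("text-generation", model="gpt2")

with mlflow.start_run():
    model_info = mlflow.transformers.log_model(
        transformers_model=pipeline,
        artifact_path="model",
        task="llm/v1/completions",
        save_pretrained=False,
    )

pyfunc_model = mlflow.pyfunc.load_model(model_info.model_uri)
# Completions-style request: a prompt string plus optional sampling parameters
prediction = pyfunc_model.predict({"prompt": "Hello, how are", "max_tokens": 50})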
Code Example of Using llm/v1 Tasks
The following code snippet demonstrates how to log a Transformers pipeline with the llm/v1/chat
task type, and use the model for chat-style inference. Check out the notebook tutorial to see more examples in action!
import mlflow
import transformers

pipeline = transformers.pipeline("text-generation", "gpt2")

with mlflow.start_run():
    model_info = mlflow.transformers.log_model(
        transformers_model=pipeline,
        artifact_path="model",
        task="llm/v1/chat",
        input_example={
            "messages": [
                {"role": "system", "content": "You are a bot."},
                {"role": "user", "content": "Hello, how are you?"},
            ]
        },
        save_pretrained=False,
    )

# Model metadata logs additional field "inference_task"
print(model_info.flavors["transformers"]["inference_task"])
# >> llm/v1/chat

# The original native task type is also saved
print(model_info.flavors["transformers"]["task"])
# >> text-generation

# Model signature is set to the chat API spec
print(model_info.signature)
# >> inputs:
# >>   ['messages': Array({content: string (required), name: string (optional), role: string (required)}) (required), 'temperature': double (optional), 'max_tokens': long (optional), 'stop': Array(string) (optional), 'n': long (optional), 'stream': boolean (optional)]
# >> outputs:
# >>   ['id': string (required), 'object': string (required), 'created': long (required), 'model': string (required), 'choices': Array({finish_reason: string (required), index: long (required), message: {content: string (required), name: string (optional), role: string (required)} (required)}) (required), 'usage': {completion_tokens: long (required), prompt_tokens: long (required), total_tokens: long (required)} (required)]
# >> params:
# >>   None

# The model can be served with the OpenAI-compatible inference API
pyfunc_model = mlflow.pyfunc.load_model(model_info.model_uri)
prediction = pyfunc_model.predict(
    {
        "messages": [
            {"role": "system", "content": "You are a bot."},
            {"role": "user", "content": "Hello, how are you?"},
        ],
        "temperature": 0.5,
        "max_tokens": 200,
    }
)

print(prediction)
# >> [{'choices': [{'finish_reason': 'stop',
# >>               'index': 0,
# >>               'message': {'content': "I'm doing well, thank you for asking.", 'role': 'assistant'}}],
# >>   'created': 1719875820,
# >>   'id': '355c4e9e-040b-46b0-bf22-00e93486100c',
# >>   'model': 'gpt2',
# >>   'object': 'chat.completion',
# >>   'usage': {'completion_tokens': 7, 'prompt_tokens': 13, 'total_tokens': 20}}]
Note that the input and output modifications only apply when the model is loaded with mlflow.pyfunc.load_model()
(e.g. when
serving the model with the mlflow models serve
CLI tool). If you want to load just the raw pipeline, you can
use mlflow.transformers.load_model()
.
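As a brief sketch of the difference, reusing model_info from the example above:

# Pyfunc flavor: OpenAI-compatible chat interface (modified input/output)
pyfunc_model = mlflow.pyfunc.load_model(model_info.model_uri)

# Transformers flavor: the raw text-generation pipeline with the native
# string-in / string-out interface, without the llm/v1 input/output handling
raw_pipeline = mlflow.transformers.load_model(model_info.model_uri)
raw_output = raw_pipeline("Hello, how are you?", max_new_tokens=20)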
Provisioned Throughput on Databricks Model Serving
Provisioned Throughput on Databricks Model Serving is a capability that optimizes inference performance for foundation models with performance guarantees. To serve Transformers models with provisioned throughput, specify an llm/v1/xxx
task type when logging the model. MLflow then logs the required metadata to enable provisioned throughput on Databricks Model Serving.
Tip
When logging large models, you can use save_pretrained=False
to avoid creating a local copy of the model weights, saving time and disk space. Please refer to the documentation for more details.
FAQ
How to override the default query parameters for OpenAI-compatible inference?
When serving a model saved with an llm/v1
task type, MLflow uses the same default values as the OpenAI APIs for parameters like temperature
and stop
. You can override them either by passing values at inference time or by setting different default values when logging the model.
- At inference time: You can pass the parameters as part of the input dictionary when calling the predict() method, just like how you pass the input messages (see the second snippet below).
- When logging the model: You can override the default values for the parameters by passing a model_config parameter when logging the model.
with mlflow.start_run():
    model_info = mlflow.transformers.log_model(
        transformers_model=pipeline,
        artifact_path="model",
        task="llm/v1/chat",
        model_config={
            "temperature": 0.5,  # <= Set the default temperature
            "stop": ["foo", "bar"],  # <= Set the default stop sequence
        },
        save_pretrained=False,
    )
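The inference-time option, as a brief sketch reusing model_info from above:

# Override parameters per request by including them in the prediction input
pyfunc_model = mlflow.pyfunc.load_model(model_info.model_uri)

prediction = pyfunc_model.predict(
    {
        "messages": [{"role": "user", "content": "Hello!"}],
        "temperature": 0.1,  # overrides the default temperature for this request
        "max_tokens": 50,
        "stop": ["###"],
    }
)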
Attention
The stop
parameter can be used to specify stop sequences for the llm/v1/chat
and llm/v1/completions
tasks. We emulate the behavior of the stop
parameter in the OpenAI APIs by passing a stopping_criteria to the Transformers pipeline with the token IDs of the given stop sequence. However, the behavior may not be stable, because the tokenizer does not always produce the same token IDs for the same sequence in different sentences, especially for SentencePiece-based tokenizers.
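For illustration only, here is a rough sketch of the general idea behind stopping generation on a token-ID sequence with the Transformers StoppingCriteria API. This is not MLflow's internal implementation, and the matching logic MLflow uses may differ:

import transformers
from transformers import StoppingCriteria, StoppingCriteriaList


class StopOnSequence(StoppingCriteria):
    """Stop generation once the most recent tokens match the stop sequence."""

    def __init__(self, stop_token_ids):
        self.stop_token_ids = stop_token_ids

    def __call__(self, input_ids, scores, **kwargs):
        if input_ids.shape[1] < len(self.stop_token_ids):
            return False
        return input_ids[0, -len(self.stop_token_ids):].tolist() == self.stop_token_ids


pipe = transformers.pipeline("text-generation", model="gpt2")

# Tokenize the stop sequence up front. As noted above, the same text may map to
# different token IDs depending on the surrounding context, which is why this
# approach is not always stable.
stop_ids = pipe.tokenizer("###", add_special_tokens=False)["input_ids"]

output = pipe(
    "Hello, how are you?",
    max_new_tokens=50,
    stopping_criteria=StoppingCriteriaList([StopOnSequence(stop_ids)]),
)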