Optimize Prompts (Experimental)
MLflow allows you to plug your prompts into advanced prompt optimization techniques through a unified interface, the mlflow.genai.optimize_prompt() API.
This feature helps you improve your prompts automatically by leveraging evaluation metrics and labeled data. MLflow provides built-in support for optimization algorithms such as DSPy's MIPROv2, and you can also implement custom optimization algorithms by extending MLflow's base optimizer class, BasePromptOptimizer.
- Unified Interface: Access state-of-the-art prompt optimization algorithms through a single, library-neutral interface.
- Extensible: Create custom optimization algorithms by extending base optimizer classes.
- Prompt Management: Integrate with the MLflow Prompt Registry for reusability, version control, and lineage.
- Evaluation: Evaluate prompt performance comprehensively with MLflow's evaluation features.
Optimization Overview
To use the mlflow.genai.optimize_prompt() API, you need to prepare the following:
Component | Description | Example |
---|---|---|
Prompt | A prompt registered in MLflow | `mlflow.genai.register_prompt(name="math", template=...)` |
Scorer | Evaluates prompt quality | A function decorated with `@scorer`, e.g. `exact_match` |
Data | Inputs & expectations | `[{"inputs": {...}, "expectations": {...}}, ...]` |
Target LLM | Model to optimize for | `LLMParams(model_name="openai/gpt-4.1-mini")` |
Optimizer Config | Optimization settings | `OptimizerConfig(num_instruction_candidates=8)` |
Getting Started
Here's a simple example of optimizing a question-answering prompt:
As a prerequisite, you need to install DSPy.
$ pip install "dspy>=2.6.0" "mlflow>=3.1.0"
Then, run the following code to register the initial prompt and optimize it.
import os
from typing import Any
import mlflow
from mlflow.genai.scorers import scorer
from mlflow.genai.optimize import OptimizerConfig, LLMParams
os.environ["OPENAI_API_KEY"] = "<YOUR_OPENAI_API_KEY>"
# Define a custom scorer function to evaluate prompt performance with the @scorer decorator.
# The scorer function for optimization can take inputs, outputs, and expectations.
@scorer
def exact_match(expectations: dict[str, Any], outputs: dict[str, Any]) -> bool:
return expectations["answer"] == outputs["answer"]
# Register the initial prompt
initial_template = """
Answer this math question: {{question}}.
Return the result in a JSON string in the format of {"answer": "xxx"}.
"""
prompt = mlflow.genai.register_prompt(
name="math",
template=initial_template,
)
# The data can be a list of dictionaries, a pandas DataFrame, or an mlflow.genai.EvaluationDataset
# It needs to contain inputs and expectations where each row is a dictionary.
train_data = [
{
"inputs": {"question": "Given that $y=3$, evaluate $(1+y)^y$."},
"expectations": {"answer": "64"},
},
{
"inputs": {
"question": "The midpoint of the line segment between $(x,y)$ and $(-9,1)$ is $(3,-5)$. Find $(x,y)$."
},
"expectations": {"answer": "(15,-11)"},
},
{
"inputs": {
"question": "What is the value of $b$ if $5^b + 5^b + 5^b + 5^b + 5^b = 625^{(b-1)}$? Express your answer as a common fraction."
},
"expectations": {"answer": "\\frac{5}{3}"},
},
{
"inputs": {"question": "Evaluate the expression $a^3\\cdot a^2$ if $a= 5$."},
"expectations": {"answer": "3125"},
},
{
"inputs": {"question": "Evaluate $\\lceil 8.8 \\rceil+\\lceil -8.8 \\rceil$."},
"expectations": {"answer": "17"},
},
]
eval_data = [
{
"inputs": {
"question": "The sum of 27 consecutive positive integers is $3^7$. What is their median?"
},
"expectations": {"answer": "81"},
},
{
"inputs": {"question": "What is the value of $x$ if $x^2 - 10x + 25 = 0$?"},
"expectations": {"answer": "5"},
},
{
"inputs": {
"question": "If $a\\ast b = 2a+5b-ab$, what is the value of $3\\ast10$?"
},
"expectations": {"answer": "26"},
},
{
"inputs": {
"question": "Given that $-4$ is a solution to $x^2 + bx -36 = 0$, what is the value of $b$?"
},
"expectations": {"answer": "-5"},
},
]
# Optimize the prompt
result = mlflow.genai.optimize_prompt(
target_llm_params=LLMParams(model_name="openai/gpt-4.1-mini"),
prompt=prompt,
train_data=train_data,
eval_data=eval_data,
scorers=[exact_match],
optimizer_config=OptimizerConfig(
num_instruction_candidates=8,
max_few_show_examples=2,
),
)
# The optimized prompt is automatically registered as a new version
print(result.prompt.uri)
In the example above, the average performance score increased from 0 to 0.5. After the optimization completes, you can visit the MLflow Prompt Registry page to see the optimized prompt.
Note that the prompt produced by mlflow.genai.optimize_prompt() instructs the model to return a JSON string,
so you need to parse the output with json.loads in your application.
See Load and Use the Prompt for how to load the optimized prompt.
import mlflow
import json
import openai


def predict(question: str, prompt_uri: str) -> str:
    # Load the optimized prompt from the MLflow Prompt Registry
    prompt = mlflow.genai.load_prompt(prompt_uri)
    # Fill the template variable with the user's question
    content = prompt.format(question=question)
    completion = openai.chat.completions.create(
        model="gpt-4.1-mini",
        messages=[{"role": "user", "content": content}],
        temperature=0.1,
    )
    # The optimized prompt asks for a JSON string, so parse it before returning
    return json.loads(completion.choices[0].message.content)["answer"]
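For example, you can call this helper with the URI of the optimized prompt printed earlier; the question below is purely illustrative:

# Hypothetical usage of the helper above; the question is illustrative
answer = predict(
    question="Evaluate $3^4$.",
    prompt_uri=result.prompt.uri,
)
print(answer)  # expected to print "81"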
Configuration
You can customize the optimization process using OptimizerConfig
, which includes the following parameters:
- algorithm: The optimization algorithm to use. Can be a string (e.g., "DSPy/MIPROv2") or a custom optimizer class. Default: "DSPy/MIPROv2"
- num_instruction_candidates: The number of candidate instructions to try. Default: 6
- max_few_show_examples: The maximum number of examples to show in few-shot demonstrations. Default: 6
- optimizer_llm: The LLM to use for optimization. Default: None (uses target LLM)
- extract_instructions: Whether to extract instructions from the initial prompt template. Default: True
- verbose: Whether to show optimizer logs during optimization. Default: False
- autolog: Whether to log the optimization parameters, datasets, and metrics. If set to True, an MLflow run is automatically created to store them. Default: False
See mlflow.genai.OptimizerConfig()
for more details.
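For illustration, here is a minimal sketch of a customized configuration. The model name and parameter values are examples rather than recommendations, and it assumes optimizer_llm accepts an LLMParams object in the same way target_llm_params does:

from mlflow.genai.optimize import OptimizerConfig, LLMParams

# Illustrative values only; assumes optimizer_llm takes an LLMParams object
optimizer_config = OptimizerConfig(
    algorithm="DSPy/MIPROv2",  # default algorithm
    num_instruction_candidates=8,  # try more instruction candidates than the default of 6
    max_few_show_examples=2,  # limit few-shot demonstrations
    optimizer_llm=LLMParams(model_name="openai/gpt-4.1"),  # LLM that drives the optimization
    verbose=True,  # show optimizer logs
    autolog=True,  # record params, datasets, and metrics in an MLflow run
)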
Custom Optimizers
MLflow supports creating custom prompt optimization algorithms by extending the base optimizer classes. This allows you to implement domain-specific optimization strategies or integrate with other optimization libraries.
BasePromptOptimizer
For custom optimization logic, extend the BasePromptOptimizer class:
from mlflow.genai.optimize import BasePromptOptimizer, OptimizerConfig, OptimizerOutput
from mlflow.genai.optimize.types import LLMParams
from mlflow.entities.model_registry import PromptVersion
from mlflow.genai.scorers import Scorer
from typing import Optional, Callable, Any
import pandas as pd
class CustomOptimizer(BasePromptOptimizer):
# Inherit the BasePromptOptimizer class and implement the `optimize` method.
def optimize(
self,
prompt: PromptVersion,
target_llm_params: LLMParams,
train_data: pd.DataFrame,
scorers: list[Scorer],
objective: Optional[Callable[[dict[str, Any]], float]] = None,
eval_data: Optional[pd.DataFrame] = None,
) -> OptimizerOutput:
# Implement your custom optimization logic here
optimized_template = f"Please answer accurately: {prompt.template}"
return OptimizerOutput(
optimized_prompt=optimized_template,
optimizer_name="CustomOptimizer",
final_eval_score=0.85,
initial_eval_score=0.75,
)
# Use the custom optimizer
result = mlflow.genai.optimize_prompt(
target_llm_params=LLMParams(model_name="openai/gpt-4o-mini"),
prompt=prompt,
train_data=train_data,
scorers=[exact_match],
optimizer_config=OptimizerConfig(algorithm=CustomOptimizer),
)
DSPyPromptOptimizer
For DSPy-based optimizations, extend the DSPyPromptOptimizer class, which provides DSPy integration infrastructure:
import dspy
from typing import Callable
from mlflow.genai.optimize import (
DSPyPromptOptimizer,
OptimizerOutput,
format_dspy_prompt,
)
from mlflow.entities.model_registry import PromptVersion
class CustomDSPyOptimizer(DSPyPromptOptimizer):
# Inherit the DSPyPromptOptimizer class and implement the `run_optimization` method.
def run_optimization(
self,
prompt: PromptVersion,
program: dspy.Module,
metric: Callable[[dspy.Example], float],
train_data: list,
eval_data: list,
) -> OptimizerOutput:
# Use DSPy's optimization components with your custom logic
# Example using DSPy's BootstrapFewShot optimizer
optimizer = dspy.BootstrapFewShot(
metric=metric,
max_bootstrapped_demos=self.optimizer_config.max_few_show_examples,
)
# Compile the program
compiled_program = optimizer.compile(
student=program,
trainset=train_data,
)
return OptimizerOutput(
optimized_prompt=format_dspy_prompt(compiled_program),
optimizer_name="BootstrapFewShot",
)
# Use the custom DSPy optimizer
result = mlflow.genai.optimize_prompt(
target_llm_params=LLMParams(model_name="openai/gpt-4o-mini"),
prompt=prompt,
train_data=train_data,
scorers=[exact_match],
optimizer_config=OptimizerConfig(algorithm=CustomDSPyOptimizer),
)
When using custom optimizers, ensure they return an OptimizerOutput
object with the optimized prompt and evaluation scores.
Performance Benchmarks
We are actively working on benchmarking. These benchmark results are preliminary and subject to change.
MLflow prompt optimization can improve your application's performance across various tasks. Here are the results from testing MLflow's optimization capabilities on several datasets:
- ARC-Challenge: The ai2_arc dataset contains a set of multiple-choice science questions
- GSM8K: The gsm8k dataset contains a set of linguistically diverse grade-school math word problems
- MATH: Competition mathematics problems requiring advanced reasoning and problem-solving skills
Dataset | Model | Baseline | Optimized |
---|---|---|---|
MATH | gpt-4.1-nano | 17.25% | 18.48% |
GSM8K | gpt-4.1-nano | 21.46% | 49.89% |
ARC-Challenge | gpt-4.1-nano | 71.42% | 89.25% |
MATH | Llama4-maverick | 33.06% | 33.26% |
GSM8K | Llama4-maverick | 55.80% | 58.22% |
ARC-Challenge | Llama4-maverick | 0.17% | 93.17% |
The results above were obtained with gpt-4.1-nano and Llama4-maverick using DSPy's MIPROv2 algorithm with default settings and task-specific evaluation metrics.
The results may differ if you use a different model, configuration, dataset, or starting prompt(s).
These results show that MLflow's prompt optimization delivers measurable performance gains across these tasks with minimal effort.
FAQ
What are the supported Dataset formats?
The training and evaluation data for the mlflow.genai.optimize_prompt()
API can be a list of dictionaries, a pandas DataFrame, a Spark DataFrame, or an mlflow.genai.EvaluationDataset.
In any case, the data needs to contain inputs and expectations columns, where each entry is a dictionary of input fields or expected output fields.
Each inputs or expectations dictionary can contain primitive types, lists, nested dictionaries, and Pydantic models. Data types are inferred from the first row of the dataset.
# ✅ OK
[
{
"inputs": {"question": "What is the capital of France?"},
"expectations": {"answer": "Paris"},
},
]
# ✅ OK
[
{
"inputs": {"question": "What are the three largest cities of Japan?"},
"expectations": {"answer": ["Tokyo", "Osaka", "Nagoya"]},
},
]
# ✅ OK
from pydantic import BaseModel
class Country(BaseModel):
name: str
capital: str
population: int
[
{
"inputs": {"question": "What is the capital of France?"},
"expectations": {
"answer": Country(name="France", capital="Paris", population=68000000)
},
},
]
# ❌ Not supported: inputs and expectations must be dictionaries
[
{
"inputs": "What is the capital of France?",
"expectations": "Paris",
},
]
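A pandas DataFrame is also accepted; a minimal sketch of the equivalent format, where each cell in the inputs and expectations columns holds a dictionary, might look like this:

import pandas as pd

# Equivalent pandas DataFrame: each cell in the inputs / expectations
# columns holds a dictionary, mirroring the list-of-dicts format above.
train_df = pd.DataFrame(
    [
        {
            "inputs": {"question": "What is the capital of France?"},
            "expectations": {"answer": "Paris"},
        },
    ]
)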
How to combine multiple scorers?
While the mlflow.genai.optimize_prompt()
API accepts multiple scorers, the optimizer needs to combine them into a single score during the optimization process.
By default, the optimizer sums the scores of all scorers that return numeric or boolean values.
To use a different aggregation, or to use scorers that return non-numeric values, pass a custom aggregation function to the objective
parameter.
@scorer
def safeness(outputs: dict[str, Any]) -> bool:
return "death" not in outputs["answer"].lower()
@scorer
def relevance(expectations: dict[str, Any], outputs: dict[str, Any]) -> bool:
return expectations["answer"] in outputs["answer"]
def objective(scores: dict[str, Any]) -> float:
if not scores["safeness"]:
return -1
return scores["relevance"]
result = mlflow.genai.optimize_prompt(
target_llm_params=LLMParams(model_name="openai/gpt-4.1-mini"),
prompt=prompt,
train_data=train_data,
eval_data=eval_data,
scorers=[safeness, relevance],
objective=objective,
)
How to create custom optimizers?
MLflow provides two base classes for creating custom optimizers:
- BasePromptOptimizer: For fully custom optimization logic. Implement your own optimize method.
- DSPyPromptOptimizer: For DSPy-based optimizations. Implement a run_optimization method; MLflow handles the required DSPy setup around your custom logic.
Choose BasePromptOptimizer when you want complete control over the optimization process, or DSPyPromptOptimizer when you want to leverage DSPy's ecosystem while customizing the optimization strategy.