Jas Bali · 62 min read

In this blog post, we delve into the integration of AWS Bedrock Agent as a ChatModel within MLflow, focusing on how to leverage Bedrock's Action Groups and Knowledge Bases to build a conversational AI application. The blog will guide you through setting up the Bedrock Agent, configuring Action Groups to enable custom actions with Lambda, and utilizing knowledge bases for context-aware interactions. A special emphasis is placed on implementing tracing within MLflow. By the end of this article, you'll have a good understanding of how to combine AWS Bedrock's advanced features with MLflow's capabilities, such as agent request tracing, model tracking, and consistent signatures for input examples.

What is AWS Bedrock?

Amazon Bedrock is a managed service by AWS that simplifies the development of generative AI applications. It provides access to a variety of foundation models (FMs) from leading AI providers through a single API, enabling developers to build and scale AI solutions securely and efficiently.

Key Components Relevant to This Integration:

Bedrock Agent: At a high level, a Bedrock Agent is an abstraction within Bedrock that consists of a foundation model, action groups, and knowledge bases.

Action Groups: These are customizable sets of actions that define what tasks the Bedrock Agent can perform. Action Groups consist of an OpenAPI Schema and the corresponding Lambda functions that will be used to execute tool calls. The OpenAPI Schema is used to define APIs available for the agent to invoke and complete tasks.

Knowledge Bases: Amazon Bedrock supports the creation of Knowledge Bases to implement Retrieval Augmented Generation workflows. It consists of data sources (on S3 or webpages) and a vector store that contains the embedded references to this data.

Bedrock's agent execution process, and the corresponding tracing for agent instrumentation, is grouped as follows:

Pre-processing: This step validates, contextualizes, and categorizes user input.

Orchestration: This step handles the interpretation of user inputs, decides when and which tasks to perform, and iteratively refines responses.

Post-processing (Optional): This step formats the final response before returning it to the user.

Traces: Each step above has an execution trace, which consists of the rationale, actions, queries, and observations at each step of the agent's response. This includes both the inputs and outputs of action groups and knowledge base queries.

We will look at these traces in detail below.

What is a ChatModel in MLflow?

The ChatModel class is specifically designed to make it easier to implement models that are compatible with popular large language model (LLM) chat APIs. It enables you to seamlessly bring in your own models or agents and leverage MLflow's functionality, even if those models aren't natively supported as a flavor in MLflow. Additionally, it provides default signatures, which are static for ChatModel, unlike PythonModel.

In the following sections, we will use ChatModel to wrap the Bedrock Agent.

For more detailed information about ChatModel, see the MLflow ChatModel documentation and tutorials.
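
To make the interface concrete, here is a minimal, hypothetical ChatModel subclass (a toy echo model, not part of the Bedrock integration) showing the fixed predict signature that the Bedrock wrapper below will also implement:

from typing import List, Optional

from mlflow.pyfunc import ChatModel
from mlflow.types.llm import ChatChoice, ChatMessage, ChatParams, ChatResponse


class EchoModel(ChatModel):
    """Toy ChatModel that echoes the last user message back."""

    def predict(
        self, context, messages: List[ChatMessage], params: Optional[ChatParams]
    ) -> ChatResponse:
        return ChatResponse(
            choices=[
                ChatChoice(
                    index=0,
                    message=ChatMessage(
                        role="assistant", content=messages[-1].content
                    ),
                )
            ],
            usage={},
        )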

Setting up an AWS Bedrock Agent with an Action Group

In this section, we will deploy all components of a Bedrock agent so that we can invoke it as a ChatModel in MLflow.

Prerequisites

You will need to set up the following items (either via the AWS console or SDKs; a minimal scripted sketch follows this list):

  • Setting up roles for the agent and the Lambda function. Example
  • Creating/deploying the agent. Example
    • Important: Save the agent ID here, as we will need it below.
  • Creating a Lambda function. Example
  • Configuring IAM permissions for the agent-Lambda interaction. Example and Example
  • Creating an action group to link the agent and the Lambda. Example
  • Deploying the Bedrock agent with an alias. Example
    • Important: Save the agent alias ID here, as we will need it below.
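
If you prefer scripting these steps over clicking through the console, here is a minimal sketch of how they map onto the boto3 bedrock-agent client. The names and ARNs are placeholders, and schema.yaml is assumed to contain the OpenAPI spec shown in the next section:

import boto3

bedrock_agent = boto3.client("bedrock-agent", region_name="us-east-1")

# Hypothetical names and ARNs -- substitute your own. In practice you may also
# need to wait for the agent to reach a ready status between these calls.
agent = bedrock_agent.create_agent(
    agentName="mars-launch-agent",
    agentResourceRoleArn="arn:aws:iam::123456789012:role/bedrock-agent-role",
    foundationModel="anthropic.claude-v2",
    instruction="You answer questions about orbital mechanics using the tools provided.",
)
agent_id = agent["agent"]["agentId"]  # Save this: it goes into model_config below

with open("schema.yaml") as f:  # The OpenAPI spec from the next section
    openapi_schema = f.read()

bedrock_agent.create_agent_action_group(
    agentId=agent_id,
    agentVersion="DRAFT",
    actionGroupName="optimal_departure_window_mars",
    actionGroupExecutor={
        "lambda": "arn:aws:lambda:us-east-1:123456789012:function:mars-window"
    },
    apiSchema={"payload": openapi_schema},
)

bedrock_agent.prepare_agent(agentId=agent_id)
alias = bedrock_agent.create_agent_alias(agentId=agent_id, agentAliasName="prod")
agent_alias_id = alias["agentAlias"]["agentAliasId"]  # Save this as well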
In our case, we are going to deploy the following example action group, which calculates the next optimal departure date for a Hohmann transfer from Earth to Mars, based on the spacecraft's mass and specific impulse.

OpenAPI schema for Action Groups

As described above, here is the OpenAPI Schema for our example action group:

openapi: 3.0.0
info:
  title: Time API
  version: 1.0.0
  description: API to get the next optimal departure date for a Hohmann transfer from Earth to Mars.
paths:
  /get-next-mars-launch-window:
    get:
      summary: Gets the next optimal launch window to Mars.
      description: Gets the next optimal launch window to Mars.
      operationId: getNextMarsLaunchWindow
      parameters:
        - name: total_mass
          in: query
          description: Total mass of the spacecraft including fuel (kg)
          required: true
          schema:
            type: string
        - name: dry_mass
          in: query
          description: Mass of the spacecraft without fuel (kg).
          required: true
          schema:
            type: string
        - name: specific_impulse
          in: query
          description: Specific impulse of the propulsion system (s).
          required: true
          schema:
            type: string
      responses:
        "200":
          description: The next optimal departure date for a Hohmann transfer from Earth to Mars, based on the spacecraft's mass and specific impulse.
          content:
            "application/json":
              schema:
                type: object
                properties:
                  next_launch_window:
                    type: string
                    description: Next Mars Launch Window

Action Groups - Lambda function

Here is the code for the action group's example Lambda function:

import json
import math
from datetime import datetime, timedelta


def lambda_handler(event, context):
    def _calculate_optimal_departure_window(
        total_mass, dry_mass, specific_impulse
    ):
        """
        Calculate the next optimal departure date for a Hohmann transfer from Earth to Mars,
        based on the spacecraft's mass and specific impulse.

        Parameters:
        - total_mass (float): Total mass of the spacecraft including fuel (kg).
        - dry_mass (float): Mass of the spacecraft without fuel (kg).
        - specific_impulse (float): Specific impulse of the propulsion system (s).

        Returns:
        - dict: {
            'next_launch_date': datetime,
            'synodic_period_days': float,
            'transfer_time_days': float,
            'delta_v_available_m_s': float,
            'delta_v_required_m_s': float,
            'is_feasible': bool
          }
        """
        current_date = None
        # Constants
        G0 = 9.80665  # m/s^2, standard gravity
        MU_SUN = (
            1.32712440018e20  # m^3/s^2, standard gravitational parameter for the Sun
        )
        AU = 1.496e11  # meters, astronomical unit
        EARTH_ORBITAL_PERIOD = 365.25  # days
        MARS_ORBITAL_PERIOD = 686.98  # days
        SYNODIC_PERIOD = 1 / abs((1 / EARTH_ORBITAL_PERIOD) - (1 / MARS_ORBITAL_PERIOD))
        TRANSFER_TIME = 259  # days, approximate duration of Hohmann transfer
        BASE_LAUNCH_DATE = datetime(2020, 7, 1)  # A reference past launch window date

        # Orbital Radii (assuming circular orbits for simplicity)
        r1 = AU  # Earth's orbital radius in meters
        r2 = 1.524 * AU  # Mars' orbital radius in meters

        # Calculate Required Delta-V for Hohmann Transfer
        # Using vis-viva equation for Hohmann transfer
        def calculate_hohmann_delta_v(mu, r_start, r_end):
            # Velocity of departure orbit (Earth)
            v_start = math.sqrt(mu / r_start)
            # Velocity of transfer orbit at departure
            a_transfer = (r_start + r_end) / 2
            v_transfer_start = math.sqrt(mu * (2 / r_start - 1 / a_transfer))
            delta_v1 = v_transfer_start - v_start

            # Velocity of arrival orbit (Mars)
            v_end = math.sqrt(mu / r_end)
            # Velocity of transfer orbit at arrival
            v_transfer_end = math.sqrt(mu * (2 / r_end - 1 / a_transfer))
            delta_v2 = v_end - v_transfer_end

            return delta_v1, delta_v2

        delta_v1, delta_v2 = calculate_hohmann_delta_v(MU_SUN, r1, r2)
        delta_v_required = abs(delta_v1) + abs(delta_v2)  # Total delta-v in m/s

        # Delta-V using Tsiolkovsky Rocket Equation
        if dry_mass <= 0 or total_mass <= dry_mass:
            raise ValueError("Total mass must be greater than dry mass.")

        delta_v_available = (
            specific_impulse * G0 * math.log(total_mass / dry_mass)
        )  # m/s

        is_feasible = delta_v_available >= delta_v_required

        if current_date is None:
            current_date = datetime.now()

        days_since_base = (current_date - BASE_LAUNCH_DATE).days
        if days_since_base < 0:
            # Current date is before the base launch date
            next_launch_date = BASE_LAUNCH_DATE
        else:
            synodic_periods_passed = days_since_base / SYNODIC_PERIOD
            synodic_periods_passed_int = math.floor(synodic_periods_passed)
            next_launch_date = BASE_LAUNCH_DATE + timedelta(
                days=(synodic_periods_passed_int + 1) * SYNODIC_PERIOD
            )

        next_launch_date = next_launch_date.replace(
            hour=0, minute=0, second=0, microsecond=0
        )

        return {
            "next_launch_date": next_launch_date,
            "synodic_period_days": SYNODIC_PERIOD,
            "transfer_time_days": TRANSFER_TIME,
            "delta_v_available_m_s": delta_v_available,
            "delta_v_required_m_s": delta_v_required,
            "is_feasible": is_feasible,
        }

    query_params = {
        param["name"]: param["value"] for param in event.get("parameters", [])
    }

    total_mass = float(query_params.get("total_mass"))
    dry_mass = float(query_params.get("dry_mass"))
    specific_impulse = float(query_params.get("specific_impulse"))

    response = {
        "next_launch_window": _calculate_optimal_departure_window(
            total_mass, dry_mass, specific_impulse
        )
    }

    # default=str makes the datetime in next_launch_date JSON-serializable
    response_body = {"application/json": {"body": json.dumps(response, default=str)}}

    action_response = {
        "actionGroup": event["actionGroup"],
        "apiPath": event["apiPath"],
        "httpMethod": event["httpMethod"],
        "httpStatusCode": 200,
        "responseBody": response_body,
    }

    session_attributes = event["sessionAttributes"]
    prompt_session_attributes = event["promptSessionAttributes"]

    return {
        "messageVersion": "1.0",
        "response": action_response,
        "sessionAttributes": session_attributes,
        "promptSessionAttributes": prompt_session_attributes,
    }
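
Before deploying, you can sanity-check the handler locally with a hand-built event shaped like the ones Bedrock sends to action-group Lambdas (the field names mirror what the handler reads; the values are illustrative):

# Minimal local smoke test for lambda_handler (values are illustrative).
sample_event = {
    "actionGroup": "optimal_departure_window_mars",
    "apiPath": "/get-next-mars-launch-window",
    "httpMethod": "GET",
    "parameters": [
        {"name": "total_mass", "value": "50000"},
        {"name": "dry_mass", "value": "10000"},
        {"name": "specific_impulse", "value": "2500"},
    ],
    "sessionAttributes": {},
    "promptSessionAttributes": {},
}

result = lambda_handler(sample_event, context=None)
print(result["response"]["responseBody"]["application/json"]["body"])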

Next, we are going to wrap the Bedrock agent in a ChatModel so that we can register and load it for inference.

Writing a ChatModel for the Bedrock agent

Here are the top-level packages used for running the following example locally in Python 3.12.7:

boto3==1.35.31
mlflow==2.16.2

Implementing Bedrock Agent as an MLflow ChatModel with Tracing

import copy
import os
import uuid
from typing import List, Optional

import boto3
import mlflow
from botocore.config import Config
from mlflow.entities import SpanType
from mlflow.pyfunc import ChatModel
from mlflow.types.llm import ChatResponse, ChatMessage, ChatParams, ChatChoice


class BedrockModel(ChatModel):
    def __init__(self):
        """
        Initializes the BedrockModel instance with placeholder values.

        Note:
            The `load_context` method cannot create new instance variables; it can only modify existing ones.
            Therefore, all instance variables should be defined in the `__init__` method with placeholder values.
        """
        self.brt = None
        self._main_bedrock_agent = None
        self._bedrock_agent_id = None
        self._bedrock_agent_alias_id = None
        self._inference_configuration = None
        self._agent_instruction = None
        self._model = None
        self._aws_region = None

    def __getstate__(self):
        """
        Prepares the instance state for pickling.

        This method is needed because the `boto3` client (`self.brt`) cannot be pickled.
        By excluding `self.brt` from the state, we ensure that the model can be serialized and deserialized properly.
        """
        # Create a dictionary of the instance's state, excluding the boto3 client
        state = self.__dict__.copy()
        del state["brt"]
        return state

    def __setstate__(self, state):
        """
        Restores the instance state during unpickling.

        This method is needed to reinitialize the `boto3` client (`self.brt`) after the instance is unpickled,
        because the client was excluded during pickling.
        """
        self.__dict__.update(state)
        self.brt = None

    def load_context(self, context):
        """
        Initializes the Bedrock client with AWS credentials.

        Args:
            context: The MLflow context containing model configuration.

        Note:
            Dependent secret variables must be in the execution environment prior to loading the model;
            else they will not be available during model initialization.
        """
        self._main_bedrock_agent = context.model_config.get("agents", {}).get(
            "main", {}
        )
        self._bedrock_agent_id = self._main_bedrock_agent.get("bedrock_agent_id")
        self._bedrock_agent_alias_id = self._main_bedrock_agent.get(
            "bedrock_agent_alias_id"
        )
        self._inference_configuration = self._main_bedrock_agent.get(
            "inference_configuration"
        )
        self._agent_instruction = self._main_bedrock_agent.get("instruction")
        self._model = self._main_bedrock_agent.get("model")
        self._aws_region = self._main_bedrock_agent.get("aws_region")

        # Initialize the Bedrock client
        self.brt = boto3.client(
            service_name="bedrock-agent-runtime",
            config=Config(region_name=self._aws_region),
            aws_access_key_id=os.environ["AWS_ACCESS_KEY"],
            aws_secret_access_key=os.environ["AWS_SECRET_ACCESS_KEY"],
            aws_session_token=os.environ["AWS_SESSION_TOKEN"],
            region_name=self._aws_region,
        )

    @staticmethod
    def _extract_trace_groups(events):
        """
        Extracts trace groups from a list of events based on their trace IDs.

        Args:
            events (list): A list of event dictionaries.

        Returns:
            dict: A dictionary where keys are trace IDs and values are lists of trace items.
        """
        from collections import defaultdict

        trace_groups = defaultdict(list)

        def find_trace_ids(obj, original_trace, depth=0, parent_key=None):
            if depth > 5:
                return  # Stop recursion after 5 levels if no traceId has been found
            if isinstance(obj, dict):
                trace_id = obj.get("traceId")
                if trace_id:
                    # Include the parent key as the 'type'
                    item = {
                        "type": parent_key,
                        "data": obj,
                        "event_order": original_trace.get("trace", {}).get(
                            "event_order"
                        ),
                    }
                    trace_groups[trace_id].append(item)
                else:
                    for key, value in obj.items():
                        find_trace_ids(
                            value, original_trace, depth=depth + 1, parent_key=key
                        )
            elif isinstance(obj, list):
                for item in obj:
                    find_trace_ids(item, item, depth=depth + 1, parent_key=parent_key)

        find_trace_ids(events, {})
        return dict(trace_groups)

    @staticmethod
    def _get_final_response_with_trace(trace_id_groups: dict[str, list[dict]]):
        """
        Processes trace groups to extract the final response and create relevant MLflow spans.

        Args:
            trace_id_groups (dict): A dictionary of trace groups keyed by trace IDs.

        Returns:
            str: The final response text extracted from the trace groups.
        """
        trace_id_groups_copy = copy.deepcopy(trace_id_groups)
        model_invocation_input_key = "modelInvocationInput"

        def _create_trace_by_type(
            trace_name, _trace_id, context_input, optional_rationale_subtrace
        ):
            @mlflow.trace(
                name=trace_name,
                attributes={"trace_attributes": trace_id_groups[_trace_id]},
            )
            def _trace_agent_pre_context(inner_input_trace):
                return optional_rationale_subtrace.get("data", {}).get("text")

            trace_id_groups_copy[_trace_id].remove(context_input)
            _trace_agent_pre_context(context_input.get("data", {}).get("text"))

        def _extract_action_group_trace(
            _trace_id, trace_group, action_group_invocation_input: dict
        ):
            @mlflow.trace(
                name="action-group-invocation",
                attributes={"trace_attributes": trace_id_groups[_trace_id]},
            )
            def _action_group_trace(inner_trace_group):
                for _trace in trace_group:
                    action_group_invocation_output = _trace.get("data", {}).get(
                        "actionGroupInvocationOutput"
                    )
                    if action_group_invocation_output is not None:
                        action_group_response = str(
                            {
                                "action_group_name": action_group_invocation_input.get(
                                    "actionGroupName"
                                ),
                                "api_path": action_group_invocation_input.get(
                                    "apiPath"
                                ),
                                "execution_type": action_group_invocation_input.get(
                                    "executionType"
                                ),
                                "execution_output": action_group_invocation_output.get(
                                    "text"
                                ),
                            }
                        )
                        trace_group.remove(_trace)
                        return action_group_response

            _action_group_trace(str(action_group_invocation_input))

        def _extract_knowledge_base_trace(
            _trace_id, trace_group, knowledge_base_lookup_input
        ):
            @mlflow.trace(
                name="knowledge-base-lookup",
                attributes={"trace_attributes": trace_id_groups[_trace_id]},
            )
            def _knowledge_base_trace(inner_trace_group):
                for _trace in trace_group:
                    knowledge_base_lookup_output = _trace.get("data", {}).get(
                        "knowledgeBaseLookupOutput"
                    )
                    if knowledge_base_lookup_output is not None:
                        knowledge_base_response = str(
                            {
                                "knowledge_base_id": knowledge_base_lookup_input.get(
                                    "knowledgeBaseId"
                                ),
                                "text": knowledge_base_lookup_input.get("text"),
                                "retrieved_references": knowledge_base_lookup_output.get(
                                    "retrievedReferences"
                                ),
                            }
                        )
                        trace_group.remove(_trace)
                        return knowledge_base_response

            _knowledge_base_trace(str(trace_group))

        def _trace_group_type(
            _trace_id, trace_group, _trace, optional_rationale_subtrace
        ):
            trace_name = "observation"
            pre_processing_trace_id_suffix = "-pre"
            if pre_processing_trace_id_suffix in _trace_id:
                trace_name = "agent-initial-context"
            else:
                for _inner_trace in trace_group:
                    action_group_invocation_input = _inner_trace.get("data", {}).get(
                        "actionGroupInvocationInput"
                    )
                    if action_group_invocation_input is not None:
                        action_group_name = action_group_invocation_input.get(
                            "actionGroupName"
                        )
                        trace_name = f"ACTION-GROUP-{action_group_name}"
                        _create_trace_by_type(
                            trace_name, _trace_id, _trace, optional_rationale_subtrace
                        )
                        _extract_action_group_trace(
                            _trace_id, trace_group, action_group_invocation_input
                        )
                        trace_group.remove(_trace)
                    knowledge_base_lookup_input = _inner_trace.get("data", {}).get(
                        "knowledgeBaseLookupInput"
                    )
                    if knowledge_base_lookup_input is not None:
                        knowledge_base_id = knowledge_base_lookup_input.get(
                            "knowledgeBaseId"
                        )
                        trace_name = f"KNOWLEDGE_BASE_{knowledge_base_id}"
                        _create_trace_by_type(
                            trace_name, _trace_id, _trace, optional_rationale_subtrace
                        )
                        _extract_knowledge_base_trace(
                            _trace_id, trace_group, knowledge_base_lookup_input
                        )
                        trace_group.remove(_trace)
            return trace_name

        for _trace_id, _trace_group in trace_id_groups_copy.items():
            trace_group = sorted(_trace_group, key=lambda tg: tg["event_order"])
            model_invocation_input_subtrace = None
            optional_rationale_subtrace = None
            for _trace in _trace_group:
                if model_invocation_input_key == _trace.get("type", ""):
                    model_invocation_input_subtrace = _trace
                elif "rationale" == _trace.get("type", ""):
                    optional_rationale_subtrace = _trace
            _trace_group_type(
                _trace_id,
                trace_group,
                model_invocation_input_subtrace,
                optional_rationale_subtrace,
            )

        final_response = (
            list(trace_id_groups_copy.values())[-1][-1]
            .get("data", {})
            .get("finalResponse", {})
            .get("text")
        )
        return final_response

    @mlflow.trace(name="Bedrock Input Prompt")
    def _get_agent_prompt(self, raw_input_question):
        """
        Constructs the agent prompt by combining the input question and the agent instruction.

        Args:
            raw_input_question (str): The user's input question.

        Returns:
            str: The formatted agent prompt.
        """
        return f"""
        Answer the following question and pay strong attention to the prompt:
        <question>
        {raw_input_question}
        </question>
        <instruction>
        {self._agent_instruction}
        </instruction>
        """

    @mlflow.trace(name="bedrock-agent", span_type=SpanType.CHAT_MODEL)
    def predict(
        self, context, messages: List[ChatMessage], params: Optional[ChatParams]
    ) -> ChatResponse:
        """
        Makes a prediction using the Bedrock agent and processes the response.

        Args:
            context: The MLflow context.
            messages (List[ChatMessage]): A list of chat messages.
            params (Optional[ChatParams]): Optional parameters for the chat.

        Returns:
            ChatResponse: The response from the Bedrock agent.
        """
        formatted_input = messages[-1].content
        session_id = uuid.uuid4().hex

        response = self.brt.invoke_agent(
            agentId=self._bedrock_agent_id,
            agentAliasId=self._bedrock_agent_alias_id,
            inputText=self._get_agent_prompt(formatted_input),
            enableTrace=True,
            sessionId=session_id,
            endSession=False,
        )

        # Since this provider's output doesn't match the OpenAI specification,
        # we need to go through the returned trace data and map it appropriately
        # to create the MLflow span object.
        events = []
        for index, event in enumerate(response.get("completion", [])):
            if "trace" in event:
                event["trace"]["event_order"] = index
            events.append(event)
        trace_id_groups = self._extract_trace_groups(events)
        final_response = self._get_final_response_with_trace(trace_id_groups)
        with mlflow.start_span(
            name="retrieved-response", span_type=SpanType.AGENT
        ) as span:
            span.set_inputs(messages)
            span.set_attributes({})

            output = ChatResponse(
                choices=[
                    ChatChoice(
                        index=0,
                        message=ChatMessage(role="user", content=final_response),
                    )
                ],
                usage={},
                model=self._model,
            )

            span.set_outputs(output)

        return output

Here are some important remarks about this BedrockModel implementation:

  • The AWS access key ID, secret access key, and session token are externalized here. They need to be present in the environment before we can run inference. You will need to generate them for your IAM user and set them as environment variables:
aws sts get-session-token --duration-seconds 3600

And then set the following:

import os

os.environ['AWS_ACCESS_KEY'] = "<AccessKeyId>"
os.environ['AWS_SECRET_ACCESS_KEY'] = "<SecretAccessKey>"
os.environ['AWS_SESSION_TOKEN'] = "<SessionToken>"

As seen in the code above, these credentials are not logged with the model and are only read inside load_context, which is invoked when the model is loaded. Further details are in the MLflow documentation.

  • The Bedrock agent ID and agent alias ID are passed via model_config, which we will use below.

  • The boto3 client has been excluded from pickling. This is done via __getstate__ and __setstate__, where we exclude it and reset it, respectively.
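
A quick way to verify these pickling hooks is a round trip: the client should be dropped during serialization and come back as None afterwards.

import pickle

model = BedrockModel()
restored = pickle.loads(pickle.dumps(model))
assert restored.brt is None  # boto3 client excluded from the pickle, reset to None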

Log and load the BedrockModel

import mlflow
from mlflow.models import infer_signature

input_example = [
    {
        "messages": [
            {
                "role": "user",
                "content": "When is the next launch window for Mars?",
            }
        ]
    }
]

output_example = {
    "choices": [
        {
            "index": 0,
            "finish_reason": "stop",
            "message": {"role": "assistant", "content": "test content"},
        }
    ]
}
signature = infer_signature(input_example, output_example)

with mlflow.start_run():
    model_config = {
        "agents": {
            "main": {
                "model": "anthropic.claude-v2",
                "aws_region": "us-east-1",
                "bedrock_agent_id": "O9KQSEVEFF",
                "bedrock_agent_alias_id": "3WHEEJKNUT",
                "instruction": (
                    "You have functions available at your disposal to use when anwering any questions about orbital mechanics."
                    "if you can't find a function to answer a question about orbital mechanics, simply reply "
                    "'I do not know'"
                ),
                "inference_configuration": {
                    "temperature": 0.5,
                    "maximumLength": 2000,
                },
            },
        },
    }

    # Input example for the model
    input_example = {
        "messages": [
            {
                "role": "user",
                "content": "When is the next launch window for Mars? My spacecraft's total mass is 50000, dry mass is 10000 and specific impulse is 2500. Mass in Kg.",
            }
        ]
    }

    # Log and load the model using MLflow
    logged_chain_info = mlflow.pyfunc.log_model(
        python_model=BedrockModel(),
        model_config=model_config,
        artifact_path="chain",  # This string is used as the path inside the MLflow model where artifacts are stored
        input_example=input_example,  # Must be a valid input to your chain
    )

loaded = mlflow.pyfunc.load_model(logged_chain_info.model_uri)

# Predict using the loaded model
response = loaded.predict(
    {
        "messages": [
            {
                "role": "user",
                "content": "When is the next launch window for Mars? My spacecraft's total mass is 50000, dry mass is 10000 and specific impulse is 2500. Mass in Kg.",
            }
        ]
    }
)
print(response)
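
If everything is wired up correctly, the printed response should resemble the following (abridged from the final trace shown later in this post):

{
    "choices": [
        {
            "index": 0,
            "message": {
                "role": "user",
                "content": "Based on the provided spacecraft dry mass of 10000 kg, total mass of 50000 kg, and specific impulse of 2500 s, the next optimal launch window for a Hohmann transfer from Earth to Mars is on November 26, 2026 UTC. The transfer will take 259 days.",
            },
            "finish_reason": "stop",
        }
    ],
    "model": "anthropic.claude-v2",
    "object": "chat.completion",
}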

Mapping Bedrock Agent Trace Data to MLflow Span Objects

In this step, we need to iterate over the data returned within the Bedrock agent's response trace to provide relevant mappings for creating the MLflow span object. The Bedrock agent's response is a flat list, with trace events connected by traceId. Here is the raw trace returned in the Bedrock agent's response:

AWS Bedrock agent's raw trace (abridged):
[
    {
        'trace': {
            'agentAliasId': '3WHEEJKNUT',
            'agentId': 'O9KQSEVEFF',
            'agentVersion': '1',
            'event_order': 0,
            'sessionId': '9566a6d78551434fb0409578ffed63c1',
            'trace': {
                'preProcessingTrace': {
                    'modelInvocationInput': {
                        'inferenceConfiguration': {
                            ...
                        },
                        'text': '\n\nHuman: You are a classifying agent that filters user inputs into categories. Your job is to sort these inputs before they...<thinking> XML tags before providing only the category letter to sort the input into within <category> XML tags.\n\nAssistant:',
                        'traceId': 'ca9880a2-dae7-46ac-a480-f38ca7e2d99f-pre-0',
                        'type': 'PRE_PROCESSING'
                    }
                }
            }
        }
    },
    {
        'trace': {
            'agentAliasId': '3WHEEJKNUT',
            'agentId': 'O9KQSEVEFF',
            'agentVersion': '1',
            'event_order': 1,
            'sessionId': '9566a6d78551434fb0409578ffed63c1',
            'trace': {
                'preProcessingTrace': {
                    'modelInvocationOutput': {
                        'parsedResponse': {
                            ...
                        },
                        'traceId': 'ca9880a2-dae7-46ac-a480-f38ca7e2d99f-pre-0'
                    }
                }
            }
        }
    },
    {
        'trace': {
            'agentAliasId': '3WHEEJKNUT',
            'agentId': 'O9KQSEVEFF',
            'agentVersion': '1',
            'event_order': 2,
            'sessionId': '9566a6d78551434fb0409578ffed63c1',
            'trace': {
                'orchestrationTrace': {
                    'modelInvocationInput': {
                        'inferenceConfiguration': {
                            ...
                        },
                        'text': '\n\nHuman:\nYou are a research assistant AI that has been equipped with one or more functions to help you answer a <question>...\n\nAssistant: <scratchpad> I understand I cannot use functions that have not been provided to me to answer this question.\n\n',
                        'traceId': 'ca9880a2-dae7-46ac-a480-f38ca7e2d99f-0',
                        'type': 'ORCHESTRATION'
                    }
                }
            }
        }
    },
    {
        'trace': {
            'agentAliasId': '3WHEEJKNUT',
            'agentId': 'O9KQSEVEFF',
            'agentVersion': '1',
            'event_order': 3,
            'sessionId': '9566a6d78551434fb0409578ffed63c1',
            'trace': {
                'orchestrationTrace': {
                    'modelInvocationOutput': {
                        'metadata': {
                            ...
                        },
                        'rawResponse': {
                            ...
                        },
                        'traceId': 'ca9880a2-dae7-46ac-a480-f38ca7e2d99f-0'
                    }
                }
            }
        }
    },
    {
        'trace': {
            'agentAliasId': '3WHEEJKNUT',
            'agentId': 'O9KQSEVEFF',
            'agentVersion': '1',
            'event_order': 4,
            'sessionId': '9566a6d78551434fb0409578ffed63c1',
            'trace': {
                'orchestrationTrace': {
                    'rationale': {
                        'text': 'To answer this question about the next Mars launch window, I will:\n\n1. Call the GET::optimal_departure_window_mars::getNext...lse values.\n\nI have verified that I have access to the GET::optimal_departure_window_mars::getNextMarsLaunchWindow function.',
                        'traceId': 'ca9880a2-dae7-46ac-a480-f38ca7e2d99f-0'
                    }
                }
            }
        }
    },
    {
        'trace': {
            'agentAliasId': '3WHEEJKNUT',
            'agentId': 'O9KQSEVEFF',
            'agentVersion': '1',
            'event_order': 5,
            'sessionId': '9566a6d78551434fb0409578ffed63c1',
            'trace': {
                'orchestrationTrace': {
                    'invocationInput': {
                        'actionGroupInvocationInput': {
                            ...
                        },
                        'invocationType': 'ACTION_GROUP',
                        'traceId': 'ca9880a2-dae7-46ac-a480-f38ca7e2d99f-0'
                    }
                }
            }
        }
    },
    {
        'trace': {
            'agentAliasId': '3WHEEJKNUT',
            'agentId': 'O9KQSEVEFF',
            'agentVersion': '1',
            'event_order': 6,
            'sessionId': '9566a6d78551434fb0409578ffed63c1',
            'trace': {
                'orchestrationTrace': {
                    'observation': {
                        'actionGroupInvocationOutput': {
                            ...
                        },
                        'traceId': 'ca9880a2-dae7-46ac-a480-f38ca7e2d99f-0',
                        'type': 'ACTION_GROUP'
                    }
                }
            }
        }
    },
    {
        'trace': {
            'agentAliasId': '3WHEEJKNUT',
            'agentId': 'O9KQSEVEFF',
            'agentVersion': '1',
            'event_order': 7,
            'sessionId': '9566a6d78551434fb0409578ffed63c1',
            'trace': {
                'orchestrationTrace': {
                    'modelInvocationInput': {
                        'inferenceConfiguration': {
                            ...
                        },
                        'text': '\n\nHuman:\nYou are a research assistant AI that has been equipped with one or more functions to help you answer a <question>...lta_v_available_m_s": 39457.985759929674, "delta_v_required_m_s": 5595.997417810693, "is_feasible": true}}</function_result>\n',
                        'traceId': 'ca9880a2-dae7-46ac-a480-f38ca7e2d99f-1',
                        'type': 'ORCHESTRATION'
                    }
                }
            }
        }
    },
    {
        'trace': {
            'agentAliasId': '3WHEEJKNUT',
            'agentId': 'O9KQSEVEFF',
            'agentVersion': '1',
            'event_order': 8,
            'sessionId': '9566a6d78551434fb0409578ffed63c1',
            'trace': {
                'orchestrationTrace': {
                    'modelInvocationOutput': {
                        'metadata': {
                            ...
                        },
                        'rawResponse': {
                            ...
                        },
                        'traceId': 'ca9880a2-dae7-46ac-a480-f38ca7e2d99f-1'
                    }
                }
            }
        }
    },
    {
        'trace': {
            'agentAliasId': '3WHEEJKNUT',
            'agentId': 'O9KQSEVEFF',
            'agentVersion': '1',
            'event_order': 9,
            'sessionId': '9566a6d78551434fb0409578ffed63c1',
            'trace': {
                'orchestrationTrace': {
                    'observation': {
                        'finalResponse': {
                            ...
                        },
                        'traceId': 'ca9880a2-dae7-46ac-a480-f38ca7e2d99f-1',
                        'type': 'FINISH'
                    }
                }
            }
        }
    },
    {
        'chunk': {
            'bytes': b'Based on the provided spacecraft dry mass of 10000 kg, total mass of 50000 kg, and specific impulse of 2500 s, the next optimal launch window for a Hohmann transfer from Earth to Mars is on November 26, 2026 UTC. The transfer will take 259 days.'
        }
    }
]

To fit this structure into MLflow's span, we first need to go through the raw response trace and group events by their traceId. After grouping the trace events by traceId, the structure looks like this:

Trace events grouped by traceId (abridged):
{
    'ca9880a2-dae7-46ac-a480-f38ca7e2d99f-0': [
        {
            'data': {
                'inferenceConfiguration': {
                    'maximumLength': 2048,
                    'stopSequences': [
                        '</function_call>',
                        '</answer>',
                        '</error>'
                    ],
                    'temperature': 0.0,
                    'topK': 250,
                    'topP': 1.0
                },
                'text': '\n\nHuman:\nYou are a research assistant AI that has been equipped with one or more functions to help you answer a <question>...\n\nAssistant: <scratchpad> I understand I cannot use functions that have not been provided to me to answer this question.\n\n',
                'traceId': 'ca9880a2-dae7-46ac-a480-f38ca7e2d99f-0',
                'type': 'ORCHESTRATION'
            },
            'event_order': 2,
            'type': 'modelInvocationInput'
        },
        {
            'data': {
                'metadata': {
                    'usage': {
                        'inputTokens': 5160,
                        'outputTokens': 135
                    }
                },
                'rawResponse': {
                    'content': 'To answer this question about the next Mars launch window, I will:\n\n1. Call the GET::optimal_departure_window_mars::getNext...l>\nGET::optimal_departure_window_mars::getNextMarsLaunchWindow(specific_impulse="2500", dry_mass="10000", total_mass="50000")'
                },
                'traceId': 'ca9880a2-dae7-46ac-a480-f38ca7e2d99f-0'
            },
            'event_order': 3,
            'type': 'modelInvocationOutput'
        },
        {
            'data': {
                'text': 'To answer this question about the next Mars launch window, I will:\n\n1. Call the GET::optimal_departure_window_mars::getNext...lse values.\n\nI have verified that I have access to the GET::optimal_departure_window_mars::getNextMarsLaunchWindow function.',
                'traceId': 'ca9880a2-dae7-46ac-a480-f38ca7e2d99f-0'
            },
            'event_order': 4,
            'type': 'rationale'
        },
        {
            'data': {
                'actionGroupInvocationInput': {
                    'actionGroupName': 'optimal_departure_window_mars',
                    'apiPath': '/get-next-mars-launch-window',
                    'executionType': 'LAMBDA',
                    'parameters': [
                        {
                            ...
                        },
                        {
                            ...
                        },
                        {
                            ...
                        }
                    ],
                    'verb': 'get'
                },
                'invocationType': 'ACTION_GROUP',
                'traceId': 'ca9880a2-dae7-46ac-a480-f38ca7e2d99f-0'
            },
            'event_order': 5,
            'type': 'invocationInput'
        },
        {
            'data': {
                'actionGroupInvocationOutput': {
                    'text': '{"next_launch_window": {"next_launch_date": "2026-11-26 00:00:00", "synodic_period_days": 779.9068939794238, "transfer_time_days": 259, "delta_v_available_m_s": 39457.985759929674, "delta_v_required_m_s": 5595.997417810693, "is_feasible": true}}'
                },
                'traceId': 'ca9880a2-dae7-46ac-a480-f38ca7e2d99f-0',
                'type': 'ACTION_GROUP'
            },
            'event_order': 6,
            'type': 'observation'
        }
    ],
    'ca9880a2-dae7-46ac-a480-f38ca7e2d99f-1': [
        {
            'data': {
                'inferenceConfiguration': {
                    'maximumLength': 2048,
                    'stopSequences': [
                        '</function_call>',
                        '</answer>',
                        '</error>'
                    ],
                    'temperature': 0.0,
                    'topK': 250,
                    'topP': 1.0
                },
                'text': '\n\nHuman:\nYou are a research assistant AI that has been equipped with one or more functions to help you answer a <question>...lta_v_available_m_s": 39457.985759929674, "delta_v_required_m_s": 5595.997417810693, "is_feasible": true}}</function_result>\n',
                'traceId': 'ca9880a2-dae7-46ac-a480-f38ca7e2d99f-1',
                'type': 'ORCHESTRATION'
            },
            'event_order': 7,
            'type': 'modelInvocationInput'
        },
        {
            'data': {
                'metadata': {
                    'usage': {
                        'inputTokens': 5405,
                        'outputTokens': 64
                    }
                },
                'rawResponse': {
                    'content': '<answer>\nBased on the provided spacecraft dry mass of 10000 kg, total mass of 50000 kg, and specific impulse of 2500 s, the ... optimal launch window for a Hohmann transfer from Earth to Mars is on November 26, 2026 UTC. The transfer will take 259 days.'
                },
                'traceId': 'ca9880a2-dae7-46ac-a480-f38ca7e2d99f-1'
            },
            'event_order': 8,
            'type': 'modelInvocationOutput'
        },
        {
            'data': {
                'finalResponse': {
                    'text': 'Based on the provided spacecraft dry mass of 10000 kg, total mass of 50000 kg, and specific impulse of 2500 s, the next optimal launch window for a Hohmann transfer from Earth to Mars is on November 26, 2026 UTC. The transfer will take 259 days.'
                },
                'traceId': 'ca9880a2-dae7-46ac-a480-f38ca7e2d99f-1',
                'type': 'FINISH'
            },
            'event_order': 9,
            'type': 'observation'
        }
    ],
    'ca9880a2-dae7-46ac-a480-f38ca7e2d99f-pre-0': [
        {
            'data': {
                'inferenceConfiguration': {
                    'maximumLength': 2048,
                    'stopSequences': [
                        '\n\nHuman:'
                    ],
                    'temperature': 0.0,
                    'topK': 250,
                    'topP': 1.0
                },
                'text': '\n\nHuman: You are a classifying agent that filters user inputs into categories. Your job is to sort these inputs before they...<thinking> XML tags before providing only the category letter to sort the input into within <category> XML tags.\n\nAssistant:',
                'traceId': 'ca9880a2-dae7-46ac-a480-f38ca7e2d99f-pre-0',
                'type': 'PRE_PROCESSING'
            },
            'event_order': 0,
            'type': 'modelInvocationInput'
        },
        {
            'data': {
                'parsedResponse': {
                    'isValid': True,
                    'rationale': 'Based on the provided instructions, this input appears to be a question about orbital mechanics that can be answered using th...equired arguments for that function - specific impulse, dry mass, and total mass. Therefore, this input should be sorted into:'
                },
                'traceId': 'ca9880a2-dae7-46ac-a480-f38ca7e2d99f-pre-0'
            },
            'event_order': 1,
            'type': 'modelInvocationOutput'
        }
    ]
}

Each group of events with the same traceId will contain at least two events: one of type modelInvocationInput and one of type modelInvocationOutput. Groups that involve action group traces will also include events of type actionGroupInvocationInput and actionGroupInvocationOutput. Similarly, groups that use knowledge bases will have additional events of type knowledgeBaseLookupInput and knowledgeBaseLookupOutput. The BedrockModel above implements an approach that parses these event groups into trace nodes. This allows the trace to display the reasoning behind selecting action groups/knowledge bases to answer queries and invoking the corresponding Lambda function calls, as defined in our example OpenAPI spec above. This structure helps to clearly show the flow of information and the decision-making process that the Bedrock agent follows.
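
The grouping step itself is small. Here is a simplified, standalone sketch of what _extract_trace_groups does, assuming the nesting shown in the raw trace above and ignoring the recursive search and depth limit of the full implementation:

from collections import defaultdict


def group_events_by_trace_id(events):
    """Bucket Bedrock trace events by their traceId (simplified sketch)."""
    groups = defaultdict(list)
    for index, event in enumerate(events):
        # Each event nests its payload under trace -> <step type> -> <item type>
        step = event.get("trace", {}).get("trace", {})
        for step_type, items in step.items():  # e.g. "orchestrationTrace"
            for item_type, data in items.items():  # e.g. "rationale"
                trace_id = data.get("traceId") if isinstance(data, dict) else None
                if trace_id:
                    groups[trace_id].append(
                        {"type": item_type, "data": data, "event_order": index}
                    )
    return dict(groups)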

Here is the final MLflow trace (abridged):
{
    "spans": [
        {
            "name": "Bedrock Agent Runtime",
            "context": {
                "span_id": "0xb802165d133a33aa",
                "trace_id": "0x9b8bd0b2e018d77f936e48a09e54fd44"
            },
            "parent_id": null,
            "start_time": 1731388531754725000,
            "end_time": 1731388550226771000,
            "status_code": "OK",
            "status_message": "",
            "attributes": {
                "mlflow.traceRequestId": "\"1e036cc3a7f946ec995f7763b8dde51c\"",
                "mlflow.spanType": "\"CHAT_MODEL\"",
                "mlflow.spanFunctionName": "\"predict\"",
                "mlflow.spanInputs": "{\"context\": \"<mlflow.pyfunc.model.PythonModelContext object at 0x13397c530>\", \"messages\": [{\"role\": \"user\", \"content\": \"When is the next launch window for Mars? My spacecraft's total mass is 50000, dry mass is 10000 and specific impulse is 2500. Mass in Kg.\", \"name\": null}], \"params\": {\"temperature\": 1.0, \"max_tokens\": null, \"stop\": null, \"n\": 1, \"stream\": false, \"top_p\": null, \"top_k\": null, \"frequency_penalty\": null, \"presence_penalty\": null}}",
                "mlflow.spanOutputs": "{\"choices\": [{\"index\": 0, \"message\": {\"role\": \"user\", \"content\": \"Based on the provided spacecraft dry mass of 10000 kg, total mass of 50000 kg, and specific impulse of 2500 s, the next optimal launch window for a Hohmann transfer from Earth to Mars is on November 26, 2026 UTC. The transfer will take 259 days.\", \"name\": null}, \"finish_reason\": \"stop\", \"logprobs\": null}], \"usage\": {\"prompt_tokens\": null, \"completion_tokens\": null, \"total_tokens\": null}, \"id\": null, \"model\": \"anthropic.claude-v2\", \"object\": \"chat.completion\", \"created\": 1731388550}"
            },
            "events": []
        },
        {
            "name": "Bedrock Input Prompt",
            "context": {
                "span_id": "0x2e7cd730be70865b",
                "trace_id": "0x9b8bd0b2e018d77f936e48a09e54fd44"
            },
            "parent_id": "0xb802165d133a33aa",
            "start_time": 1731388531755172000,
            "end_time": 1731388531755252000,
            "status_code": "OK",
            "status_message": "",
            "attributes": {
                "mlflow.traceRequestId": "\"1e036cc3a7f946ec995f7763b8dde51c\"",
                "mlflow.spanType": "\"UNKNOWN\"",
                "mlflow.spanFunctionName": "\"_get_agent_prompt\"",
                "mlflow.spanInputs": "{\"raw_input_question\": \"When is the next launch window for Mars? My spacecraft's total mass is 50000, dry mass is 10000 and specific impulse is 2500. Mass in Kg.\"}",
                "mlflow.spanOutputs": "\"\\n Answer the following question and pay strong attention to the prompt:\\n <question>\\n When is the next launch window for Mars? My spacecraft's total mass is 50000, dry mass is 10000 and specific impulse is 2500. Mass in Kg.\\n </question>\\n <instruction>\\n You have functions available at your disposal to use when anwering any questions about orbital mechanics.if you can't find a function to answer a question about orbital mechanics, simply reply 'I do not know'\\n </instruction>\\n \""
            },
            "events": []
        },
        {
            "name": "ACTION GROUP DECISION -optimal_departure_window_mars",
            "context": {
                "span_id": "0x131e4e08cd5e95d9",
                "trace_id": "0x9b8bd0b2e018d77f936e48a09e54fd44"
            },
            "parent_id": "0xb802165d133a33aa",
            "start_time": 1731388550223219000,
            "end_time": 1731388550224592000,
            "status_code": "OK",
            "status_message": "",
            "attributes": {
                "mlflow.traceRequestId": "\"1e036cc3a7f946ec995f7763b8dde51c\"",
                "mlflow.spanType": "\"UNKNOWN\"",
"trace_attributes": "[{\"type\": \"modelInvocationInput\", \"data\": {\"inferenceConfiguration\": {\"maximumLength\": 2048, \"stopSequences\": [\"</function_call>\", \"</answer>\", \"</error>\"], \"temperature\": 0.0, \"topK\": 250, \"topP\": 1.0}, \"text\": \"\\n\\nHuman:\\nYou are a research assistant AI that has been equipped with one or more functions to help you answer a <question>. Your goal is to answer the user's question to the best of your ability, using the function(s) to gather more information if necessary to better answer the question. If you choose to call a function, the result of the function call will be added to the conversation history in <function_results> tags (if the call succeeded) or <error> tags (if the function failed). \\nYou were created with these instructions to consider as well:\\n<auxiliary_instructions>\\n You are a friendly chat bot. You have access to a function called that returns\\n information about the Mars launch window. When responding with Mars launch window,\\n please make sure to add the timezone UTC.\\n </auxiliary_instructions>\\n\\nHere are some examples of correct action by other, different agents with access to functions that may or may not be similar to ones you are provided.\\n\\n<examples>\\n <example_docstring> Here is an example of how you would correctly answer a question using a <function_call> and the corresponding <function_result>. Notice that you are free to think before deciding to make a <function_call> in the <scratchpad>.</example_docstring>\\n <example>\\n <functions>\\n <function>\\n <function_name>get::policyengineactions::getpolicyviolations</function_name>\\n <function_description>Returns a list of policy engine violations for the specified alias within the specified date range.</function_description>\\n <required_argument>alias (string): The alias of the employee under whose name current violations needs to be listed</required_argument>\\n <required_argument>startDate (string): The start date of the range to filter violations. The format for startDate is MM/DD/YYYY.</required_argument>\\n <required_argument>endDate (string): The end date of the range to filter violations</required_argument>\\n <returns>array: Successful response</returns>\\n <raises>object: Invalid request</raises>\\n </function>\\n <function>\\n <function_name>post::policyengineactions::acknowledgeviolations</function_name>\\n <function_description>Acknowledge policy engine violation. Generally used to acknowledge violation, once user notices a violation under their alias or their managers alias.</function_description>\\n <required_argument>policyId (string): The ID of the policy violation</required_argument>\\n <required_argument>expectedDateOfResolution (string): The date by when the violation will be addressed/resolved</required_argument>\\n <returns>object: Successful response</returns>\\n <raises>object: Invalid request</raises>\\n </function>\\n <function>\\n <function_name>get::activedirectoryactions::getmanager</function_name>\\n <function_description>This API is used to identify the manager hierarchy above a given person. 
Every person could have a manager and the manager could have another manager to which they report to</function_description>\\n <required_argument>alias (string): The alias of the employee under whose name current violations needs to be listed</required_argument>\\n <returns>object: Successful response</returns>\\n <raises>object: Invalid request</raises>\\n </function>\\n \\n </functions>\\n\\n <question>Can you show me my policy engine violation from 1st january 2023 to 1st february 2023? My alias is jsmith.</question>\\n <scratchpad>\\n To answer this question, I will need to:\\n 1. I do not have knowledge to policy engine violations, so I should see if I can use any of the available functions to help. I have been equipped with get::policyengineactions::getpolicyviolations that gets the policy engine violations for a given alias, start date and end date. I will use this function to gather more information.\\n </scratchpad>\\n <function_call>get::policyengineactions::getpolicyviolations(alias=\\\"jsmith\\\", startDate=\\\"1st January 2023\\\", endDate=\\\"1st February 2023\\\")</function_call>\\n <function_result>{response: [{creationDate: \\\"2023-06-01T09:30:00Z\\\", riskLevel: \\\"High\\\", policyId: \\\"POL-001\\\", policyUrl: \\\"https://example.com/policies/POL-001\\\", referenceUrl: \\\"https://example.com/violations/POL-001\\\"}, {creationDate: \\\"2023-06-02T14:45:00Z\\\", riskLevel: \\\"Medium\\\", policyId: \\\"POL-002\\\", policyUrl: \\\"https://example.com/policies/POL-002\\\", referenceUrl: \\\"https://example.com/violations/POL-002\\\"}]}</function_result>\\n <answer>The policy engine violations between 1st january 2023 to 1st february 2023 for alias jsmith are - Policy ID: POL-001, Policy ID: POL-002</answer>\\n </example>\\n\\n <example_docstring>Here is another example that utilizes multiple function calls.</example_docstring>\\n <example>\\n <functions>\\n <function>\\n <function_name>get::policyengineactions::getpolicyviolations</function_name>\\n <function_description>Returns a list of policy engine violations for the specified alias within the specified date range.</function_description>\\n <required_argument>alias (string): The alias of the employee under whose name current violations needs to be listed</required_argument>\\n <required_argument>startDate (string): The start date of the range to filter violations. The format for startDate is MM/DD/YYYY.</required_argument>\\n <required_argument>endDate (string): The end date of the range to filter violations</required_argument>\\n <returns>array: Successful response</returns>\\n <raises>object: Invalid request</raises>\\n </function>\\n <function>\\n <function_name>post::policyengineactions::acknowledgeviolations</function_name>\\n <function_description>Acknowledge policy engine violation. Generally used to acknowledge violation, once user notices a violation under their alias or their managers alias.</function_description>\\n <required_argument>policyId (string): The ID of the policy violation</required_argument>\\n <required_argument>expectedDateOfResolution (string): The date by when the violation will be addressed/resolved</required_argument>\\n <returns>object: Successful response</returns>\\n <raises>object: Invalid request</raises>\\n </function>\\n <function>\\n <function_name>get::activedirectoryactions::getmanager</function_name>\\n <function_description>This API is used to identify the manager hierarchy above a given person. 
Every person could have a manager and the manager could have another manager to which they report to</function_description>\\n <required_argument>alias (string): The alias of the employee under whose name current violations needs to be listed</required_argument>\\n <returns>object: Successful response</returns>\\n <raises>object: Invalid request</raises>\\n </function>\\n \\n </functions>\\n <question>Can you check the policy engine violations under my manager between 2nd May to 5th May? My alias is john.</question>\\n <scratchpad>\\n To answer this question, I will need to:\\n 1. Get the manager alias of the user using get::activedirectoryactions::getmanager function.\\n 2. Use the returned manager alias to get the policy engine violations using the get::policyengineactions::getpolicyviolations function.\\n\\n I have double checked and made sure that I have been provided the get::activedirectoryactions::getmanager and the get::policyengineactions::getpolicyviolations functions.\\n </scratchpad>\\n <function_call>get::activedirectoryactions::getmanager(alias=\\\"john\\\")</function_call>\\n <function_result>{response: {managerAlias: \\\"mark\\\", managerLevel: \\\"6\\\", teamName: \\\"Builder\\\", managerName: \\\"Mark Hunter\\\"}}}}</function_result>\\n <scratchpad>\\n 1. I have the managerAlias from the function results as mark and I have the start and end date from the user input. I can use the function result to call get::policyengineactions::getpolicyviolations function.\\n 2. I will then return the get::policyengineactions::getpolicyviolations function result to the user.\\n\\n I have double checked and made sure that I have been provided the get::policyengineactions::getpolicyviolations functions.\\n </scratchpad>\\n <function_call>get::policyengineactions::getpolicyviolations(alias=\\\"mark\\\", startDate=\\\"2nd May 2023\\\", endDate=\\\"5th May 2023\\\")</function_call>\\n <function_result>{response: [{creationDate: \\\"2023-05-02T09:30:00Z\\\", riskLevel: \\\"High\\\", policyId: \\\"POL-001\\\", policyUrl: \\\"https://example.com/policies/POL-001\\\", referenceUrl: \\\"https://example.com/violations/POL-001\\\"}, {creationDate: \\\"2023-05-04T14:45:00Z\\\", riskLevel: \\\"Low\\\", policyId: \\\"POL-002\\\", policyUrl: \\\"https://example.com/policies/POL-002\\\", referenceUrl: \\\"https://example.com/violations/POL-002\\\"}]}</function_result>\\n <answer>\\n The policy engine violations between 2nd May 2023 to 5th May 2023 for your manager's alias mark are - Policy ID: POL-001, Policy ID: POL-002\\n </answer>\\n </example>\\n\\n <example_docstring>Functions can also be search engine API's that issue a query to a knowledge base. Here is an example that utilizes regular function calls in combination with function calls to a search engine API. Please make sure to extract the source for the information within the final answer when using information returned from the search engine.</example_docstring>\\n <example>\\n <functions>\\n <function>\\n <function_name>get::benefitsaction::getbenefitplanname</function_name>\\n <function_description>Get's the benefit plan name for a user. The API takes in a userName and a benefit type and returns the benefit name to the user (i.e. 
Aetna, Premera, Fidelity, etc.).</function_description>\\n <optional_argument>userName (string): None</optional_argument>\\n <optional_argument>benefitType (string): None</optional_argument>\\n <returns>object: Successful response</returns>\\n <raises>object: Invalid request</raises>\\n </function>\\n <function>\\n <function_name>post::benefitsaction::increase401klimit</function_name>\\n <function_description>Increases the 401k limit for a generic user. The API takes in only the current 401k limit and returns the new limit.</function_description>\\n <optional_argument>currentLimit (string): None</optional_argument>\\n <returns>object: Successful response</returns>\\n <raises>object: Invalid request</raises>\\n </function>\\n <function>\\n <function_name>get::x_amz_knowledgebase_dentalinsurance::search</function_name>\\n <function_description>This is a search tool that provides information about Delta Dental benefits. It has information about covered dental benefits and other relevant information</function_description>\\n <required_argument>query(string): A full sentence query that is fed to the search tool</required_argument>\\n <returns>Returns string related to the user query asked.</returns>\\n </function>\\n <function>\\n <function_name>get::x_amz_knowledgebase_401kplan::search</function_name>\\n <function_description>This is a search tool that provides information about Amazon 401k plan benefits. It can determine what a person's yearly 401k contribution limit is, based on their age.</function_description>\\n <required_argument>query(string): A full sentence query that is fed to the search tool</required_argument>\\n <returns>Returns string related to the user query asked.</returns>\\n </function>\\n <function>\\n <function_name>get::x_amz_knowledgebase_healthinsurance::search</function_name>\\n <function_description>This is a search tool that provides information about Aetna and Premera health benefits. It has information about the savings plan and shared deductible plan, as well as others.</function_description>\\n <required_argument>query(string): A full sentence query that is fed to the search tool</required_argument>\\n <returns>Returns string related to the user query asked.</returns>\\n </function>\\n \\n </functions>\\n\\n <question>What is my deductible? My username is Bob and my benefitType is Dental. Also, what is the 401k yearly contribution limit?</question>\\n <scratchpad> I understand I cannot use functions that have not been provided to me to answer this question.\\n To answer this question, I will:\\n 1. Call the get::benefitsaction::getbenefitplanname function to get the benefit plan name for the user Bob with benefit type Dental.\\n 2. Call the get::x_amz_knowledgebase_dentalinsurance::search function to search for information about deductibles for the plan name returned from step 1.\\n 3. Call the get::x_amz_knowledgebase_401k::search function to search for information about 401k yearly contribution limits.\\n 4. 
Return the deductible information from the search results to the user.\\n I have checked that I have access to the get::benefitsaction::getbenefitplanname, x_amz_knowledgebase_dentalinsurance::search, and x_amz_knowledgebase_401k::search functions.\\n </scratchpad>\\n <function_call>get::benefitsaction::getbenefitplanname(userName=\\\"Bob\\\", benefitType=\\\"Dental\\\")</function_call>\\n <function_result>{{'response': {{'planName': 'Delta Dental'}}}}</function_result>\\n <scratchpad>\\n I have received the plan name Delta Dental for the user Bob with Dental benefits. I will now call the x_amz_knowledgebase_dentalinsurance::search function to find deductible information for Delta Dental.\\n </scratchpad>\\n <function_call>get::x_amz_knowledgebase_dentalinsurance::search(searchQuery=\\\"What is the deductible for Delta Dental?\\\")</function_call>\\n <function_result>{{'response': {{'responseCode': '200', 'responseBody': \\\"\\\"<answer>\\\\n<answer_part>\\\\n<text>The user's individual deductible is $50 per benefit period</text>\\\\n<source>dfe040f8-46ed-4a65-b3ea-529fa55f6b9e</source>\\\\n</answer_part>\\\\n<answer_part>\\\\n<text>If they are enrolled with dependents, the maximum family deductible is $150 per benefit period.</text>\\\\n<source>0e666064-31d8-4223-b7ba-8eecf40b7b47</source>\\\\n</answer_part>\\\\n</answer>\\\"}}}}</function_result> <scratchpad>\\n I have found the deductible information for Dental benefits. I will now call the x_amz_knowledgebase_401k::search function to find yearly 401k contribution limits.\\n </scratchpad>\\n <function_call>get::x_amz_knowledgebase_401k::search(searchQuery=\\\"What is the yearly 401k contribution limit?\\\")</function_call>\\n <function_result>{{'response': {{'responseCode': '200', 'responseBody': \\\"<answer>\\\\n<answer_part>\\\\n<text>The yearly 401k contribution limit is $20,500.</text>\\\\n<source>c546cbe8-07f6-45d1-90ca-74d87ab2885a</source>\\\\n</answer_part>\\\\n</answer>\\\"}}}}</function_result>\\n <answer>\\n <answer_part>\\n <text>The deductible for your Delta Dental plan is $50 per benefit period.</text>\\n <source>dfe040f8-46ed-4a65-b3ea-529fa55f6b9e</source>\\n </answer_part>\\n <answer_part>\\n <text>If you have dependents enrolled, the maximum family deductible is $150 per benefit period.</text>\\n <source>0e666064-31d8-4223-b7ba-8eecf40b7b47</source>\\n </answer_part>\\n <answer_part>\\n <text>The yearly 401k contribution limit is $20,500.</text>\\n <source>c546cbe8-07f6-45d1-90ca-74d87ab2885a</source>\\n </answer_part>\\n </answer>\\n </example>\\n\\n \\n\\n <example_docstring>Here's a final example where the question asked could not be answered with information gathered from calling the provided functions. In this example, notice how you respond by telling the user you cannot answer, without using a function that was not provided to you.</example_docstring>\\n <example>\\n <functions>\\n <function>\\n <function_name>get::policyengineactions::getpolicyviolations</function_name>\\n <function_description>Returns a list of policy engine violations for the specified alias within the specified date range.</function_description>\\n <required_argument>alias (string): The alias of the employee under whose name current violations needs to be listed</required_argument>\\n <required_argument>startDate (string): The start date of the range to filter violations. 
The format for startDate is MM/DD/YYYY.</required_argument>\\n <required_argument>endDate (string): The end date of the range to filter violations</required_argument>\\n <returns>array: Successful response</returns>\\n <raises>object: Invalid request</raises>\\n </function>\\n <function>\\n <function_name>post::policyengineactions::acknowledgeviolations</function_name>\\n <function_description>Acknowledge policy engine violation. Generally used to acknowledge violation, once user notices a violation under their alias or their managers alias.</function_description>\\n <required_argument>policyId (string): The ID of the policy violation</required_argument>\\n <required_argument>expectedDateOfResolution (string): The date by when the violation will be addressed/resolved</required_argument>\\n <returns>object: Successful response</returns>\\n <raises>object: Invalid request</raises>\\n </function>\\n <function>\\n <function_name>get::activedirectoryactions::getmanager</function_name>\\n <function_description>This API is used to identify the manager hierarchy above a given person. Every person could have a manager and the manager could have another manager to which they report to</function_description>\\n <required_argument>alias (string): The alias of the employee under whose name current violations needs to be listed</required_argument>\\n <returns>object: Successful response</returns>\\n <raises>object: Invalid request</raises>\\n </function>\\n \\n </functions>\\n <question>Who are the reportees of David?</question>\\n <scratchpad>\\n After reviewing the functions I was equipped with, I realize I am not able to accurately answer this question since I can't access reportees of David. Therefore, I should explain to the user I cannot answer this question.\\n </scratchpad>\\n <answer>\\n Sorry, I am unable to assist you with this request.\\n </answer>\\n </example>\\n</examples>\\n\\nThe above examples have been provided to you to illustrate general guidelines and format for use of function calling for information retrieval, and how to use your scratchpad to plan your approach. IMPORTANT: the functions provided within the examples should not be assumed to have been provided to you to use UNLESS they are also explicitly given to you within <functions></functions> tags below. All of the values and information within the examples (the questions, function results, and answers) are strictly part of the examples and have not been provided to you.\\n\\nNow that you have read and understood the examples, I will define the functions that you have available to you to use. 
Here is a comprehensive list.\\n\\n<functions>\\n<function>\\n<function_name>GET::optimal_departure_window_mars::getNextMarsLaunchWindow</function_name>\\n<function_description>Gets the next optimal launch window to Mars.</function_description>\\n<required_argument>specific_impulse (string): Specific impulse of the propulsion system (s).</required_argument>\\n<required_argument>dry_mass (string): Mass of the spacecraft without fuel (kg).</required_argument>\\n<required_argument>total_mass (string): Total mass of the spacecraft including fuel (kg)</required_argument>\\n<returns>object: The next optimal departure date for a Hohmann transfer from Earth to Mars, based on the spacecraft's mass and specific impulse.</returns>\\n</function>\\n\\n\\n</functions>\\n\\nNote that the function arguments have been listed in the order that they should be passed into the function.\\n\\n\\n\\nDo not modify or extend the provided functions under any circumstances. For example, GET::optimal_departure_window_mars::getNextMarsLaunchWindow with additional parameters would be considered modifying the function which is not allowed. Please use the functions only as defined.\\n\\nDO NOT use any functions that I have not equipped you with.\\n\\n Do not make assumptions about inputs; instead, make sure you know the exact function and input to use before you call a function.\\n\\nTo call a function, output the name of the function in between <function_call> and </function_call> tags. You will receive a <function_result> in response to your call that contains information that you can use to better answer the question. Or, if the function call produced an error, you will receive an <error> in response.\\n\\n\\n\\nThe format for all other <function_call> MUST be: <function_call>$FUNCTION_NAME($FUNCTION_PARAMETER_NAME=$FUNCTION_PARAMETER_VALUE)</function_call>\\n\\nRemember, your goal is to answer the user's question to the best of your ability, using only the function(s) provided within the <functions></functions> tags to gather more information if necessary to better answer the question.\\n\\nDo not modify or extend the provided functions under any circumstances. For example, calling GET::optimal_departure_window_mars::getNextMarsLaunchWindow with additional parameters would be modifying the function which is not allowed. Please use the functions only as defined.\\n\\nBefore calling any functions, create a plan for performing actions to answer this question within the <scratchpad>. Double check your plan to make sure you don't call any functions that you haven't been provided with. Always return your final answer within <answer></answer> tags.\\n\\n\\n\\nThe user input is <question>Answer the following question and pay strong attention to the prompt:\\n <question>\\n When is the next launch window for Mars? My spacecraft's total mass is 50000, dry mass is 10000 and specific impulse is 2500. 
Mass in Kg.\\n </question>\\n <instruction>\\n You have functions available at your disposal to use when anwering any questions about orbital mechanics.if you can't find a function to answer a question about orbital mechanics, simply reply 'I do not know'\\n </instruction></question>\\n\\n\\nAssistant: <scratchpad> I understand I cannot use functions that have not been provided to me to answer this question.\\n\\n\", \"traceId\": \"e0b2b2c2-fb7c-4e17-8a1f-a3781100face-0\", \"type\": \"ORCHESTRATION\"}, \"event_order\": 2}, {\"type\": \"modelInvocationOutput\", \"data\": {\"metadata\": {\"usage\": {\"inputTokens\": 5160, \"outputTokens\": 135}}, \"rawResponse\": {\"content\": \"To answer this question about the next Mars launch window, I will:\\n\\n1. Call the GET::optimal_departure_window_mars::getNextMarsLaunchWindow function to get the next optimal launch window, passing in the provided spacecraft mass and specific impulse values.\\n\\nI have verified that I have access to the GET::optimal_departure_window_mars::getNextMarsLaunchWindow function.\\n\\n</scratchpad>\\n\\n<function_call>\\nGET::optimal_departure_window_mars::getNextMarsLaunchWindow(specific_impulse=\\\"2500\\\", dry_mass=\\\"10000\\\", total_mass=\\\"50000\\\")\"}, \"traceId\": \"e0b2b2c2-fb7c-4e17-8a1f-a3781100face-0\"}, \"event_order\": 3}, {\"type\": \"rationale\", \"data\": {\"text\": \"To answer this question about the next Mars launch window, I will:\\n\\n1. Call the GET::optimal_departure_window_mars::getNextMarsLaunchWindow function to get the next optimal launch window, passing in the provided spacecraft mass and specific impulse values.\\n\\nI have verified that I have access to the GET::optimal_departure_window_mars::getNextMarsLaunchWindow function.\", \"traceId\": \"e0b2b2c2-fb7c-4e17-8a1f-a3781100face-0\"}, \"event_order\": 4}, {\"type\": \"invocationInput\", \"data\": {\"actionGroupInvocationInput\": {\"actionGroupName\": \"optimal_departure_window_mars\", \"apiPath\": \"/get-next-mars-launch-window\", \"executionType\": \"LAMBDA\", \"parameters\": [{\"name\": \"total_mass\", \"type\": \"string\", \"value\": \"50000\"}, {\"name\": \"dry_mass\", \"type\": \"string\", \"value\": \"10000\"}, {\"name\": \"specific_impulse\", \"type\": \"string\", \"value\": \"2500\"}], \"verb\": \"get\"}, \"invocationType\": \"ACTION_GROUP\", \"traceId\": \"e0b2b2c2-fb7c-4e17-8a1f-a3781100face-0\"}, \"event_order\": 5}, {\"type\": \"observation\", \"data\": {\"actionGroupInvocationOutput\": {\"text\": \"{\\\"next_launch_window\\\": {\\\"next_launch_date\\\": \\\"2026-11-26 00:00:00\\\", \\\"synodic_period_days\\\": 779.9068939794238, \\\"transfer_time_days\\\": 259, \\\"delta_v_available_m_s\\\": 39457.985759929674, \\\"delta_v_required_m_s\\\": 5595.997417810693, \\\"is_feasible\\\": true}}\"}, \"traceId\": \"e0b2b2c2-fb7c-4e17-8a1f-a3781100face-0\", \"type\": \"ACTION_GROUP\"}, \"event_order\": 6}]",
"mlflow.spanFunctionName": "\"_trace_agent_pre_context\"",
"mlflow.spanInputs": "{\"inner_input_trace\": \"\\n\\nHuman:\\nYou are a research assistant AI that has been equipped with one or more functions to help you answer a <question>. Your goal is to answer the user's question to the best of your ability, using the function(s) to gather more information if necessary to better answer the question. If you choose to call a function, the result of the function call will be added to the conversation history in <function_results> tags (if the call succeeded) or <error> tags (if the function failed). \\nYou were created with these instructions to consider as well:\\n<auxiliary_instructions>\\n You are a friendly chat bot. You have access to a function called that returns\\n information about the Mars launch window. When responding with Mars launch window,\\n please make sure to add the timezone UTC.\\n </auxiliary_instructions>\\n\\nHere are some examples of correct action by other, different agents with access to functions that may or may not be similar to ones you are provided.\\n\\n<examples>\\n <example_docstring> Here is an example of how you would correctly answer a question using a <function_call> and the corresponding <function_result>. Notice that you are free to think before deciding to make a <function_call> in the <scratchpad>.</example_docstring>\\n <example>\\n <functions>\\n <function>\\n <function_name>get::policyengineactions::getpolicyviolations</function_name>\\n <function_description>Returns a list of policy engine violations for the specified alias within the specified date range.</function_description>\\n <required_argument>alias (string): The alias of the employee under whose name current violations needs to be listed</required_argument>\\n <required_argument>startDate (string): The start date of the range to filter violations. The format for startDate is MM/DD/YYYY.</required_argument>\\n <required_argument>endDate (string): The end date of the range to filter violations</required_argument>\\n <returns>array: Successful response</returns>\\n <raises>object: Invalid request</raises>\\n </function>\\n <function>\\n <function_name>post::policyengineactions::acknowledgeviolations</function_name>\\n <function_description>Acknowledge policy engine violation. Generally used to acknowledge violation, once user notices a violation under their alias or their managers alias.</function_description>\\n <required_argument>policyId (string): The ID of the policy violation</required_argument>\\n <required_argument>expectedDateOfResolution (string): The date by when the violation will be addressed/resolved</required_argument>\\n <returns>object: Successful response</returns>\\n <raises>object: Invalid request</raises>\\n </function>\\n <function>\\n <function_name>get::activedirectoryactions::getmanager</function_name>\\n <function_description>This API is used to identify the manager hierarchy above a given person. Every person could have a manager and the manager could have another manager to which they report to</function_description>\\n <required_argument>alias (string): The alias of the employee under whose name current violations needs to be listed</required_argument>\\n <returns>object: Successful response</returns>\\n <raises>object: Invalid request</raises>\\n </function>\\n \\n </functions>\\n\\n <question>Can you show me my policy engine violation from 1st january 2023 to 1st february 2023? My alias is jsmith.</question>\\n <scratchpad>\\n To answer this question, I will need to:\\n 1. 
I do not have knowledge to policy engine violations, so I should see if I can use any of the available functions to help. I have been equipped with get::policyengineactions::getpolicyviolations that gets the policy engine violations for a given alias, start date and end date. I will use this function to gather more information.\\n </scratchpad>\\n <function_call>get::policyengineactions::getpolicyviolations(alias=\\\"jsmith\\\", startDate=\\\"1st January 2023\\\", endDate=\\\"1st February 2023\\\")</function_call>\\n <function_result>{response: [{creationDate: \\\"2023-06-01T09:30:00Z\\\", riskLevel: \\\"High\\\", policyId: \\\"POL-001\\\", policyUrl: \\\"https://example.com/policies/POL-001\\\", referenceUrl: \\\"https://example.com/violations/POL-001\\\"}, {creationDate: \\\"2023-06-02T14:45:00Z\\\", riskLevel: \\\"Medium\\\", policyId: \\\"POL-002\\\", policyUrl: \\\"https://example.com/policies/POL-002\\\", referenceUrl: \\\"https://example.com/violations/POL-002\\\"}]}</function_result>\\n <answer>The policy engine violations between 1st january 2023 to 1st february 2023 for alias jsmith are - Policy ID: POL-001, Policy ID: POL-002</answer>\\n </example>\\n\\n <example_docstring>Here is another example that utilizes multiple function calls.</example_docstring>\\n <example>\\n <functions>\\n <function>\\n <function_name>get::policyengineactions::getpolicyviolations</function_name>\\n <function_description>Returns a list of policy engine violations for the specified alias within the specified date range.</function_description>\\n <required_argument>alias (string): The alias of the employee under whose name current violations needs to be listed</required_argument>\\n <required_argument>startDate (string): The start date of the range to filter violations. The format for startDate is MM/DD/YYYY.</required_argument>\\n <required_argument>endDate (string): The end date of the range to filter violations</required_argument>\\n <returns>array: Successful response</returns>\\n <raises>object: Invalid request</raises>\\n </function>\\n <function>\\n <function_name>post::policyengineactions::acknowledgeviolations</function_name>\\n <function_description>Acknowledge policy engine violation. Generally used to acknowledge violation, once user notices a violation under their alias or their managers alias.</function_description>\\n <required_argument>policyId (string): The ID of the policy violation</required_argument>\\n <required_argument>expectedDateOfResolution (string): The date by when the violation will be addressed/resolved</required_argument>\\n <returns>object: Successful response</returns>\\n <raises>object: Invalid request</raises>\\n </function>\\n <function>\\n <function_name>get::activedirectoryactions::getmanager</function_name>\\n <function_description>This API is used to identify the manager hierarchy above a given person. Every person could have a manager and the manager could have another manager to which they report to</function_description>\\n <required_argument>alias (string): The alias of the employee under whose name current violations needs to be listed</required_argument>\\n <returns>object: Successful response</returns>\\n <raises>object: Invalid request</raises>\\n </function>\\n \\n </functions>\\n <question>Can you check the policy engine violations under my manager between 2nd May to 5th May? My alias is john.</question>\\n <scratchpad>\\n To answer this question, I will need to:\\n 1. 
Get the manager alias of the user using get::activedirectoryactions::getmanager function.\\n 2. Use the returned manager alias to get the policy engine violations using the get::policyengineactions::getpolicyviolations function.\\n\\n I have double checked and made sure that I have been provided the get::activedirectoryactions::getmanager and the get::policyengineactions::getpolicyviolations functions.\\n </scratchpad>\\n <function_call>get::activedirectoryactions::getmanager(alias=\\\"john\\\")</function_call>\\n <function_result>{response: {managerAlias: \\\"mark\\\", managerLevel: \\\"6\\\", teamName: \\\"Builder\\\", managerName: \\\"Mark Hunter\\\"}}}}</function_result>\\n <scratchpad>\\n 1. I have the managerAlias from the function results as mark and I have the start and end date from the user input. I can use the function result to call get::policyengineactions::getpolicyviolations function.\\n 2. I will then return the get::policyengineactions::getpolicyviolations function result to the user.\\n\\n I have double checked and made sure that I have been provided the get::policyengineactions::getpolicyviolations functions.\\n </scratchpad>\\n <function_call>get::policyengineactions::getpolicyviolations(alias=\\\"mark\\\", startDate=\\\"2nd May 2023\\\", endDate=\\\"5th May 2023\\\")</function_call>\\n <function_result>{response: [{creationDate: \\\"2023-05-02T09:30:00Z\\\", riskLevel: \\\"High\\\", policyId: \\\"POL-001\\\", policyUrl: \\\"https://example.com/policies/POL-001\\\", referenceUrl: \\\"https://example.com/violations/POL-001\\\"}, {creationDate: \\\"2023-05-04T14:45:00Z\\\", riskLevel: \\\"Low\\\", policyId: \\\"POL-002\\\", policyUrl: \\\"https://example.com/policies/POL-002\\\", referenceUrl: \\\"https://example.com/violations/POL-002\\\"}]}</function_result>\\n <answer>\\n The policy engine violations between 2nd May 2023 to 5th May 2023 for your manager's alias mark are - Policy ID: POL-001, Policy ID: POL-002\\n </answer>\\n </example>\\n\\n <example_docstring>Functions can also be search engine API's that issue a query to a knowledge base. Here is an example that utilizes regular function calls in combination with function calls to a search engine API. Please make sure to extract the source for the information within the final answer when using information returned from the search engine.</example_docstring>\\n <example>\\n <functions>\\n <function>\\n <function_name>get::benefitsaction::getbenefitplanname</function_name>\\n <function_description>Get's the benefit plan name for a user. The API takes in a userName and a benefit type and returns the benefit name to the user (i.e. Aetna, Premera, Fidelity, etc.).</function_description>\\n <optional_argument>userName (string): None</optional_argument>\\n <optional_argument>benefitType (string): None</optional_argument>\\n <returns>object: Successful response</returns>\\n <raises>object: Invalid request</raises>\\n </function>\\n <function>\\n <function_name>post::benefitsaction::increase401klimit</function_name>\\n <function_description>Increases the 401k limit for a generic user. 
The API takes in only the current 401k limit and returns the new limit.</function_description>\\n <optional_argument>currentLimit (string): None</optional_argument>\\n <returns>object: Successful response</returns>\\n <raises>object: Invalid request</raises>\\n </function>\\n <function>\\n <function_name>get::x_amz_knowledgebase_dentalinsurance::search</function_name>\\n <function_description>This is a search tool that provides information about Delta Dental benefits. It has information about covered dental benefits and other relevant information</function_description>\\n <required_argument>query(string): A full sentence query that is fed to the search tool</required_argument>\\n <returns>Returns string related to the user query asked.</returns>\\n </function>\\n <function>\\n <function_name>get::x_amz_knowledgebase_401kplan::search</function_name>\\n <function_description>This is a search tool that provides information about Amazon 401k plan benefits. It can determine what a person's yearly 401k contribution limit is, based on their age.</function_description>\\n <required_argument>query(string): A full sentence query that is fed to the search tool</required_argument>\\n <returns>Returns string related to the user query asked.</returns>\\n </function>\\n <function>\\n <function_name>get::x_amz_knowledgebase_healthinsurance::search</function_name>\\n <function_description>This is a search tool that provides information about Aetna and Premera health benefits. It has information about the savings plan and shared deductible plan, as well as others.</function_description>\\n <required_argument>query(string): A full sentence query that is fed to the search tool</required_argument>\\n <returns>Returns string related to the user query asked.</returns>\\n </function>\\n \\n </functions>\\n\\n <question>What is my deductible? My username is Bob and my benefitType is Dental. Also, what is the 401k yearly contribution limit?</question>\\n <scratchpad> I understand I cannot use functions that have not been provided to me to answer this question.\\n To answer this question, I will:\\n 1. Call the get::benefitsaction::getbenefitplanname function to get the benefit plan name for the user Bob with benefit type Dental.\\n 2. Call the get::x_amz_knowledgebase_dentalinsurance::search function to search for information about deductibles for the plan name returned from step 1.\\n 3. Call the get::x_amz_knowledgebase_401k::search function to search for information about 401k yearly contribution limits.\\n 4. Return the deductible information from the search results to the user.\\n I have checked that I have access to the get::benefitsaction::getbenefitplanname, x_amz_knowledgebase_dentalinsurance::search, and x_amz_knowledgebase_401k::search functions.\\n </scratchpad>\\n <function_call>get::benefitsaction::getbenefitplanname(userName=\\\"Bob\\\", benefitType=\\\"Dental\\\")</function_call>\\n <function_result>{{'response': {{'planName': 'Delta Dental'}}}}</function_result>\\n <scratchpad>\\n I have received the plan name Delta Dental for the user Bob with Dental benefits. 
I will now call the x_amz_knowledgebase_dentalinsurance::search function to find deductible information for Delta Dental.\\n </scratchpad>\\n <function_call>get::x_amz_knowledgebase_dentalinsurance::search(searchQuery=\\\"What is the deductible for Delta Dental?\\\")</function_call>\\n <function_result>{{'response': {{'responseCode': '200', 'responseBody': \\\"\\\"<answer>\\\\n<answer_part>\\\\n<text>The user's individual deductible is $50 per benefit period</text>\\\\n<source>dfe040f8-46ed-4a65-b3ea-529fa55f6b9e</source>\\\\n</answer_part>\\\\n<answer_part>\\\\n<text>If they are enrolled with dependents, the maximum family deductible is $150 per benefit period.</text>\\\\n<source>0e666064-31d8-4223-b7ba-8eecf40b7b47</source>\\\\n</answer_part>\\\\n</answer>\\\"}}}}</function_result> <scratchpad>\\n I have found the deductible information for Dental benefits. I will now call the x_amz_knowledgebase_401k::search function to find yearly 401k contribution limits.\\n </scratchpad>\\n <function_call>get::x_amz_knowledgebase_401k::search(searchQuery=\\\"What is the yearly 401k contribution limit?\\\")</function_call>\\n <function_result>{{'response': {{'responseCode': '200', 'responseBody': \\\"<answer>\\\\n<answer_part>\\\\n<text>The yearly 401k contribution limit is $20,500.</text>\\\\n<source>c546cbe8-07f6-45d1-90ca-74d87ab2885a</source>\\\\n</answer_part>\\\\n</answer>\\\"}}}}</function_result>\\n <answer>\\n <answer_part>\\n <text>The deductible for your Delta Dental plan is $50 per benefit period.</text>\\n <source>dfe040f8-46ed-4a65-b3ea-529fa55f6b9e</source>\\n </answer_part>\\n <answer_part>\\n <text>If you have dependents enrolled, the maximum family deductible is $150 per benefit period.</text>\\n <source>0e666064-31d8-4223-b7ba-8eecf40b7b47</source>\\n </answer_part>\\n <answer_part>\\n <text>The yearly 401k contribution limit is $20,500.</text>\\n <source>c546cbe8-07f6-45d1-90ca-74d87ab2885a</source>\\n </answer_part>\\n </answer>\\n </example>\\n\\n \\n\\n <example_docstring>Here's a final example where the question asked could not be answered with information gathered from calling the provided functions. In this example, notice how you respond by telling the user you cannot answer, without using a function that was not provided to you.</example_docstring>\\n <example>\\n <functions>\\n <function>\\n <function_name>get::policyengineactions::getpolicyviolations</function_name>\\n <function_description>Returns a list of policy engine violations for the specified alias within the specified date range.</function_description>\\n <required_argument>alias (string): The alias of the employee under whose name current violations needs to be listed</required_argument>\\n <required_argument>startDate (string): The start date of the range to filter violations. The format for startDate is MM/DD/YYYY.</required_argument>\\n <required_argument>endDate (string): The end date of the range to filter violations</required_argument>\\n <returns>array: Successful response</returns>\\n <raises>object: Invalid request</raises>\\n </function>\\n <function>\\n <function_name>post::policyengineactions::acknowledgeviolations</function_name>\\n <function_description>Acknowledge policy engine violation. 
Generally used to acknowledge violation, once user notices a violation under their alias or their managers alias.</function_description>\\n <required_argument>policyId (string): The ID of the policy violation</required_argument>\\n <required_argument>expectedDateOfResolution (string): The date by when the violation will be addressed/resolved</required_argument>\\n <returns>object: Successful response</returns>\\n <raises>object: Invalid request</raises>\\n </function>\\n <function>\\n <function_name>get::activedirectoryactions::getmanager</function_name>\\n <function_description>This API is used to identify the manager hierarchy above a given person. Every person could have a manager and the manager could have another manager to which they report to</function_description>\\n <required_argument>alias (string): The alias of the employee under whose name current violations needs to be listed</required_argument>\\n <returns>object: Successful response</returns>\\n <raises>object: Invalid request</raises>\\n </function>\\n \\n </functions>\\n <question>Who are the reportees of David?</question>\\n <scratchpad>\\n After reviewing the functions I was equipped with, I realize I am not able to accurately answer this question since I can't access reportees of David. Therefore, I should explain to the user I cannot answer this question.\\n </scratchpad>\\n <answer>\\n Sorry, I am unable to assist you with this request.\\n </answer>\\n </example>\\n</examples>\\n\\nThe above examples have been provided to you to illustrate general guidelines and format for use of function calling for information retrieval, and how to use your scratchpad to plan your approach. IMPORTANT: the functions provided within the examples should not be assumed to have been provided to you to use UNLESS they are also explicitly given to you within <functions></functions> tags below. All of the values and information within the examples (the questions, function results, and answers) are strictly part of the examples and have not been provided to you.\\n\\nNow that you have read and understood the examples, I will define the functions that you have available to you to use. Here is a comprehensive list.\\n\\n<functions>\\n<function>\\n<function_name>GET::optimal_departure_window_mars::getNextMarsLaunchWindow</function_name>\\n<function_description>Gets the next optimal launch window to Mars.</function_description>\\n<required_argument>specific_impulse (string): Specific impulse of the propulsion system (s).</required_argument>\\n<required_argument>dry_mass (string): Mass of the spacecraft without fuel (kg).</required_argument>\\n<required_argument>total_mass (string): Total mass of the spacecraft including fuel (kg)</required_argument>\\n<returns>object: The next optimal departure date for a Hohmann transfer from Earth to Mars, based on the spacecraft's mass and specific impulse.</returns>\\n</function>\\n\\n\\n</functions>\\n\\nNote that the function arguments have been listed in the order that they should be passed into the function.\\n\\n\\n\\nDo not modify or extend the provided functions under any circumstances. For example, GET::optimal_departure_window_mars::getNextMarsLaunchWindow with additional parameters would be considered modifying the function which is not allowed. 
Please use the functions only as defined.\\n\\nDO NOT use any functions that I have not equipped you with.\\n\\n Do not make assumptions about inputs; instead, make sure you know the exact function and input to use before you call a function.\\n\\nTo call a function, output the name of the function in between <function_call> and </function_call> tags. You will receive a <function_result> in response to your call that contains information that you can use to better answer the question. Or, if the function call produced an error, you will receive an <error> in response.\\n\\n\\n\\nThe format for all other <function_call> MUST be: <function_call>$FUNCTION_NAME($FUNCTION_PARAMETER_NAME=$FUNCTION_PARAMETER_VALUE)</function_call>\\n\\nRemember, your goal is to answer the user's question to the best of your ability, using only the function(s) provided within the <functions></functions> tags to gather more information if necessary to better answer the question.\\n\\nDo not modify or extend the provided functions under any circumstances. For example, calling GET::optimal_departure_window_mars::getNextMarsLaunchWindow with additional parameters would be modifying the function which is not allowed. Please use the functions only as defined.\\n\\nBefore calling any functions, create a plan for performing actions to answer this question within the <scratchpad>. Double check your plan to make sure you don't call any functions that you haven't been provided with. Always return your final answer within <answer></answer> tags.\\n\\n\\n\\nThe user input is <question>Answer the following question and pay strong attention to the prompt:\\n <question>\\n When is the next launch window for Mars? My spacecraft's total mass is 50000, dry mass is 10000 and specific impulse is 2500. Mass in Kg.\\n </question>\\n <instruction>\\n You have functions available at your disposal to use when anwering any questions about orbital mechanics.if you can't find a function to answer a question about orbital mechanics, simply reply 'I do not know'\\n </instruction></question>\\n\\n\\nAssistant: <scratchpad> I understand I cannot use functions that have not been provided to me to answer this question.\\n\\n\"}",
"mlflow.spanOutputs": "\"To answer this question about the next Mars launch window, I will:\\n\\n1. Call the GET::optimal_departure_window_mars::getNextMarsLaunchWindow function to get the next optimal launch window, passing in the provided spacecraft mass and specific impulse values.\\n\\nI have verified that I have access to the GET::optimal_departure_window_mars::getNextMarsLaunchWindow function.\""
},
"events": []
},
{
"name": "Invoking Action Group",
"context": {
"span_id": "0x692bd6457647dc76",
"trace_id": "0x9b8bd0b2e018d77f936e48a09e54fd44"
},
"parent_id": "0xb802165d133a33aa",
"start_time": 1731388550224851000,
"end_time": 1731388550225218000,
"status_code": "OK",
"status_message": "",
"attributes": {
"mlflow.traceRequestId": "\"1e036cc3a7f946ec995f7763b8dde51c\"",
"mlflow.spanType": "\"UNKNOWN\"",
"trace_attributes": "[{\"type\": \"modelInvocationInput\", \"data\": {\"inferenceConfiguration\": {\"maximumLength\": 2048, \"stopSequences\": [\"</function_call>\", \"</answer>\", \"</error>\"], \"temperature\": 0.0, \"topK\": 250, \"topP\": 1.0}, \"text\": \"\\n\\nHuman:\\nYou are a research assistant AI that has been equipped with one or more functions to help you answer a <question>. Your goal is to answer the user's question to the best of your ability, using the function(s) to gather more information if necessary to better answer the question. If you choose to call a function, the result of the function call will be added to the conversation history in <function_results> tags (if the call succeeded) or <error> tags (if the function failed). \\nYou were created with these instructions to consider as well:\\n<auxiliary_instructions>\\n You are a friendly chat bot. You have access to a function called that returns\\n information about the Mars launch window. When responding with Mars launch window,\\n please make sure to add the timezone UTC.\\n </auxiliary_instructions>\\n\\nHere are some examples of correct action by other, different agents with access to functions that may or may not be similar to ones you are provided.\\n\\n<examples>\\n <example_docstring> Here is an example of how you would correctly answer a question using a <function_call> and the corresponding <function_result>. Notice that you are free to think before deciding to make a <function_call> in the <scratchpad>.</example_docstring>\\n <example>\\n <functions>\\n <function>\\n <function_name>get::policyengineactions::getpolicyviolations</function_name>\\n <function_description>Returns a list of policy engine violations for the specified alias within the specified date range.</function_description>\\n <required_argument>alias (string): The alias of the employee under whose name current violations needs to be listed</required_argument>\\n <required_argument>startDate (string): The start date of the range to filter violations. The format for startDate is MM/DD/YYYY.</required_argument>\\n <required_argument>endDate (string): The end date of the range to filter violations</required_argument>\\n <returns>array: Successful response</returns>\\n <raises>object: Invalid request</raises>\\n </function>\\n <function>\\n <function_name>post::policyengineactions::acknowledgeviolations</function_name>\\n <function_description>Acknowledge policy engine violation. Generally used to acknowledge violation, once user notices a violation under their alias or their managers alias.</function_description>\\n <required_argument>policyId (string): The ID of the policy violation</required_argument>\\n <required_argument>expectedDateOfResolution (string): The date by when the violation will be addressed/resolved</required_argument>\\n <returns>object: Successful response</returns>\\n <raises>object: Invalid request</raises>\\n </function>\\n <function>\\n <function_name>get::activedirectoryactions::getmanager</function_name>\\n <function_description>This API is used to identify the manager hierarchy above a given person. 
Every person could have a manager and the manager could have another manager to which they report to</function_description>\\n <required_argument>alias (string): The alias of the employee under whose name current violations needs to be listed</required_argument>\\n <returns>object: Successful response</returns>\\n <raises>object: Invalid request</raises>\\n </function>\\n \\n </functions>\\n\\n <question>Can you show me my policy engine violation from 1st january 2023 to 1st february 2023? My alias is jsmith.</question>\\n <scratchpad>\\n To answer this question, I will need to:\\n 1. I do not have knowledge to policy engine violations, so I should see if I can use any of the available functions to help. I have been equipped with get::policyengineactions::getpolicyviolations that gets the policy engine violations for a given alias, start date and end date. I will use this function to gather more information.\\n </scratchpad>\\n <function_call>get::policyengineactions::getpolicyviolations(alias=\\\"jsmith\\\", startDate=\\\"1st January 2023\\\", endDate=\\\"1st February 2023\\\")</function_call>\\n <function_result>{response: [{creationDate: \\\"2023-06-01T09:30:00Z\\\", riskLevel: \\\"High\\\", policyId: \\\"POL-001\\\", policyUrl: \\\"https://example.com/policies/POL-001\\\", referenceUrl: \\\"https://example.com/violations/POL-001\\\"}, {creationDate: \\\"2023-06-02T14:45:00Z\\\", riskLevel: \\\"Medium\\\", policyId: \\\"POL-002\\\", policyUrl: \\\"https://example.com/policies/POL-002\\\", referenceUrl: \\\"https://example.com/violations/POL-002\\\"}]}</function_result>\\n <answer>The policy engine violations between 1st january 2023 to 1st february 2023 for alias jsmith are - Policy ID: POL-001, Policy ID: POL-002</answer>\\n </example>\\n\\n <example_docstring>Here is another example that utilizes multiple function calls.</example_docstring>\\n <example>\\n <functions>\\n <function>\\n <function_name>get::policyengineactions::getpolicyviolations</function_name>\\n <function_description>Returns a list of policy engine violations for the specified alias within the specified date range.</function_description>\\n <required_argument>alias (string): The alias of the employee under whose name current violations needs to be listed</required_argument>\\n <required_argument>startDate (string): The start date of the range to filter violations. The format for startDate is MM/DD/YYYY.</required_argument>\\n <required_argument>endDate (string): The end date of the range to filter violations</required_argument>\\n <returns>array: Successful response</returns>\\n <raises>object: Invalid request</raises>\\n </function>\\n <function>\\n <function_name>post::policyengineactions::acknowledgeviolations</function_name>\\n <function_description>Acknowledge policy engine violation. Generally used to acknowledge violation, once user notices a violation under their alias or their managers alias.</function_description>\\n <required_argument>policyId (string): The ID of the policy violation</required_argument>\\n <required_argument>expectedDateOfResolution (string): The date by when the violation will be addressed/resolved</required_argument>\\n <returns>object: Successful response</returns>\\n <raises>object: Invalid request</raises>\\n </function>\\n <function>\\n <function_name>get::activedirectoryactions::getmanager</function_name>\\n <function_description>This API is used to identify the manager hierarchy above a given person. 
Every person could have a manager and the manager could have another manager to which they report to</function_description>\\n <required_argument>alias (string): The alias of the employee under whose name current violations needs to be listed</required_argument>\\n <returns>object: Successful response</returns>\\n <raises>object: Invalid request</raises>\\n </function>\\n \\n </functions>\\n <question>Can you check the policy engine violations under my manager between 2nd May to 5th May? My alias is john.</question>\\n <scratchpad>\\n To answer this question, I will need to:\\n 1. Get the manager alias of the user using get::activedirectoryactions::getmanager function.\\n 2. Use the returned manager alias to get the policy engine violations using the get::policyengineactions::getpolicyviolations function.\\n\\n I have double checked and made sure that I have been provided the get::activedirectoryactions::getmanager and the get::policyengineactions::getpolicyviolations functions.\\n </scratchpad>\\n <function_call>get::activedirectoryactions::getmanager(alias=\\\"john\\\")</function_call>\\n <function_result>{response: {managerAlias: \\\"mark\\\", managerLevel: \\\"6\\\", teamName: \\\"Builder\\\", managerName: \\\"Mark Hunter\\\"}}}}</function_result>\\n <scratchpad>\\n 1. I have the managerAlias from the function results as mark and I have the start and end date from the user input. I can use the function result to call get::policyengineactions::getpolicyviolations function.\\n 2. I will then return the get::policyengineactions::getpolicyviolations function result to the user.\\n\\n I have double checked and made sure that I have been provided the get::policyengineactions::getpolicyviolations functions.\\n </scratchpad>\\n <function_call>get::policyengineactions::getpolicyviolations(alias=\\\"mark\\\", startDate=\\\"2nd May 2023\\\", endDate=\\\"5th May 2023\\\")</function_call>\\n <function_result>{response: [{creationDate: \\\"2023-05-02T09:30:00Z\\\", riskLevel: \\\"High\\\", policyId: \\\"POL-001\\\", policyUrl: \\\"https://example.com/policies/POL-001\\\", referenceUrl: \\\"https://example.com/violations/POL-001\\\"}, {creationDate: \\\"2023-05-04T14:45:00Z\\\", riskLevel: \\\"Low\\\", policyId: \\\"POL-002\\\", policyUrl: \\\"https://example.com/policies/POL-002\\\", referenceUrl: \\\"https://example.com/violations/POL-002\\\"}]}</function_result>\\n <answer>\\n The policy engine violations between 2nd May 2023 to 5th May 2023 for your manager's alias mark are - Policy ID: POL-001, Policy ID: POL-002\\n </answer>\\n </example>\\n\\n <example_docstring>Functions can also be search engine API's that issue a query to a knowledge base. Here is an example that utilizes regular function calls in combination with function calls to a search engine API. Please make sure to extract the source for the information within the final answer when using information returned from the search engine.</example_docstring>\\n <example>\\n <functions>\\n <function>\\n <function_name>get::benefitsaction::getbenefitplanname</function_name>\\n <function_description>Get's the benefit plan name for a user. The API takes in a userName and a benefit type and returns the benefit name to the user (i.e. 
Aetna, Premera, Fidelity, etc.).</function_description>\\n <optional_argument>userName (string): None</optional_argument>\\n <optional_argument>benefitType (string): None</optional_argument>\\n <returns>object: Successful response</returns>\\n <raises>object: Invalid request</raises>\\n </function>\\n <function>\\n <function_name>post::benefitsaction::increase401klimit</function_name>\\n <function_description>Increases the 401k limit for a generic user. The API takes in only the current 401k limit and returns the new limit.</function_description>\\n <optional_argument>currentLimit (string): None</optional_argument>\\n <returns>object: Successful response</returns>\\n <raises>object: Invalid request</raises>\\n </function>\\n <function>\\n <function_name>get::x_amz_knowledgebase_dentalinsurance::search</function_name>\\n <function_description>This is a search tool that provides information about Delta Dental benefits. It has information about covered dental benefits and other relevant information</function_description>\\n <required_argument>query(string): A full sentence query that is fed to the search tool</required_argument>\\n <returns>Returns string related to the user query asked.</returns>\\n </function>\\n <function>\\n <function_name>get::x_amz_knowledgebase_401kplan::search</function_name>\\n <function_description>This is a search tool that provides information about Amazon 401k plan benefits. It can determine what a person's yearly 401k contribution limit is, based on their age.</function_description>\\n <required_argument>query(string): A full sentence query that is fed to the search tool</required_argument>\\n <returns>Returns string related to the user query asked.</returns>\\n </function>\\n <function>\\n <function_name>get::x_amz_knowledgebase_healthinsurance::search</function_name>\\n <function_description>This is a search tool that provides information about Aetna and Premera health benefits. It has information about the savings plan and shared deductible plan, as well as others.</function_description>\\n <required_argument>query(string): A full sentence query that is fed to the search tool</required_argument>\\n <returns>Returns string related to the user query asked.</returns>\\n </function>\\n \\n </functions>\\n\\n <question>What is my deductible? My username is Bob and my benefitType is Dental. Also, what is the 401k yearly contribution limit?</question>\\n <scratchpad> I understand I cannot use functions that have not been provided to me to answer this question.\\n To answer this question, I will:\\n 1. Call the get::benefitsaction::getbenefitplanname function to get the benefit plan name for the user Bob with benefit type Dental.\\n 2. Call the get::x_amz_knowledgebase_dentalinsurance::search function to search for information about deductibles for the plan name returned from step 1.\\n 3. Call the get::x_amz_knowledgebase_401k::search function to search for information about 401k yearly contribution limits.\\n 4. 
Return the deductible information from the search results to the user.\\n I have checked that I have access to the get::benefitsaction::getbenefitplanname, x_amz_knowledgebase_dentalinsurance::search, and x_amz_knowledgebase_401k::search functions.\\n </scratchpad>\\n <function_call>get::benefitsaction::getbenefitplanname(userName=\\\"Bob\\\", benefitType=\\\"Dental\\\")</function_call>\\n <function_result>{{'response': {{'planName': 'Delta Dental'}}}}</function_result>\\n <scratchpad>\\n I have received the plan name Delta Dental for the user Bob with Dental benefits. I will now call the x_amz_knowledgebase_dentalinsurance::search function to find deductible information for Delta Dental.\\n </scratchpad>\\n <function_call>get::x_amz_knowledgebase_dentalinsurance::search(searchQuery=\\\"What is the deductible for Delta Dental?\\\")</function_call>\\n <function_result>{{'response': {{'responseCode': '200', 'responseBody': \\\"\\\"<answer>\\\\n<answer_part>\\\\n<text>The user's individual deductible is $50 per benefit period</text>\\\\n<source>dfe040f8-46ed-4a65-b3ea-529fa55f6b9e</source>\\\\n</answer_part>\\\\n<answer_part>\\\\n<text>If they are enrolled with dependents, the maximum family deductible is $150 per benefit period.</text>\\\\n<source>0e666064-31d8-4223-b7ba-8eecf40b7b47</source>\\\\n</answer_part>\\\\n</answer>\\\"}}}}</function_result> <scratchpad>\\n I have found the deductible information for Dental benefits. I will now call the x_amz_knowledgebase_401k::search function to find yearly 401k contribution limits.\\n </scratchpad>\\n <function_call>get::x_amz_knowledgebase_401k::search(searchQuery=\\\"What is the yearly 401k contribution limit?\\\")</function_call>\\n <function_result>{{'response': {{'responseCode': '200', 'responseBody': \\\"<answer>\\\\n<answer_part>\\\\n<text>The yearly 401k contribution limit is $20,500.</text>\\\\n<source>c546cbe8-07f6-45d1-90ca-74d87ab2885a</source>\\\\n</answer_part>\\\\n</answer>\\\"}}}}</function_result>\\n <answer>\\n <answer_part>\\n <text>The deductible for your Delta Dental plan is $50 per benefit period.</text>\\n <source>dfe040f8-46ed-4a65-b3ea-529fa55f6b9e</source>\\n </answer_part>\\n <answer_part>\\n <text>If you have dependents enrolled, the maximum family deductible is $150 per benefit period.</text>\\n <source>0e666064-31d8-4223-b7ba-8eecf40b7b47</source>\\n </answer_part>\\n <answer_part>\\n <text>The yearly 401k contribution limit is $20,500.</text>\\n <source>c546cbe8-07f6-45d1-90ca-74d87ab2885a</source>\\n </answer_part>\\n </answer>\\n </example>\\n\\n \\n\\n <example_docstring>Here's a final example where the question asked could not be answered with information gathered from calling the provided functions. In this example, notice how you respond by telling the user you cannot answer, without using a function that was not provided to you.</example_docstring>\\n <example>\\n <functions>\\n <function>\\n <function_name>get::policyengineactions::getpolicyviolations</function_name>\\n <function_description>Returns a list of policy engine violations for the specified alias within the specified date range.</function_description>\\n <required_argument>alias (string): The alias of the employee under whose name current violations needs to be listed</required_argument>\\n <required_argument>startDate (string): The start date of the range to filter violations. 
The format for startDate is MM/DD/YYYY.</required_argument>\\n <required_argument>endDate (string): The end date of the range to filter violations</required_argument>\\n <returns>array: Successful response</returns>\\n <raises>object: Invalid request</raises>\\n </function>\\n <function>\\n <function_name>post::policyengineactions::acknowledgeviolations</function_name>\\n <function_description>Acknowledge policy engine violation. Generally used to acknowledge violation, once user notices a violation under their alias or their managers alias.</function_description>\\n <required_argument>policyId (string): The ID of the policy violation</required_argument>\\n <required_argument>expectedDateOfResolution (string): The date by when the violation will be addressed/resolved</required_argument>\\n <returns>object: Successful response</returns>\\n <raises>object: Invalid request</raises>\\n </function>\\n <function>\\n <function_name>get::activedirectoryactions::getmanager</function_name>\\n <function_description>This API is used to identify the manager hierarchy above a given person. Every person could have a manager and the manager could have another manager to which they report to</function_description>\\n <required_argument>alias (string): The alias of the employee under whose name current violations needs to be listed</required_argument>\\n <returns>object: Successful response</returns>\\n <raises>object: Invalid request</raises>\\n </function>\\n \\n </functions>\\n <question>Who are the reportees of David?</question>\\n <scratchpad>\\n After reviewing the functions I was equipped with, I realize I am not able to accurately answer this question since I can't access reportees of David. Therefore, I should explain to the user I cannot answer this question.\\n </scratchpad>\\n <answer>\\n Sorry, I am unable to assist you with this request.\\n </answer>\\n </example>\\n</examples>\\n\\nThe above examples have been provided to you to illustrate general guidelines and format for use of function calling for information retrieval, and how to use your scratchpad to plan your approach. IMPORTANT: the functions provided within the examples should not be assumed to have been provided to you to use UNLESS they are also explicitly given to you within <functions></functions> tags below. All of the values and information within the examples (the questions, function results, and answers) are strictly part of the examples and have not been provided to you.\\n\\nNow that you have read and understood the examples, I will define the functions that you have available to you to use. 
Here is a comprehensive list.\\n\\n<functions>\\n<function>\\n<function_name>GET::optimal_departure_window_mars::getNextMarsLaunchWindow</function_name>\\n<function_description>Gets the next optimal launch window to Mars.</function_description>\\n<required_argument>specific_impulse (string): Specific impulse of the propulsion system (s).</required_argument>\\n<required_argument>dry_mass (string): Mass of the spacecraft without fuel (kg).</required_argument>\\n<required_argument>total_mass (string): Total mass of the spacecraft including fuel (kg)</required_argument>\\n<returns>object: The next optimal departure date for a Hohmann transfer from Earth to Mars, based on the spacecraft's mass and specific impulse.</returns>\\n</function>\\n\\n\\n</functions>\\n\\nNote that the function arguments have been listed in the order that they should be passed into the function.\\n\\n\\n\\nDo not modify or extend the provided functions under any circumstances. For example, GET::optimal_departure_window_mars::getNextMarsLaunchWindow with additional parameters would be considered modifying the function which is not allowed. Please use the functions only as defined.\\n\\nDO NOT use any functions that I have not equipped you with.\\n\\n Do not make assumptions about inputs; instead, make sure you know the exact function and input to use before you call a function.\\n\\nTo call a function, output the name of the function in between <function_call> and </function_call> tags. You will receive a <function_result> in response to your call that contains information that you can use to better answer the question. Or, if the function call produced an error, you will receive an <error> in response.\\n\\n\\n\\nThe format for all other <function_call> MUST be: <function_call>$FUNCTION_NAME($FUNCTION_PARAMETER_NAME=$FUNCTION_PARAMETER_VALUE)</function_call>\\n\\nRemember, your goal is to answer the user's question to the best of your ability, using only the function(s) provided within the <functions></functions> tags to gather more information if necessary to better answer the question.\\n\\nDo not modify or extend the provided functions under any circumstances. For example, calling GET::optimal_departure_window_mars::getNextMarsLaunchWindow with additional parameters would be modifying the function which is not allowed. Please use the functions only as defined.\\n\\nBefore calling any functions, create a plan for performing actions to answer this question within the <scratchpad>. Double check your plan to make sure you don't call any functions that you haven't been provided with. Always return your final answer within <answer></answer> tags.\\n\\n\\n\\nThe user input is <question>Answer the following question and pay strong attention to the prompt:\\n <question>\\n When is the next launch window for Mars? My spacecraft's total mass is 50000, dry mass is 10000 and specific impulse is 2500. 
Mass in Kg.\\n </question>\\n <instruction>\\n You have functions available at your disposal to use when anwering any questions about orbital mechanics.if you can't find a function to answer a question about orbital mechanics, simply reply 'I do not know'\\n </instruction></question>\\n\\n\\nAssistant: <scratchpad> I understand I cannot use functions that have not been provided to me to answer this question.\\n\\n\", \"traceId\": \"e0b2b2c2-fb7c-4e17-8a1f-a3781100face-0\", \"type\": \"ORCHESTRATION\"}, \"event_order\": 2}, {\"type\": \"modelInvocationOutput\", \"data\": {\"metadata\": {\"usage\": {\"inputTokens\": 5160, \"outputTokens\": 135}}, \"rawResponse\": {\"content\": \"To answer this question about the next Mars launch window, I will:\\n\\n1. Call the GET::optimal_departure_window_mars::getNextMarsLaunchWindow function to get the next optimal launch window, passing in the provided spacecraft mass and specific impulse values.\\n\\nI have verified that I have access to the GET::optimal_departure_window_mars::getNextMarsLaunchWindow function.\\n\\n</scratchpad>\\n\\n<function_call>\\nGET::optimal_departure_window_mars::getNextMarsLaunchWindow(specific_impulse=\\\"2500\\\", dry_mass=\\\"10000\\\", total_mass=\\\"50000\\\")\"}, \"traceId\": \"e0b2b2c2-fb7c-4e17-8a1f-a3781100face-0\"}, \"event_order\": 3}, {\"type\": \"rationale\", \"data\": {\"text\": \"To answer this question about the next Mars launch window, I will:\\n\\n1. Call the GET::optimal_departure_window_mars::getNextMarsLaunchWindow function to get the next optimal launch window, passing in the provided spacecraft mass and specific impulse values.\\n\\nI have verified that I have access to the GET::optimal_departure_window_mars::getNextMarsLaunchWindow function.\", \"traceId\": \"e0b2b2c2-fb7c-4e17-8a1f-a3781100face-0\"}, \"event_order\": 4}, {\"type\": \"invocationInput\", \"data\": {\"actionGroupInvocationInput\": {\"actionGroupName\": \"optimal_departure_window_mars\", \"apiPath\": \"/get-next-mars-launch-window\", \"executionType\": \"LAMBDA\", \"parameters\": [{\"name\": \"total_mass\", \"type\": \"string\", \"value\": \"50000\"}, {\"name\": \"dry_mass\", \"type\": \"string\", \"value\": \"10000\"}, {\"name\": \"specific_impulse\", \"type\": \"string\", \"value\": \"2500\"}], \"verb\": \"get\"}, \"invocationType\": \"ACTION_GROUP\", \"traceId\": \"e0b2b2c2-fb7c-4e17-8a1f-a3781100face-0\"}, \"event_order\": 5}, {\"type\": \"observation\", \"data\": {\"actionGroupInvocationOutput\": {\"text\": \"{\\\"next_launch_window\\\": {\\\"next_launch_date\\\": \\\"2026-11-26 00:00:00\\\", \\\"synodic_period_days\\\": 779.9068939794238, \\\"transfer_time_days\\\": 259, \\\"delta_v_available_m_s\\\": 39457.985759929674, \\\"delta_v_required_m_s\\\": 5595.997417810693, \\\"is_feasible\\\": true}}\"}, \"traceId\": \"e0b2b2c2-fb7c-4e17-8a1f-a3781100face-0\", \"type\": \"ACTION_GROUP\"}, \"event_order\": 6}]",
"mlflow.spanFunctionName": "\"_action_group_trace\"",
"mlflow.spanInputs": "{\"inner_trace_group\": \"{'actionGroupName': 'optimal_departure_window_mars', 'apiPath': '/get-next-mars-launch-window', 'executionType': 'LAMBDA', 'parameters': [{'name': 'total_mass', 'type': 'string', 'value': '50000'}, {'name': 'dry_mass', 'type': 'string', 'value': '10000'}, {'name': 'specific_impulse', 'type': 'string', 'value': '2500'}], 'verb': 'get'}\"}",
"mlflow.spanOutputs": "\"{'action_group_name': 'optimal_departure_window_mars', 'api_path': '/get-next-mars-launch-window', 'execution_type': 'LAMBDA', 'execution_output': '{\\\"next_launch_window\\\": {\\\"next_launch_date\\\": \\\"2026-11-26 00:00:00\\\", \\\"synodic_period_days\\\": 779.9068939794238, \\\"transfer_time_days\\\": 259, \\\"delta_v_available_m_s\\\": 39457.985759929674, \\\"delta_v_required_m_s\\\": 5595.997417810693, \\\"is_feasible\\\": true}}'}\""
},
"events": []
},
{
"name": "Retrieved Response",
"context": {
"span_id": "0xfe0b5f9149c39d7d",
"trace_id": "0x9b8bd0b2e018d77f936e48a09e54fd44"
},
"parent_id": "0xb802165d133a33aa",
"start_time": 1731388550225320000,
"end_time": 1731388550226466000,
"status_code": "OK",
"status_message": "",
"attributes": {
"mlflow.traceRequestId": "\"1e036cc3a7f946ec995f7763b8dde51c\"",
"mlflow.spanType": "\"AGENT\"",
"mlflow.spanInputs": "[{\"role\": \"user\", \"content\": \"When is the next launch window for Mars? My spacecraft's total mass is 50000, dry mass is 10000 and specific impulse is 2500. Mass in Kg.\", \"name\": null}]",
"mlflow.spanOutputs": "{\"choices\": [{\"index\": 0, \"message\": {\"role\": \"user\", \"content\": \"Based on the provided spacecraft dry mass of 10000 kg, total mass of 50000 kg, and specific impulse of 2500 s, the next optimal launch window for a Hohmann transfer from Earth to Mars is on November 26, 2026 UTC. The transfer will take 259 days.\", \"name\": null}, \"finish_reason\": \"stop\", \"logprobs\": null}], \"usage\": {\"prompt_tokens\": null, \"completion_tokens\": null, \"total_tokens\": null}, \"id\": null, \"model\": \"anthropic.claude-v2\", \"object\": \"chat.completion\", \"created\": 1731388550}"
},
"events": []
}
],
"request": "{\"context\": \"<mlflow.pyfunc.model.PythonModelContext object at 0x13397c530>\", \"messages\": [{\"role\": \"user\", \"content\": \"When is the next launch window for Mars? My spacecraft's total mass is 50000, dry mass is 10000 and specific impulse is 2500. Mass in Kg.\", \"name\": null}], \"params\": {\"temperature\": 1.0, \"max_tokens\": null, \"stop\": null, \"n\": 1, \"stream\": false, \"top_p\": null, \"top_k\": null, \"frequency_penalty\": null, \"presence_penalty\": null}}",
"response": "{\"choices\": [{\"index\": 0, \"message\": {\"role\": \"user\", \"content\": \"Based on the provided spacecraft dry mass of 10000 kg, total mass of 50000 kg, and specific impulse of 2500 s, the next optimal launch window for a Hohmann transfer from Earth to Mars is on November 26, 2026 UTC. The transfer will take 259 days.\", \"name\": null}, \"finish_reason\": \"stop\", \"logprobs\": null}], \"usage\": {\"prompt_tokens\": null, \"completion_tokens\": null, \"total_tokens\": null}, \"id\": null, \"model\": \"anthropic.claude-v2\", \"object\": \"chat.completion\", \"created\": 1731388550}"
}

Visualizing Trace Breakdown in the MLflow UI

  1. Initial Prompt Submitted to the Bedrock Agent. Thumbnail

  2. In this trace, we can observe how the Bedrock Agent evaluates and selects the most suitable Action Group for the task at hand. Thumbnail

  3. Once an Action Group is selected, its invocation is traced, displaying the input and output interactions with the underlying Lambda function as outlined by the OpenAPI Spec above. Thumbnail

  4. Furthermore, Bedrock's supplementary trace is included under the Attributes section, along with additional metadata, as shown below. Thumbnail

  5. Subsequently, the final response from the agent is traced, as depicted below. Thumbnail

Note: We cannot break down the span's duration into individual trace durations because the Bedrock Agent's trace response does not include timestamps for each trace step.

Conclusion

In this blog, we explored how to integrate the AWS Bedrock Agent as an MLflow ChatModel, focusing on Action Groups, Knowledge Bases, and Tracing. We demonstrated how to easily build a custom ChatModel using MLflow's flexible and powerful APIs. This approach enables you to leverage MLflow's tracing and logging capabilities, even for models or flavors that are not natively supported by MLflow.

Key Takeaways from This Blog:

  • Deploying a Bedrock Agent with Action Groups as AWS Lambda Functions:
    • We covered how to set up a Bedrock Agent and implement custom actions using AWS Lambda functions within Action Groups.
  • Mapping the AWS Bedrock Agent's Custom Tracing to MLflow span/trace objects:
    • We demonstrated how to convert the agent's custom tracing data into MLflow span objects for better observability.
  • Logging and Loading the Bedrock Agent as an MLflow ChatModel:
    • We showed how to log the Bedrock Agent into MLflow as a ChatModel and how to load it for future use (a minimal sketch follows this list).
  • Externalizing AWS Client and Bedrock Configurations:
    • We explained how to externalize AWS client and Bedrock configurations to safeguard secrets and make it easy to adjust model settings without re-logging the model.
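
As a concrete reminder of what the logging and loading flow looks like, here is a minimal, self-contained sketch. EchoChatModel is a hypothetical stand-in for the Bedrock-backed ChatModel built earlier in this post (the real version calls the bedrock-agent-runtime client inside predict), and the artifact path is illustrative:

import mlflow
from mlflow.types.llm import ChatChoice, ChatMessage, ChatParams, ChatResponse

# Hypothetical stand-in for the Bedrock-backed ChatModel described above.
class EchoChatModel(mlflow.pyfunc.ChatModel):
    def predict(self, context, messages: list[ChatMessage], params: ChatParams) -> ChatResponse:
        reply = ChatMessage(role="assistant", content=f"echo: {messages[-1].content}")
        return ChatResponse(choices=[ChatChoice(index=0, message=reply)])

with mlflow.start_run():
    info = mlflow.pyfunc.log_model(
        artifact_path="chat_model",
        python_model=EchoChatModel(),
    )

# Load the logged model (here or in another process) and query it.
loaded = mlflow.pyfunc.load_model(info.model_uri)
print(loaded.predict({"messages": [{"role": "user", "content": "Hello"}]}))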

Further Reading and References

· 16 min read
Yuki Watanabe

Thumbnail

Augmenting LLMs with various data sources is a powerful strategy for building LLM applications. However, as these systems grow more complex, it becomes challenging to prototype changes and iterate on improvements.

LlamaIndex Workflow is a great framework for building such compound systems. Combined with MLflow, the Workflow API brings efficiency and robustness to the development cycle, enabling easy debugging, experiment tracking, and evaluation for continuous improvement.

In this blog, we will go through the journey of building a sophisticated chatbot with LlamaIndex's Workflow API and MLflow.

What is LlamaIndex Workflow?

LlamaIndex Workflow is an event-driven orchestration framework for designing dynamic AI applications. The core of LlamaIndex Workflow consists of:

  • Steps are units of execution, representing distinct actions in the workflow.

  • Events trigger these steps, acting as signals that control the workflow’s flow.

  • Workflow connects the two as a Python class: each step is implemented as a method of the workflow class, with its input and output events declared in the method signature.

This simple yet powerful abstraction allows you to break down complex tasks into manageable steps, enabling greater flexibility and scalability. As a framework embodying event-driven design, the Workflow APIs make it intuitive to design parallel and asynchronous execution flows, significantly enhancing the efficiency of long-running tasks and helping provide production-ready scalability.
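To make this concrete, here is a minimal sketch of a two-step workflow (the event and step names are illustrative, not part of the tutorial repository):

from llama_index.core.workflow import Event, StartEvent, StopEvent, Workflow, step

class GreetEvent(Event):
    """A custom event carrying data between steps."""
    name: str

class HelloWorkflow(Workflow):
    @step
    async def prepare(self, ev: StartEvent) -> GreetEvent:
        # StartEvent carries the keyword arguments passed to run()
        return GreetEvent(name=ev.get("name", "world"))

    @step
    async def greet(self, ev: GreetEvent) -> StopEvent:
        # Returning a StopEvent ends the workflow with a result
        return StopEvent(result=f"Hello, {ev.name}!")

# result = await HelloWorkflow(timeout=10).run(name="MLflow")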

Why Use MLflow with LlamaIndex Workflow?

Workflow provides great flexibility to design nearly arbitrary execution flows. However, with this great power comes a great responsibility. Without managing your changes properly, it can become a chaotic mess of indeterminate states and confusing configurations. After a few dozen changes, you may be asking yourself, "how did my workflow even work?".

MLflow brings a powerful MLOps harness to LlamaIndex Workflows throughout the end-to-end development cycle.

  • Experiment Tracking: MLflow allows you to record various components like steps, prompts, LLMs, and tools, making it easy to improve the system iteratively.

  • Reproducibility: MLflow packages environment information such as global configurations (Settings), library versions, and metadata to ensure consistent deployment across different stages of the ML lifecycle.

  • Tracing: Debugging issues in a complex event-driven workflow is cumbersome. MLflow Tracing is a production-ready observability solution that natively integrates with LlamaIndex, giving you observability into each internal stage within your Workflow.

  • Evaluation: Measuring is a crucial task for improving your model. MLflow Evaluation is a great tool to evaluate the quality, speed, and cost of your LLM application. It is tightly integrated with MLflow's experiment tracking capabilities, streamlining the process of making iterative improvements.

Let's Build!🛠️

Strategy: Hybrid Approach Using Multiple Retrieval Methods

Retrieval-Augmented Generation (RAG) is a powerful framework, but the retrieval step can often become a bottleneck, because embedding-based retrieval may not always capture the most relevant context. While many techniques exist to improve retrieval quality, no single solution works universally. Therefore, an effective strategy is to combine multiple retrieval approaches.

The concept we will explore here is to run several retrieval methods in parallel: (1) standard vector search, (2) keyword-based search (BM25), and (3) web search. The retrieved contexts are then merged, with irrelevant data filtered out to enhance the overall quality.

Hybrid RAG Concept

How do we bring this concept to life? Let’s dive in and build this hybrid RAG using LlamaIndex Workflow and MLflow.

1. Set Up Repository

The sample code, including the environment setup script, is available in the GitHub repository. It contains a complete workflow definition, a hands-on notebook, and a sample dataset for running experiments. To clone it to your working environment, use the following command:

git clone https://github.com/mlflow/mlflow.git

After cloning the repository, set up the virtual environment by running:

cd mlflow/examples/llama_index/workflow
chmod +x install.sh
./install.sh

Once the installation is complete, start Jupyter Notebook within the Poetry environment using:

poetry run jupyter notebook

Next, open the Tutorial.ipynb notebook located in the root directory. Throughout this blog, we will walk through this notebook to guide you through the development process.

2. Start an MLflow Experiment

An MLflow Experiment is where you track all aspects of model development, including model definitions, configurations, parameters, dependency versions, and more. Let’s start by creating a new MLflow experiment called "LlamaIndex Workflow RAG":

import mlflow

mlflow.set_experiment("LlamaIndex Workflow RAG")

At this point, the experiment doesn't have any recorded data yet. To view the experiment in the MLflow UI, open a new terminal and run the mlflow ui command, then navigate to the provided URL in your browser:

poetry run mlflow ui

Empty MLflow Experiment

3. Choose your LLM and Embeddings

Now, set up your preferred LLM and embeddings models in LlamaIndex's Settings object. These models will be used throughout the LlamaIndex components.

For this demonstration, we’ll use OpenAI models, but you can easily switch to different LLM providers or local models by following the instructions in the notebook.

import getpass
import os

os.environ["OPENAI_API_KEY"] = getpass.getpass("Enter OpenAI API Key")

from llama_index.core import Settings
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.llms.openai import OpenAI

# LlamaIndex by default uses OpenAI APIs for LLMs and embeddings models. You can use the default
# models (`gpt-3.5-turbo` and `text-embedding-ada-002` as of Oct 2024), but we recommend using the
# latest efficient models instead to get better results at lower cost.
Settings.embed_model = OpenAIEmbedding(model="text-embedding-3-large")
Settings.llm = OpenAI(model="gpt-4o-mini")

💡 MLflow will automatically log the Settings configuration into your MLflow Experiment when logging models, ensuring reproducibility and reducing the risk of discrepancies between environments.

4. Set Up Web Search API

Later in this blog, we will add a web search capability to the QA bot. We will use Tavily AI, a search API optimized for LLM applications and natively integrated with LlamaIndex. Visit their website to get an API key for free-tier use, or use a different search engine integrated with LlamaIndex, e.g. GoogleSearchToolSpec.

Once you get the API key, set it to the environment variable:

os.environ["TAVILY_AI_API_KEY"] = getpass.getpass("Enter Tavily AI API Key")

5. Set Up Document Indices for Retrieval

The next step is to build a document index for retrieval from MLflow documentation. The urls.txt file in the data directory contains a list of MLflow documentation pages. These pages can be loaded as document objects using the web page reader utility.

from llama_index.readers.web import SimpleWebPageReader

with open("data/urls.txt", "r") as file:
    urls = [line.strip() for line in file if line.strip()]

documents = SimpleWebPageReader(html_to_text=True).load_data(urls)

Next, ingest these documents into a vector database. In this tutorial, we’ll use the Qdrant vector store, which is free if self-hosted. If Docker is installed on your machine, you can start the Qdrant database by running the official Docker container:

$ docker pull qdrant/qdrant
$ docker run -p 6333:6333 -p 6334:6334 \
-v $(pwd)/.qdrant_storage:/qdrant/storage:z \
qdrant/qdrant

Once the container is running, you can create an index object that connects to the Qdrant database:

import qdrant_client
from llama_index.vector_stores.qdrant import QdrantVectorStore

client = qdrant_client.QdrantClient(host="localhost", port=6333)
vector_store = QdrantVectorStore(client=client, collection_name="mlflow_doc")

from llama_index.core import StorageContext, VectorStoreIndex

storage_context = StorageContext.from_defaults(vector_store=vector_store)
index = VectorStoreIndex.from_documents(
    documents=documents,
    storage_context=storage_context,
)

Of course, you can use your preferred vector store here. LlamaIndex supports a variety of vector databases, such as FAISS, Chroma, and Databricks Vector Search. If you choose an alternative, follow the relevant LlamaIndex documentation and update the workflow/workflow.py file accordingly.

In addition to evaluating the vector search retrieval, we will assess the keyword-based retriever (BM25) later. Let's set up local document storage to enable BM25 retrieval in the workflow.

from llama_index.core.node_parser import SentenceSplitter
from llama_index.retrievers.bm25 import BM25Retriever

splitter = SentenceSplitter(chunk_size=512)
nodes = splitter.get_nodes_from_documents(documents)
bm25_retriever = BM25Retriever.from_defaults(nodes=nodes)
bm25_retriever.persist(".bm25_retriever")

6. Define a Workflow

Now that the environment and data sources are ready, we can build the workflow and experiment with it. The complete workflow code is defined in the workflow directory. Let's explore some key components of the implementation.

Events

The workflow/events.py file defines all the events used within the workflow. These are simple Pydantic models that carry information between workflow steps. For example, the VectorSearchRetrieveEvent triggers the vector search step by passing the user's query.

class VectorSearchRetrieveEvent(Event):
    """Event for triggering VectorStore index retrieval step."""

    query: str

Prompts

Throughout the workflow execution, we call LLMs multiple times. The prompt templates for these LLM calls are defined in the workflow/prompts.py file.

Workflow Class

The main workflow class is defined in workflow/workflow.py. Let's break down how it works.

The constructor accepts a retrievers argument, which specifies the retrieval methods to be used in the workflow. For instance, if ["vector_search", "bm25"] is passed, the workflow performs vector search and keyword-based search, skipping web search.

💡 Deciding which retrievers to utilize dynamically allows us to experiment with different retrieval strategies without needing to replicate nearly identical model code.

class HybridRAGWorkflow(Workflow):

    VALID_RETRIEVERS = {"vector_search", "bm25", "web_search"}

    def __init__(self, retrievers=None, **kwargs):
        super().__init__(**kwargs)
        self.llm = Settings.llm
        self.retrievers = retrievers or []

        if invalid_retrievers := set(self.retrievers) - self.VALID_RETRIEVERS:
            raise ValueError(f"Invalid retrievers specified: {invalid_retrievers}")

        self._use_vs_retriever = "vector_search" in self.retrievers
        self._use_bm25_retriever = "bm25" in self.retrievers
        self._use_web_search = "web_search" in self.retrievers

        if self._use_vs_retriever:
            qd_client = qdrant_client.QdrantClient(host=_QDRANT_HOST, port=_QDRANT_PORT)
            vector_store = QdrantVectorStore(client=qd_client, collection_name=_QDRANT_COLLECTION_NAME)
            index = VectorStoreIndex.from_vector_store(vector_store=vector_store)
            self.vs_retriever = index.as_retriever()

        if self._use_bm25_retriever:
            self.bm25_retriever = BM25Retriever.from_persist_dir(_BM25_PERSIST_DIR)

        if self._use_web_search:
            self.tavily_tool = TavilyToolSpec(api_key=os.environ.get("TAVILY_AI_API_KEY"))

The workflow begins by executing a step that takes the StartEvent as input, which is the route_retrieval step in this case. This step inspects the retrievers parameter and triggers the necessary retrieval steps. By using the send_event() method of the context object, multiple events can be dispatched in parallel from this single step.

    # If no retriever is specified, proceed directly to the final query step with an empty context
    if len(self.retrievers) == 0:
        return QueryEvent(context="")

    # Trigger the retrieval steps based on the configuration
    if self._use_vs_retriever:
        ctx.send_event(VectorSearchRetrieveEvent(query=query))
    if self._use_bm25_retriever:
        ctx.send_event(BM25RetrieveEvent(query=query))
    if self._use_web_search:
        ctx.send_event(TransformQueryEvent(query=query))

The retrieval steps are straightforward. However, the web search step is more advanced as it includes an additional step to transform the user's question into a search-friendly query using an LLM.

The results from all the retrieval steps are aggregated in the gather_retrieval_results step. Here, the ctx.collect_events() method is used to poll for the results of the asynchronously executed steps.

    results = ctx.collect_events(ev, [RetrievalResultEvent] * len(self.retrievers))

Passing all results from multiple retrievers often leads to a large context with unrelated or duplicate content. To address this, we need to filter and select the most relevant results. While a score-based approach is common, web search results do not return similarity scores. Therefore, we use an LLM to sort and filter out irrelevant results. The rerank step achieves this by leveraging the built-in reranker integration with RankGPT.

    reranker = RankGPTRerank(llm=self.llm, top_n=5)
    reranked_nodes = reranker.postprocess_nodes(ev.nodes, query_str=query)
    reranked_context = "\n".join(node.text for node in reranked_nodes)

Finally, the reranked context is passed to the LLM along with the user query to generate the final answer. The result is returned as a StopEvent with the result key.

    @step
    async def query_result(self, ctx: Context, ev: QueryEvent) -> StopEvent:
        """Get result with relevant text."""
        query = await ctx.get("query")

        prompt = FINAL_QUERY_TEMPLATE.format(context=ev.context, query=query)
        response = self.llm.complete(prompt).text
        return StopEvent(result=response)

Now, let's instantiate the workflow and run it.

# Workflow with VS + BM25 retrieval
from workflow.workflow import HybridRAGWorkflow

workflow = HybridRAGWorkflow(retrievers=["vector_search", "bm25"], timeout=60)
response = await workflow.run(query="Why use MLflow with LlamaIndex?")
print(response)

7. Log the Workflow in an MLflow Experiment

Now we want to run the workflow with several different retrieval strategies and evaluate the performance of each. However, before running the evaluation, we'll log the model in MLflow to track both the model and its performance within an MLflow Experiment.

For the LlamaIndex Workflow, we use the new Model-from-code method, which logs models as standalone Python scripts. This approach avoids the risks and instability associated with serialization methods like pickle, relying instead on code as the single source of truth for the model definition. When combined with MLflow's environment-freezing capability, it provides a reliable way to persist models. For more details, refer to the MLflow documentation.

💡 In the workflow directory, there's a model.py script that imports the HybridRAGWorkflow and instantiates it with dynamic configurations passed via the model_config parameter during logging. This design allows you to track models with different configurations without duplicating the model definition.
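For reference, the core of such a script looks roughly like this (a simplified sketch; see the repository for the exact version):

# workflow/model.py (simplified sketch)
import mlflow
from workflow.workflow import HybridRAGWorkflow

# Read the retriever list passed via `model_config` at logging time
model_config = mlflow.models.ModelConfig()
workflow = HybridRAGWorkflow(retrievers=model_config.get("retrievers"), timeout=60)

# Register the instance as the model to be logged
mlflow.models.set_model(workflow)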

We'll start an MLflow Run and log the model script model.py with different configurations using the mlflow.llama_index.log_model() API.

# Different configurations we will evaluate. We don't run the evaluation for every permutation
# for demonstration purposes, but you can add as many patterns as you want.
run_name_to_retrievers = {
    # 1. No retrievers (prior knowledge in LLM).
    "none": [],
    # 2. Vector search retrieval only.
    "vs": ["vector_search"],
    # 3. Vector search and keyword search (BM25).
    "vs + bm25": ["vector_search", "bm25"],
    # 4. All retrieval methods including web search.
    "vs + bm25 + web": ["vector_search", "bm25", "web_search"],
}

# Create an MLflow Run and log the model with each configuration.
models = []
for run_name, retrievers in run_name_to_retrievers.items():
    with mlflow.start_run(run_name=run_name):
        model_info = mlflow.llama_index.log_model(
            # Specify the model Python script.
            llama_index_model="workflow/model.py",
            # Specify retrievers to use.
            model_config={"retrievers": retrievers},
            # Define dependency files to save along with the model.
            code_paths=["workflow"],
            # Subdirectory to save artifacts (not important).
            artifact_path="model",
        )
        models.append(model_info)

Now open the MLflow UI again; this time it should show four MLflow Runs recorded with different retrievers parameter values. Clicking each Run name and navigating to the "Artifacts" tab shows that MLflow records the model and various metadata, such as dependency versions and settings.

MLflow Runs
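You can also load any of the logged workflows back as a generic pyfunc model for a quick sanity check. Depending on your MLflow version, the input may need to be a dict of the workflow's run() arguments:

# Load the "vs" model back and run a single query through it
loaded_model = mlflow.pyfunc.load_model(models[1].model_uri)
print(loaded_model.predict({"query": "What is MLflow Tracking?"}))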

8. Enable MLflow Tracing

Before running the evaluation, there’s one final step: enabling MLflow Tracing. We'll dive into this feature and why we do this here later, but for now, you can enable it with a simple one-line command. MLflow will automatically trace every LlamaIndex execution.

mlflow.llama_index.autolog()

9. Evaluate the Workflow with Different Retriever Strategies

The example repository includes a sample evaluation dataset, mlflow_qa_dataset.csv, containing 30 question-answer pairs related to MLflow.

import pandas as pd

eval_df = pd.read_csv("data/mlflow_qa_dataset.csv")
display(eval_df.head(3))

To evaluate the workflow, use the mlflow.evaluate() API, which requires (1) your dataset, (2) the logged model, and (3) the metrics you want to compute.

from mlflow.metrics import latency
from mlflow.metrics.genai import answer_correctness


for model_info in models:
    with mlflow.start_run(run_id=model_info.run_id):
        result = mlflow.evaluate(
            # Pass the URI of the logged model above.
            model=model_info.model_uri,
            data=eval_df,
            # Specify the column for ground truth answers.
            targets="ground_truth",
            # Define the metrics to compute.
            extra_metrics=[
                latency(),
                answer_correctness("openai:/gpt-4o-mini"),
            ],
            # The answer_correctness metric requires an "inputs" column to be
            # present in the dataset. We have "query" instead, so we need to
            # specify the mapping in the `evaluator_config` parameter.
            evaluator_config={"col_mapping": {"inputs": "query"}},
        )

In this example, we evaluate the model with two metrics:

  1. Latency: Measures the time taken to execute a workflow for a single query.
  2. Answer Correctness: Evaluates the accuracy of answers based on the ground truth, scored by the OpenAI GPT-4o-mini model on a 1–5 scale.

These metrics are just for demonstration purposes—you can add additional metrics like toxicity or faithfulness, or even create your own. See the MLflow documentation for the full set of built-in metrics and how to define custom metrics.

The evaluation process will take a few minutes. Once completed, you can view the results in the MLflow UI. Open the Experiment page and click on the chart icon 📈 above the Run list.

Evaluation Result

💡 The evaluation results can differ depending on your model setup and inherent randomness.

The first row shows bar charts for the answer correctness metrics, while the second row displays latency results. The best-performing combination is "Vector Search + BM25". Interestingly, adding web search not only increases latency significantly but also decreases answer correctness.

Why does this happen? It appears some answers from the web-search-enabled model are off-topic. For example, in response to a question about starting the Model Registry, the web-search model provides an unrelated answer about model deployment, while the "vs + bm25" model offers a correct response.

Answer Comparison

Where did this incorrect answer come from? This seems to be a retriever issue, as we only changed the retrieval strategy. However, it's difficult to see what each retriever returned from the final result. To gain deeper insights into what's happening behind the scenes, MLflow Tracing is the perfect solution.

10. Inspecting Quality Issues with MLflow Trace

MLflow Tracing is a new feature that brings observability to LLM applications. It integrates seamlessly with LlamaIndex, recording all inputs, outputs, and metadata about intermediate steps during workflow execution. Since we called mlflow.llama_index.autolog() at the start, every LlamaIndex operation has been traced and recorded in the MLflow Experiment.

To inspect the trace for a specific question from the evaluation, navigate to the "Traces" tab on the experiment page. Look for the row with the particular question in the request column and the run name "vs + bm25 + web." Clicking the request ID link opens the Trace UI, where you can view detailed information about each step in the execution, including inputs, outputs, metadata, and latency.

Trace

In this case, we identified the issue by examining the reranker step. The web search retriever returned irrelevant context related to model serving, and the reranker incorrectly ranked it as the most relevant. With this insight, we can determine potential improvements, such as refining the reranker to better understand MLflow topics, improving web search precision, or even removing the web search retriever altogether.

Conclusion

In this blog, we explored how the combination of LlamaIndex and MLflow can elevate the development of Retrieval-Augmented Generation (RAG) workflows, bringing together powerful model management and observability capabilities. By integrating multiple retrieval strategies (such as vector search, BM25, and web search), we demonstrated how flexible retrieval can enhance the performance of LLM-driven applications.

  • Experiment Tracking allowed us to organize and log different workflow configurations, ensuring reproducibility and enabling us to track model performance across multiple runs.
  • MLflow Evaluate enabled us to easily log and evaluate the workflow with different retriever strategies, using key metrics like latency and answer correctness to compare performance.
  • MLflow UI gave us a clear visualization of how various retrieval strategies impacted both accuracy and latency, helping us identify the most effective configurations.
  • MLflow Tracing, integrated with LlamaIndex, provided detailed observability into each step of the workflow for diagnosing quality issues, such as incorrect reranking of search results.

With these tools, you have a complete framework for building, logging, and optimizing RAG workflows. As LLM technology continues to evolve, the ability to track, evaluate, and fine-tune every aspect of model performance will be essential. We highly encourage you to experiment further and see how these tools can be tailored to your own applications.

To continue learning, explore the following resources:

· 17 min read
Pedro Azevedo
Rahul Pandey

In this blog post, we'll embark on a journey to revolutionize how we evaluate language models. We'll explore the power of MLflow Evaluate and harness the capabilities of Large Language Models (LLMs) as judges. By the end, you'll learn how to create custom metrics, implement LLM-based evaluation, and apply these techniques to real-world scenarios. Get ready to transform your model assessment process and gain deeper insights into your AI's performance!

The Challenge of Evaluating Language Models

Evaluating large language models (LLMs) and natural language processing (NLP) systems presents several challenges, primarily due to their complexity and the diversity of tasks they can perform.

One major difficulty is creating metrics that comprehensively measure performance across varied applications, from generating coherent text to understanding nuanced human emotions. Traditional benchmarks often fail to capture these subtleties, leading to incomplete assessments.

An LLM acting as a judge can address these issues by leveraging its extensive training data to provide a more nuanced evaluation, offering insights into model behavior and areas needing improvement. For instance, an LLM can analyze whether a model generates text that is not only grammatically correct but also contextually appropriate and engaging, something more static metrics might miss.

However, to move forward effectively, we need more than just better evaluation methods. Standardized experimentation setups are essential to ensure that comparisons between models are both fair and replicable. A uniform framework for testing and evaluation would enable researchers to build on each other's work, leading to more consistent progress and the development of more robust models.

Introducing MLflow LLM Evaluate

MLflow LLM Evaluate is a powerful function within the MLflow ecosystem that allows for comprehensive model assessment by providing a standardized experiment setup. It supports both built-in metrics and custom LLM-judged metrics, making it an ideal tool for evaluating complex language tasks. With MLflow LLM Evaluate, you can:

MLflow Evaluate

Conquering new markets with an LLM as a judge

Imagine you're part of a global travel agency, "Worldwide WanderAgency," that's expanding its reach to Spanish-speaking countries.

Your team has developed an AI-powered translation system to help create culturally appropriate marketing materials and customer communications. However, as you begin to use this system, you realize that traditional evaluation metrics, such as BLEU (Bilingual Evaluation Understudy), fall short in capturing the nuances of language translation, especially when it comes to preserving cultural context and idiomatic expressions.

For instance, consider the phrase "kick the bucket." A direct translation might focus on the literal words, but the idiom actually means "to die." A traditional metric like BLEU may incorrectly evaluate the translation as adequate if the translated words match a reference translation, even if the cultural meaning is lost. In such cases, the metric might score the translation highly despite it being completely inappropriate in context. This could lead to embarrassing or culturally insensitive marketing content, which is something your team wants to avoid.

You need a way to evaluate whether the translation not only is accurate but also preserves the intended meaning, tone, and cultural context. This is where MLflow Evaluate and LLMs (Large Language Models) as judges come into play. These tools can assess translations more holistically by considering context, idiomatic expressions, and cultural relevance, providing a more reliable evaluation of the AI’s output.

Custom Metrics: Tailoring Evaluation to Your Needs

In the following section, we’ll implement three metrics:

  • The "cultural_sensitivity" metric ensures translations maintain cultural context and appropriateness.
  • The "faithfulness" metric checks that chatbot responses align accurately with company policies and retrieved content.
  • The "toxicity" metric evaluates responses for harmful or inappropriate content, ensuring respectful customer interactions.

These metrics will help Worldwide WanderAgency ensure their AI-driven translations and interactions meet their specific needs.

Evaluating Worldwide WanderAgency's AI Systems

Now that we understand WanderAgency's challenges, let's dive into a code walkthrough to address them. We'll implement custom metrics to measure AI performance and build a gauge visualization chart for sharing results with stakeholders.

We'll start by evaluating a language translation model, focusing on the "cultural_sensitivity" metric to ensure it preserves cultural nuances. This will help WanderAgency maintain high standards in global communication.

Cultural Sensitivity Metric

The travel agency wants to ensure their translations are not only accurate but also culturally appropriate. To achieve this, they are considering a custom metric that allows Worldwide WanderAgency to quantify how well their translations maintain cultural context and idiomatic expressions.

For instance, a phrase that is polite in one culture might be inappropriate in another. In English, addressing someone as "Dear" in a professional email might be seen as polite. However, in Spanish, using "Querido" in a professional context can be too personal and inappropriate.

How can we evaluate such an abstract concept in a systematic way? Traditional metrics fall short here, so we need a better approach. In this case, an LLM as a judge is a great fit! For this use case, let's create a "cultural_sensitivity" metric.

Here's a brief overview of the process: Start by installing all the necessary libraries for this demo to work.

pip install "mlflow>=2.14.1" openai transformers torch torchvision evaluate datasets tiktoken fastapi rouge_score textstat tenacity plotly ipykernel "nbformat>=5.10.4"

We will be using GPT-3.5 and GPT-4 in this example, so let's start by making sure our OpenAI API key is set up.

Import the necessary libraries.

import mlflow
import os

# Run a quick validation that we have an entry for the OPENAI_API_KEY within environment variables
assert "OPENAI_API_KEY" in os.environ, "OPENAI_API_KEY environment variable must be set"

import openai
import pandas as pd

When using the mlflow.evaluate() function, your large language model (LLM) can take one of the following forms:

  1. A mlflow.pyfunc.PyFuncModel() — typically an MLflow model.
  2. A Python function that accepts strings as inputs and returns a single string as output.
  3. An MLflow Deployments endpoint URI.
  4. model=None if the data you are providing has already been scored by a model, and you do not need to specify one.

For this example, we will use an MLflow model.

We’ll begin by logging a translation model in MLflow. For this tutorial, we'll use GPT-3.5 with a defined system prompt.

In a production environment, you would typically experiment with different prompts and models to determine the most suitable configuration for your use case. For more details, refer to MLflow’s Prompt Engineering UI.


system_prompt = "Translate the following sentences into Spanish"
# Let's set up an experiment to make it easier to track our results
mlflow.set_experiment("/Path/to/your/experiment")

basic_translation_model = mlflow.openai.log_model(
    model="gpt-3.5-turbo",
    task=openai.chat.completions,
    artifact_path="model",
    messages=[
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": "{user_input}"},
    ],
)

Let's test the model to make sure it works.

model = mlflow.pyfunc.load_model(basic_translation_model.model_uri)

model.predict("Hello, how are you?")

# Output = ['¡Hola, ¿cómo estás?']

To use mlflow.evaluate(), we first need to prepare sample data that will serve as input to our LLM. In this scenario, the input would consist of the content the company is aiming to translate.

For demonstration purposes, we will define a set of common English expressions that we want the model to translate.

# Prepare evaluation data
eval_data = pd.DataFrame(
    {
        "llm_inputs": [
            "I'm over the moon about the news!",
            "Spill the beans.",
            "Bite the bullet.",
            "Better late than never.",
        ]
    }
)

To meet the objectives of the travel agency, we will define custom metrics that evaluate the quality of translations. In particular, we need to assess how faithfully the translations capture not only the literal meaning but also cultural nuances.

By default, mlflow.evaluate() uses openai:/gpt-4 as the evaluation model. However, you also have the option to use a local model for evaluation, such as a model wrapped in a PyFunc (e.g., Ollama).

For this example, we will use GPT-4 as the evaluation model.

To begin, provide a few examples that illustrate good and poor translation scores.

# Define the custom metric
cultural_sensitivity = mlflow.metrics.genai.make_genai_metric(
    name="cultural_sensitivity",
    definition="Assesses how well the translation preserves cultural nuances and idioms.",
    grading_prompt="Score from 1-5, where 1 is culturally insensitive and 5 is highly culturally aware.",
    examples=[
        mlflow.metrics.genai.EvaluationExample(
            input="Break a leg!",
            output="¡Rómpete una pierna!",
            score=2,
            justification="This is a literal translation that doesn't capture the idiomatic meaning.",
        ),
        mlflow.metrics.genai.EvaluationExample(
            input="Break a leg!",
            output="¡Mucha mierda!",
            score=5,
            justification="This translation uses the equivalent Spanish theater idiom, showing high cultural awareness.",
        ),
        mlflow.metrics.genai.EvaluationExample(
            input="It's raining cats and dogs.",
            output="Está lloviendo gatos y perros.",
            score=1,
            justification="This literal translation does not convey the idiomatic meaning of heavy rain.",
        ),
        mlflow.metrics.genai.EvaluationExample(
            input="It's raining cats and dogs.",
            output="Está lloviendo a cántaros.",
            score=5,
            justification="This translation uses a Spanish idiom that accurately conveys the meaning of heavy rain.",
        ),
        mlflow.metrics.genai.EvaluationExample(
            input="Kick the bucket.",
            output="Patear el balde.",
            score=1,
            justification="This literal translation fails to convey the idiomatic meaning of dying.",
        ),
        mlflow.metrics.genai.EvaluationExample(
            input="Kick the bucket.",
            output="Estirar la pata.",
            score=5,
            justification="This translation uses the equivalent Spanish idiom for dying, showing high cultural awareness.",
        ),
        mlflow.metrics.genai.EvaluationExample(
            input="Once in a blue moon.",
            output="Una vez en una luna azul.",
            score=2,
            justification="This literal translation does not capture the rarity implied by the idiom.",
        ),
        mlflow.metrics.genai.EvaluationExample(
            input="Once in a blue moon.",
            output="De vez en cuando.",
            score=4,
            justification="This translation captures the infrequency but lacks the idiomatic color of the original.",
        ),
        mlflow.metrics.genai.EvaluationExample(
            input="The ball is in your court.",
            output="La pelota está en tu cancha.",
            score=3,
            justification="This translation is understandable but somewhat lacks the idiomatic nuance of making a decision.",
        ),
        mlflow.metrics.genai.EvaluationExample(
            input="The ball is in your court.",
            output="Te toca a ti.",
            score=5,
            justification="This translation accurately conveys the idiomatic meaning of it being someone else's turn to act.",
        ),
    ],
    model="openai:/gpt-4",
    parameters={"temperature": 0.0},
)

The Toxicity Metric

In addition to this custom metric, let's use MLflow's built-in metrics for the evaluators. In this case, MLflow will use the roberta-hate-speech model to detect toxicity. This metric evaluates responses for any harmful or inappropriate content, reinforcing the company's commitment to a positive customer experience.

# Log and evaluate the model
with mlflow.start_run() as run:
    results = mlflow.evaluate(
        basic_translation_model.model_uri,
        data=eval_data,
        model_type="text",
        evaluators="default",
        extra_metrics=[cultural_sensitivity],
        evaluator_config={
            "col_mapping": {
                "inputs": "llm_inputs",
            }
        },
    )

mlflow.end_run()

You can retrieve the final results as follows:

results.tables["eval_results_table"]
| | llm_inputs | outputs | token_count | toxicity/v1/score | flesch_kincaid_grade_level/v1/score | ari_grade_level/v1/score | cultural_sensitivity/v1/score | cultural_sensitivity/v1/justification |
|---|---|---|---|---|---|---|---|---|
| 0 | I'm over the moon about the news! | ¡Estoy feliz por la noticia! | 9 | 0.000258 | 5.2 | 3.7 | 4 | The translation captures the general sentiment... |
| 1 | Spill the beans. | Revela el secreto. | 7 | 0.001017 | 9.2 | 5.2 | 5 | The translation accurately captures the idioma... |
| 2 | Bite the bullet. | Morder la bala. | 7 | 0.001586 | 0.9 | 3.6 | 2 | The translation "Morder la bala" is a litera... |
| 3 | Better late than never. | Más vale tarde que nunca. | 7 | 0.004947 | 0.5 | 0.9 | 5 | The translation accurately captures the idioma... |

Let's analyze the final metrics...

cultural_sensitivity_score = results.metrics['cultural_sensitivity/v1/mean']
print(f"Cultural Sensitivity Score: {cultural_sensitivity_score}")

toxicity_score = results.metrics['toxicity/v1/mean']
# Calculate non-toxicity score
non_toxicity_score = "{:.2f}".format((1 - toxicity_score) * 100)
print(f"Non-Toxicity Score: {non_toxicity_score}%")

Output:

Cultural Sensitivity Score: 3.75
Non-Toxicity Score: 99.80%

It is often the case that we want to monitor and track these metrics on a dashboard so that both data scientists and stakeholders have an understanding of the performance and reliability of these solutions.

For this example let's create a gauge to display the final metric.

import plotly.graph_objects as go
from plotly.subplots import make_subplots

def create_gauge_chart(value1, title1, value2, title2):
    # Create a subplot figure with two columns
    fig = make_subplots(rows=1, cols=2, specs=[[{"type": "indicator"}, {"type": "indicator"}]])

    # Add the first gauge chart
    fig.add_trace(
        go.Indicator(
            mode="gauge+number",
            value=value1,
            title={"text": title1},
            gauge={"axis": {"range": [None, 5]}},
        ),
        row=1,
        col=1,
    )

    # Add the second gauge chart
    fig.add_trace(
        go.Indicator(
            mode="gauge+number",
            value=value2,
            title={"text": title2},
            gauge={"axis": {"range": [None, 100]}},
        ),
        row=1,
        col=2,
    )

    # Update layout
    fig.update_layout(height=400, width=800)

    # Show figure
    fig.show()

create_gauge_chart(cultural_sensitivity_score, "Cultural Sensitivity Score", float(non_toxicity_score), "Non Toxicity Score")

Gauge Chart

The Faithfulness Metric

As Worldwide WanderAgency's AI grows, they add a customer service chatbot that handles questions in multiple languages. This chatbot uses a RAG (Retrieval-Augmented Generation) system, which means it retrieves information from a database or documents and then generates an answer based on that information.

It's important that the answers provided by the chatbot stay true to the information it retrieves. To make sure of this, we create a "faithfulness" metric. This metric checks how well the chatbot's responses match the materials it’s supposed to be based on, ensuring the information given to customers is accurate.

For example, if the retrieved document says "Returns are accepted within 30 days," and the chatbot replies with "Our return policy is flexible and varies by region," it is not aligning well with the retrieved material. This inaccurate response (bad faithfulness) could mislead customers and create confusion.

Using MLflow to Evaluate RAG - Faithfulness

Let's evaluate how well our chatbot is doing in sticking to the retrieved information. Instead of using an MLflow model this time, we’ll use a custom function to define the faithfulness metric and see how aligned the chatbot's answers are with the data it pulls from.

# Prepare evaluation data
eval_data = pd.DataFrame(
    {
        "llm_inputs": [
            """Question: What is the company's policy on employee training?
            context: "Our company offers various training programs to support employee development. Employees are required to complete at least one training course per year related to their role. Additional training opportunities are available based on performance reviews." """,
            """Question: What is the company's policy on sick leave?
            context: "Employees are entitled to 10 days of paid sick leave per year. Sick leave can be used for personal illness or to care for an immediate family member. A doctor's note is required for sick leave exceeding three consecutive days." """,
            """Question: How does the company handle performance reviews?
            context: "Performance reviews are conducted annually. Employees are evaluated based on their job performance, goal achievement, and overall contribution to the team. Feedback is provided, and development plans are created to support employee growth." """,
        ]
    }
)

Now let's define some examples for this faithfulness metric.

examples = [
    mlflow.metrics.genai.EvaluationExample(
        input="""Question: What is the company's policy on remote work?
        context: "Our company supports a flexible working environment. Employees can work remotely up to three days a week, provided they maintain productivity and attend all mandatory meetings." """,
        output="Employees can work remotely up to three days a week if they maintain productivity and attend mandatory meetings.",
        score=5,
        justification="The answer is accurate and directly related to the question and context provided.",
    ),
    mlflow.metrics.genai.EvaluationExample(
        input="""Question: What is the company's policy on remote work?
        context: "Our company supports a flexible working environment. Employees can work remotely up to three days a week, provided they maintain productivity and attend all mandatory meetings." """,
        output="Employees are allowed to work remotely as long as they want.",
        score=2,
        justification="The answer is somewhat related but incorrect because it does not mention the three-day limit.",
    ),
    mlflow.metrics.genai.EvaluationExample(
        input="""Question: What is the company's policy on remote work?
        context: "Our company supports a flexible working environment. Employees can work remotely up to three days a week, provided they maintain productivity and attend all mandatory meetings." """,
        output="Our company supports flexible work arrangements.",
        score=3,
        justification="The answer is related to the context but does not specifically answer the question about the remote work policy.",
    ),
    mlflow.metrics.genai.EvaluationExample(
        input="""Question: What is the company's annual leave policy?
        context: "Employees are entitled to 20 days of paid annual leave per year. Leave must be approved by the employee's direct supervisor and should be planned in advance to ensure minimal disruption to work." """,
        output="Employees are entitled to 20 days of paid annual leave per year, which must be approved by their supervisor.",
        score=5,
        justification="The answer is accurate and directly related to the question and context provided.",
    ),
]

# Define the custom metric
faithfulness = mlflow.metrics.genai.make_genai_metric(
    name="faithfulness",
    definition="Assesses how well the answer relates to the question and provided context.",
    grading_prompt="Score from 1-5, where 1 is not related at all and 5 is highly relevant and accurate.",
    examples=examples,
)

Define our LLM function (in this case, it can be any function that follows the input/output formats supported by mlflow.evaluate()).

# Using a custom function
def my_llm(inputs):
    answers = []
    system_prompt = "Please answer the following question in formal language based on the context provided."
    for index, row in inputs.iterrows():
        print("INPUTS:", row)
        completion = openai.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": f"{row}"},
            ],
        )
        answers.append(completion.choices[0].message.content)

    return answers

This results in code that is similar to what we did before...

with mlflow.start_run() as run:
    results = mlflow.evaluate(
        my_llm,
        eval_data,
        model_type="text",
        evaluators="default",
        extra_metrics=[faithfulness],
        evaluator_config={
            "col_mapping": {
                "inputs": "llm_inputs",
            }
        },
    )
mlflow.end_run()

GenAI Metrics

Alternatively, we can leverage MLflow's built-in metrics for generative AI, using the same examples.

MLflow provides several built-in metrics that use an LLM as a judge. Despite differences in implementation, these metrics are used in the same way. Simply include them in the extra_metrics argument of the mlflow.evaluate() function.

In this case, we will use MLflow’s built-in faithfulness metric.

from mlflow.metrics.genai import EvaluationExample, faithfulness
faithfulness_metric = faithfulness(model="openai:/gpt-4")
print(faithfulness_metric)

mlflow.evaluate() simplifies the process of providing grading context, such as the documents retrieved by our system, directly into the evaluation. This feature integrates seamlessly with LangChain's retrievers, allowing you to supply the context for evaluation as a dedicated column. For more details, refer to this example.

In this case, since our retrieved documents are already included within the final prompt and we are not leveraging LangChain for this tutorial, we will simply map the llm_input column as our grading context.

with mlflow.start_run() as run:
    results = mlflow.evaluate(
        my_llm,
        eval_data,
        model_type="text",
        evaluators="default",
        extra_metrics=[faithfulness_metric],
        evaluator_config={
            "col_mapping": {
                "inputs": "llm_inputs",
                "context": "llm_inputs",
            }
        },
    )
mlflow.end_run()

After the evaluation we get the following results: Gauge faithfulness Chart

Conclusion

By combining the Cultural Sensitivity score with our other calculated metrics, our travel agency can further refine its model to ensure the delivery of high-quality content across all languages. Moving forward, we can revisit and adjust the prompts used to boost our Cultural Sensitivity score. Alternatively, we could fine-tune a smaller model to maintain the same high level of cultural sensitivity while reducing costs. These steps will help us provide even better service to the agency's diverse customer base.

mlflow.evaluate(), combined with LLMs as judges, opens up new possibilities for nuanced and context-aware model evaluation. By creating custom metrics tailored to specific aspects of model performance, data scientists can gain deeper insights into their models' strengths and weaknesses.

The flexibility offered by make_genai_metric() allows you to create evaluation criteria that are perfectly suited to your specific use case. Whether you need structured guidance for your LLM judge or want full control over the prompting process, MLflow provides the tools you need.

As you explore MLflow evaluate and LLM-based metrics, remember that the key lies in designing thoughtful evaluation criteria and providing clear instructions to your LLM judge. With these tools at your disposal, you're well-equipped to take your model evaluation to the next level, ensuring that your language models not only perform well on traditional metrics but also meet the nuanced requirements of real-world applications.

The built-in metrics, such as toxicity, offer standardized assessments that are crucial for ensuring the safety and accessibility of model outputs.

As a final challenge, re-run all the tests performed but this time with "gpt-4o-mini" and see how the performance is affected.
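As a hint, switching the judge is mostly a matter of changing the model argument when building the metric; a minimal sketch:

# Rebuild the custom metric with a cheaper judge model
cultural_sensitivity_mini = mlflow.metrics.genai.make_genai_metric(
    name="cultural_sensitivity",
    definition="Assesses how well the translation preserves cultural nuances and idioms.",
    grading_prompt="Score from 1-5, where 1 is culturally insensitive and 5 is highly culturally aware.",
    # Optionally pass the same EvaluationExample list used earlier via `examples=...`
    model="openai:/gpt-4o-mini",
    parameters={"temperature": 0.0},
)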

· 12 min read
Awadelrahman M. A. Ahmed

We all (well, most of us) remember November 2022 when the public release of ChatGPT by OpenAI marked a significant turning point in the world of AI. While generative artificial intelligence (GenAI) had been evolving for some time, ChatGPT, built on OpenAI's GPT-3.5 architecture, quickly captured the public’s imagination. This led to an explosion of interest in GenAI, both within the tech industry and among the general public.

On the tools side, MLflow continues to solidify its position as the favorite tool for machine learning operations (MLOps) among the ML community. However, the rise of GenAI has introduced new needs in how we use MLflow. One of these new challenges is how we log models in MLflow. If you’ve used MLflow before (and I bet you have), you’re probably familiar with the mlflow.log_model() function and how it efficiently pickles model artifacts.

Particularly with GenAI, there’s a new requirement: logging models "from code" instead of serializing them into a pickle file! And guess what? This need isn’t limited to GenAI models! So, in this post I will explore this concept and how MLflow has adapted to meet this new requirement.

You will notice that this feature is implemented at a very abstract level, allowing you to log any model "as code", whether it’s GenAI or not! I like to think of it as a generic approach, with GenAI models being just one of its use cases. So, in this post, I’ll explore this new feature, "Models from Code logging".

By the end of this post, you should be able to answer the three main questions: 'What,' 'Why,' and 'How' to use Models from Code logging.

What Is Models from Code Logging?

In fact, when MLflow announced this feature, it got me thinking in a more abstract way about the concept of a "model"! You might find it interesting as well, if you zoom out and consider a model as a mathematical representation or function that describes the relationship between input and output variables. At this level of abstraction, a model can be many things!

One might even recognize that a model, as an object or artifact, represents just one form of what a model can be, even if it’s the most popular in the ML community. If you think about it, a model can also be as simple as a piece of code for a mapping function or a code that sends API requests to external services such as OpenAI's APIs.

I'll explain the detailed workflow of how to log models from code later in the post, but for now, let's consider it at a high level with two main steps: first, writing your model code, and second, logging your model from code. This will look like the following figure:

High Level Models from Code Logging Workflow:

High Level Models-from-Code Logging Workflow

🔴 It's important to note that when we refer to "model code," we're talking about code that can be treated as a model itself. This means it's not your training code that generates a trained model object, but rather the step-by-step code that is executed as a model itself.

How Models from Code Differs from Object-Based Logging?

In the previous section, we discussed the concept of Models from Code logging. However, concepts often become clearer when contrasted with their alternatives; a technique known as contrast learning. In our case, the alternative is Object-Based logging, which is the commonly used approach for logging models in MLflow.

Object-Based logging treats a trained model as an object that can be stored and reused. After training, the model is saved as an object and can be easily loaded for deployment. For example, this process can be initiated by calling mlflow.log_model(), where MLflow handles the serialization, often using Pickle or similar methods.

Object-Based logging can be broken down into three high-level steps as in the following figure: first, creating the model object (whether by training it or acquiring it), second, serializing it (usually with Pickle or a similar tool), and third, logging it as an object.

High Level Object-Based Logging Workflow:

High Level Object-Based Logging Workflow
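For contrast, a minimal Object-Based logging example might look like this (scikit-learn is used purely for illustration):

import mlflow
import numpy as np
from sklearn.linear_model import LinearRegression

# Step 1: create the model object by training it
X, y = np.arange(10).reshape(-1, 1), np.arange(10, 20)
model = LinearRegression().fit(X, y)

# Steps 2 and 3: MLflow serializes the fitted object and logs it as an artifact
with mlflow.start_run():
    mlflow.sklearn.log_model(model, artifact_path="model")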

💡The main distinction between the popular Object-Based logging and Models from Code logging is that in the former, we log the model object itself, whether it's a model you've trained or a pre-trained model you've acquired. In the latter, however, we log the code that represents your model.

When Do You Need Models from Code Logging?

By now, I hope you have a clear understanding of what Models from Code logging is! You might still be wondering, though, about the specific use cases where this feature can be applied. This section will cover exactly that—the why!

While we mentioned GenAI as a motivational use case in the introduction, we also highlighted that MLflow has approached Models from Code logging in a more generic way and we will see that in the next section. This means you can leverage the generalizability of the Models from Code feature for a wide range of scenarios. I’ve identified three key usage patterns that I believe are particularly relevant:

1️⃣ When Your Model Relies on External Services:

This is one of the obvious and common use cases, especially with the rise of modern AI applications. It’s becoming increasingly clear that we are shifting from building AI at the "model" granularity to the "system" granularity.

In other words, AI is no longer just about individual models; it’s about how those models interact within a broader ecosystem. As we become more dependent on external AI services and APIs, the need for Models from Code logging becomes more pronounced.

For instance, frameworks like LangChain allow developers to build applications that chain together various AI models and services to perform complex tasks, such as language understanding and information retrieval. In such scenarios, the "model" is not just a set of trained parameters that can be pickled but a "system" of interconnected services, often orchestrated by code that makes API calls to external platforms.

Models from Code logging in these situations ensures that the entire workflow, including the logic and dependencies, is preserved. It offers the ability to maintain the same model-like experience by capturing the code, making it possible to faithfully recreate the model’s behavior, even when the actual computational work is performed outside your domain.

2️⃣ When You’re Combining Multiple Models to Calculate a Complex Metric:

Apart from GenAI, you can still benefit from the Models from Code feature in various other domains. There are many situations where multiple specialized models are combined to produce a comprehensive output. Note that we are not just referring to traditional ensemble modeling (predicting the same variable); often, you need to combine multiple models to predict different components of a complex inferential task.

One concrete example could be Customer Lifetime Value (CLV) in customer analytics. In the context of CLV, you might have separate models for:

  • Customer Retention: Forecasting how long a customer will continue to engage with the business.
  • Purchase Frequency: Predicting how often a customer will make a purchase.
  • Average Order Value: Estimating the typical value of each transaction.

Each of these models might already be logged and tracked properly using MLflow. Now, you need to "combine" these models into a single "system" that calculates CLV. We refer to it as a "system" because it contains multiple components.

The beauty of MLflow's Models from Code logging is that it allows you to treat this "CLV system" as a "CLV model". It enables you to leverage MLflow's capabilities, maintaining the MLflow-like model structure with all the advantages of tracking, versioning, and deploying your CLV model as a cohesive unit, even though it's built on top of other models. While such a complex model system can be built using a custom MLflow PythonModel, utilizing the Models from Code feature dramatically simplifies the serialization process, reducing the friction of building your solution. A sketch of what this could look like follows.
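Here is a hypothetical "CLV system" file logged from code; the component model URIs and the multiplicative CLV formula are illustrative assumptions, not a prescribed recipe:

# clv_model_code.py (hypothetical sketch)
import mlflow
import numpy as np

class CLVModel(mlflow.pyfunc.PythonModel):
    def load_context(self, context):
        # Load the previously logged component models; these URIs are placeholders
        self.retention = mlflow.pyfunc.load_model("models:/customer_retention/1")
        self.frequency = mlflow.pyfunc.load_model("models:/purchase_frequency/1")
        self.order_value = mlflow.pyfunc.load_model("models:/average_order_value/1")

    def predict(self, context, model_input, params=None):
        # A simple multiplicative CLV: expected lifetime x frequency x average order value
        retention = np.asarray(self.retention.predict(model_input))
        frequency = np.asarray(self.frequency.predict(model_input))
        order_value = np.asarray(self.order_value.predict(model_input))
        return retention * frequency * order_value

mlflow.models.set_model(CLVModel())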

3️⃣ When You Don’t Have Serialization at All:

Despite the rise of deep learning, industries still rely on rule-based algorithms that don’t produce serialized models. In these cases, Models from Code logging can be beneficial for integrating these processes into the MLflow ecosystem.

One example is in industrial quality control, where the Canny edge detection algorithm is often used to identify defects. This rule-based algorithm doesn’t involve serialization but is defined by specific steps.

Another example, which is gaining attention nowadays, is Causal AI. Constraint-based causal discovery algorithms, like the PC (Peter-Clark) algorithm, discover causal relationships in data but are implemented as code rather than as model objects.

In either case, with the Models from Code feature, you can log the entire process as a "model" in MLflow, preserving the logic and parameters while benefiting from MLflow’s tracking and versioning features.
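As an illustration, the edge-detection case could be wrapped as follows (a sketch assuming opencv-python is installed; the thresholds are arbitrary):

# canny_model_code.py (hypothetical sketch)
import cv2
import mlflow

class CannyEdgeModel(mlflow.pyfunc.PythonModel):
    def predict(self, context, model_input, params=None):
        # model_input is assumed to be a list of grayscale images as uint8 numpy arrays
        return [cv2.Canny(image, 100, 200) for image in model_input]

mlflow.models.set_model(CannyEdgeModel())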

How To Implement Models from Code Logging?

I hope that by this point, you have a clear understanding of the "What" and "Why" of Models from Code, and now you might be eager to get hands-on and focus on the How!

In this section, I'll provide a generic workflow for implementing MLflow's Models from Code logging, followed by a basic yet broadly applicable example. I hope the workflow provides a broad understanding that allows you to address a wide range of scenarios. I will also include links at the end to resources that cover more specific use cases (e.g., AI models).

Models from Code Workflow:

A key "ingredient" of the implementation is MLflow's component pyfunc. If you're not familiar with it, think of pyfunc as a universal interface in MLflow that lets you turn any model, from any framework, into an MLflow model by defining a custom Python function. You can also refer to this earlier post if you wish to gain a deeper understanding.

For our Models from Code logging, we’ll particularly use the PythonModel class within pyfunc. This class in the MLflow Python client library allows us to create and manage Python functions as MLflow models. It enables us to define a custom function that processes input data and returns predictions or results. This model can then be deployed, tracked, and shared using MLflow's features.

It seems to be exactly what we're looking for—we have some code that serves as our model, and we want to log it! That's why you'll soon see mlflow.pyfunc.PythonModel in our code example!

Now, each time we need to implement Models from Code, we create two separate Python files:

  1. The first contains our model code (let's call it model_code.py). This file contains a class that inherits from the mlflow.pyfunc.PythonModel class. The class we're defining contains our model logic. It could be our calls to OpenAI APIs, CLV (Customer Lifetime Value) model, or our causal discovery code. We'll see a very simple 101 example soon.

    📌 But wait! IMPORTANT:

    • Our model_code.py script needs to call (i.e., include) mlflow.models.set_model() to set the model, which is crucial for loading the model back using load_model() for inference. You will notice this in the example.
  2. The second file logs our class (that we defined in model_code.py). Think of it as the driver code; it can be either a notebook or a Python script (let's call it driver.py). In this file, we'll include the code responsible for logging our model code (essentially, providing the path to model_code.py).

Then we can deploy our model. Later, when the serving environment is loaded, model_code.py is executed, and when a serving request comes in, PyFuncClass.predict() is called.

This figure gives a generic template of these two files.

Models from Code files

A 101 Example of Model from Code Logging :

Let’s consider a straightforward example: a simple function to calculate the area of a circle based on its radius. With Models from Code, we can log this calculation as a model! I like to think of it as framing the calculation as a prediction problem, allowing us to write our model code with a predict method.

1. Our model_code.py file:

import mlflow
import math

class CircleAreaModel(mlflow.pyfunc.PythonModel):
    def predict(self, context, model_input, params=None):
        return [math.pi * (r ** 2) for r in model_input]

# It's important to call set_model() so it can be loaded for inference
# Also, note that it is set to an instance of the class, not the class itself.
mlflow.models.set_model(model=CircleAreaModel())

2. Our driver.py file:

This can be defined within a notebook as well. Here are its essential contents:

import mlflow

code_path = "model_code.py"  # make sure you provide the correct path

with mlflow.start_run():
    logged_model_info = mlflow.pyfunc.log_model(
        python_model=code_path,
        artifact_path="test_code_logging",
    )

# We can print some info about the logged model
print(f"MLflow Run: {logged_model_info.run_id}")
print(f"Model URI: {logged_model_info.model_uri}")

How this looks in MLflow:

Executing driver.py will start an MLflow run and log our model as code. The logged files can be seen in the MLflow UI, as demonstrated below:

Models from Code files

Conclusion and Further Learning

I hope that by this point, I have fulfilled the promises I made earlier! You should now have a clearer understanding of What Models from Code is and how it differs from the popular object-based approach, which logs models as serialized objects. You should also have a solid foundation of Why and when to use it, as well as an understanding of How to implement it through our general example.

As we mentioned in the introduction and throughout the post, there are various use cases where Models from Code can be beneficial. Our 101 example is just the beginning—there is much more to explore. Below is a list of code examples that you may find helpful:

  1. Logging models from code using Pyfunc log model API ( model code | driver code )
  2. Logging model from code using Langchain log model API ( model code | driver code )

· 22 min read
Michael Berk
MLflow maintainers

In this blog, we'll guide you through creating an AutoGen agent framework within an MLflow custom PyFunc. By combining MLflow with AutoGen's ability to create multi-agent frameworks, we are able to create scalable and stable GenAI applications.

Agent Frameworks

Agent frameworks enable autonomous agents to handle complex, multi-turn tasks by integrating discrete logic at each step. These frameworks are crucial for LLM-driven workflows, where agents manage dynamic interactions across multiple stages. Each agent operates based on specific logic, enabling precise task automation, decision-making, and coordination. This is ideal for applications like workflow orchestration, customer support, and multi-agent systems, where LLMs must interpret evolving context and respond accordingly.

· 8 min read
Michael Berk
MLflow maintainers

In this blog, we'll guide you through creating a LangGraph chatbot using MLflow. By combining MLflow with LangGraph's ability to create and manage cyclical graphs, you can create powerful stateful, multi-actor applications in a scalable fashion.

Throughout this post we will demonstrate how to leverage MLflow's capabilities to create a serializable and servable MLflow model which can easily be tracked, versioned, and deployed on a variety of servers. We'll be using the langchain flavor combined with MLflow's model from code feature.

What is LangGraph?

LangGraph is a library for building stateful, multi-actor applications with LLMs, used to create agent and multi-agent workflows. Compared to other LLM frameworks, it offers these core benefits:

  • Cycles and Branching: Implement loops and conditionals in your apps.
  • Persistence: Automatically save state after each step in the graph. Pause and resume the graph execution at any point to support error recovery, human-in-the-loop workflows, time travel and more.
  • Human-in-the-Loop: Interrupt graph execution to approve or edit next action planned by the agent.
  • Streaming Support: Stream outputs as they are produced by each node (including token streaming).
  • Integration with LangChain: LangGraph integrates seamlessly with LangChain.

LangGraph allows you to define flows that involve cycles, essential for most agentic architectures, differentiating it from DAG-based solutions. As a very low-level framework, it provides fine-grained control over both the flow and state of your application, crucial for creating reliable agents. Additionally, LangGraph includes built-in persistence, enabling advanced human-in-the-loop and memory features.
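To make the cycles-and-branching point concrete, here is a minimal sketch that needs no LLM at all: one node increments a counter, and a conditional edge loops back until a stopping condition is met.

from typing import TypedDict

from langgraph.graph import StateGraph, START, END

class CounterState(TypedDict):
    count: int

def increment(state: CounterState) -> CounterState:
    return {"count": state["count"] + 1}

def should_continue(state: CounterState) -> str:
    # Loop back into the cycle until the counter reaches 3
    return "increment" if state["count"] < 3 else END

builder = StateGraph(CounterState)
builder.add_node("increment", increment)
builder.add_edge(START, "increment")
builder.add_conditional_edges("increment", should_continue)
graph = builder.compile()

print(graph.invoke({"count": 0}))  # {'count': 3}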

LangGraph is inspired by Pregel and Apache Beam. The public interface draws inspiration from NetworkX. LangGraph is built by LangChain Inc, the creators of LangChain, but can be used without LangChain.

For a full walkthrough, check out the LangGraph Quickstart and for more on the fundamentals of design with LangGraph, check out the conceptual guides.

1 - Setup

First, we must install the required dependencies. We will use OpenAI for our LLM in this example, but using LangChain with LangGraph makes it easy to substitute any alternative supported LLM or LLM provider.

%%capture
%pip install langchain_openai==0.2.0 langchain==0.3.0 langgraph==0.2.27
%pip install -U mlflow

Next, let's get our relevant secrets. getpass, as demonstrated in the LangGraph quickstart, is a great way to insert your keys into an interactive Jupyter environment.

import os

# Set required environment variables for authenticating to OpenAI
# Check additional MLflow tutorials for examples of authentication if needed
# https://mlflow.org/docs/latest/llms/openai/guide/index.html#direct-openai-service-usage
assert "OPENAI_API_KEY" in os.environ, "Please set the OPENAI_API_KEY environment variable."

2 - Custom Utilities

While this is a demo, it's good practice to separate reusable utilities into a separate file/directory. Below we create three general utilities that would be valuable when building additional MLflow + LangGraph implementations.

Note that we use the magic %%writefile command to create a new file in a jupyter notebook context. If you're running this outside of an interactive notebook, simply create the file below, omitting the %%writefile {FILE_NAME}.py line.

%%writefile langgraph_utils.py
# omit this line if directly creating this file; this command is purely for running within Jupyter

from typing import Union

from langgraph.pregel.io import AddableValuesDict

def _langgraph_message_to_mlflow_message(
    langgraph_message: AddableValuesDict,
) -> dict:
    langgraph_type_to_mlflow_role = {
        "human": "user",
        "ai": "assistant",
        "system": "system",
    }

    if type_clean := langgraph_type_to_mlflow_role.get(langgraph_message.type):
        return {"role": type_clean, "content": langgraph_message.content}
    else:
        raise ValueError(f"Incorrect role specified: {langgraph_message.type}")


def get_most_recent_message(response: AddableValuesDict) -> str:
    most_recent_message = response.get("messages")[-1]
    return _langgraph_message_to_mlflow_message(most_recent_message)["content"]


def increment_message_history(
    response: AddableValuesDict, new_message: Union[dict, AddableValuesDict]
) -> list[dict]:
    if isinstance(new_message, AddableValuesDict):
        new_message = _langgraph_message_to_mlflow_message(new_message)

    message_history = [
        _langgraph_message_to_mlflow_message(message)
        for message in response.get("messages")
    ]

    return message_history + [new_message]

By the end of this step, you should see a new file in your current directory with the name langgraph_utils.py.

Note that it's best practice to add unit tests and properly organize your project into logically structured directories.
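For instance, a minimal pytest sketch for these utilities could look like the following (assuming the langchain_core message classes, which expose the .type and .content attributes our helpers rely on):

# tests/test_langgraph_utils.py
from langchain_core.messages import AIMessage, HumanMessage

from langgraph_utils import get_most_recent_message, increment_message_history

def test_get_most_recent_message():
    response = {"messages": [HumanMessage("Hi"), AIMessage("Hello!")]}
    assert get_most_recent_message(response) == "Hello!"

def test_increment_message_history():
    response = {"messages": [HumanMessage("Hi")]}
    new_message = {"role": "user", "content": "Bye"}
    assert increment_message_history(response, new_message) == [
        {"role": "user", "content": "Hi"},
        {"role": "user", "content": "Bye"},
    ]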

3 - Log the LangGraph Model

Great! Now that we have some reusable utilities located in ./langgraph_utils.py, we are ready to log the model with MLflow's langchain flavor.

3.1 - Create our Model-From-Code File

Quickly, some background. MLflow serializes model artifacts to the MLflow tracking server. Many popular ML packages don't have robust serialization and deserialization support, so MLflow augments this functionality via the models from code feature. With models from code, we're able to leverage Python as the serialization format, instead of popular alternatives such as JSON or pkl. This opens up tons of flexibility and stability.

To create a Python file with models from code, we must perform the following steps:

  1. Create a new python file. Let's call it graph.py.
  2. Define our langgraph graph.
  3. Leverage mlflow.models.set_model to indicate to MLflow which object in the Python script is our model of interest.

That's it!

%%writefile graph.py
# omit this line if directly creating this file; this command is purely for running within Jupyter

import os
from typing import Annotated, TypedDict

import mlflow
from langchain_openai import ChatOpenAI
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages
from langgraph.graph.state import CompiledStateGraph

def load_graph() -> CompiledStateGraph:
    """Create example chatbot from LangGraph Quickstart."""

    assert "OPENAI_API_KEY" in os.environ, "Please set the OPENAI_API_KEY environment variable."

    class State(TypedDict):
        messages: Annotated[list, add_messages]

    graph_builder = StateGraph(State)
    llm = ChatOpenAI()

    def chatbot(state: State):
        return {"messages": [llm.invoke(state["messages"])]}

    graph_builder.add_node("chatbot", chatbot)
    graph_builder.add_edge(START, "chatbot")
    graph_builder.add_edge("chatbot", END)
    graph = graph_builder.compile()
    return graph

# Set our model to be leveraged via model from code
mlflow.models.set_model(load_graph())

3.2 - Log with "Model from Code"

After creating this implementation, we can leverage the standard MLflow APIs to log the model.

import mlflow

with mlflow.start_run():
    model_info = mlflow.langchain.log_model(
        lc_model="graph.py",  # Path to our model Python file
        artifact_path="langgraph",
    )

model_uri = model_info.model_uri

4 - Use the Logged Model

Now that we have successfully logged a model, we can load it and leverage it for inference.

In the code below, we demonstrate that our chain has chatbot functionality!

import mlflow

# Custom utilities for handling chat history
from langgraph_utils import (
    increment_message_history,
    get_most_recent_message,
)

# Enable tracing
mlflow.set_experiment("Tracing example") # In Databricks, use an absolute path. Visit Databricks docs for more.
mlflow.langchain.autolog()

# Load the model
loaded_model = mlflow.langchain.load_model(model_uri)

# Show inference and message history functionality
print("-------- Message 1 -----------")
message = "What's my name?"
payload = {"messages": [{"role": "user", "content": message}]}
response = loaded_model.invoke(payload)

print(f"User: {message}")
print(f"Agent: {get_most_recent_message(response)}")

print("\n-------- Message 2 -----------")
message = "My name is Morpheus."
new_messages = increment_message_history(response, {"role": "user", "content": message})
payload = {"messages": new_messages}
response = loaded_model.invoke(payload)

print(f"User: {message}")
print(f"Agent: {get_most_recent_message(response)}")

print("\n-------- Message 3 -----------")
message = "What is my name?"
new_messages = increment_message_history(response, {"role": "user", "content": message})
payload = {"messages": new_messages}
response = loaded_model.invoke(payload)

print(f"User: {message}")
print(f"Agent: {get_most_recent_message(response)}")

Output:

-------- Message 1 -----------
User: What's my name?
Agent: I'm sorry, I cannot guess your name as I do not have access to that information. If you would like to share your name with me, feel free to do so.

-------- Message 2 -----------
User: My name is Morpheus.
Agent: Nice to meet you, Morpheus! How can I assist you today?

-------- Message 3 -----------
User: What is my name?
Agent: Your name is Morpheus.

4.1 - MLflow Tracing

Before concluding, let's demonstrate MLflow tracing.

MLflow Tracing is a feature that enhances LLM observability in your Generative AI (GenAI) applications by capturing detailed information about the execution of your application’s services. Tracing provides a way to record the inputs, outputs, and metadata associated with each intermediate step of a request, enabling you to easily pinpoint the source of bugs and unexpected behaviors.

Start the MLflow server as outlined in the tracking server docs. After entering the MLflow UI, we can see our experiment and corresponding traces.

MLflow UI Experiment Traces

As you can see, we've logged our traces and can easily view them by clicking our experiment of interest and then the "Tracing" tab.

MLflow UI Trace

After clicking on one of the traces, we can now see run execution for a single query. Notice that we log inputs, outputs, and lots of great metadata such as usage and invocation parameters. As we scale our application both from a usage and complexity perspective, this thread-safe and highly-performant tracking system will ensure robust monitoring of the app.

5 - Summary

There are many logical extensions of this tutorial; however, the MLflow components can remain largely unchanged. Some examples include persisting chat history to a database (see the sketch below), implementing a more complex LangGraph object, productionizing this solution, and much more!
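For example, persisting chat history could be as simple as the following sketch (the table layout and session IDs are illustrative, reusing the message-dict format from our utilities):

import json
import sqlite3

conn = sqlite3.connect("chat_history.db")
conn.execute("CREATE TABLE IF NOT EXISTS chats (session_id TEXT, messages TEXT)")

def save_history(session_id: str, messages: list) -> None:
    # Store the MLflow-style message dicts as a JSON blob per session
    conn.execute("INSERT INTO chats VALUES (?, ?)", (session_id, json.dumps(messages)))
    conn.commit()

save_history("session-1", [{"role": "user", "content": "What is my name?"}])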

To summarize, here's what was covered in this tutorial:

  • Creating a simple LangGraph chain.
  • Leveraging MLflow model from code functionality to log our graph.
  • Loading the model via the standard MLflow APIs.
  • Leveraging MLflow tracing to view graph execution.

Happy coding!

· 4 min read
MLflow maintainers

We're excited to announce the release of a powerful new feature in MLflow: MLflow Tracing. This feature brings comprehensive instrumentation capabilities to your GenAI applications, enabling you to gain deep insights into the execution of your models and workflows, from simple chat interfaces to complex multi-stage Retrieval Augmented Generation (RAG) applications.

NOTE: MLflow Tracing has been released in MLflow 2.14.0 and is not available in previous versions.

Introducing MLflow Tracing

Tracing is a critical aspect of understanding and optimizing complex applications, especially in the realm of machine learning and artificial intelligence. With the release of MLflow Tracing, you can now easily capture, visualize, and analyze detailed execution traces of your GenAI applications. This new feature aims to provide greater visibility and control over your applications' performance and behavior, aiding in everything from fine-tuning to debugging.

What is MLflow Tracing?

MLflow Tracing offers a variety of methods to enable tracing in your applications:

  • Automated Tracing with LangChain: A fully automated integration with LangChain allows you to activate tracing simply by enabling mlflow.langchain.autolog().
  • Manual Trace Instrumentation with High-Level Fluent APIs: Use decorators, function wrappers, and context managers via the fluent API to add tracing functionality with minimal code modifications.
  • Low-Level Client APIs for Tracing: The MLflow client API provides a thread-safe way to handle trace implementations for fine-grained control of what and when data is recorded.

Getting Started with MLflow Tracing

LangChain Automatic Tracing

The easiest way to get started with MLflow Tracing is through the built-in integration with LangChain. By enabling autologging, traces are automatically logged to the active MLflow experiment when calling invocation APIs on chains. Here’s a quick example:

import os
from langchain.prompts import PromptTemplate
from langchain_openai import OpenAI
import mlflow

assert "OPENAI_API_KEY" in os.environ, "Please set your OPENAI_API_KEY environment variable."

mlflow.set_experiment("LangChain Tracing")
mlflow.langchain.autolog(log_models=True, log_input_examples=True)

llm = OpenAI(temperature=0.7, max_tokens=1000)
prompt_template = "Imagine you are {person}, and you are answering a question: {question}"
chain = PromptTemplate.from_template(prompt_template) | llm

chain.invoke({"person": "Richard Feynman", "question": "Why should we colonize Mars?"})
chain.invoke({"person": "Linus Torvalds", "question": "Can I set everyone's access to sudo?"})

And this is what you will see after invoking the chains when navigating to the LangChain Tracing experiment in the MLflow UI:

Traces in UI

Fluent APIs for Manual Tracing

For more control, you can use MLflow’s fluent APIs to manually instrument your code. This approach allows you to capture detailed trace data with minimal changes to your existing code.

Trace Decorator

The trace decorator captures the inputs and outputs of a function:

import mlflow

mlflow.set_experiment("Tracing Demo")

@mlflow.trace
def some_function(x, y, z=2):
    return x + (y - z)

some_function(2, 4)

Context Handler

The context handler is ideal for supplementing span information with additional data at the point of information generation:

import mlflow

@mlflow.trace
def first_func(x, y=2):
    return x + y

@mlflow.trace
def second_func(a, b=3):
    return a * b

def do_math(a, x, operation="add"):
    with mlflow.start_span(name="Math") as span:
        span.set_inputs({"a": a, "x": x})
        span.set_attributes({"mode": operation})
        first = first_func(x)
        second = second_func(a)
        result = first + second if operation == "add" else first - second
        span.set_outputs({"result": result})
        return result

do_math(8, 3, "add")

Comprehensive Tracing with Client APIs

For advanced use cases, the MLflow client API offers fine-grained control over trace management. These APIs allow you to create, manipulate, and retrieve traces programmatically, albeit with additional complexity throughout the implementation.

Starting and Managing Traces with the Client APIs

from mlflow import MlflowClient

client = MlflowClient()

# Start a new trace
root_span = client.start_trace("my_trace")
request_id = root_span.request_id

# Create a child span
child_span = client.start_span(
    name="child_span",
    request_id=request_id,
    parent_id=root_span.span_id,
    inputs={"input_key": "input_value"},
    attributes={"attribute_key": "attribute_value"},
)

# End the child span
client.end_span(
    request_id=child_span.request_id,
    span_id=child_span.span_id,
    outputs={"output_key": "output_value"},
    attributes={"custom_attribute": "value"},
)

# End the root span (trace)
client.end_trace(
    request_id=request_id,
    outputs={"final_output_key": "final_output_value"},
    attributes={"token_usage": "1174"},
)
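Traces created this way can also be retrieved programmatically. A short sketch, assuming MLflow 2.14+ where the client exposes get_trace:

# Fetch the completed trace by its request ID to inspect the recorded spans
trace = client.get_trace(request_id)

print(trace.info.status)
print(len(trace.data.spans))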

Diving Deeper into Tracing

MLflow Tracing is designed to be flexible and powerful, supporting various use cases from simple function tracing to complex, asynchronous workflows.

To learn more about this feature, read the guide, review the API Docs and get started with the LangChain integration today!

Join Us on This Journey

The introduction of MLflow Tracing marks a significant milestone in our mission to provide comprehensive tools for managing machine learning workflows. We’re excited about the possibilities this new feature opens up and look forward to your feedback and contributions.

For those in our community with a passion for sharing knowledge, we invite you to collaborate. Whether it’s writing tutorials, sharing use-cases, or providing feedback, every contribution enriches the MLflow community.

Stay tuned for more updates, and as always, happy coding!