from typing import TYPE_CHECKING, Any, Optional
import requests
from mlflow import MlflowException
from mlflow.deployments import BaseDeploymentClient
from mlflow.deployments.constants import (
MLFLOW_DEPLOYMENT_CLIENT_REQUEST_RETRY_CODES,
)
from mlflow.deployments.server.constants import (
MLFLOW_DEPLOYMENTS_CRUD_ENDPOINT_BASE,
MLFLOW_DEPLOYMENTS_ENDPOINTS_BASE,
MLFLOW_DEPLOYMENTS_QUERY_SUFFIX,
)
from mlflow.deployments.utils import resolve_endpoint_url
from mlflow.environment_variables import (
MLFLOW_DEPLOYMENT_PREDICT_TIMEOUT,
MLFLOW_HTTP_REQUEST_TIMEOUT,
)
from mlflow.protos.databricks_pb2 import BAD_REQUEST
from mlflow.store.entities.paged_list import PagedList
from mlflow.utils.annotations import experimental
from mlflow.utils.credentials import get_default_host_creds
from mlflow.utils.rest_utils import augmented_raise_for_status, http_request
from mlflow.utils.uri import join_paths
if TYPE_CHECKING:
from mlflow.deployments.server.config import Endpoint
@experimental
class MlflowDeploymentClient(BaseDeploymentClient):
    """
    Client for interacting with the MLflow AI Gateway.

    Example:

    First, start the MLflow AI Gateway:

    .. code-block:: bash

        mlflow gateway start --config-path path/to/config.yaml

    Then, create a client and use it to interact with the server:

    .. code-block:: python

        from mlflow.deployments import get_deploy_client

        client = get_deploy_client("http://localhost:5000")
        endpoints = client.list_endpoints()
        assert [e.dict() for e in endpoints] == [
            {
                "name": "chat",
                "endpoint_type": "llm/v1/chat",
                "model": {"name": "gpt-4o-mini", "provider": "openai"},
                "endpoint_url": "http://localhost:5000/gateway/chat/invocations",
            },
        ]
    """

    def create_deployment(self, name, model_uri, flavor=None, config=None, endpoint=None):
        """
        .. warning::

            This method is not implemented for `MlflowDeploymentClient`.
        """
        raise NotImplementedError

    def update_deployment(self, name, model_uri=None, flavor=None, config=None, endpoint=None):
        """
        .. warning::

            This method is not implemented for `MlflowDeploymentClient`.
        """
        raise NotImplementedError

    def delete_deployment(self, name, config=None, endpoint=None):
        """
        .. warning::

            This method is not implemented for `MlflowDeploymentClient`.
        """
        raise NotImplementedError

    def list_deployments(self, endpoint=None):
        """
        .. warning::

            This method is not implemented for `MlflowDeploymentClient`.
        """
        raise NotImplementedError

    def get_deployment(self, name, endpoint=None):
        """
        .. warning::

            This method is not implemented for `MlflowDeploymentClient`.
        """
        raise NotImplementedError

    def create_endpoint(self, name, config=None):
        """
        .. warning::

            This method is not implemented for `MlflowDeploymentClient`.
        """
        raise NotImplementedError

    def update_endpoint(self, endpoint, config=None):
        """
        .. warning::

            This method is not implemented for `MlflowDeploymentClient`.
        """
        raise NotImplementedError

    def delete_endpoint(self, endpoint):
        """
        .. warning::

            This method is not implemented for `MlflowDeploymentClient`.
        """
        raise NotImplementedError

    def _call_endpoint(
        self,
        method: str,
        route: str,
        json_body: Optional[dict[str, Any]] = None,
        timeout: Optional[int] = None,
    ):
        """
        Send an HTTP request to ``route`` on the target server and return the decoded JSON body.

        Args:
            method: The HTTP method to use. The comparison against ``"get"`` is case-insensitive.
            route: The endpoint path passed to ``http_request``.
            json_body: The request payload. It is sent as query parameters for ``GET`` requests
                and as the JSON body for every other method.
            timeout: The request timeout. When ``None``, the value of
                ``MLFLOW_HTTP_REQUEST_TIMEOUT`` is used instead.

        Returns:
            The result of ``response.json()``.
        """
        payload_key = "params" if method.lower() == "get" else "json"
        response = http_request(
            host_creds=get_default_host_creds(self.target_uri),
            endpoint=route,
            method=method,
            timeout=MLFLOW_HTTP_REQUEST_TIMEOUT.get() if timeout is None else timeout,
            retry_codes=MLFLOW_DEPLOYMENT_CLIENT_REQUEST_RETRY_CODES,
            # Status checking is delegated to `augmented_raise_for_status` below.
            raise_on_status=False,
            **{payload_key: json_body},
        )
        augmented_raise_for_status(response)
        return response.json()

    def _to_endpoint(self, payload: dict[str, Any]) -> "Endpoint":
        """
        Build an ``Endpoint`` from a server response entry.

        The ``endpoint_url`` field of ``payload`` is replaced with the result of
        ``resolve_endpoint_url`` applied to this client's ``target_uri``; every other field is
        forwarded to ``Endpoint`` unchanged.
        """
        # Delayed import to avoid importing mlflow.gateway in the module scope
        from mlflow.deployments.server.config import Endpoint

        return Endpoint(
            **{
                **payload,
                "endpoint_url": resolve_endpoint_url(self.target_uri, payload["endpoint_url"]),
            }
        )

    @experimental
    def get_endpoint(self, endpoint) -> "Endpoint":
        """
        Gets a specified endpoint configured for the MLflow AI Gateway.

        Args:
            endpoint: The name of the endpoint to retrieve.

        Returns:
            An `Endpoint` object representing the endpoint.

        Example:

        .. code-block:: python

            from mlflow.deployments import get_deploy_client

            client = get_deploy_client("http://localhost:5000")
            endpoint = client.get_endpoint(endpoint="chat")
            assert endpoint.dict() == {
                "name": "chat",
                "endpoint_type": "llm/v1/chat",
                "model": {"name": "gpt-4o-mini", "provider": "openai"},
                "endpoint_url": "http://localhost:5000/gateway/chat/invocations",
            }
        """
        route = join_paths(MLFLOW_DEPLOYMENTS_CRUD_ENDPOINT_BASE, endpoint)
        return self._to_endpoint(self._call_endpoint("GET", route))

    def _list_endpoints(self, page_token=None) -> "PagedList[Endpoint]":
        """
        Fetch a single page of endpoints.

        Args:
            page_token: The token identifying the page to fetch. When ``None``, no
                ``page_token`` parameter is sent with the request.

        Returns:
            A ``PagedList`` of ``Endpoint`` objects built from the ``endpoints`` entry of the
            response, carrying the ``next_page_token`` value of the response (``None`` when the
            response has none).
        """
        params = None if page_token is None else {"page_token": page_token}
        response_json = self._call_endpoint(
            "GET", MLFLOW_DEPLOYMENTS_CRUD_ENDPOINT_BASE, json_body=params
        )
        routes = [self._to_endpoint(resp) for resp in response_json.get("endpoints", [])]
        return PagedList(routes, response_json.get("next_page_token"))

    @experimental
    def list_endpoints(self) -> "list[Endpoint]":
        """
        List endpoints configured for the MLflow AI Gateway.

        Returns:
            A list of ``Endpoint`` objects.

        Example:

        .. code-block:: python

            from mlflow.deployments import get_deploy_client

            client = get_deploy_client("http://localhost:5000")
            endpoints = client.list_endpoints()
            assert [e.dict() for e in endpoints] == [
                {
                    "name": "chat",
                    "endpoint_type": "llm/v1/chat",
                    "model": {"name": "gpt-4o-mini", "provider": "openai"},
                    "endpoint_url": "http://localhost:5000/gateway/chat/invocations",
                },
            ]
        """
        endpoints = []
        next_page_token = None
        # Keep requesting pages until a page comes back without a continuation token.
        while True:
            page = self._list_endpoints(next_page_token)
            endpoints.extend(page)
            next_page_token = page.token
            if next_page_token is None:
                break
        return endpoints

    @experimental
    def predict(self, deployment_name=None, inputs=None, endpoint=None) -> dict[str, Any]:
        """
        Submit a query to a configured provider endpoint.

        Args:
            deployment_name: Unused.
            inputs: The inputs to the query, as a dictionary.
            endpoint: The name of the endpoint to query.

        Returns:
            A dictionary containing the response from the endpoint.

        Raises:
            MlflowException: If ``endpoint`` is not specified, or if the request timed out
                (in both cases with the ``BAD_REQUEST`` error code). Any other
                ``MlflowException`` raised while sending the request is propagated unchanged.

        Example:

        .. code-block:: python

            from mlflow.deployments import get_deploy_client

            client = get_deploy_client("http://localhost:5000")

            response = client.predict(
                endpoint="chat",
                inputs={"messages": [{"role": "user", "content": "Hello"}]},
            )
            assert response == {
                "id": "chatcmpl-8OLoQuaeJSLybq3NBoe0w5eyqjGb9",
                "object": "chat.completion",
                "created": 1700814410,
                "model": "gpt-4o-mini",
                "choices": [
                    {
                        "index": 0,
                        "message": {
                            "role": "assistant",
                            "content": "Hello! How can I assist you today?",
                        },
                        "finish_reason": "stop",
                    }
                ],
                "usage": {
                    "prompt_tokens": 9,
                    "completion_tokens": 9,
                    "total_tokens": 18,
                },
            }

        Additional parameters that are valid for a given provider and endpoint configuration can be
        included with the request as shown below, using an openai completions endpoint request as
        an example:

        .. code-block:: python

            from mlflow.deployments import get_deploy_client

            client = get_deploy_client("http://localhost:5000")
            client.predict(
                endpoint="completions",
                inputs={
                    "prompt": "Hello!",
                    "temperature": 0.3,
                    "max_tokens": 500,
                },
            )
        """
        # `endpoint` only defaults to None to match the base-class signature; without a name
        # there is no meaningful query route to build.
        if endpoint is None:
            raise MlflowException(
                message="`endpoint` must be specified to query the MLflow AI Gateway.",
                error_code=BAD_REQUEST,
            )

        query_route = join_paths(
            MLFLOW_DEPLOYMENTS_ENDPOINTS_BASE, endpoint, MLFLOW_DEPLOYMENTS_QUERY_SUFFIX
        )
        # Resolve the timeout once so the value reported on failure is the one actually used.
        timeout = MLFLOW_DEPLOYMENT_PREDICT_TIMEOUT.get()
        try:
            return self._call_endpoint("POST", query_route, inputs, timeout)
        except MlflowException as e:
            # NOTE(review): presumably `http_request` wraps `requests` errors and keeps the
            # original as `__cause__` - confirm against mlflow.utils.rest_utils.
            if isinstance(e.__cause__, requests.exceptions.Timeout):
                raise MlflowException(
                    message=(
                        "The provider has timed out while generating a response to your "
                        "query. Please evaluate the available parameters for the query "
                        "that you are submitting. Some parameter values and inputs can "
                        "increase the computation time beyond the allowable route "
                        f"timeout of {timeout} "
                        "seconds."
                    ),
                    error_code=BAD_REQUEST,
                ) from e
            raise
def run_local(name, model_uri, flavor=None, config=None):
    """
    Do nothing and return ``None``; none of the arguments are used.

    NOTE(review): presumably kept so this module exposes the functions expected of a
    deployment target - confirm against the deployments plugin interface.
    """
    return None
def target_help():
    """
    Do nothing and return ``None``.

    NOTE(review): presumably kept so this module exposes the functions expected of a
    deployment target - confirm against the deployments plugin interface.
    """
    return None