import os
from mlflow.deployments import BaseDeploymentClient
from mlflow.exceptions import MlflowException
from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE
from mlflow.utils.openai_utils import (
_OAITokenHolder,
_OpenAIApiConfig,
_OpenAIEnvVar,
)
from mlflow.utils.rest_utils import augmented_raise_for_status
class OpenAIDeploymentClient(BaseDeploymentClient):
    """
    Client for interacting with OpenAI endpoints.

    Example:

    First, set up credentials for authentication:

    .. code-block:: bash

        export OPENAI_API_KEY=...

    .. seealso::

        See https://mlflow.org/docs/latest/python_api/openai/index.html for other authentication
        methods.

    Then, create a deployment client and use it to interact with OpenAI endpoints:

    .. code-block:: python

        from mlflow.deployments import get_deploy_client

        client = get_deploy_client("openai")
        client.predict(
            endpoint="gpt-4o-mini",
            inputs={
                "messages": [
                    {"role": "user", "content": "Hello!"},
                ],
            },
        )
    """

    # ``OPENAI_API_TYPE`` values for which the Azure OpenAI code paths are taken.
    _AZURE_API_TYPES = ("azure", "azure_ad", "azuread")

    # Base URL for the ``/models`` REST calls when ``OPENAI_API_BASE`` is not set.
    _DEFAULT_OPENAI_API_BASE = "https://api.openai.com/v1"

    def create_deployment(self, name, model_uri, flavor=None, config=None, endpoint=None):
        """
        .. warning::

            This method is not implemented for `OpenAIDeploymentClient`.
        """
        raise NotImplementedError

    def update_deployment(self, name, model_uri=None, flavor=None, config=None, endpoint=None):
        """
        .. warning::

            This method is not implemented for `OpenAIDeploymentClient`.
        """
        raise NotImplementedError

    def delete_deployment(self, name, config=None, endpoint=None):
        """
        .. warning::

            This method is not implemented for `OpenAIDeploymentClient`.
        """
        raise NotImplementedError

    def list_deployments(self, endpoint=None):
        """
        .. warning::

            This method is not implemented for `OpenAIDeploymentClient`.
        """
        raise NotImplementedError

    def get_deployment(self, name, endpoint=None):
        """
        .. warning::

            This method is not implemented for `OpenAIDeploymentClient`.
        """
        raise NotImplementedError

    def predict(self, deployment_name=None, inputs=None, endpoint=None):
        """Query an OpenAI chat-completions endpoint.

        See https://platform.openai.com/docs/api-reference for more information.

        Args:
            deployment_name: Unused.
            inputs: A dictionary containing the model inputs to query. Its ``"messages"``
                entry is forwarded as ``messages`` to ``client.chat.completions.create``.
            endpoint: The name of the endpoint to query; forwarded as the ``model`` argument.

        Returns:
            A dictionary containing the model outputs (the response's ``model_dump()``).

        Raises:
            MlflowException: If ``OPENAI_API_KEY`` is not set, or if ``inputs`` is ``None``
                or has no ``"messages"`` entry.
        """
        _check_openai_key()

        # Validate before building a client / refreshing a token, so a malformed request
        # fails with a clear MlflowException rather than a bare TypeError/KeyError.
        try:
            messages = inputs["messages"]
        except (TypeError, KeyError) as e:
            raise MlflowException(
                "`inputs` must be a dictionary containing a 'messages' key.",
                error_code=INVALID_PARAMETER_VALUE,
            ) from e

        api_config = _get_api_config_without_openai_dep()
        api_token = _OAITokenHolder(api_config.api_type)
        api_token.refresh()

        if api_config.api_type in self._AZURE_API_TYPES:
            from openai import AzureOpenAI

            client = AzureOpenAI(
                api_key=api_token.token,
                azure_endpoint=api_config.api_base,
                api_version=api_config.api_version,
                azure_deployment=api_config.deployment_id,
                max_retries=api_config.max_retries,
                timeout=api_config.timeout,
            )
        else:
            from openai import OpenAI

            client = OpenAI(
                api_key=api_token.token,
                base_url=api_config.api_base,
                max_retries=api_config.max_retries,
                timeout=api_config.timeout,
            )

        return client.chat.completions.create(messages=messages, model=endpoint).model_dump()

    def create_endpoint(self, name, config=None):
        """
        .. warning::

            This method is not implemented for `OpenAIDeploymentClient`.
        """
        raise NotImplementedError

    def update_endpoint(self, endpoint, config=None):
        """
        .. warning::

            This method is not implemented for `OpenAIDeploymentClient`.
        """
        raise NotImplementedError

    def delete_endpoint(self, endpoint):
        """
        .. warning::

            This method is not implemented for `OpenAIDeploymentClient`.
        """
        raise NotImplementedError

    def list_endpoints(self):
        """
        List the currently available models.

        Sends ``GET <api_base>/models``, where ``<api_base>`` is ``OPENAI_API_BASE`` when set
        and ``https://api.openai.com/v1`` otherwise.

        Returns:
            The decoded JSON response body.

        Raises:
            MlflowException: If ``OPENAI_API_KEY`` is not set.
            NotImplementedError: If an Azure OpenAI API type is configured.
        """
        return self._get_models("", operation="List endpoints")

    def get_endpoint(self, endpoint):
        """
        Get information about a specific model.

        Sends ``GET <api_base>/models/<endpoint>``, where ``<api_base>`` is
        ``OPENAI_API_BASE`` when set and ``https://api.openai.com/v1`` otherwise.

        Args:
            endpoint: The model name to look up.

        Returns:
            The decoded JSON response body.

        Raises:
            MlflowException: If ``OPENAI_API_KEY`` is not set.
            NotImplementedError: If an Azure OpenAI API type is configured.
        """
        return self._get_models(f"/{endpoint}", operation="Get endpoint")

    def _get_models(self, path, operation):
        """Issue an authenticated ``GET`` against the OpenAI ``/models`` REST resource.

        Args:
            path: Suffix appended to ``<api_base>/models`` (``""`` lists every model).
            operation: Operation name used as the prefix of the Azure error message.

        Returns:
            The decoded JSON response body.

        Raises:
            MlflowException: If ``OPENAI_API_KEY`` is not set.
            NotImplementedError: If an Azure OpenAI API type is configured.

        Any error raised by ``augmented_raise_for_status`` for an unsuccessful response
        propagates unchanged.
        """
        _check_openai_key()
        api_config = _get_api_config_without_openai_dep()
        import requests

        if api_config.api_type in self._AZURE_API_TYPES:
            raise NotImplementedError(
                f"{operation} is not implemented for Azure OpenAI API",
            )

        # Honor OPENAI_API_BASE exactly like ``predict`` does, so proxies and
        # OpenAI-compatible servers are queried instead of the public API.
        base_url = (api_config.api_base or self._DEFAULT_OPENAI_API_BASE).rstrip("/")
        api_key = os.environ["OPENAI_API_KEY"]
        response = requests.get(
            f"{base_url}/models{path}",
            headers={"Authorization": f"Bearer {api_key}"},
            # Without a timeout, requests can block forever on an unresponsive server.
            timeout=api_config.timeout,
        )
        augmented_raise_for_status(response)
        return response.json()
def run_local(name, model_uri, flavor=None, config=None):
    """Module-level local-serving hook; a no-op for the OpenAI target.

    Presumably present to satisfy MLflow's deployment-plugin module interface
    (TODO confirm). Every argument is accepted and ignored.

    Returns:
        None
    """
    return None
def target_help():
    """Module-level help hook for the OpenAI deployment target; currently a no-op.

    Presumably present to satisfy MLflow's deployment-plugin module interface
    (TODO confirm).

    Returns:
        None
    """
    return None
def _get_api_config_without_openai_dep() -> _OpenAIApiConfig:
    """Build an ``_OpenAIApiConfig`` from ``OPENAI_*`` environment variables.

    Configuration is read from the process environment only; the ``openai``
    package is not imported here. Unset variables become ``None``.

    Returns:
        An ``_OpenAIApiConfig`` whose ``batch_size`` and ``max_tokens_per_minute``
        depend on whether ``OPENAI_API_TYPE`` names an Azure API type
        (``16`` / ``60_000``) or not (``1024`` / ``90_000``);
        ``max_requests_per_minute`` is always ``3_500``.
    """
    env = os.environ
    api_type = env.get(_OpenAIEnvVar.OPENAI_API_TYPE.value)

    if api_type not in ("azure", "azure_ad", "azuread"):
        # The maximum batch size is 2048:
        # https://github.com/openai/openai-python/blob/b82a3f7e4c462a8a10fa445193301a3cefef9a4a/openai/embeddings_utils.py#L43
        # A smaller batch size is used to be safe.
        batch_size, max_tokens_per_minute = 1024, 90_000
    else:
        batch_size, max_tokens_per_minute = 16, 60_000

    return _OpenAIApiConfig(
        api_type=api_type,
        batch_size=batch_size,
        max_requests_per_minute=3_500,
        max_tokens_per_minute=max_tokens_per_minute,
        api_base=env.get(_OpenAIEnvVar.OPENAI_API_BASE.value),
        api_version=env.get(_OpenAIEnvVar.OPENAI_API_VERSION.value),
        deployment_id=env.get(_OpenAIEnvVar.OPENAI_DEPLOYMENT_NAME.value),
    )
def _check_openai_key():
    """Ensure the ``OPENAI_API_KEY`` environment variable is defined.

    Only the variable's presence is checked; its value is not inspected.

    Raises:
        MlflowException: With ``INVALID_PARAMETER_VALUE`` when the variable is absent.
    """
    if "OPENAI_API_KEY" in os.environ:
        return
    raise MlflowException(
        "OPENAI_API_KEY environment variable not set",
        error_code=INVALID_PARAMETER_VALUE,
    )