Source code for mlflow.deployments.openai

import os

from mlflow.deployments import BaseDeploymentClient
from mlflow.exceptions import MlflowException
from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE
from mlflow.utils.openai_utils import (
    _OAITokenHolder,
    _OpenAIApiConfig,
    _OpenAIEnvVar,
)
from mlflow.utils.rest_utils import augmented_raise_for_status


class OpenAIDeploymentClient(BaseDeploymentClient):
    """
    Client for interacting with OpenAI endpoints.

    Example:

    First, set up credentials for authentication:

    .. code-block:: bash

        export OPENAI_API_KEY=...

    .. seealso::

        See https://mlflow.org/docs/latest/python_api/openai/index.html for other authentication
        methods.

    Then, create a deployment client and use it to interact with OpenAI endpoints:

    .. code-block:: python

        from mlflow.deployments import get_deploy_client

        client = get_deploy_client("openai")
        client.predict(
            endpoint="gpt-4o-mini",
            inputs={
                "messages": [
                    {"role": "user", "content": "Hello!"},
                ],
            },
        )
    """

    # Values of the configured API type that select the Azure OpenAI code paths.
    _AZURE_API_TYPES = ("azure", "azure_ad", "azuread")

    def create_deployment(self, name, model_uri, flavor=None, config=None, endpoint=None):
        """
        .. warning::

            This method is not implemented for `OpenAIDeploymentClient`.
        """
        raise NotImplementedError

    def update_deployment(self, name, model_uri=None, flavor=None, config=None, endpoint=None):
        """
        .. warning::

            This method is not implemented for `OpenAIDeploymentClient`.
        """
        raise NotImplementedError

    def delete_deployment(self, name, config=None, endpoint=None):
        """
        .. warning::

            This method is not implemented for `OpenAIDeploymentClient`.
        """
        raise NotImplementedError

    def list_deployments(self, endpoint=None):
        """
        .. warning::

            This method is not implemented for `OpenAIDeploymentClient`.
        """
        raise NotImplementedError

    def get_deployment(self, name, endpoint=None):
        """
        .. warning::

            This method is not implemented for `OpenAIDeploymentClient`.
        """
        raise NotImplementedError

    def predict(self, deployment_name=None, inputs=None, endpoint=None):
        """Query an OpenAI endpoint.
        See https://platform.openai.com/docs/api-reference for more information.

        Args:
            deployment_name: Unused.
            inputs: A dictionary containing the model inputs to query. It must provide a
                ``"messages"`` entry, which is forwarded to the chat completions API.
            endpoint: The name of the endpoint to query. It is passed to the API as the
                ``model`` argument.

        Returns:
            A dictionary containing the model outputs.

        Raises:
            MlflowException: If the ``OPENAI_API_KEY`` environment variable is not set, or if
                ``inputs`` is missing or has no ``"messages"`` entry.
        """
        _check_openai_key()

        # Validate the payload before refreshing the token or building a client, so that a
        # malformed call fails with a clear MLflow error instead of a bare TypeError/KeyError.
        try:
            messages = inputs["messages"]
        except (TypeError, KeyError) as err:
            raise MlflowException(
                'The `inputs` argument must be a dictionary containing a "messages" key.',
                error_code=INVALID_PARAMETER_VALUE,
            ) from err

        api_config = _get_api_config_without_openai_dep()
        api_token = _OAITokenHolder(api_config.api_type)
        api_token.refresh()

        # The openai package is imported lazily, only on the branch that needs it.
        if api_config.api_type in self._AZURE_API_TYPES:
            from openai import AzureOpenAI

            client = AzureOpenAI(
                api_key=api_token.token,
                azure_endpoint=api_config.api_base,
                api_version=api_config.api_version,
                azure_deployment=api_config.deployment_id,
                max_retries=api_config.max_retries,
                timeout=api_config.timeout,
            )
        else:
            from openai import OpenAI

            client = OpenAI(
                api_key=api_token.token,
                base_url=api_config.api_base,
                max_retries=api_config.max_retries,
                timeout=api_config.timeout,
            )

        return client.chat.completions.create(messages=messages, model=endpoint).model_dump()

    def create_endpoint(self, name, config=None):
        """
        .. warning::

            This method is not implemented for `OpenAIDeploymentClient`.
        """
        raise NotImplementedError

    def update_endpoint(self, endpoint, config=None):
        """
        .. warning::

            This method is not implemented for `OpenAIDeploymentClient`.
        """
        raise NotImplementedError

    def delete_endpoint(self, endpoint):
        """
        .. warning::

            This method is not implemented for `OpenAIDeploymentClient`.
        """
        raise NotImplementedError

    def list_endpoints(self):
        """
        List the currently available models.

        Returns:
            The decoded JSON body of the ``GET https://api.openai.com/v1/models`` response.

        Raises:
            MlflowException: If the ``OPENAI_API_KEY`` environment variable is not set.
            NotImplementedError: If the configured API type is an Azure variant.
        """
        return self._get_models_resource(
            "https://api.openai.com/v1/models",
            "List endpoints is not implemented for Azure OpenAI API",
        )

    def get_endpoint(self, endpoint):
        """
        Get information about a specific model.

        Args:
            endpoint: The model name, appended to the models URL as its last path segment.

        Returns:
            The decoded JSON body of the ``GET https://api.openai.com/v1/models/<endpoint>``
            response.

        Raises:
            MlflowException: If the ``OPENAI_API_KEY`` environment variable is not set.
            NotImplementedError: If the configured API type is an Azure variant.
        """
        return self._get_models_resource(
            f"https://api.openai.com/v1/models/{endpoint}",
            "Get endpoint is not implemented for Azure OpenAI API",
        )

    def _get_models_resource(self, url, azure_error_message):
        """Send an authenticated GET to an OpenAI models URL and return the decoded JSON body.

        Shared implementation behind ``list_endpoints`` and ``get_endpoint``.

        Args:
            url: The full URL to request.
            azure_error_message: Message for the ``NotImplementedError`` raised when the
                configured API type is an Azure variant.
        """
        _check_openai_key()
        api_config = _get_api_config_without_openai_dep()

        import requests

        if api_config.api_type in self._AZURE_API_TYPES:
            raise NotImplementedError(azure_error_message)

        api_key = os.environ["OPENAI_API_KEY"]
        response = requests.get(
            url,
            headers={"Authorization": f"Bearer {api_key}"},
            # Without a timeout, requests waits indefinitely on an unresponsive server.
            # Reuses the timeout that predict() hands to the OpenAI client.
            # NOTE(review): assumes api_config.timeout is a number of seconds (or None) --
            # confirm against _OpenAIApiConfig.
            timeout=api_config.timeout,
        )
        augmented_raise_for_status(response)
        return response.json()
def run_local(name, model_uri, flavor=None, config=None):
    """Do nothing; this module provides no local-run behavior."""


def target_help():
    """Do nothing; this module provides no target help text."""


def _get_api_config_without_openai_dep() -> _OpenAIApiConfig:
    """
    Build the OpenAI API configuration from environment variables.

    The API type, version, base URL and deployment name are read from the variables named
    by ``_OpenAIEnvVar``; any that are unset come through as ``None``. Batch size and the
    tokens-per-minute limit depend on whether the API type is an Azure variant.
    """
    environ = os.environ
    kind = environ.get(_OpenAIEnvVar.OPENAI_API_TYPE.value)
    version = environ.get(_OpenAIEnvVar.OPENAI_API_VERSION.value)
    base_url = environ.get(_OpenAIEnvVar.OPENAI_API_BASE.value)
    deployment = environ.get(_OpenAIEnvVar.OPENAI_DEPLOYMENT_NAME.value)

    # For the non-Azure case, the maximum batch size is 2048:
    # https://github.com/openai/openai-python/blob/b82a3f7e4c462a8a10fa445193301a3cefef9a4a/openai/embeddings_utils.py#L43
    # A smaller batch size (1024) is used to be safe.
    is_azure = kind in ("azure", "azure_ad", "azuread")
    chunk_size, token_limit = (16, 60_000) if is_azure else (1024, 90_000)

    return _OpenAIApiConfig(
        api_type=kind,
        batch_size=chunk_size,
        max_requests_per_minute=3_500,
        max_tokens_per_minute=token_limit,
        api_base=base_url,
        api_version=version,
        deployment_id=deployment,
    )


def _check_openai_key():
    """Raise ``MlflowException`` unless ``OPENAI_API_KEY`` is present in the environment."""
    if "OPENAI_API_KEY" in os.environ:
        return
    raise MlflowException(
        "OPENAI_API_KEY environment variable not set",
        error_code=INVALID_PARAMETER_VALUE,
    )