Source code for mlflow.litellm

from contextlib import contextmanager
from threading import Thread

from mlflow.utils.annotations import experimental
from mlflow.utils.autologging_utils import (
    autologging_integration,
    safe_patch,
)
from mlflow.utils.databricks_utils import is_in_databricks_runtime

FLAVOR_NAME = "litellm"


@experimental
def autolog(
    log_traces: bool = True,
    disable: bool = False,
    silent: bool = False,
):
    """
    Enables (or disables) and configures autologging from LiteLLM to MLflow.
    Currently, MLflow only supports autologging for tracing.

    Args:
        log_traces: If ``True``, traces are logged for LiteLLM calls. If ``False``,
            no traces are collected during inference. Defaults to ``True``.
        disable: If ``True``, disables the LiteLLM autologging integration. If ``False``,
            enables the LiteLLM autologging integration.
        silent: If ``True``, suppress all event logs and warnings from MLflow during LiteLLM
            autologging. If ``False``, show all events and warnings.
    """
    import litellm

    # This needs to be called before doing any safe-patching (otherwise safe-patch is a no-op).
    # TODO: Since this implementation is inconsistent, explore a universal way to solve the issue.
    _autolog(log_traces=log_traces, disable=disable, silent=silent)

    if log_traces and not disable:
        litellm.success_callback = _append_mlflow_callbacks(litellm.success_callback)
        litellm.failure_callback = _append_mlflow_callbacks(litellm.failure_callback)

        if is_in_databricks_runtime():
            # Patch the main APIs, e.g. completion, to inject custom handling for threading.
            # By default, those APIs start a new thread to call the log_success_event()
            # handler of the logging callbacks and never wait for it to finish. This is
            # problematic in Databricks notebooks, because the inline trace UI display
            # assumes that the trace is generated synchronously. If the trace is generated
            # asynchronously, it is displayed in different, later cells.
            # To work around this issue, we monkey-patch these APIs to wait for the logging
            # threads to finish before returning the result.
            # This is not required in the OSS environment, where we don't show the inline
            # trace UI.
            for func in [
                "completion",
                "embedding",
                "text_completion",
                "image_generation",
                "transcription",
                "speech",
            ]:
                _patch_threading_in_function(litellm, func)

            # For the streaming case, we need to patch the iterator, because traces are
            # generated when consuming the generator, not when calling the main APIs.
            _patch_threading_in_function(litellm.utils.CustomStreamWrapper, "__next__")

            # NB: We don't need to patch async functions, because Databricks notebooks wait
            # for the async task to finish before finishing the cell.
    else:
        litellm.success_callback = _remove_mlflow_callbacks(litellm.success_callback)
        litellm.failure_callback = _remove_mlflow_callbacks(litellm.failure_callback)
        # The callback also needs to be removed from 'callbacks', as litellm adds
        # success/failure callbacks there as well.
        litellm.callbacks = _remove_mlflow_callbacks(litellm.callbacks)
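
# A minimal usage sketch (an illustrative comment, not executed on import): once
# autolog() is enabled, LiteLLM calls are traced automatically. The model name
# and prompt below are placeholders.
#
#     import litellm
#     import mlflow
#
#     mlflow.litellm.autolog()
#     litellm.completion(
#         model="gpt-4o-mini",
#         messages=[{"role": "user", "content": "What is MLflow?"}],
#     )
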
# This is required by mlflow.autolog()
autolog.integration_name = FLAVOR_NAME


# NB: The @autologging_integration annotation must be applied here, and the callback injection
# needs to happen outside the annotated function. This is because the annotated function is NOT
# executed when disable=True is passed. This prevents us from removing our callback and patching
# when autologging is turned off.
@autologging_integration(FLAVOR_NAME)
def _autolog(
    log_traces: bool,
    disable: bool = False,
    silent: bool = False,
):
    pass


def _patch_threading_in_function(target, function_name: str):
    """
    Apply the threading patch to a synchronous function. We capture the threads started by
    the function using the _patch_thread_start context manager, then join them to ensure
    they are finished before the notebook cell finishes executing.
    """

    def _patch_fn(original, *args, **kwargs):
        with _patch_thread_start() as logging_threads:
            result = original(*args, **kwargs)

        for thread in logging_threads:
            thread.join()

        return result

    safe_patch(FLAVOR_NAME, target, function_name, _patch_fn)


@contextmanager
def _patch_thread_start():
    """
    A context manager to collect threads started for logging handlers. This is done by
    monkey-patching the start() method of threading.Thread.

    Note that failure handlers are executed synchronously, so we don't need to patch them.
    """
    original = Thread.start
    logging_threads = []

    def patched_thread(self, *args, **kwargs):
        target = getattr(self, "_target", None)
        # success_handler is for a normal request; run_success_... is for streaming.
        # - https://github.com/BerriAI/litellm/blob/4f8a3fd4cfc20cf43b38379928b41c2691c85d36/litellm/utils.py#L946
        # - https://github.com/BerriAI/litellm/blob/4f8a3fd4cfc20cf43b38379928b41c2691c85d36/litellm/utils.py#L7526
        if target and target.__name__ in [
            "success_handler",
            "run_success_logging_and_cache_storage",
        ]:
            logging_threads.append(self)
        return original(self, *args, **kwargs)

    Thread.start = patched_thread
    try:
        yield logging_threads
    finally:
        Thread.start = original


def _append_mlflow_callbacks(callbacks):
    if not any(cb == "mlflow" for cb in callbacks):
        return callbacks + ["mlflow"]
    return callbacks


def _remove_mlflow_callbacks(callbacks):
    try:
        from litellm.integrations.mlflow import MlflowLogger

        return [cb for cb in callbacks if not (cb == "mlflow" or isinstance(cb, MlflowLogger))]
    except ImportError:
        return callbacks
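
# Behavior sketch for the callback-list helpers above (the non-MLflow callback
# name is a placeholder; _remove_mlflow_callbacks also strips MlflowLogger
# instances when litellm's MLflow integration is importable):
#
#     >>> _append_mlflow_callbacks(["langfuse"])
#     ['langfuse', 'mlflow']
#     >>> _append_mlflow_callbacks(["mlflow"])  # already present; unchanged
#     ['mlflow']
#     >>> _remove_mlflow_callbacks(["langfuse", "mlflow"])
#     ['langfuse']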