Source code for mlflow.models.dependencies_schemas

import json
import logging
from abc import ABC, abstractmethod
from contextlib import contextmanager
from dataclasses import dataclass, field
from enum import Enum
from typing import TYPE_CHECKING, Optional

from mlflow.utils.annotations import experimental

if TYPE_CHECKING:
    from mlflow.models.model import Model

_logger = logging.getLogger(__name__)


class DependenciesSchemasType(Enum):
    """
    Kinds of dependency schemas that can be attached to a model.

    The value of each member is the string key under which schemas of that
    kind are stored and serialized.
    """

    # Schemas describing retrievers / vector indexes.
    RETRIEVERS = "retrievers"


@experimental
def set_retriever_schema(
    *,
    primary_key: str,
    text_column: str,
    doc_uri: Optional[str] = None,
    other_columns: Optional[list[str]] = None,
    name: Optional[str] = "retriever",
):
    """
    Register the schema of a retriever (vector store) so that MLflow can tell which
    field of each retrieved document holds the document text, the document URI, etc.
    when documents are retrieved during model inference.

    Call this after defining your vector store in a Python file or notebook.
    Schemas are kept in module-level state, one entry per ``name``. Registering the
    same ``name`` again with identical fields is a no-op; registering it with
    different fields logs a warning and replaces the stored fields.

    Args:
        primary_key: The primary key of the retriever or vector index.
        text_column: The name of the text column to use for the embeddings.
        doc_uri: The name of the column that contains the document URI.
        other_columns: A list of other columns that are part of the vector index
            that need to be retrieved during trace logging.
        name: The name of the retriever tool or vector store index.

    .. code-block:: Python
        :caption: Example

        from mlflow.models import set_retriever_schema

        set_retriever_schema(
            primary_key="chunk_id",
            text_column="chunk_text",
            doc_uri="doc_uri",
            other_columns=["title"],
        )
    """
    registry_key = DependenciesSchemasType.RETRIEVERS.value
    registered = globals().get(registry_key, [])

    # Everything except the name; these are the fields compared and overwritten.
    requested = {
        "primary_key": primary_key,
        "text_column": text_column,
        "doc_uri": doc_uri,
        "other_columns": other_columns or [],
    }

    # Look for an entry that was already registered under this name.
    current = None
    for candidate in registered:
        if candidate["name"] == name:
            current = candidate
            break

    if current is None:
        registered.append({**requested, "name": name})
    else:
        if all(current[key] == value for key, value in requested.items()):
            # Identical registration: nothing to warn about or update.
            return

        _logger.warning(
            f"A retriever schema with the name '{name}' already exists. "
            "Overriding the existing schema."
        )
        # Replace the stored fields in place; the entry keeps its name.
        current.update(requested)

    globals()[registry_key] = registered
def _get_retriever_schema():
    """
    Get the retriever schemas registered by the user.

    Returns:
        list[RetrieverSchema]: One ``RetrieverSchema`` per registered entry, or an
        empty list when nothing is registered.
    """
    retriever_schemas = globals().get(DependenciesSchemasType.RETRIEVERS.value, [])
    if not retriever_schemas:
        return []

    return [
        RetrieverSchema(
            name=retriever.get("name"),
            primary_key=retriever.get("primary_key"),
            text_column=retriever.get("text_column"),
            doc_uri=retriever.get("doc_uri"),
            other_columns=retriever.get("other_columns"),
        )
        for retriever in retriever_schemas
    ]


def _clear_retriever_schema():
    """
    Clear the retriever schemas registered by the user.
    """
    globals().pop(DependenciesSchemasType.RETRIEVERS.value, None)


def _clear_dependencies_schemas():
    """
    Clear all the dependencies schemas defined by the user.
    """
    # Retriever schemas are the only kind of dependencies schema handled here.
    _clear_retriever_schema()


@contextmanager
def _get_dependencies_schemas():
    """
    Yield a ``DependenciesSchemas`` built from the currently registered retriever
    schemas, and clear the registered schemas on exit, even if the body raises.
    """
    dependencies_schemas = DependenciesSchemas(retriever_schemas=_get_retriever_schema())
    try:
        yield dependencies_schemas
    finally:
        _clear_dependencies_schemas()


def _get_dependencies_schema_from_model(model: "Model") -> Optional[dict]:
    """
    Get the dependencies schema from the logged model metadata.

    `dependencies_schemas` is a dictionary that defines the dependencies schemas, such as
    the retriever schemas. This code is now only useful for Databricks integration.

    Returns:
        A dictionary with a single ``"dependencies_schemas"`` key whose value maps each
        dependency name to its JSON-encoded schema, or ``None`` when the model metadata
        has no ``"dependencies_schemas"`` entry.
    """
    if model.metadata and "dependencies_schemas" in model.metadata:
        dependencies_schemas = model.metadata["dependencies_schemas"]
        return {
            "dependencies_schemas": {
                dependency: json.dumps(schema)
                for dependency, schema in dependencies_schemas.items()
            }
        }
    return None


@dataclass
class Schema(ABC):
    """
    Base class for defining a dependencies schema of a model.

    Args:
        type (DependenciesSchemasType): The type of the schema.
    """

    type: DependenciesSchemasType

    @abstractmethod
    def to_dict(self):
        """
        Convert the schema to a dictionary.
        Subclasses must implement this method.
        """

    @classmethod
    @abstractmethod
    def from_dict(cls, data: dict):
        """
        Build a schema from a dictionary.
        Subclasses must implement this method.
        """


@dataclass
class RetrieverSchema(Schema):
    """
    Define the schema of a retriever (vector search index) used by a model.

    Args:
        name (str): The name of the vector search index schema.
        primary_key (str): The primary key for the index.
        text_column (str): The main text column for the index.
        doc_uri (Optional[str]): The document URI for the index.
        other_columns (Optional[List[str]]): Additional columns in the index.
    """

    # Declared as dataclass fields so that the generated __eq__ and __repr__ take them
    # into account. Without these declarations only the inherited ``type`` field is
    # compared, which makes every RetrieverSchema equal to every other one.
    # The explicit __init__ below is kept: @dataclass does not replace an __init__
    # that the class defines itself, so the constructor signature is unchanged.
    name: str
    primary_key: str
    text_column: str
    doc_uri: Optional[str] = None
    other_columns: list[str] = field(default_factory=list)

    def __init__(
        self,
        name: str,
        primary_key: str,
        text_column: str,
        doc_uri: Optional[str] = None,
        other_columns: Optional[list[str]] = None,
    ):
        super().__init__(type=DependenciesSchemasType.RETRIEVERS)
        self.name = name
        self.primary_key = primary_key
        self.text_column = text_column
        self.doc_uri = doc_uri
        # Normalize a missing/empty value to a fresh empty list.
        self.other_columns = other_columns or []

    def to_dict(self):
        """
        Return ``{"retrievers": [<schema fields>]}``, i.e. a one-element list of this
        schema's fields stored under the schema type's string value.
        """
        return {
            self.type.value: [
                {
                    "name": self.name,
                    "primary_key": self.primary_key,
                    "text_column": self.text_column,
                    "doc_uri": self.doc_uri,
                    "other_columns": self.other_columns,
                }
            ]
        }

    @classmethod
    def from_dict(cls, data: dict):
        """
        Build a ``RetrieverSchema`` from a dictionary of its fields.

        ``name``, ``primary_key`` and ``text_column`` are required (a missing key raises
        ``KeyError``); ``doc_uri`` and ``other_columns`` are optional.
        """
        return cls(
            name=data["name"],
            primary_key=data["primary_key"],
            text_column=data["text_column"],
            doc_uri=data.get("doc_uri"),
            other_columns=data.get("other_columns", []),
        )


@dataclass
class DependenciesSchemas:
    """
    Container for all dependencies schemas of a model.

    Args:
        retriever_schemas (list[RetrieverSchema]): The retriever schemas of the model.
    """

    retriever_schemas: list[RetrieverSchema] = field(default_factory=list)

    def to_dict(self) -> Optional[dict[str, dict[str, list[dict]]]]:
        """
        Return ``{"dependencies_schemas": {"retrievers": [...]}}`` with one dictionary
        per retriever schema, or ``None`` when there are no retriever schemas.
        """
        if not self.retriever_schemas:
            return None

        retrievers_key = DependenciesSchemasType.RETRIEVERS.value
        return {
            "dependencies_schemas": {
                retrievers_key: [
                    # Each RetrieverSchema.to_dict() wraps its fields in a one-element
                    # list under the same key; unwrap it to build a flat list.
                    index.to_dict()[retrievers_key][0]
                    for index in self.retriever_schemas
                ],
            }
        }