Source code for mlflow.models.dependencies_schemas

import logging
from abc import ABC, abstractmethod
from contextlib import contextmanager
from dataclasses import dataclass, field
from enum import Enum
from typing import Dict, List, Optional

from mlflow.utils.annotations import experimental

_logger = logging.getLogger(__name__)


class DependenciesSchemasType(Enum):
    """
    Kinds of dependencies schema that can be attached to a model.

    The value of each member is the string key used for that kind of schema.
    """

    # Schemas describing the documents returned by retrievers.
    RETRIEVERS = "retrievers"


@experimental
def set_retriever_schema(
    *,
    primary_key: str,
    text_column: str,
    doc_uri: Optional[str] = None,
    other_columns: Optional[List[str]] = None,
    name: Optional[str] = "retriever",
):
    """
    Register the schema of a retriever so that retrieved documents can be interpreted.

    After defining your vector store in a Python file or notebook, call
    set_retriever_schema() so that, when MLflow retrieves documents during
    model inference, MLflow can interpret the fields in each retrieved document
    and determine which fields correspond to the document text, document URI, etc.

    Registering a schema under a name that is already registered is a no-op when
    every field matches; otherwise a warning is logged and the existing schema's
    fields are overridden.

    Args:
        primary_key: The primary key of the retriever or vector index.
        text_column: The name of the text column to use for the embeddings.
        doc_uri: The name of the column that contains the document URI.
        other_columns: A list of other columns that are part of the vector index
            that need to be retrieved during trace logging.
        name: The name of the retriever tool or vector store index.

    .. code-block:: Python
        :caption: Example

        from mlflow.models import set_retriever_schema

        set_retriever_schema(
            primary_key="chunk_id",
            text_column="chunk_text",
            doc_uri="doc_uri",
            other_columns=["title"],
        )
    """
    registry_key = DependenciesSchemasType.RETRIEVERS.value
    registered = globals().get(registry_key, [])

    # Everything except the name; these are the fields compared and overridden.
    requested = {
        "primary_key": primary_key,
        "text_column": text_column,
        "doc_uri": doc_uri,
        "other_columns": other_columns or [],
    }

    # Look for a schema that was already registered under the same name.
    match = None
    for candidate in registered:
        if candidate["name"] == name:
            match = candidate
            break

    if match is None:
        registered.append({**requested, "name": name})
    else:
        if all(match[key] == value for key, value in requested.items()):
            # Identical registration: nothing to warn about or update.
            return
        _logger.warning(
            f"A retriever schema with the name '{name}' already exists. "
            "Overriding the existing schema."
        )
        match.update(requested)

    globals()[registry_key] = registered
def _get_retriever_schema():
    """
    Get the retriever schemas defined by the user.

    Returns:
        A list of ``RetrieverSchema`` objects, one per schema stored in this
        module's globals; an empty list when none is stored.
    """
    retriever_schemas = globals().get(DependenciesSchemasType.RETRIEVERS.value, [])
    if not retriever_schemas:
        return []

    return [
        RetrieverSchema(
            name=retriever.get("name"),
            primary_key=retriever.get("primary_key"),
            text_column=retriever.get("text_column"),
            doc_uri=retriever.get("doc_uri"),
            other_columns=retriever.get("other_columns"),
        )
        for retriever in retriever_schemas
    ]


def _clear_retriever_schema():
    """
    Clear the retriever schemas defined by the user.
    """
    globals().pop(DependenciesSchemasType.RETRIEVERS.value, None)


def _clear_dependencies_schemas():
    """
    Clear all the dependencies schemas defined by the user.
    """
    # Retriever schemas are the only kind of dependencies schema stored here.
    _clear_retriever_schema()


@contextmanager
def _get_dependencies_schemas():
    """
    Yield the currently defined dependencies schemas, then clear them.

    The schemas are collected before the body of the ``with`` block runs and
    are cleared when the block exits, including when it raises.

    Yields:
        DependenciesSchemas: The schemas defined at the time of entry.
    """
    dependencies_schemas = DependenciesSchemas(retriever_schemas=_get_retriever_schema())
    try:
        yield dependencies_schemas
    finally:
        _clear_dependencies_schemas()


@dataclass
class Schema(ABC):
    """
    Base class for defining the dependencies schemas of a model.

    Args:
        type (DependenciesSchemasType): The type of the schema.
    """

    type: DependenciesSchemasType

    @abstractmethod
    def to_dict(self):
        """
        Convert the schema to a dictionary.
        Subclasses must implement this method.
        """

    @classmethod
    @abstractmethod
    def from_dict(cls, data: Dict[str, str]):
        """
        Convert the dictionary to a Schema.
        Subclasses must implement this method.
        """


@dataclass
class RetrieverSchema(Schema):
    """
    Define the schema of a retriever (vector search index) used by a model.

    Args:
        name (str): The name of the vector search index schema.
        primary_key (str): The primary key for the index.
        text_column (str): The main text column for the index.
        doc_uri (Optional[str]): The document URI for the index.
        other_columns (Optional[List[str]]): Additional columns in the index.
    """

    # Declare the attributes as dataclass fields. The generated ``__eq__`` and
    # ``__repr__`` only look at declared fields, so without these declarations
    # every RetrieverSchema would compare equal (only ``type`` was a field).
    name: str
    primary_key: str
    text_column: str
    doc_uri: Optional[str] = None
    other_columns: List[str] = field(default_factory=list)

    # A hand-written ``__init__`` is kept (``@dataclass`` does not replace an
    # ``__init__`` defined in the class body) so that callers never pass ``type``.
    def __init__(
        self,
        name: str,
        primary_key: str,
        text_column: str,
        doc_uri: Optional[str] = None,
        other_columns: Optional[List[str]] = None,
    ):
        super().__init__(type=DependenciesSchemasType.RETRIEVERS)
        self.name = name
        self.primary_key = primary_key
        self.text_column = text_column
        self.doc_uri = doc_uri
        self.other_columns = other_columns or []

    def to_dict(self):
        """
        Convert the schema to a dictionary.

        Returns:
            A dict with a single key, the schema type's value, mapped to a
            one-element list holding this schema's fields.
        """
        return {
            self.type.value: [
                {
                    "name": self.name,
                    "primary_key": self.primary_key,
                    "text_column": self.text_column,
                    "doc_uri": self.doc_uri,
                    "other_columns": self.other_columns,
                }
            ]
        }

    @classmethod
    def from_dict(cls, data: Dict[str, str]):
        """
        Build a RetrieverSchema from a dictionary of its fields.

        Args:
            data: Must contain ``name``, ``primary_key`` and ``text_column``;
                ``doc_uri`` and ``other_columns`` are optional.

        Raises:
            KeyError: If a required key is missing.
        """
        return cls(
            name=data["name"],
            primary_key=data["primary_key"],
            text_column=data["text_column"],
            doc_uri=data.get("doc_uri"),
            other_columns=data.get("other_columns", []),
        )


@dataclass
class DependenciesSchemas:
    """
    Container for all the dependencies schemas of a model.

    Args:
        retriever_schemas (List[RetrieverSchema]): The retriever schemas.
    """

    retriever_schemas: List[RetrieverSchema] = field(default_factory=list)

    def to_dict(self) -> Optional[Dict[str, Dict[str, List[Dict]]]]:
        """
        Convert the schemas to a dictionary.

        Returns:
            ``None`` when there are no retriever schemas; otherwise a dict of the
            form ``{"dependencies_schemas": {"retrievers": [<schema fields>, ...]}}``.
        """
        if not self.retriever_schemas:
            return None

        retrievers_key = DependenciesSchemasType.RETRIEVERS.value
        return {
            "dependencies_schemas": {
                retrievers_key: [
                    # Each schema serializes to {key: [fields]}; unwrap to the fields.
                    schema.to_dict()[retrievers_key][0]
                    for schema in self.retriever_schemas
                ],
            }
        }