Source code for mlflow.models.model_config

import os
from typing import Any, Optional, Union

import yaml

from mlflow.exceptions import MlflowException
from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE

__mlflow_model_config__ = None


class ModelConfig:
    """
    ModelConfig used in code to read a YAML configuration file or a dictionary.

    Args:
        development_config: Path to the YAML configuration file or a dictionary containing the
            configuration. If the configuration is not provided, an error is raised

    .. code-block:: python
        :caption: Example usage in model code

        from mlflow.models import ModelConfig

        # Load the configuration from a dictionary
        config = ModelConfig(development_config={"key1": "value1"})
        print(config.get("key1"))

    .. code-block:: yaml
        :caption: yaml file for model configuration

        key1: value1
        another_key:
            - value2
            - value3

    .. code-block:: python
        :caption: Example yaml usage in model code

        from mlflow.models import ModelConfig

        # Load the configuration from a file
        config = ModelConfig(development_config="config.yaml")
        print(config.get("key1"))

    When invoking the ModelConfig locally in a model file, development_config can be passed in
    which would be used as configuration for the model.

    .. code-block:: python
        :caption: Example to use ModelConfig when logging model as code: agent.py

        import mlflow
        from mlflow.models import ModelConfig

        config = ModelConfig(development_config={"key1": "value1"})


        class TestModel(mlflow.pyfunc.PythonModel):
            def predict(self, context, model_input, params=None):
                return config.get("key1")


        mlflow.models.set_model(TestModel())

    But this development_config configuration file will be overridden when logging a model.
    When no model_config is passed in while logging the model, an error will be raised when
    trying to load the model using ModelConfig.
    Note: development_config is not used when logging the model.

    .. code-block:: python
        :caption: Example to use agent.py to log the model: deploy.py

        model_config = {"key1": "value2"}
        with mlflow.start_run():
            model_info = mlflow.pyfunc.log_model(
                artifact_path="model", python_model="agent.py", model_config=model_config
            )

        loaded_model = mlflow.pyfunc.load_model(model_info.model_uri)

        # This will print "value2" as the model_config passed in while logging the model
        print(loaded_model.predict(None))
    """

    def __init__(self, *, development_config: Optional[Union[str, dict[str, Any]]] = None):
        """Select the active configuration and validate that it is usable.

        Args:
            development_config: Path to a YAML file, or a dictionary, used only when the
                module-level ``__mlflow_model_config__`` global is ``None``.

        Raises:
            FileNotFoundError: If the selected configuration is empty/falsy, or if it is not
                a dictionary and does not name an existing file.
        """
        config = globals().get("__mlflow_model_config__", None)
        # Here mlflow_model_config have 3 states:
        # 1. None, this means if the mlflow_model_config is None, use development_config if
        #    available
        # 2. "", Empty string, this means the users explicitly didn't set the model config
        #    while logging the model so if ModelConfig is used, it should throw an error
        # 3. A valid path, this means the users have set the model config while logging the
        #    model so use that path
        if config is not None:
            self.config = config
        else:
            self.config = development_config

        # Covers state 2 above ("") as well as a missing development_config.
        if not self.config:
            raise FileNotFoundError(
                "Config file is not provided which is needed to load the model. "
                "Please provide a valid path."
            )

        # Anything that is not a dictionary is treated as a filesystem path.
        if not isinstance(self.config, dict) and not os.path.isfile(self.config):
            raise FileNotFoundError(f"Config file '{self.config}' not found.")

    def _read_config(self):
        """Return the configuration contents.

        A dictionary configuration is returned as-is (the same object, not a copy). Any other
        configuration is treated as a path to a YAML file, which is opened and parsed on
        every call.

        Raises:
            OSError: If the configuration file cannot be opened (propagated from ``open``).
            MlflowException: If the file content cannot be parsed as YAML.

        Returns:
            The configuration dictionary, or whatever ``yaml.safe_load`` produces for the
            file content.
        """
        if isinstance(self.config, dict):
            return self.config

        # Decode explicitly as UTF-8 so parsing does not depend on the platform's locale
        # encoding.
        with open(self.config, encoding="utf-8") as file:
            try:
                return yaml.safe_load(file)
            except yaml.YAMLError as e:
                # Chain the parser error so the original traceback stays attached as the cause.
                raise MlflowException(
                    f"Error parsing YAML file: {e}", error_code=INVALID_PARAMETER_VALUE
                ) from e

    def to_dict(self):
        """Returns the configuration as a dictionary."""
        return self._read_config()

    def get(self, key):
        """Gets the value of a top-level parameter in the configuration.

        Args:
            key: Name of the top-level entry to look up.

        Raises:
            KeyError: If the configuration is empty or does not contain ``key``.

        Returns:
            The value stored under ``key``.
        """
        config_data = self._read_config()

        # An empty/None configuration is reported the same way as a missing key.
        if config_data and key in config_data:
            return config_data[key]
        else:
            raise KeyError(f"Key '{key}' not found in configuration: {config_data}.")
def _set_model_config(model_config):
    """Record ``model_config`` in the module-level ``__mlflow_model_config__`` global.

    ``ModelConfig.__init__`` looks this global up and, when it is not ``None``, uses it
    instead of the ``development_config`` argument.
    """
    global __mlflow_model_config__
    __mlflow_model_config__ = model_config