Source code for mlflow.artifacts

"""
APIs for interacting with artifacts in MLflow
"""

import io
import json
import pathlib
import tempfile
from typing import Optional

from mlflow.exceptions import MlflowException
from mlflow.protos.databricks_pb2 import BAD_REQUEST, INVALID_PARAMETER_VALUE
from mlflow.tracking import _get_store
from mlflow.tracking.artifact_utils import (
    _download_artifact_from_uri,
    _get_root_uri_and_artifact_path,
    add_databricks_profile_info_to_artifact_uri,
    get_artifact_repository,
)


def download_artifacts(
    artifact_uri: Optional[str] = None,
    run_id: Optional[str] = None,
    artifact_path: Optional[str] = None,
    dst_path: Optional[str] = None,
    tracking_uri: Optional[str] = None,
) -> str:
    """Download an artifact file or directory to a local directory.

    The artifacts can be addressed either by a full ``artifact_uri`` or by a
    ``run_id`` (optionally combined with an ``artifact_path`` relative to that
    run's artifact root). Exactly one of the two addressing modes must be used.

    Args:
        artifact_uri: URI pointing to the artifacts, such as
            ``"runs:/500cf58bee2b40a4a82861cc31a617b1/my_model.pkl"``,
            ``"models:/my_model/Production"``, or ``"s3://my_bucket/my/file.txt"``.
            Mutually exclusive with ``run_id``.
        run_id: ID of the MLflow Run containing the artifacts. Mutually
            exclusive with ``artifact_uri``.
        artifact_path: Only valid together with ``run_id``. Path, relative to
            the run's artifact root, of the artifacts to download. Defaults to
            the whole artifact root.
        dst_path: Local destination directory. It is created (including
            parents) if missing. When omitted, the artifacts go to a new
            uniquely-named local directory, unless they already live on the
            local filesystem, in which case their local path is returned as-is.
        tracking_uri: Tracking URI used to resolve ``run_id``. It is only
            consulted in ``run_id`` mode.

    Returns:
        The local filesystem path of the downloaded artifact file or directory.

    Raises:
        MlflowException: With ``INVALID_PARAMETER_VALUE`` if both or neither of
            ``run_id``/``artifact_uri`` are given, or if ``artifact_path`` is
            combined with ``artifact_uri``.
    """
    # "Exactly one of the two is set" <=> their None-ness differs.
    if (run_id is None) == (artifact_uri is None):
        raise MlflowException(
            message="Exactly one of `run_id` or `artifact_uri` must be specified",
            error_code=INVALID_PARAMETER_VALUE,
        )
    if artifact_uri is not None and artifact_path is not None:
        raise MlflowException(
            message="`artifact_path` cannot be specified if `artifact_uri` is specified",
            error_code=INVALID_PARAMETER_VALUE,
        )

    if dst_path is not None:
        pathlib.Path(dst_path).mkdir(parents=True, exist_ok=True)

    # URI mode: the helper resolves the repository itself.
    if artifact_uri is not None:
        return _download_artifact_from_uri(artifact_uri, output_path=dst_path)

    # Run mode: look up the run's artifact root through the tracking store.
    run_root_uri = _get_store(store_uri=tracking_uri).get_run(run_id).info.artifact_uri
    repository = get_artifact_repository(
        add_databricks_profile_info_to_artifact_uri(run_root_uri, tracking_uri)
    )
    relative_path = "" if artifact_path is None else artifact_path
    return repository.download_artifacts(relative_path, dst_path=dst_path)
def list_artifacts(
    artifact_uri: Optional[str] = None,
    run_id: Optional[str] = None,
    artifact_path: Optional[str] = None,
    tracking_uri: Optional[str] = None,
):
    """List the artifacts located directly under a URI or run path.

    Args:
        artifact_uri: URI pointing to the artifacts, such as
            ``"runs:/500cf58bee2b40a4a82861cc31a617b1/my_model.pkl"``,
            ``"models:/my_model/Production"``, or ``"s3://my_bucket/my/file.txt"``.
            Mutually exclusive with ``run_id``.
        run_id: ID of the MLflow Run containing the artifacts. Mutually
            exclusive with ``artifact_uri``.
        artifact_path: Only valid together with ``run_id``. Path, relative to
            the run's artifact root, whose contents should be listed.
        tracking_uri: Tracking URI used to resolve ``run_id`` when listing
            artifacts. It is only consulted in ``run_id`` mode.

    Returns:
        The artifact repository's ``list_artifacts`` result: the ``FileInfo``
        entries found directly under the requested path.

    Raises:
        MlflowException: An invalid-parameter error if both or neither of
            ``run_id``/``artifact_uri`` are given, or if ``artifact_path`` is
            combined with ``artifact_uri``.
    """
    # "Exactly one of the two is set" <=> their None-ness differs.
    if (run_id is None) == (artifact_uri is None):
        raise MlflowException.invalid_parameter_value(
            message="Exactly one of `run_id` or `artifact_uri` must be specified",
        )
    if artifact_uri is not None and artifact_path is not None:
        raise MlflowException.invalid_parameter_value(
            message="`artifact_path` cannot be specified if `artifact_uri` is specified",
        )

    if artifact_uri is None:
        # Run mode: resolve the run's artifact root via the tracking store.
        run_root_uri = _get_store(store_uri=tracking_uri).get_run(run_id).info.artifact_uri
        repository = get_artifact_repository(
            add_databricks_profile_info_to_artifact_uri(run_root_uri, tracking_uri)
        )
        return repository.list_artifacts(artifact_path)

    # URI mode: split the URI into a repository root and a path inside it.
    root_uri, relative_path = _get_root_uri_and_artifact_path(artifact_uri)
    return get_artifact_repository(artifact_uri=root_uri).list_artifacts(relative_path)
def load_text(artifact_uri: str) -> str:
    """Loads the artifact contents as a string.

    The artifact is downloaded into a temporary directory, which is removed
    before returning. It is decoded as UTF-8 regardless of the platform's
    default locale encoding.

    Args:
        artifact_uri: Artifact location.

    Returns:
        The contents of the artifact as a string.

    Raises:
        MlflowException: With ``BAD_REQUEST`` if the file content cannot be
            read or decoded as UTF-8 text.

    .. code-block:: python
        :caption: Example

        import mlflow

        with mlflow.start_run() as run:
            artifact_uri = run.info.artifact_uri
            mlflow.log_text("This is a sentence", "file.txt")
            file_content = mlflow.artifacts.load_text(artifact_uri + "/file.txt")
            print(file_content)

    .. code-block:: text
        :caption: Output

        This is a sentence
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        local_artifact = download_artifacts(artifact_uri, dst_path=tmpdir)
        # Explicit UTF-8: relying on the locale default (e.g. cp1252 on Windows)
        # silently mis-decodes or rejects non-ASCII text.
        with open(local_artifact, encoding="utf-8") as local_artifact_fd:
            try:
                return local_artifact_fd.read()
            # UnicodeDecodeError is a ValueError; OSError covers read failures.
            # Other errors (e.g. MemoryError) are not a client mistake and propagate.
            except (OSError, ValueError) as e:
                raise MlflowException(
                    "Unable to form a str object from file content", BAD_REQUEST
                ) from e
def load_dict(artifact_uri: str) -> dict:
    """Loads the artifact contents as a dictionary.

    The artifact is downloaded into a temporary directory, which is removed
    before returning. It is parsed as UTF-8 encoded JSON.

    Args:
        artifact_uri: artifact location.

    Returns:
        The parsed JSON content; a dictionary for JSON objects. The top-level
        JSON type is not validated here.

    Raises:
        MlflowException: With ``BAD_REQUEST`` if the file is not valid UTF-8 or
            not valid JSON.

    .. code-block:: python
        :caption: Example

        import mlflow

        with mlflow.start_run() as run:
            artifact_uri = run.info.artifact_uri
            mlflow.log_dict({"mlflow-version": "0.28", "n_cores": "10"}, "config.json")
            config_json = mlflow.artifacts.load_dict(artifact_uri + "/config.json")
            print(config_json)

    .. code-block:: text
        :caption: Output

        {'mlflow-version': '0.28', 'n_cores': '10'}
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        local_artifact = download_artifacts(artifact_uri, dst_path=tmpdir)
        # JSON interchange is UTF-8; don't depend on the locale default encoding.
        with open(local_artifact, encoding="utf-8") as local_artifact_fd:
            try:
                return json.load(local_artifact_fd)
            # Undecodable bytes are "bad content" just like malformed JSON, so
            # both surface as the same MlflowException.
            except (json.JSONDecodeError, UnicodeDecodeError) as e:
                raise MlflowException(
                    "Unable to form a JSON object from file content", BAD_REQUEST
                ) from e
def load_image(artifact_uri: str):
    """Loads artifact contents as a ``PIL.Image.Image`` object

    The artifact is downloaded into a temporary directory and decoded from an
    in-memory copy of its bytes. The returned image therefore holds no handle
    on the (already deleted) temporary file.

    Args:
        artifact_uri: Artifact location.

    Returns:
        A PIL.Image object.

    Raises:
        ImportError: If Pillow is not installed.
        MlflowException: With ``BAD_REQUEST`` if the file content cannot be
            read or decoded as an image.

    .. code-block:: python
        :caption: Example

        import mlflow
        from PIL import Image

        with mlflow.start_run() as run:
            image = Image.new("RGB", (100, 100))
            artifact_uri = run.info.artifact_uri
            mlflow.log_image(image, "image.png")
            image = mlflow.artifacts.load_image(artifact_uri + "/image.png")
            print(image)

    .. code-block:: text
        :caption: Output

        <PIL.PngImagePlugin.PngImageFile image mode=RGB size=100x100 at 0x11D2FA3D0>
    """
    try:
        from PIL import Image
    except ImportError as exc:
        raise ImportError(
            "`load_image` requires Pillow. Please install it via: pip install Pillow"
        ) from exc

    with tempfile.TemporaryDirectory() as tmpdir:
        local_artifact = download_artifacts(artifact_uri, dst_path=tmpdir)
        try:
            # Decode from memory rather than from the path. For multi-frame
            # images (GIF, TIFF, ...) Pillow leaves the source file open even
            # after load(). That leaks a descriptor into tmpdir and makes
            # TemporaryDirectory cleanup fail on Windows.
            image_bytes = pathlib.Path(local_artifact).read_bytes()
            image_obj = Image.open(io.BytesIO(image_bytes))
            image_obj.load()
        # Pillow's decoders raise a variety of exception types, so any failure
        # here is reported uniformly as unusable content.
        except Exception as e:
            raise MlflowException(
                "Unable to form a PIL Image object from file content", BAD_REQUEST
            ) from e
    return image_obj