Source code for mlflow.data.http_dataset_source

import os
import re
from typing import Any
from urllib.parse import urlparse

from mlflow.data.dataset_source import DatasetSource
from mlflow.exceptions import MlflowException
from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE
from mlflow.utils.file_utils import create_tmp_dir
from mlflow.utils.rest_utils import augmented_raise_for_status, cloud_storage_http_request


def _is_path(filename: str) -> bool:
    """
    Return True if `filename` is a path, False otherwise. For example,
    "foo/bar" is a path, but "bar" is not.
    """
    return os.path.basename(filename) != filename


[docs]class HTTPDatasetSource(DatasetSource): """ Represents the source of a dataset stored at a web location and referred to by an HTTP or HTTPS URL. """ def __init__(self, url): self._url = url @property def url(self): """The HTTP/S URL referring to the dataset source location. Returns: The HTTP/S URL referring to the dataset source location. """ return self._url @staticmethod def _get_source_type() -> str: return "http" def _extract_filename(self, response) -> str: """ Extracts a filename from the Content-Disposition header or the URL's path. """ if content_disposition := response.headers.get("Content-Disposition"): for match in re.finditer(r"filename=(.+)", content_disposition): filename = match[1].strip("'\"") if _is_path(filename): raise MlflowException.invalid_parameter_value( f"Invalid filename in Content-Disposition header: {filename}. " "It must be a file name, not a path." ) return filename # Extract basename from URL if no valid filename in Content-Disposition return os.path.basename(urlparse(self.url).path)
[docs] def load(self, dst_path=None) -> str: """Downloads the dataset source to the local filesystem. Args: dst_path: Path of the local filesystem destination directory to which to download the dataset source. If the directory does not exist, it is created. If unspecified, the dataset source is downloaded to a new uniquely-named directory on the local filesystem. Returns: The path to the downloaded dataset source on the local filesystem. """ resp = cloud_storage_http_request( method="GET", url=self.url, stream=True, ) augmented_raise_for_status(resp) basename = self._extract_filename(resp) if not basename: basename = "dataset_source" if dst_path is None: dst_path = create_tmp_dir() dst_path = os.path.join(dst_path, basename) with open(dst_path, "wb") as f: chunk_size = 1024 * 1024 # 1 MB for chunk in resp.iter_content(chunk_size=chunk_size): f.write(chunk) return dst_path
@staticmethod def _can_resolve(raw_source: Any) -> bool: """ Args: raw_source: The raw source, e.g. a string like "http://mysite/mydata.tar.gz". Returns: True if this DatasetSource can resolve the raw source, False otherwise. """ if not isinstance(raw_source, str): return False try: parsed_source = urlparse(str(raw_source)) return parsed_source.scheme in ["http", "https"] except Exception: return False @classmethod def _resolve(cls, raw_source: Any) -> "HTTPDatasetSource": """ Args: raw_source: The raw source, e.g. a string like "http://mysite/mydata.tar.gz". """ return HTTPDatasetSource(raw_source)
[docs] def to_dict(self) -> dict[Any, Any]: """ Returns: A JSON-compatible dictionary representation of the HTTPDatasetSource. """ return { "url": self.url, }
[docs] @classmethod def from_dict(cls, source_dict: dict[Any, Any]) -> "HTTPDatasetSource": """ Args: source_dict: A dictionary representation of the HTTPDatasetSource. """ url = source_dict.get("url") if url is None: raise MlflowException( 'Failed to parse HTTPDatasetSource. Missing expected key: "url"', INVALID_PARAMETER_VALUE, ) return cls(url=url)