# Source code for mlflow.data.huggingface_dataset_source

from typing import TYPE_CHECKING, Any, Mapping, Optional, Sequence, Union

from mlflow.data.dataset_source import DatasetSource

if TYPE_CHECKING:
    import datasets


class HuggingFaceDatasetSource(DatasetSource):
    """Represents the source of a Hugging Face dataset used in MLflow Tracking."""

    def __init__(
        self,
        path: str,
        config_name: Optional[str] = None,
        data_dir: Optional[str] = None,
        data_files: Optional[
            Union[str, Sequence[str], Mapping[str, Union[str, Sequence[str]]]]
        ] = None,
        split: Optional[Union[str, "datasets.Split"]] = None,
        revision: Optional[Union[str, "datasets.Version"]] = None,
        trust_remote_code: Optional[bool] = None,
    ):
        """Create a `HuggingFaceDatasetSource` instance.

        Arguments in `__init__` match arguments of the same name in
        [`datasets.load_dataset()`](https://huggingface.co/docs/datasets/v2.14.5/en/package_reference/loading_methods#datasets.load_dataset).
        The only exception is `config_name`, which matches `name` in `datasets.load_dataset()`;
        it is renamed to avoid clashing with the `mlflow.data.Dataset` `name` attribute.

        Args:
            path: The path of the Hugging Face dataset. If it is a dataset from the Hugging Face
                hub, `path` must match the hub path, e.g., "databricks/databricks-dolly-15k".
            config_name: The name of the Hugging Face dataset configuration.
            data_dir: The `data_dir` of the Hugging Face dataset configuration.
            data_files: Paths to source data file(s) for the Hugging Face dataset configuration.
            split: Which split of the data to load.
            revision: Version of the dataset script to load.
            trust_remote_code: Whether to trust remote code from the dataset repo.
        """
        self.path = path
        self.config_name = config_name
        self.data_dir = data_dir
        self.data_files = data_files
        self.split = split
        self.revision = revision
        self.trust_remote_code = trust_remote_code

    @staticmethod
    def _get_source_type() -> str:
        return "hugging_face"

    def load(self, **kwargs):
        """Load the Hugging Face dataset described by this `HuggingFaceDatasetSource`.

        Args:
            kwargs: Additional keyword arguments forwarded to the Hugging Face
                `datasets.load_dataset()` method. They must not repeat any argument
                already supplied by this source.

        Returns:
            An instance of `datasets.Dataset`.

        Raises:
            KeyError: If `kwargs` contains an argument this source already sets.
        """
        import datasets
        from packaging.version import Version

        load_kwargs = {
            "path": self.path,
            "name": self.config_name,
            "data_dir": self.data_dir,
            "data_files": self.data_files,
            "split": self.split,
            "revision": self.revision,
        }
        # `trust_remote_code` is only accepted by `datasets.load_dataset()` in >= 2.16.0.
        if Version(datasets.__version__) >= Version("2.16.0"):
            load_kwargs["trust_remote_code"] = self.trust_remote_code

        # Refuse silent overrides: a caller-supplied kwarg must not shadow a source field.
        intersecting_keys = load_kwargs.keys() & kwargs.keys()
        if intersecting_keys:
            raise KeyError(
                f"Found duplicated arguments in `HuggingFaceDatasetSource` and "
                f"`kwargs`: {intersecting_keys}. Please remove them from `kwargs`."
            )
        load_kwargs.update(kwargs)
        return datasets.load_dataset(**load_kwargs)

    @staticmethod
    def _can_resolve(raw_source: Any):
        # NB: Initially, we expect that Hugging Face dataset sources will only be used with
        # Hugging Face datasets constructed by from_huggingface_dataset, which can create
        # an instance of HuggingFaceDatasetSource directly without the need for resolution
        return False

    @classmethod
    def _resolve(cls, raw_source: str) -> "HuggingFaceDatasetSource":
        raise NotImplementedError

    @staticmethod
    def _optional_str(value: Any) -> Optional[str]:
        """Return `str(value)`, or `None` when `value` is `None`."""
        return None if value is None else str(value)

    def to_dict(self) -> dict[Any, Any]:
        """Serialize this source to a plain dictionary.

        `split` (possibly a `datasets.Split`) and `revision` (possibly a `datasets.Version`)
        are converted to strings. An unset value stays `None` rather than becoming the
        string "None", so `from_dict(to_dict())` reproduces a loadable source.
        """
        # NOTE(review): `trust_remote_code` is not serialized. Presumably this is
        # deliberate, so a deserialized source never opts into running remote code
        # implicitly. Confirm before adding it.
        return {
            "path": self.path,
            "config_name": self.config_name,
            "data_dir": self.data_dir,
            "data_files": self.data_files,
            "split": self._optional_str(self.split),
            "revision": self._optional_str(self.revision),
        }

    @classmethod
    def from_dict(cls, source_dict: dict[Any, Any]) -> "HuggingFaceDatasetSource":
        """Reconstruct a `HuggingFaceDatasetSource` from a dict produced by `to_dict()`."""
        split = source_dict.get("split")
        # Earlier `to_dict()` implementations stored an unset split as the string "None".
        # Map it back to `None` so `load()` does not request a split literally named "None".
        if split == "None":
            split = None
        return cls(
            path=source_dict.get("path"),
            config_name=source_dict.get("config_name"),
            data_dir=source_dict.get("data_dir"),
            data_files=source_dict.get("data_files"),
            split=split,
            revision=source_dict.get("revision"),
        )