from typing import TYPE_CHECKING, Any, Mapping, Optional, Sequence, Union
from mlflow.data.dataset_source import DatasetSource
if TYPE_CHECKING:
import datasets
[docs]class HuggingFaceDatasetSource(DatasetSource):
"""Represents the source of a Hugging Face dataset used in MLflow Tracking."""
def __init__(
self,
path: str,
config_name: Optional[str] = None,
data_dir: Optional[str] = None,
data_files: Optional[
Union[str, Sequence[str], Mapping[str, Union[str, Sequence[str]]]]
] = None,
split: Optional[Union[str, "datasets.Split"]] = None,
revision: Optional[Union[str, "datasets.Version"]] = None,
trust_remote_code: Optional[bool] = None,
):
"""Create a `HuggingFaceDatasetSource` instance.
Arguments in `__init__` match arguments of the same name in
[`datasets.load_dataset()`](https://huggingface.co/docs/datasets/v2.14.5/en/package_reference/loading_methods#datasets.load_dataset).
The only exception is `config_name` matches `name` in `datasets.load_dataset()`, because
we need to differentiate from `mlflow.data.Dataset` `name` attribute.
Args:
path: The path of the Hugging Face dataset, if it is a dataset from HuggingFace hub,
`path` must match the hub path, e.g., "databricks/databricks-dolly-15k".
config_name: The name of of the Hugging Face dataset configuration.
data_dir: The `data_dir` of the Hugging Face dataset configuration.
data_files: Paths to source data file(s) for the Hugging Face dataset configuration.
split: Which split of the data to load.
revision: Version of the dataset script to load.
trust_remote_code: Whether to trust remote code from the dataset repo.
"""
self.path = path
self.config_name = config_name
self.data_dir = data_dir
self.data_files = data_files
self.split = split
self.revision = revision
self.trust_remote_code = trust_remote_code
@staticmethod
def _get_source_type() -> str:
return "hugging_face"
[docs] def load(self, **kwargs):
"""Load the Hugging Face dataset based on `HuggingFaceDatasetSource`.
Args:
kwargs: Additional keyword arguments used for loading the dataset with the Hugging Face
`datasets.load_dataset()` method.
Returns:
An instance of `datasets.Dataset`.
"""
import datasets
from packaging.version import Version
load_kwargs = {
"path": self.path,
"name": self.config_name,
"data_dir": self.data_dir,
"data_files": self.data_files,
"split": self.split,
"revision": self.revision,
}
# this argument only exists in >= 2.16.0
if Version(datasets.__version__) >= Version("2.16.0"):
load_kwargs["trust_remote_code"] = self.trust_remote_code
intersecting_keys = set(load_kwargs.keys()) & set(kwargs.keys())
if intersecting_keys:
raise KeyError(
f"Found duplicated arguments in `HuggingFaceDatasetSource` and "
f"`kwargs`: {intersecting_keys}. Please remove them from `kwargs`."
)
load_kwargs.update(kwargs)
return datasets.load_dataset(**load_kwargs)
@staticmethod
def _can_resolve(raw_source: Any):
# NB: Initially, we expect that Hugging Face dataset sources will only be used with
# Hugging Face datasets constructed by from_huggingface_dataset, which can create
# an instance of HuggingFaceDatasetSource directly without the need for resolution
return False
@classmethod
def _resolve(cls, raw_source: str) -> "HuggingFaceDatasetSource":
raise NotImplementedError
[docs] def to_dict(self) -> dict[Any, Any]:
return {
"path": self.path,
"config_name": self.config_name,
"data_dir": self.data_dir,
"data_files": self.data_files,
"split": str(self.split),
"revision": self.revision,
}
[docs] @classmethod
def from_dict(cls, source_dict: dict[Any, Any]) -> "HuggingFaceDatasetSource":
return cls(
path=source_dict.get("path"),
config_name=source_dict.get("config_name"),
data_dir=source_dict.get("data_dir"),
data_files=source_dict.get("data_files"),
split=source_dict.get("split"),
revision=source_dict.get("revision"),
)