Source code for mlflow.data.tensorflow_dataset

import json
import logging
from functools import cached_property
from typing import Any, Optional, Union

import numpy as np

from mlflow.data.dataset import Dataset
from mlflow.data.dataset_source import DatasetSource
from mlflow.data.digest_utils import (
    MAX_ROWS,
    compute_numpy_digest,
    get_normalized_md5_digest,
)
from mlflow.data.evaluation_dataset import EvaluationDataset
from mlflow.data.pyfunc_dataset_mixin import PyFuncConvertibleDatasetMixin, PyFuncInputsOutputs
from mlflow.data.schema import TensorDatasetSchema
from mlflow.exceptions import MlflowException
from mlflow.protos.databricks_pb2 import INTERNAL_ERROR, INVALID_PARAMETER_VALUE
from mlflow.types.schema import Schema
from mlflow.types.utils import _infer_schema

# Module-level logger; used below to warn when schema inference fails.
_logger = logging.getLogger(__name__)


class TensorFlowDataset(Dataset, PyFuncConvertibleDatasetMixin):
    """
    Represents a TensorFlow dataset for use with MLflow Tracking.

    The wrapped data is either a ``tf.data.Dataset`` or a TensorFlow Tensor. Optional
    targets must be of the same kind as the features.
    """

    def __init__(
        self,
        features,
        source: DatasetSource,
        targets=None,
        name: Optional[str] = None,
        digest: Optional[str] = None,
    ):
        """
        Args:
            features: A TensorFlow dataset or tensor of features.
            source: The source of the TensorFlow dataset.
            targets: A TensorFlow dataset or tensor of targets. Optional.
            name: The name of the dataset. E.g. "wiki_train". If unspecified, a name is
                automatically generated.
            digest: The digest (hash, fingerprint) of the dataset. If unspecified, a digest
                is automatically computed.

        Raises:
            MlflowException: If ``features`` is neither a ``tf.data.Dataset`` nor a
                TensorFlow Tensor, or if ``targets`` is given but is not of the same kind
                as ``features``.
        """
        import tensorflow as tf

        if not isinstance(features, tf.data.Dataset) and not tf.is_tensor(features):
            raise MlflowException(
                "'features' must be an instance of tf.data.Dataset or a TensorFlow Tensor."
                f" Found: {type(features)}.",
                INVALID_PARAMETER_VALUE,
            )
        if tf.is_tensor(features) and targets is not None and not tf.is_tensor(targets):
            raise MlflowException(
                "If 'features' is a TensorFlow Tensor, then 'targets' must also be a TensorFlow"
                f" Tensor. Found: {type(targets)}.",
                INVALID_PARAMETER_VALUE,
            )
        if (
            isinstance(features, tf.data.Dataset)
            and targets is not None
            and not isinstance(targets, tf.data.Dataset)
        ):
            raise MlflowException(
                "If 'features' is an instance of tf.data.Dataset, then 'targets' must also be an"
                f" instance of tf.data.Dataset. Found: {type(targets)}.",
                INVALID_PARAMETER_VALUE,
            )

        # NOTE(review): presumably the base initializer calls `_compute_digest()` when no
        # digest is supplied, and that method reads these attributes -- keep the
        # assignments ahead of `super().__init__()`.
        self._features = features
        self._targets = targets
        super().__init__(source=source, name=name, digest=digest)

    def _compute_tensorflow_dataset_digest(
        self,
        dataset,
        targets=None,
    ) -> str:
        """Computes a digest for the given Tensorflow dataset.

        Every element of ``dataset`` (and then of ``targets``, if given) is flattened into a
        single 1-D array, truncated to its first ``MAX_ROWS`` values and hashed. Elements that
        pandas cannot hash contribute their truncated size instead.

        Args:
            dataset: A Tensorflow dataset.
            targets: A Tensorflow dataset of targets. Optional.

        Returns:
            A string digest.
        """
        import pandas as pd
        import tensorflow as tf

        hashable_elements = []

        def hash_tf_dataset_iterator_element(element):
            if element is None:
                return
            flat_element = tf.nest.flatten(element)
            # `np.asarray(...).ravel()` rather than `x.flatten()`: leaves are not always
            # NumPy objects (scalar string components are yielded as plain `bytes`, which
            # have no `.flatten()`). For ndarray / NumPy-scalar leaves the values are the same.
            flattened_array = np.concatenate([np.asarray(x).ravel() for x in flat_element])
            trimmed_array = flattened_array[0:MAX_ROWS]
            try:
                hashable_elements.append(pd.util.hash_array(trimmed_array))
            except TypeError:
                # Unhashable content: fall back to the element's (trimmed) size.
                hashable_elements.append(np.int64(trimmed_array.size))

        # NOTE(review): both loops consume the dataset to exhaustion, so an unbounded
        # dataset would never finish here.
        for element in dataset.as_numpy_iterator():
            hash_tf_dataset_iterator_element(element)
        if targets is not None:
            for element in targets.as_numpy_iterator():
                hash_tf_dataset_iterator_element(element)

        return get_normalized_md5_digest(hashable_elements)

    def _compute_tensor_digest(
        self,
        tensor_data,
        tensor_targets,
    ) -> str:
        """Computes a digest for the given Tensorflow tensor.

        Args:
            tensor_data: A Tensorflow tensor, representing the features.
            tensor_targets: A Tensorflow tensor, representing the targets. Optional.

        Returns:
            A string digest.
        """
        if tensor_targets is None:
            return compute_numpy_digest(tensor_data.numpy())
        else:
            return compute_numpy_digest(tensor_data.numpy(), tensor_targets.numpy())

    def _compute_digest(self) -> str:
        """
        Computes a digest for the dataset. Called if the user doesn't supply
        a digest when constructing the dataset.

        Dispatches to the ``tf.data.Dataset`` digest or to the tensor digest depending on
        the type of the wrapped features.
        """
        import tensorflow as tf

        if isinstance(self._features, tf.data.Dataset):
            return self._compute_tensorflow_dataset_digest(self._features, self._targets)
        return self._compute_tensor_digest(self._features, self._targets)

    def to_dict(self) -> dict[str, str]:
        """Create config dictionary for the dataset.

        Returns a string dictionary containing the following fields: name, digest, source, source
        type, schema, and profile. ``schema`` is ``None`` when no schema could be inferred.
        """
        schema = json.dumps(self.schema.to_dict()) if self.schema else None
        config = super().to_dict()
        config.update(
            {
                "schema": schema,
                "profile": json.dumps(self.profile),
            }
        )
        return config

    @property
    def data(self):
        """
        The underlying TensorFlow data.
        """
        return self._features

    @property
    def source(self) -> DatasetSource:
        """
        The source of the dataset.
        """
        return self._source

    @property
    def targets(self):
        """
        The targets of the dataset.
        """
        return self._targets

    @property
    def profile(self) -> Optional[Any]:
        """
        A profile of the dataset. May be None if no profile is available.

        The profile is a dict with ``features_cardinality`` and, when targets are present,
        ``targets_cardinality``. For a ``tf.data.Dataset`` the value is its ``cardinality()``;
        for a Tensor it is ``tf.size`` (the total number of elements).
        """
        import tensorflow as tf

        # NOTE(review): tf.data reports infinite/unknown cardinality through negative
        # sentinel values; those are passed through unchanged -- confirm consumers expect it.
        profile = {
            "features_cardinality": int(self._features.cardinality().numpy())
            if isinstance(self._features, tf.data.Dataset)
            else int(tf.size(self._features).numpy()),
        }
        if self._targets is not None:
            profile.update(
                {
                    "targets_cardinality": int(self._targets.cardinality().numpy())
                    if isinstance(self._targets, tf.data.Dataset)
                    else int(tf.size(self._targets).numpy()),
                }
            )
        return profile

    @cached_property
    def schema(self) -> Optional[TensorDatasetSchema]:
        """
        An MLflow TensorSpec schema representing the tensor dataset

        Schema inference is best-effort: any failure is logged as a warning and ``None``
        is returned. The result is cached after the first access.
        """
        try:
            features_schema = TensorFlowDataset._get_tf_object_schema(self._features)
            targets_schema = None
            if self._targets is not None:
                targets_schema = TensorFlowDataset._get_tf_object_schema(self._targets)
            return TensorDatasetSchema(features=features_schema, targets=targets_schema)
        except Exception as e:
            _logger.warning("Failed to infer schema for TensorFlow dataset. Exception: %s", e)
            return None

    @staticmethod
    def _get_tf_object_schema(tf_object) -> Schema:
        """Infers an MLflow schema from a ``tf.data.Dataset`` or a TensorFlow Tensor.

        For a dataset, only the first element yielded by ``as_numpy_iterator()`` is inspected.

        Args:
            tf_object: A ``tf.data.Dataset`` or a TensorFlow Tensor.

        Returns:
            The inferred schema.

        Raises:
            MlflowException: If the dataset is empty, if its first element is not an
                ``np.ndarray``, ``dict`` or ``tuple``, or if ``tf_object`` is neither a
                dataset nor a Tensor.
        """
        import tensorflow as tf

        if isinstance(tf_object, tf.data.Dataset):
            try:
                numpy_data = next(tf_object.as_numpy_iterator())
            except StopIteration:
                # An empty dataset has no element to inspect; report that explicitly
                # instead of leaking a message-less StopIteration to the caller.
                raise MlflowException(
                    "Failed to infer schema for tf.data.Dataset because the dataset is empty.",
                    INVALID_PARAMETER_VALUE,
                ) from None
            if isinstance(numpy_data, np.ndarray):
                return _infer_schema(numpy_data)
            elif isinstance(numpy_data, dict):
                return TensorFlowDataset._get_schema_from_tf_dataset_dict_numpy_data(numpy_data)
            elif isinstance(numpy_data, tuple):
                return TensorFlowDataset._get_schema_from_tf_dataset_tuple_numpy_data(numpy_data)
            else:
                raise MlflowException(
                    "Failed to infer schema for tf.data.Dataset due to unrecognized numpy iterator"
                    " data type. Numpy iterator data types 'np.ndarray', 'dict', and 'tuple' are"
                    f" supported. Found: {type(numpy_data)}.",
                    INVALID_PARAMETER_VALUE,
                )
        elif tf.is_tensor(tf_object):
            return _infer_schema(tf_object.numpy())
        else:
            raise MlflowException(
                "Cannot infer schema of an object that is not an instance of tf.data.Dataset or"
                f" a TensorFlow Tensor. Found: {type(tf_object)}",
                INTERNAL_ERROR,
            )

    @staticmethod
    def _get_schema_from_tf_dataset_dict_numpy_data(numpy_data: dict[Any, Any]) -> Schema:
        """Infers a schema from a dict-shaped dataset element.

        Args:
            numpy_data: Mapping whose values must all be ``np.ndarray`` instances.

        Raises:
            MlflowException: If any value is not an ``np.ndarray``.
        """
        if not all(isinstance(data_element, np.ndarray) for data_element in numpy_data.values()):
            raise MlflowException(
                "Failed to infer schema for tf.data.Dataset. Schemas can only be inferred"
                " if the dataset consists of tensors. Ragged tensors, tensor arrays, and"
                " other types are not supported. Additionally, datasets with nested tensors"
                " are not supported.",
                INVALID_PARAMETER_VALUE,
            )
        return _infer_schema(numpy_data)

    @staticmethod
    def _get_schema_from_tf_dataset_tuple_numpy_data(numpy_data: tuple[Any, ...]) -> Schema:
        """Infers a schema from a tuple-shaped dataset element.

        Args:
            numpy_data: Tuple whose items must all be ``np.ndarray`` instances.

        Raises:
            MlflowException: If any item is not an ``np.ndarray``.
        """
        if not all(isinstance(data_element, np.ndarray) for data_element in numpy_data):
            raise MlflowException(
                "Failed to infer schema for tf.data.Dataset. Schemas can only be inferred"
                " if the dataset consists of tensors. Ragged tensors, tensor arrays, and"
                " other types are not supported. Additionally, datasets with nested tensors"
                " are not supported.",
                INVALID_PARAMETER_VALUE,
            )
        return _infer_schema(
            {
                # MLflow Schemas currently require each tensor to have a name, if more than
                # one tensor is defined. Accordingly, use the index as the name
                str(i): data_element
                for i, data_element in enumerate(numpy_data)
            }
        )

    def to_pyfunc(self) -> PyFuncInputsOutputs:
        """
        Converts the dataset to a collection of pyfunc inputs and outputs for model
        evaluation. Required for use with mlflow.evaluate().
        """
        return PyFuncInputsOutputs(self._features, self._targets)

    def to_evaluation_dataset(self, path=None, feature_names=None) -> EvaluationDataset:
        """
        Converts the dataset to an EvaluationDataset for model evaluation. Only supported if the
        dataset is a Tensor. Required for use with mlflow.evaluate().

        Args:
            path: Forwarded unchanged to ``EvaluationDataset``.
            feature_names: Forwarded unchanged to ``EvaluationDataset``.

        Raises:
            MlflowException: If the features, or the targets when present, are not Tensors.
        """
        import tensorflow as tf

        # check that data and targets are Tensors
        if not tf.is_tensor(self._features):
            raise MlflowException("Data must be a Tensor to convert to an EvaluationDataset.")
        if self._targets is not None and not tf.is_tensor(self._targets):
            raise MlflowException("Targets must be a Tensor to convert to an EvaluationDataset.")
        return EvaluationDataset(
            data=self._features.numpy(),
            targets=self._targets.numpy() if self._targets is not None else None,
            path=path,
            feature_names=feature_names,
        )

def from_tensorflow(
    features,
    source: Optional[Union[str, DatasetSource]] = None,
    targets=None,
    name: Optional[str] = None,
    digest: Optional[str] = None,
) -> TensorFlowDataset:
    """Builds a TensorFlowDataset from TensorFlow data, optional targets, and a source.

    A ``DatasetSource`` instance is used as given; any other non-``None`` source is handed to
    ``resolve_dataset_source`` to obtain one. When no source is supplied, a
    ``CodeDatasetSource`` is created from the tags resolved by the tracking context registry.

    Args:
        features: A TensorFlow dataset or tensor of features.
        source: The source from which the data was derived, e.g. a filesystem path, an S3 URI,
            an HTTPS URL, a delta table name with version, or spark table etc. If source is not
            a path like string, pass in a DatasetSource object directly. If no source is
            specified, a CodeDatasetSource is used, which will source information from the run
            context.
        targets: A TensorFlow dataset or tensor of targets. Optional.
        name: The name of the dataset. If unspecified, a name is generated.
        digest: A dataset digest (hash). If unspecified, a digest is computed automatically.

    Returns:
        The constructed ``TensorFlowDataset``.
    """
    from mlflow.data.code_dataset_source import CodeDatasetSource
    from mlflow.data.dataset_source_registry import resolve_dataset_source
    from mlflow.tracking.context import registry

    if source is None:
        # No explicit source: describe the dataset by the code/run context instead.
        dataset_source = CodeDatasetSource(tags=registry.resolve_tags())
    elif isinstance(source, DatasetSource):
        dataset_source = source
    else:
        dataset_source = resolve_dataset_source(source)

    return TensorFlowDataset(
        features=features,
        source=dataset_source,
        targets=targets,
        name=name,
        digest=digest,
    )