Source code for mlflow.entities.dataset_input
from typing import List, Optional
from mlflow.entities._mlflow_object import _MlflowObject
from mlflow.entities.dataset import Dataset
from mlflow.entities.input_tag import InputTag
from mlflow.protos.service_pb2 import DatasetInput as ProtoDatasetInput
[docs]class DatasetInput(_MlflowObject):
"""DatasetInput object associated with an experiment."""
def __init__(self, dataset: Dataset, tags: Optional[List[InputTag]] = None) -> None:
self._dataset = dataset
self._tags = tags or []
def __eq__(self, other: _MlflowObject) -> bool:
if type(other) is type(self):
return self.__dict__ == other.__dict__
return False
def _add_tag(self, tag: InputTag) -> None:
self._tags.append(tag)
@property
def tags(self) -> List[InputTag]:
"""Array of input tags."""
return self._tags
@property
def dataset(self) -> Dataset:
"""Dataset."""
return self._dataset
[docs] def to_proto(self):
dataset_input = ProtoDatasetInput()
dataset_input.tags.extend([tag.to_proto() for tag in self.tags])
dataset_input.dataset.MergeFrom(self.dataset.to_proto())
return dataset_input
[docs] @classmethod
def from_proto(cls, proto):
dataset_input = cls(Dataset.from_proto(proto.dataset))
for input_tag in proto.tags:
dataset_input._add_tag(InputTag.from_proto(input_tag))
return dataset_input