Source code for mlflow.data.spark_dataset_source

from typing import Any, Dict, Optional

from mlflow.data.dataset_source import DatasetSource
from mlflow.exceptions import MlflowException
from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE


class SparkDatasetSource(DatasetSource):
    """
    Represents the source of a dataset stored in a Spark table.
    """

    def __init__(
        self,
        path: Optional[str] = None,
        table_name: Optional[str] = None,
        sql: Optional[str] = None,
    ):
        if (path, table_name, sql).count(None) != 2:
            raise MlflowException(
                'Must specify exactly one of "path", "table_name", or "sql"',
                INVALID_PARAMETER_VALUE,
            )
        self._path = path
        self._table_name = table_name
        self._sql = sql

    @staticmethod
    def _get_source_type() -> str:
        return "spark"
    def load(self, **kwargs):
        """Loads the dataset source as a Spark DataFrame.

        Returns:
            An instance of ``pyspark.sql.DataFrame``.
        """
        from pyspark.sql import SparkSession

        spark = SparkSession.builder.getOrCreate()
        if self._path:
            return spark.read.parquet(self._path)
        if self._table_name:
            return spark.read.table(self._table_name)
        if self._sql:
            return spark.sql(self._sql)
    @staticmethod
    def _can_resolve(raw_source: Any):
        return False

    @classmethod
    def _resolve(cls, raw_source: str) -> "SparkDatasetSource":
        raise NotImplementedError
    def to_dict(self) -> Dict[Any, Any]:
        info = {}
        if self._path is not None:
            info["path"] = self._path
        elif self._table_name is not None:
            info["table_name"] = self._table_name
        elif self._sql is not None:
            info["sql"] = self._sql
        return info
    @classmethod
    def from_dict(cls, source_dict: Dict[Any, Any]) -> "SparkDatasetSource":
        return cls(
            path=source_dict.get("path"),
            table_name=source_dict.get("table_name"),
            sql=source_dict.get("sql"),
        )
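
The snippet below is a minimal usage sketch, not part of the module above. It constructs a source from a table name, round-trips it through ``to_dict``/``from_dict``, and loads the data as a ``pyspark.sql.DataFrame``. The table name is a placeholder, and a running Spark session with access to that table is assumed.

from mlflow.data.spark_dataset_source import SparkDatasetSource

# Exactly one of "path", "table_name", or "sql" may be provided.
source = SparkDatasetSource(table_name="my_schema.my_table")

# Serialize the source to a plain dict and reconstruct it.
source_dict = source.to_dict()  # {"table_name": "my_schema.my_table"}
restored = SparkDatasetSource.from_dict(source_dict)

# Load the underlying data as a pyspark.sql.DataFrame.
df = restored.load()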