
MLflow Spark MLlib Integration

Apache Spark MLlib provides distributed machine learning algorithms for processing large-scale datasets across clusters. MLflow integrates with Spark MLlib to track distributed ML pipelines, manage models, and enable flexible deployment from cluster training to standalone inference.

Why MLflow + Spark MLlib?

Pipeline Tracking

Automatically log Spark ML pipelines with all stages, transformers, and estimators. Track parameters from each pipeline component and maintain complete lineage.

Format Flexibility

Save models in native Spark format for distributed batch processing or PyFunc format for inference outside a Spark cluster with automatic DataFrame conversion.

Datasource Autologging

Track data sources automatically with paths, formats, and versions. Maintain complete data lineage for distributed ML workflows.

Cross-Platform Deployment

Deploy Spark models with PyFunc wrappers for REST APIs and edge computing, or convert to ONNX for platform-independent inference.

Basic Model Logging

Log Spark MLlib models with mlflow.spark.log_model():

```python
import mlflow
import mlflow.spark
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.feature import Tokenizer, HashingTF
from pyspark.ml import Pipeline
from pyspark.sql import SparkSession

# Initialize Spark session
spark = SparkSession.builder.appName("MLflowSparkExample").getOrCreate()

# Prepare training data
training = spark.createDataFrame(
    [
        (0, "a b c d e spark", 1.0),
        (1, "b d", 0.0),
        (2, "spark f g h", 1.0),
        (3, "hadoop mapreduce", 0.0),
    ],
    ["id", "text", "label"],
)

# Create ML Pipeline
tokenizer = Tokenizer(inputCol="text", outputCol="words")
hashingTF = HashingTF(inputCol=tokenizer.getOutputCol(), outputCol="features")
lr = LogisticRegression(maxIter=10, regParam=0.001)
pipeline = Pipeline(stages=[tokenizer, hashingTF, lr])

# Train and log the model
with mlflow.start_run():
    model = pipeline.fit(training)

    # Log the entire pipeline
    model_info = mlflow.spark.log_model(
        spark_model=model, artifact_path="spark-pipeline"
    )

    # Log parameters manually
    mlflow.log_params(
        {
            "max_iter": lr.getMaxIter(),
            "reg_param": lr.getRegParam(),
            "num_features": hashingTF.getNumFeatures(),
        }
    )

print(f"Model logged with URI: {model_info.model_uri}")
```

This logs the complete fitted pipeline, including every stage and its parameters, in both the Spark native and PyFunc formats.

Model Formats and Loading

The native Spark format preserves full Spark ML functionality for distributed processing:

```python
# Load as native Spark model (requires Spark session)
spark_model = mlflow.spark.load_model(model_info.model_uri)

# Use for distributed batch scoring
test_data = spark.createDataFrame(
    [(4, "spark i j k"), (5, "l m n"), (6, "spark hadoop spark"), (7, "apache hadoop")],
    ["id", "text"],
)

predictions = spark_model.transform(test_data)
predictions.show()
```

Datasource Autologging

Track data sources automatically during model training:

```python
import mlflow.spark

mlflow.spark.autolog()

with mlflow.start_run():
    raw_data = spark.read.parquet("s3://my-bucket/training-data/")
    model = pipeline.fit(raw_data)
    mlflow.spark.log_model(model, artifact_path="model")
```

Datasource autologging requires Spark 3.0 or later and the MLflow-Spark JAR attached to the Spark session, and it is not supported on Databricks shared or serverless clusters. It logs the path, format, and version of every datasource read.
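The JAR is usually attached when the Spark session is built, through the `spark.jars.packages` setting. A minimal sketch, assuming the JAR is published on Maven as `org.mlflow:mlflow-spark_<scala-version>:<mlflow-version>` (confirm the exact coordinate for your MLflow and Scala versions); the helper names `mlflow_spark_package` and `build_session` are hypothetical.

```python
def mlflow_spark_package(mlflow_version: str, scala_version: str = "2.12") -> str:
    """Return the (assumed) Maven coordinate of the MLflow-Spark JAR."""
    # Assumed naming scheme: org.mlflow:mlflow-spark_<scala>:<mlflow version>.
    return f"org.mlflow:mlflow-spark_{scala_version}:{mlflow_version}"


def build_session(mlflow_version: str, scala_version: str = "2.12"):
    """Create a SparkSession that resolves the MLflow-Spark JAR at startup."""
    # Imported lazily so the coordinate helper works without pyspark installed.
    from pyspark.sql import SparkSession

    return (
        SparkSession.builder.appName("MLflowDatasourceAutolog")
        # spark.jars.packages tells Spark to download the JAR (and its
        # dependencies) from Maven when the JVM starts.
        .config(
            "spark.jars.packages",
            mlflow_spark_package(mlflow_version, scala_version),
        )
        .getOrCreate()
    )
```

Usage: `spark = build_session(mlflow.__version__)`, then `mlflow.spark.autolog()`. Because `spark.jars.packages` is only read when the JVM starts, `getOrCreate()` ignores it if a session already exists in the process.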

Model Signatures

Infer signatures automatically for Spark ML models:

```python
from mlflow.models import infer_signature
from pyspark.ml.functions import array_to_vector

vector_data = spark.createDataFrame(
    [([3.0, 4.0], 0.0), ([5.0, 6.0], 1.0)], ["features_array", "label"]
).select(array_to_vector("features_array").alias("features"), "label")

lr = LogisticRegression(featuresCol="features", labelCol="label")
model = lr.fit(vector_data)

predictions = model.transform(vector_data)

# Infer signature from pandas DataFrames
signature = infer_signature(
    vector_data.limit(2).toPandas(),
    predictions.select("prediction").limit(2).toPandas(),
)

with mlflow.start_run():
    mlflow.spark.log_model(
        spark_model=model,
        artifact_path="vector_model",
        signature=signature,
    )
```

ONNX Conversion

Convert Spark models to ONNX (experimental):

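```python
import onnxmltools
from onnxmltools.convert.common.data_types import StringTensorType

with mlflow.start_run():
    model = pipeline.fit(training)
    mlflow.spark.log_model(spark_model=model, artifact_path="spark_model")

    # The Spark ML converter requires the input names and types; this pipeline
    # takes one string column named "text". Converter coverage varies by stage,
    # so conversion fails for any stage without an ONNX converter.
    onnx_model = onnxmltools.convert_sparkml(
        model,
        name="SparkMLPipeline",
        initial_types=[("text", StringTensorType([None, 1]))],
    )
    onnxmltools.utils.save_model(onnx_model, "model.onnx")
    mlflow.log_artifact("model.onnx")
```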

Model Registry

Register and promote Spark models:

```python
from mlflow import MlflowClient

client = MlflowClient()

with mlflow.start_run():
    model = pipeline.fit(training)

    mlflow.spark.log_model(
        spark_model=model,
        artifact_path="production_candidate",
        registered_model_name="CustomerSegmentationModel",
    )

    mlflow.set_tags(
        {
            "validation_passed": "true",
            "deployment_target": "batch_scoring",
        }
    )

# Note: model stages are deprecated since MLflow 2.9 in favor of aliases.
model_version = client.get_latest_versions(
    "CustomerSegmentationModel", stages=["None"]
)[0]

client.transition_model_version_stage(
    name="CustomerSegmentationModel", version=model_version.version, stage="Staging"
)
```
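In recent MLflow releases, registry stages are deprecated in favor of aliases, which are named pointers to a specific model version. A minimal sketch of the same promotion using `MlflowClient.search_model_versions` and `MlflowClient.set_registered_model_alias`; the helper `promote_latest_version` and the alias name `staging` are hypothetical, and the client is passed in so the logic can be exercised without a tracking server.

```python
def promote_latest_version(client, model_name: str, alias: str = "staging") -> str:
    """Point `alias` at the newest version of `model_name`; return that version.

    `client` is expected to be an mlflow.MlflowClient, or any object exposing
    search_model_versions() and set_registered_model_alias().
    """
    versions = client.search_model_versions(f"name='{model_name}'")
    if not versions:
        raise ValueError(f"No registered versions found for {model_name!r}")
    # Compare version numbers numerically so that "10" ranks above "2".
    latest = max(versions, key=lambda mv: int(mv.version))
    client.set_registered_model_alias(model_name, alias, latest.version)
    return latest.version
```

Usage: `promote_latest_version(MlflowClient(), "CustomerSegmentationModel")`, after which batch jobs can load the aliased version with `mlflow.spark.load_model("models:/CustomerSegmentationModel@staging")`.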

Learn More