MLflow Spark MLlib Integration
Apache Spark MLlib provides distributed machine learning algorithms for processing large-scale datasets across clusters. MLflow integrates with Spark MLlib to track distributed ML pipelines, manage models, and enable flexible deployment from cluster training to standalone inference.
Why MLflow + Spark MLlib?
Pipeline Tracking
Automatically log Spark ML pipelines with all stages, transformers, and estimators. Track parameters from each pipeline component and maintain complete lineage.
Format Flexibility
A single logged model carries both the native Spark format, for distributed batch processing, and the PyFunc format, for inference outside a Spark cluster with automatic DataFrame conversion.
Datasource Autologging
Track data sources automatically with paths, formats, and versions. Maintain complete data lineage for distributed ML workflows.
Cross-Platform Deployment
Serve Spark models through PyFunc wrappers (for example, behind a REST API), or convert them to ONNX for platform-independent inference.
Basic Model Logging
Log Spark MLlib models with mlflow.spark.log_model():
import mlflow
import mlflow.spark
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.feature import Tokenizer, HashingTF
from pyspark.ml import Pipeline
from pyspark.sql import SparkSession
# Initialize Spark session
spark = SparkSession.builder.appName("MLflowSparkExample").getOrCreate()
# Prepare training data
training = spark.createDataFrame(
    [
        (0, "a b c d e spark", 1.0),
        (1, "b d", 0.0),
        (2, "spark f g h", 1.0),
        (3, "hadoop mapreduce", 0.0),
    ],
    ["id", "text", "label"],
)
# Create ML Pipeline
tokenizer = Tokenizer(inputCol="text", outputCol="words")
hashingTF = HashingTF(inputCol=tokenizer.getOutputCol(), outputCol="features")
lr = LogisticRegression(maxIter=10, regParam=0.001)
pipeline = Pipeline(stages=[tokenizer, hashingTF, lr])
# Train and log the model
with mlflow.start_run():
    model = pipeline.fit(training)
    # Log the entire pipeline
    model_info = mlflow.spark.log_model(
        spark_model=model, artifact_path="spark-pipeline"
    )
    # Log parameters manually
    mlflow.log_params(
        {
            "max_iter": lr.getMaxIter(),
            "reg_param": lr.getRegParam(),
            "num_features": hashingTF.getNumFeatures(),
        }
    )
print(f"Model logged with URI: {model_info.model_uri}")
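mlflow.spark.log_model() saves the fitted pipeline, including every stage and its parameter values, in both the native Spark and PyFunc formats. It does not record those values as MLflow run parameters, which is why the example logs them separately with mlflow.log_params().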
Model Formats and Loading
Native Spark Format
The native Spark format preserves full Spark ML functionality for distributed processing:
# Load as native Spark model (requires Spark session)
spark_model = mlflow.spark.load_model(model_info.model_uri)
# Use for distributed batch scoring
test_data = spark.createDataFrame(
[(4, "spark i j k"), (5, "l m n"), (6, "spark hadoop spark"), (7, "apache hadoop")],
["id", "text"],
)
predictions = spark_model.transform(test_data)
predictions.show()
PyFunc Format
The PyFunc format enables inference outside a Spark cluster:
import pandas as pd
# Load as PyFunc model
pyfunc_model = mlflow.pyfunc.load_model(model_info.model_uri)
# Use with pandas DataFrame
test_data = pd.DataFrame(
{"text": ["spark machine learning", "hadoop distributed computing"]}
)
predictions = pyfunc_model.predict(test_data)
print(predictions)
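PyFunc converts the input pandas DataFrame to a Spark DataFrame, using the active Spark session or creating a local one if none exists, and returns the values of the model's prediction column. PySpark, along with the Java runtime that Spark needs, must still be installed in the inference environment.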
Datasource Autologging
Track data sources automatically during model training:
import mlflow.spark
mlflow.spark.autolog()
with mlflow.start_run():
    raw_data = spark.read.parquet("s3://my-bucket/training-data/")
    model = pipeline.fit(raw_data)
    mlflow.spark.log_model(model, artifact_path="model")
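Datasource autologging requires Spark 3.0 or later and a Spark session started with the MLflow-Spark JAR attached. It is not supported on Databricks shared or serverless clusters. For each datasource read, it logs the path, the format, and the version where one applies (for example, Delta tables) as a tag on the MLflow run.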
Model Signatures
Infer a signature for a Spark ML model with infer_signature(), passing sample inputs and predictions converted to pandas:
from mlflow.models import infer_signature
from pyspark.ml.functions import array_to_vector
vector_data = spark.createDataFrame(
[([3.0, 4.0], 0.0), ([5.0, 6.0], 1.0)], ["features_array", "label"]
).select(array_to_vector("features_array").alias("features"), "label")
lr = LogisticRegression(featuresCol="features", labelCol="label")
model = lr.fit(vector_data)
predictions = model.transform(vector_data)
# Infer signature from pandas DataFrames
signature = infer_signature(
    # Input schema: features only, since the label is not supplied at inference time
    vector_data.select("features").limit(2).toPandas(),
    predictions.select("prediction").limit(2).toPandas(),
)
with mlflow.start_run():
    mlflow.spark.log_model(
        spark_model=model,
        artifact_path="vector_model",
        signature=signature,
    )
ONNX Conversion
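Convert Spark models to ONNX with onnxmltools (experimental). Converters exist only for a subset of Spark ML stages, so conversion can fail for some pipelines, and convert_sparkml() requires the name and type of each input column:
import onnxmltools
from onnxmltools.convert.common.data_types import StringTensorType

with mlflow.start_run():
    model = pipeline.fit(training)
    mlflow.spark.log_model(spark_model=model, artifact_path="spark_model")
    # Declare the pipeline's input column: name, element type, and shape
    initial_types = [("text", StringTensorType([None, 1]))]
    onnx_model = onnxmltools.convert_sparkml(
        model, name="SparkMLPipeline", initial_types=initial_types
    )
    onnxmltools.utils.save_model(onnx_model, "model.onnx")
    mlflow.log_artifact("model.onnx")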
Model Registry
Register and promote Spark models:
from mlflow import MlflowClient
client = MlflowClient()
with mlflow.start_run():
    model = pipeline.fit(training)
    mlflow.spark.log_model(
        spark_model=model,
        artifact_path="production_candidate",
        registered_model_name="CustomerSegmentationModel",
    )
    mlflow.set_tags(
        {
            "validation_passed": "true",
            "deployment_target": "batch_scoring",
        }
    )
model_version = client.get_latest_versions(
    "CustomerSegmentationModel", stages=["None"]
)[0]
client.transition_model_version_stage(
    name="CustomerSegmentationModel", version=model_version.version, stage="Staging"
)
Learn More
Model Registry
Manage model versions, aliases, and lifecycle stages for production deployment workflows.
Model Signatures
Define input and output schemas for model validation and type checking.
Model Deployment
Deploy Spark models with MLflow serving, batch inference, and cloud platforms.
MLflow Tracking
Track experiments, parameters, metrics, and artifacts across ML workflows.