Sentence Transformers within MLflow
Attention

The sentence_transformers flavor is in active development and is marked as Experimental. Public APIs may change, and new features may be added as additional functionality is brought to the flavor.
The sentence_transformers model flavor enables logging of sentence-transformers models in MLflow format via the mlflow.sentence_transformers.save_model() and mlflow.sentence_transformers.log_model() functions. Using these functions also adds the python_function flavor to the MLflow Models, enabling the model to be interpreted as a generic Python function for inference via mlflow.pyfunc.load_model().

Additionally, mlflow.sentence_transformers.load_model() can be used to load a saved or logged MLflow Model with the sentence_transformers flavor in the native sentence-transformers format.
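For illustration, a logged model can be loaded back either as a generic pyfunc or in its native form. The snippet below is a minimal sketch; the model_uri value is a hypothetical placeholder that would come from your own run:

import mlflow

# Hypothetical URI of a previously logged model; substitute your own run ID and artifact path
model_uri = "runs:/<run_id>/model_artifact_path"

# Load as a generic python_function model for inference
pyfunc_model = mlflow.pyfunc.load_model(model_uri)
embeddings = pyfunc_model.predict(["MLflow and sentence-transformers work well together."])

# Load back as a native SentenceTransformer instance
native_model = mlflow.sentence_transformers.load_model(model_uri)
embeddings = native_model.encode(["MLflow and sentence-transformers work well together."])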
Tutorials for Sentence Transformers
Looking to get right into some usable examples and tutorials that show how to leverage this library with MLflow? See the Tutorials.
Input and Output Types for PyFunc
The sentence_transformers python_function (pyfunc) model flavor standardizes the process of embedding sentences and computing semantic similarity. This standardization allows for serving and batch inference by adapting the required data structures for sentence_transformers into formats compatible with JSON serialization and casting to Pandas DataFrames.
Note

The sentence_transformers flavor supports various models for tasks such as embedding generation, semantic similarity, and paraphrase mining. The specific input and output types will depend on the model and task being performed.
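As a concrete illustration of these JSON-compatible formats, the sketch below scores a served embedding model over REST. It assumes the model has already been served locally (for example with mlflow models serve -m <model_uri> -p 5000); the host, port, and payload contents are illustrative only:

import requests

# Assumes a sentence_transformers pyfunc model is being served locally on port 5000
payload = {"inputs": ["MLflow is great!", "Sentence embeddings enable semantic search."]}

response = requests.post(
    "http://127.0.0.1:5000/invocations",
    json=payload,
    headers={"Content-Type": "application/json"},
)

# The server returns the embeddings as JSON-serialized predictions
print(response.json())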
Saving and Logging Sentence Transformers Models
You can save and log sentence-transformers models in MLflow. Here’s an example of both saving and logging a model:
import mlflow
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("model_name")

# Saving the model
mlflow.sentence_transformers.save_model(model=model, path="path/to/save/directory")

# Logging the model
with mlflow.start_run():
    mlflow.sentence_transformers.log_model(
        sentence_transformers_model=model, artifact_path="model_artifact_path"
    )
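If desired, an input_example can also be supplied at logging time so that a serialized example input is stored with the model and a signature can be inferred from it; the example sentence below is purely illustrative:

# Logging with an illustrative input_example so MLflow can infer a signature
with mlflow.start_run():
    mlflow.sentence_transformers.log_model(
        sentence_transformers_model=model,
        artifact_path="model_artifact_path",
        input_example=["This is a sentence to encode."],
    )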
Saving Sentence Transformers Models with an OpenAI-Compatible Inference Interface
Note

This feature is only available in MLflow 2.11.0 and above.

MLflow’s sentence_transformers flavor allows you to pass in the task param with the string value "llm/v1/embeddings" when saving a model with mlflow.sentence_transformers.save_model() and mlflow.sentence_transformers.log_model().
For example:
import mlflow
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("all-MiniLM-L6-v2")

mlflow.sentence_transformers.save_model(
    model=model, path="path/to/save/directory", task="llm/v1/embeddings"
)
When task is set to "llm/v1/embeddings", MLflow handles the following for you:
Setting an embeddings-compatible signature for the model
Performing data pre- and post-processing to ensure the inputs and outputs conform to the Embeddings API spec, which is compatible with OpenAI’s API spec.
Note that these modifications only apply when the model is loaded with mlflow.pyfunc.load_model() (e.g. when serving the model with the mlflow models serve CLI tool). If you want to load just the base model, you can always do so via mlflow.sentence_transformers.load_model().
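As a minimal sketch of what this looks like in practice (assuming the model was saved with task="llm/v1/embeddings" as above), the pyfunc-loaded model accepts an OpenAI-style payload with an input field and returns a response conforming to the Embeddings API spec:

import mlflow

# Load the model saved above; the OpenAI-compatible interface applies because the
# model is loaded through the pyfunc flavor
model = mlflow.pyfunc.load_model("path/to/save/directory")

# The Embeddings API spec uses an "input" field holding one or more strings
response = model.predict({"input": ["MLflow is great!", "Embeddings are useful."]})

# The output conforms to the Embeddings API spec described above
print(response)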
Aside from the sentence-transformers flavor, the transformers flavor also supports OpenAI-compatible inference interfaces ("llm/v1/chat" and "llm/v1/completions"). Refer to the Transformers flavor guide for more information.
Custom Python Function Implementation
In addition to using pre-built models, you can create custom python_function models that wrap a sentence-transformers model. Here’s an example of a custom implementation for comparing the similarity between text documents:
import mlflow
from mlflow.pyfunc import PythonModel
import pandas as pd
import numpy as np
from sentence_transformers import SentenceTransformer, util


class DocumentSimilarityModel(PythonModel):
    def load_context(self, context):
        """Load the model context for inference."""
        self.model = SentenceTransformer.load(context.artifacts["model_path"])

    def predict(self, context, model_input):
        """Predict method for comparing similarity between documents."""
        if isinstance(model_input, pd.DataFrame) and model_input.shape[1] == 2:
            documents = model_input.values
        else:
            raise ValueError("Input must be a DataFrame with exactly two columns.")

        # Compute embeddings for each document column separately
        embeddings1 = self.model.encode(documents[:, 0], convert_to_tensor=True)
        embeddings2 = self.model.encode(documents[:, 1], convert_to_tensor=True)

        # Calculate cosine similarity for each row's pair of documents
        # (the diagonal of the pairwise similarity matrix)
        similarity_scores = util.cos_sim(embeddings1, embeddings2).diagonal()

        return pd.DataFrame(similarity_scores.numpy(), columns=["similarity_score"])


# Example model saving and loading
model = SentenceTransformer("all-MiniLM-L6-v2")
model_path = "/tmp/sentence_transformers_model"
model.save(model_path)

# Example usage
with mlflow.start_run():
    model_info = mlflow.pyfunc.log_model(
        artifact_path="document_similarity_model",
        python_model=DocumentSimilarityModel(),
        artifacts={"model_path": model_path},
    )

loaded = mlflow.pyfunc.load_model(model_info.model_uri)

# Test prediction
df = pd.DataFrame(
    {
        "doc1": ["Sentence Transformers is a wonderful package!"],
        "doc2": ["MLflow is pretty great too!"],
    }
)

result = loaded.predict(df)
print(result)
This will generate the similarity score for the documents passed in, as shown below:
   similarity_score
0          0.275423