mlflow.transformers
MLflow module for HuggingFace/transformer support.
-
mlflow.transformers.
autolog
(log_input_examples=False, log_model_signatures=False, log_models=False, log_datasets=False, disable=False, exclusive=False, disable_for_unsupported_versions=False, silent=False, extra_tags=None)
Note
Experimental: This function may change or be removed in a future release without warning.
Note
Autologging is known to be compatible with the following package versions: 4.25.1 <= transformers <= 4.45.2. Autologging may not succeed when used with package versions outside of this range.
This autologging integration is solely used for disabling spurious autologging of irrelevant sub-models that are created during the training and evaluation of transformers-based models. Autologging functionality is not implemented fully for the transformers flavor.
-
mlflow.transformers.
generate_signature_output
(pipeline, data, model_config=None, params=None, flavor_config=None)
Note
Experimental: This function may change or be removed in a future release without warning.
Utility for generating the response output for the purposes of extracting an output signature for model saving and logging. This function simulates loading of a saved model or pipeline as a pyfunc model without having to incur a write to disk.
- Parameters
pipeline – A transformers pipeline object. Note that component-level or model-level inputs are not permitted for extracting an output example.
data – An example input that is compatible with the given pipeline.
model_config – Any additional model configuration, provided as kwargs, to inform the format of the output type from a pipeline inference call.
params – A dictionary of additional parameters to pass to the pipeline for inference.
flavor_config – The flavor configuration for the model.
- Returns
The output from the pyfunc pipeline wrapper’s predict method.
-
mlflow.transformers.
get_default_conda_env
(model)
Note
Experimental: This function may change or be removed in a future release without warning.
- Returns
The default Conda environment for MLflow Models produced with the transformers flavor, based on the model instance framework type of the model to be logged.
-
mlflow.transformers.
get_default_pip_requirements
(model) → List[str]
Note
Experimental: This function may change or be removed in a future release without warning.
- Parameters
model – The model instance to be saved in order to provide the required underlying deep learning execution framework dependency requirements. Note that this must be the actual model instance and not a Pipeline.
- Returns
A list of default pip requirements for MLflow Models that have been produced with the transformers flavor. Calls to save_model() and log_model() produce a pip environment that contains these requirements at a minimum.
-
mlflow.transformers.
is_gpu_available
()
-
mlflow.transformers.
load_model
(model_uri: str, dst_path: Optional[str] = None, return_type='pipeline', device=None, **kwargs)
Note
Experimental: This function may change or be removed in a future release without warning.
Note
The ‘transformers’ MLflow Models integration is known to be compatible with 4.25.1 <= transformers <= 4.45.2. MLflow Models integrations with transformers may not succeed when used with package versions outside of this range.
Load a transformers object from a local file or a run.
- Parameters
model_uri –
The location, in URI format, of the MLflow model. For example:
/Users/me/path/to/local/model
relative/path/to/local/model
s3://my_bucket/path/to/model
runs:/<mlflow_run_id>/run-relative/path/to/model
mlflow-artifacts:/path/to/model
For more information about supported URI schemes, see Referencing Artifacts.
dst_path – The local filesystem path to utilize for downloading the model artifact. This directory must already exist if provided. If unspecified, a local output path will be created.
return_type –
A return type modifier for the stored transformers object. If set as “components”, the return type will be a dictionary of the saved individual components of either the Pipeline or the pre-trained model. The components for NLP-focused models will typically consist of a return representation as shown below with a text-classification example:
    {"model": BertForSequenceClassification, "tokenizer": BertTokenizerFast}
Vision models will return an ImageProcessor instance of the appropriate type, while multi-modal models will return both a FeatureExtractor and a Tokenizer along with the model. Returning “components” can be useful for certain model types that do not have the desired pipeline return types for certain use cases. If set as “pipeline”, the model, along with any and all required Tokenizer, FeatureExtractor, Processor, or ImageProcessor objects will be returned within a Pipeline object of the appropriate type defined by the task set by the model instance type. To override this behavior, supply a valid task argument during model logging or saving. Default is “pipeline”.
device – The device on which to load the model. Default is None. Use 0 to load to the default GPU.
kwargs – Optional configuration options for loading of a transformers object. For information on parameters and their usage, see transformers documentation.
- Returns
A transformers model instance or a dictionary of components.
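As a rough illustration of the parameters described above, the sketch below builds a runs:/ URI, chooses a device, and loads a logged model. The helper names runs_uri, pick_device, and load_logged_model are hypothetical and not part of MLflow. The mlflow import is deferred into load_logged_model so that the two pure helpers can be exercised without MLflow installed.

```python
from typing import Optional


def runs_uri(run_id: str, artifact_path: str) -> str:
    """Build a URI of the form runs:/<mlflow_run_id>/run-relative/path/to/model."""
    return f"runs:/{run_id}/{artifact_path.strip('/')}"


def pick_device(gpu_available: bool) -> Optional[int]:
    """Return 0 (the default GPU) when a GPU is available, otherwise None."""
    return 0 if gpu_available else None


def load_logged_model(model_uri: str, return_type: str = "pipeline"):
    """Load a logged transformers model as a Pipeline or as components.

    Hypothetical wrapper: it only forwards arguments documented for
    mlflow.transformers.load_model.
    """
    import mlflow.transformers  # deferred import: needed only when loading

    device = pick_device(mlflow.transformers.is_gpu_available())
    return mlflow.transformers.load_model(
        model_uri, return_type=return_type, device=device
    )
```

Calling load_logged_model(uri, return_type="components") would be expected to return a dictionary of components such as the text-classification example shown in the return_type description, while the default returns a Pipeline.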
-
mlflow.transformers.
log_model
(transformers_model, artifact_path: str, processor=None, task: Optional[str] = None, torch_dtype: Optional[torch.dtype] = None, model_card=None, inference_config: Optional[Dict[str, Any]] = None, code_paths: Optional[List[str]] = None, registered_model_name: Optional[str] = None, signature: Optional[ModelSignature] = None, input_example: Optional[ModelInputExample] = None, await_registration_for=300, pip_requirements: Optional[Union[List[str], str]] = None, extra_pip_requirements: Optional[Union[List[str], str]] = None, conda_env=None, metadata: Optional[Dict[str, Any]] = None, model_config: Optional[Dict[str, Any]] = None, example_no_conversion: Optional[bool] = None, prompt_template: Optional[str] = None, save_pretrained: bool = True, **kwargs)
Note
Experimental: This function may change or be removed in a future release without warning.
Note
The ‘transformers’ MLflow Models integration is known to be compatible with 4.25.1 <= transformers <= 4.45.2. MLflow Models integrations with transformers may not succeed when used with package versions outside of this range.
Log a transformers object as an MLflow artifact for the current run. Note that logging transformers models with custom code (i.e. models that require trust_remote_code=True) requires transformers >= 4.26.0.
- Parameters
transformers_model –
The transformers model to save. This can be one of the following formats:
- A transformers Pipeline instance.
- A dictionary that maps required components of a pipeline to the named keys of [“model”, “image_processor”, “tokenizer”, “feature_extractor”]. The model key in the dictionary must map to a value that inherits from PreTrainedModel, TFPreTrainedModel, or FlaxPreTrainedModel. All other component entries in the dictionary must support the defined task type that is associated with the base model type configuration.
- A string that represents a path to a local/DBFS directory containing a model checkpoint. The directory must contain a config.json file that is required for loading the transformers model. This is particularly useful when logging a model that cannot be loaded into memory for serialization.
An example of specifying a Pipeline from a default pipeline instantiation:
    import mlflow
    from transformers import pipeline

    qa_pipe = pipeline("question-answering", "csarron/mobilebert-uncased-squad-v2")

    with mlflow.start_run():
        mlflow.transformers.log_model(
            transformers_model=qa_pipe,
            artifact_path="model",
        )
An example of specifying component-level parts of a transformers model is shown below:
    import mlflow
    from transformers import MobileBertForQuestionAnswering, AutoTokenizer

    architecture = "csarron/mobilebert-uncased-squad-v2"
    tokenizer = AutoTokenizer.from_pretrained(architecture)
    model = MobileBertForQuestionAnswering.from_pretrained(architecture)

    with mlflow.start_run():
        components = {
            "model": model,
            "tokenizer": tokenizer,
        }
        mlflow.transformers.log_model(
            transformers_model=components,
            artifact_path="model",
        )
An example of specifying a local checkpoint path is shown below:
    import mlflow

    with mlflow.start_run():
        mlflow.transformers.log_model(
            transformers_model="path/to/local/checkpoint",
            artifact_path="model",
        )
artifact_path – The run-relative artifact path under which the serialized model is logged.
processor –
An optional Processor subclass object. Some model architectures, particularly multi-modal types, utilize Processors to combine text encoding and image or audio encoding in a single entrypoint.
Note
If a processor is supplied when logging a model, the model will be unavailable for loading as a Pipeline or for usage with pyfunc inference.
task – The transformers-specific task type of the model. These strings are utilized so that a pipeline can be created with the appropriate internal call architecture to meet the needs of a given model. If this argument is not specified, the pipeline utilities within the transformers library will be used to infer the correct task type. If the value specified is not a supported type within the version of transformers that is currently installed, an Exception will be thrown.
torch_dtype – The PyTorch dtype applied to the model when loading back. This is useful when you want to save the model with a specific dtype that is different from the dtype of the model when it was trained. If not specified, the current dtype of the model instance will be used.
model_card –
An Optional ModelCard instance from huggingface-hub. If provided, the contents of the model card will be saved along with the provided transformers_model. If not provided, an attempt will be made to fetch the card from the base pretrained model that is provided (or the one that is included within a provided Pipeline).
Note
In order for a ModelCard to be fetched (if not provided), the huggingface_hub package must be installed and the version must be >=0.10.0
inference_config –
Warning
Deprecated. inference_config is deprecated in favor of model_config.
code_paths –
A list of local filesystem paths to Python file dependencies (or directories containing file dependencies). These files are prepended to the system path when the model is loaded. Files declared as dependencies for a given model should have relative imports declared from a common root path if multiple files are defined with import dependencies between them to avoid import errors when loading the model.
For a detailed explanation of code_paths functionality, recommended usage patterns and limitations, see the code_paths usage guide.
registered_model_name – This argument may change or be removed in a future release without warning. If given, create a model version under registered_model_name, also creating a registered model if one with the given name does not exist.
signature –
A Model Signature object that describes the input and output Schema of the model. The model signature can be inferred using the infer_signature function of mlflow.models.signature.
    import mlflow
    from mlflow.models import infer_signature
    from mlflow.transformers import generate_signature_output
    from transformers import pipeline

    en_to_de = pipeline("translation_en_to_de")

    data = "MLflow is great!"
    output = generate_signature_output(en_to_de, data)
    signature = infer_signature(data, output)

    with mlflow.start_run() as run:
        mlflow.transformers.log_model(
            transformers_model=en_to_de,
            artifact_path="english_to_german_translator",
            signature=signature,
            input_example=data,
        )

    model_uri = f"runs:/{run.info.run_id}/english_to_german_translator"
    loaded = mlflow.pyfunc.load_model(model_uri)

    print(loaded.predict(data))
    # MLflow ist großartig!
If an input_example is provided and the signature is not, a signature will be inferred automatically and applied to the MLmodel file iff the pipeline type is a text-based model (NLP). If the pipeline type is not a supported type, this inference functionality will not function correctly and a warning will be issued. In order to ensure that a precise signature is logged, it is recommended to explicitly provide one.
input_example – One or several instances of valid model input. The input example is used as a hint of what data to feed the model. It will be converted to a Pandas DataFrame and then serialized to JSON using the Pandas split-oriented format, or a numpy array where the example will be serialized to JSON by converting it to a list. Bytes are base64-encoded. When the signature parameter is None, the input example is used to infer a model signature.
await_registration_for – Number of seconds to wait for the model version to finish being created and reach READY status. By default, the function waits for five minutes. Specify 0 or None to skip waiting.
pip_requirements – Either an iterable of pip requirement strings (e.g. ["transformers", "-r requirements.txt", "-c constraints.txt"]) or the string path to a pip requirements file on the local filesystem (e.g. "requirements.txt"). If provided, this describes the environment this model should be run in. If None, a default list of requirements is inferred by mlflow.models.infer_pip_requirements() from the current software environment. If the requirement inference fails, it falls back to using get_default_pip_requirements(). Both requirements and constraints are automatically parsed and written to requirements.txt and constraints.txt files, respectively, and stored as part of the model. Requirements are also written to the pip section of the model’s conda environment (conda.yaml) file.
extra_pip_requirements –
Either an iterable of pip requirement strings (e.g. ["pandas", "-r requirements.txt", "-c constraints.txt"]) or the string path to a pip requirements file on the local filesystem (e.g. "requirements.txt"). If provided, this describes additional pip requirements that are appended to a default set of pip requirements generated automatically based on the user’s current software environment. Both requirements and constraints are automatically parsed and written to requirements.txt and constraints.txt files, respectively, and stored as part of the model. Requirements are also written to the pip section of the model’s conda environment (conda.yaml) file.
Warning
The following arguments can’t be specified at the same time:
- conda_env
- pip_requirements
- extra_pip_requirements
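The restriction above can be checked before calling log_model(). The following is a minimal sketch under stated assumptions: supplied_env_args, check_env_args, and log_with_extra_requirements are hypothetical helpers rather than MLflow APIs, and the artifact path "model" is a placeholder. The mlflow import is deferred so the validation helpers run without MLflow installed.

```python
from typing import Any, List, Optional


def supplied_env_args(
    conda_env: Optional[Any] = None,
    pip_requirements: Optional[Any] = None,
    extra_pip_requirements: Optional[Any] = None,
) -> List[str]:
    """Return the names of the environment arguments that are not None."""
    candidates = (
        ("conda_env", conda_env),
        ("pip_requirements", pip_requirements),
        ("extra_pip_requirements", extra_pip_requirements),
    )
    return [name for name, value in candidates if value is not None]


def check_env_args(**env_args: Any) -> None:
    """Raise ValueError when more than one environment argument is supplied."""
    supplied = supplied_env_args(**env_args)
    if len(supplied) > 1:
        raise ValueError(f"Specify only one of: {', '.join(supplied)}")


def log_with_extra_requirements(pipe, extra_pip_requirements):
    """Log a pipeline, appending extra requirements to the inferred defaults."""
    import mlflow.transformers  # deferred import: needed only when logging

    check_env_args(extra_pip_requirements=extra_pip_requirements)
    with mlflow.start_run():
        return mlflow.transformers.log_model(
            transformers_model=pipe,
            artifact_path="model",  # placeholder artifact path
            extra_pip_requirements=extra_pip_requirements,
        )
```

Passing pip_requirements instead of extra_pip_requirements would replace the inferred defaults rather than extend them, as described in the two parameter entries.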
This example demonstrates how to specify pip requirements using pip_requirements and extra_pip_requirements.
conda_env –
Either a dictionary representation of a Conda environment or the path to a conda environment yaml file. If provided, this describes the environment this model should be run in. At a minimum, it should specify the dependencies contained in get_default_conda_env(). If None, a conda environment with pip requirements inferred by mlflow.models.infer_pip_requirements() is added to the model. If the requirement inference fails, it falls back to using get_default_pip_requirements(). pip requirements from conda_env are written to a pip requirements.txt file and the full conda environment is written to conda.yaml. The following is an example dictionary representation of a conda environment:
    {
        "name": "mlflow-env",
        "channels": ["conda-forge"],
        "dependencies": [
            "python=3.8.15",
            {
                "pip": [
                    "transformers==x.y.z"
                ],
            },
        ],
    }
metadata – Custom metadata dictionary passed to the model and stored in the MLmodel file.
model_config –
A dict of valid overrides that can be applied to a pipeline instance during inference. These arguments are used exclusively for the case of loading the model as a pyfunc Model or for use in Spark. These values are not applied to a returned Pipeline from a call to mlflow.transformers.load_model().
Warning
If the key provided is not compatible with either the Pipeline instance for the task provided or is not a valid override to any arguments available in the Model, an Exception will be raised at runtime. It is very important to validate the entries in this dictionary to ensure that they are valid prior to saving or logging.
An example of providing overrides for a text generation model:
    import mlflow
    from transformers import pipeline, AutoTokenizer

    task = "text-generation"
    architecture = "gpt2"

    sentence_pipeline = pipeline(
        task=task,
        tokenizer=AutoTokenizer.from_pretrained(architecture),
        model=architecture,
    )

    # Prompts used to validate the overrides prior to save or log
    prompts = ["Generative models are", "I'd like a coconut so that I can"]

    model_config = {
        "top_k": 2,
        "num_beams": 5,
        "max_length": 30,
        "temperature": 0.62,
        "top_p": 0.85,
        "repetition_penalty": 1.15,
    }

    # Verify that no exceptions are thrown
    sentence_pipeline(prompts, **model_config)

    with mlflow.start_run():
        mlflow.transformers.log_model(
            transformers_model=sentence_pipeline,
            artifact_path="my_sentence_generator",
            task=task,
            model_config=model_config,
        )
example_no_conversion – This parameter is deprecated and will be removed in a future release. It’s no longer used and can be safely removed. Input examples are not converted anymore.
prompt_template –
A string that, if provided, will be used to format the user’s input prior to inference. The string should contain a single placeholder, {prompt}, which will be replaced with the user’s input. For example: "Answer the following question. Q: {prompt} A:".
Currently, only the following pipeline types are supported:
save_pretrained –
If set to False, MLflow will not save the Transformer model weight files, instead only saving the reference to the HuggingFace Hub model repository and its commit hash. This is useful when you load the pretrained model from HuggingFace Hub and want to log or save it to MLflow without modifying the model weights. In such a case, setting this flag to False will save storage space and reduce the time needed to save the model. Please refer to Storage-Efficient Model Logging for more detailed usage.
Warning
If the model is saved with save_pretrained set to False, the model cannot be registered to the MLflow Model Registry. To convert the model to one that can be registered, you can use mlflow.transformers.persist_pretrained_model() to download the model weights from the HuggingFace Hub and save them in the existing model artifacts. Please refer to the Transformers flavor documentation for more detailed usage.
    import mlflow.transformers

    model_uri = "YOUR_MODEL_URI_LOGGED_WITH_SAVE_PRETRAINED_FALSE"
    # persist_pretrained_model returns None; it updates the model artifacts in place
    mlflow.transformers.persist_pretrained_model(model_uri)
    mlflow.register_model(model_uri, "model_name")
Important
When you save the PEFT model, MLflow will override the save_pretrained flag to False and only store the PEFT adapter weights. The base model weights are not saved but the reference to the HuggingFace repository and its commit hash are logged instead.
kwargs – Additional arguments for mlflow.models.model.Model.
-
mlflow.transformers.
persist_pretrained_model
(model_uri: str) → None
Persist Transformers pretrained model weights to the artifacts directory of the specified model_uri. This API is primarily used for updating an MLflow Model that was logged or saved with save_pretrained=False. Such models cannot be registered to Databricks Workspace Model Registry, due to the full pretrained model weights being absent in the artifacts. Transformers models saved in this mode store only the reference to the HuggingFace Hub repository. This API will download the model weights from the HuggingFace Hub repository and save them in the artifacts of the given model_uri so that the model can be registered to Databricks Workspace Model Registry.
- Parameters
model_uri – The URI of the existing MLflow Model of the Transformers flavor. It must be logged/saved with save_pretrained=False.
Examples:
    import mlflow
    from mlflow.exceptions import MlflowException
    from transformers import pipeline

    # Saving a model with save_pretrained=False
    with mlflow.start_run() as run:
        model = pipeline("question-answering", "csarron/mobilebert-uncased-squad-v2")
        mlflow.transformers.log_model(
            transformers_model=model, artifact_path="pipeline", save_pretrained=False
        )

    # The model cannot be registered to the Model Registry as it is
    try:
        mlflow.register_model(f"runs:/{run.info.run_id}/pipeline", "qa_pipeline")
    except MlflowException as e:
        print(e.message)

    # Use this API to persist the pretrained model weights
    mlflow.transformers.persist_pretrained_model(f"runs:/{run.info.run_id}/pipeline")

    # Now the model can be registered to the Model Registry
    mlflow.register_model(f"runs:/{run.info.run_id}/pipeline", "qa_pipeline")
-
mlflow.transformers.
save_model
(transformers_model, path: str, processor=None, task: Optional[str] = None, torch_dtype: Optional[torch.dtype] = None, model_card=None, inference_config: Optional[Dict[str, Any]] = None, code_paths: Optional[List[str]] = None, mlflow_model: Optional[Model] = None, signature: Optional[ModelSignature] = None, input_example: Optional[ModelInputExample] = None, pip_requirements: Optional[Union[List[str], str]] = None, extra_pip_requirements: Optional[Union[List[str], str]] = None, conda_env=None, metadata: Optional[Dict[str, Any]] = None, model_config: Optional[Dict[str, Any]] = None, example_no_conversion: Optional[bool] = None, prompt_template: Optional[str] = None, save_pretrained: bool = True, **kwargs) → None
Note
Experimental: This function may change or be removed in a future release without warning.
Note
The ‘transformers’ MLflow Models integration is known to be compatible with 4.25.1 <= transformers <= 4.45.2. MLflow Models integrations with transformers may not succeed when used with package versions outside of this range.
Save a trained transformers model to a path on the local file system. Note that saving transformers models with custom code (i.e. models that require trust_remote_code=True) requires transformers >= 4.26.0.
- Parameters
transformers_model –
The transformers model to save. This can be one of the following formats:
- A transformers Pipeline instance.
- A dictionary that maps required components of a pipeline to the named keys of [“model”, “image_processor”, “tokenizer”, “feature_extractor”]. The model key in the dictionary must map to a value that inherits from PreTrainedModel, TFPreTrainedModel, or FlaxPreTrainedModel. All other component entries in the dictionary must support the defined task type that is associated with the base model type configuration.
- A string that represents a path to a local/DBFS directory containing a model checkpoint. The directory must contain a config.json file that is required for loading the transformers model. This is particularly useful when logging a model that cannot be loaded into memory for serialization.
An example of specifying a Pipeline from a default pipeline instantiation:
    import mlflow
    from transformers import pipeline

    qa_pipe = pipeline("question-answering", "csarron/mobilebert-uncased-squad-v2")

    with mlflow.start_run():
        mlflow.transformers.save_model(
            transformers_model=qa_pipe,
            path="path/to/save/model",
        )
An example of specifying component-level parts of a transformers model is shown below:
    import mlflow
    from transformers import MobileBertForQuestionAnswering, AutoTokenizer

    architecture = "csarron/mobilebert-uncased-squad-v2"
    tokenizer = AutoTokenizer.from_pretrained(architecture)
    model = MobileBertForQuestionAnswering.from_pretrained(architecture)

    with mlflow.start_run():
        components = {
            "model": model,
            "tokenizer": tokenizer,
        }
        mlflow.transformers.save_model(
            transformers_model=components,
            path="path/to/save/model",
        )
An example of specifying a local checkpoint path is shown below:
    import mlflow

    with mlflow.start_run():
        mlflow.transformers.save_model(
            transformers_model="path/to/local/checkpoint",
            path="path/to/save/model",
        )
path – Local path destination for the serialized model to be saved.
processor –
An optional Processor subclass object. Some model architectures, particularly multi-modal types, utilize Processors to combine text encoding and image or audio encoding in a single entrypoint.
Note
If a processor is supplied when saving a model, the model will be unavailable for loading as a Pipeline or for usage with pyfunc inference.
task – The transformers-specific task type of the model, or MLflow inference task type. If provided a transformers-specific task type, these strings are utilized so that a pipeline can be created with the appropriate internal call architecture to meet the needs of a given model. If this argument is provided as an inference task type or not specified, the pipeline utilities within the transformers library will be used to infer the correct task type. If the value specified is not a supported type, an Exception will be thrown.
torch_dtype – The PyTorch dtype applied to the model when loading back. This is useful when you want to save the model with a specific dtype that is different from the dtype of the model when it was trained. If not specified, the current dtype of the model instance will be used.
model_card –
An Optional ModelCard instance from huggingface-hub. If provided, the contents of the model card will be saved along with the provided transformers_model. If not provided, an attempt will be made to fetch the card from the base pretrained model that is provided (or the one that is included within a provided Pipeline).
Note
In order for a ModelCard to be fetched (if not provided), the huggingface_hub package must be installed and the version must be >=0.10.0
inference_config –
Warning
Deprecated. inference_config is deprecated in favor of model_config.
code_paths –
A list of local filesystem paths to Python file dependencies (or directories containing file dependencies). These files are prepended to the system path when the model is loaded. Files declared as dependencies for a given model should have relative imports declared from a common root path if multiple files are defined with import dependencies between them to avoid import errors when loading the model.
For a detailed explanation of code_paths functionality, recommended usage patterns and limitations, see the code_paths usage guide.
mlflow_model – An MLflow model object that specifies the flavor that this model is being added to.
signature –
A Model Signature object that describes the input and output Schema of the model. The model signature can be inferred using the infer_signature function of mlflow.models.signature.
    import mlflow
    from mlflow.models import infer_signature
    from mlflow.transformers import generate_signature_output
    from transformers import pipeline

    en_to_de = pipeline("translation_en_to_de")

    data = "MLflow is great!"
    output = generate_signature_output(en_to_de, data)
    signature = infer_signature(data, output)

    mlflow.transformers.save_model(
        transformers_model=en_to_de,
        path="/path/to/save/model",
        signature=signature,
        input_example=data,
    )

    loaded = mlflow.pyfunc.load_model("/path/to/save/model")
    print(loaded.predict(data))
    # MLflow ist großartig!
If an input_example is provided and the signature is not, a signature will be inferred automatically and applied to the MLmodel file iff the pipeline type is a text-based model (NLP). If the pipeline type is not a supported type, this inference functionality will not function correctly and a warning will be issued. In order to ensure that a precise signature is logged, it is recommended to explicitly provide one.
input_example – One or several instances of valid model input. The input example is used as a hint of what data to feed the model. It will be converted to a Pandas DataFrame and then serialized to JSON using the Pandas split-oriented format, or a numpy array where the example will be serialized to JSON by converting it to a list. Bytes are base64-encoded. When the signature parameter is None, the input example is used to infer a model signature.
pip_requirements – Either an iterable of pip requirement strings (e.g. ["transformers", "-r requirements.txt", "-c constraints.txt"]) or the string path to a pip requirements file on the local filesystem (e.g. "requirements.txt"). If provided, this describes the environment this model should be run in. If None, a default list of requirements is inferred by mlflow.models.infer_pip_requirements() from the current software environment. If the requirement inference fails, it falls back to using get_default_pip_requirements(). Both requirements and constraints are automatically parsed and written to requirements.txt and constraints.txt files, respectively, and stored as part of the model. Requirements are also written to the pip section of the model’s conda environment (conda.yaml) file.
extra_pip_requirements –
Either an iterable of pip requirement strings (e.g. ["pandas", "-r requirements.txt", "-c constraints.txt"]) or the string path to a pip requirements file on the local filesystem (e.g. "requirements.txt"). If provided, this describes additional pip requirements that are appended to a default set of pip requirements generated automatically based on the user’s current software environment. Both requirements and constraints are automatically parsed and written to requirements.txt and constraints.txt files, respectively, and stored as part of the model. Requirements are also written to the pip section of the model’s conda environment (conda.yaml) file.
Warning
The following arguments can’t be specified at the same time:
- conda_env
- pip_requirements
- extra_pip_requirements
This example demonstrates how to specify pip requirements using pip_requirements and extra_pip_requirements.
conda_env –
Either a dictionary representation of a Conda environment or the path to a conda environment yaml file. If provided, this describes the environment this model should be run in. At a minimum, it should specify the dependencies contained in get_default_conda_env(). If None, a conda environment with pip requirements inferred by mlflow.models.infer_pip_requirements() is added to the model. If the requirement inference fails, it falls back to using get_default_pip_requirements(). pip requirements from conda_env are written to a pip requirements.txt file and the full conda environment is written to conda.yaml. The following is an example dictionary representation of a conda environment:
    {
        "name": "mlflow-env",
        "channels": ["conda-forge"],
        "dependencies": [
            "python=3.8.15",
            {
                "pip": [
                    "transformers==x.y.z"
                ],
            },
        ],
    }
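A dictionary of the shape shown above can also be assembled programmatically. The sketch below is illustrative only: build_conda_env is a hypothetical helper, not an MLflow function, and the version strings are the same placeholders used in the example dictionary and should be replaced with the versions actually in use.

```python
from typing import Any, Dict, List


def build_conda_env(
    python_version: str,
    pip_requirements: List[str],
    name: str = "mlflow-env",
) -> Dict[str, Any]:
    """Assemble a conda environment dictionary in the documented shape."""
    return {
        "name": name,
        "channels": ["conda-forge"],
        "dependencies": [
            f"python={python_version}",
            # Copy the list so later changes by the caller do not leak in.
            {"pip": list(pip_requirements)},
        ],
    }


# "x.y.z" is a placeholder pin, as in the example dictionary.
env = build_conda_env("3.8.15", ["transformers==x.y.z"])
print(env["dependencies"][1]["pip"])
```

The resulting dictionary could be passed as conda_env to save_model() or log_model(). Since conda_env cannot be combined with pip_requirements or extra_pip_requirements, any extra packages would need to go into the pip list here.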
metadata – Custom metadata dictionary passed to the model and stored in the MLmodel file.
model_config –
A dict of valid overrides that can be applied to a pipeline instance during inference. These arguments are used exclusively for the case of loading the model as a pyfunc Model or for use in Spark. These values are not applied to a returned Pipeline from a call to mlflow.transformers.load_model().
Warning
If the key provided is not compatible with either the Pipeline instance for the task provided or is not a valid override to any arguments available in the Model, an Exception will be raised at runtime. It is very important to validate the entries in this dictionary to ensure that they are valid prior to saving or logging.
An example of providing overrides for a text generation model:
    import mlflow
    from transformers import pipeline, AutoTokenizer

    task = "text-generation"
    architecture = "gpt2"

    sentence_pipeline = pipeline(
        task=task,
        tokenizer=AutoTokenizer.from_pretrained(architecture),
        model=architecture,
    )

    # Prompts used to validate the overrides prior to save or log
    prompts = ["Generative models are", "I'd like a coconut so that I can"]

    model_config = {
        "top_k": 2,
        "num_beams": 5,
        "max_length": 30,
        "temperature": 0.62,
        "top_p": 0.85,
        "repetition_penalty": 1.15,
    }

    # Verify that no exceptions are thrown
    sentence_pipeline(prompts, **model_config)

    mlflow.transformers.save_model(
        transformers_model=sentence_pipeline,
        path="/path/for/model",
        task=task,
        model_config=model_config,
    )
example_no_conversion – This parameter is deprecated and will be removed in a future release. It’s no longer used and can be safely removed. Input examples are not converted anymore.
prompt_template –
A string that, if provided, will be used to format the user’s input prior to inference. The string should contain a single placeholder, {prompt}, which will be replaced with the user’s input. For example: "Answer the following question. Q: {prompt} A:".
Currently, only the following pipeline types are supported:
save_pretrained –
If set to False, MLflow will not save the Transformer model weight files, instead only saving the reference to the HuggingFace Hub model repository and its commit hash. This is useful when you load the pretrained model from HuggingFace Hub and want to log or save it to MLflow without modifying the model weights. In such a case, setting this flag to False will save storage space and reduce the time needed to save the model. Please refer to Storage-Efficient Model Logging for more detailed usage.
Warning
If the model is saved with save_pretrained set to False, the model cannot be registered to the MLflow Model Registry. To convert the model to one that can be registered, you can use mlflow.transformers.persist_pretrained_model() to download the model weights from the HuggingFace Hub and save them in the existing model artifacts. Please refer to the Transformers flavor documentation for more detailed usage.
    import mlflow.transformers

    model_uri = "YOUR_MODEL_URI_LOGGED_WITH_SAVE_PRETRAINED_FALSE"
    # persist_pretrained_model returns None; it updates the model artifacts in place
    mlflow.transformers.persist_pretrained_model(model_uri)
    mlflow.register_model(model_uri, "model_name")
Important
When you save the PEFT model, MLflow will override the save_pretrained flag to False and only store the PEFT adapter weights. The base model weights are not saved but the reference to the HuggingFace repository and its commit hash are logged instead.
kwargs – Optional additional configurations for transformers serialization.