Get Started with MLflow + Tensorflow

In this guide, we will show how to train your model with Tensorflow and log your training using MLflow.

We will use Databricks Community Edition (CE) as our tracking server. It has built-in support for MLflow. Databricks CE is the free version of the Databricks platform. If you haven’t already, please register for an account.

You can run the code in this guide in cloud-based notebooks, such as Databricks notebooks or Google Colab, or on your local machine.

Install dependencies

Let’s install the mlflow package.

%pip install -q mlflow

Then let’s import the packages.

import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow import keras

Load the dataset

We will do a simple image classification of handwritten digits using the MNIST dataset.

Let’s load the dataset using tensorflow_datasets (tfds), which returns datasets in the format of tf.data.Dataset.

# Load the mnist dataset.
train_ds, test_ds = tfds.load(
    "mnist",
    split=["train", "test"],
)
Downloading and preparing dataset 11.06 MiB (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to /root/tensorflow_datasets/mnist/3.0.1...
Dataset mnist downloaded and prepared to /root/tensorflow_datasets/mnist/3.0.1. Subsequent calls will reuse this data.

Let’s preprocess our data with the following steps:

- Scale each pixel’s value to [0, 1].
- Batch the dataset.
- Use prefetch to speed up training.

def preprocess_fn(data):
    image = tf.cast(data["image"], tf.float32) / 255
    label = data["label"]
    return (image, label)

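# Scale pixels, batch, and prefetch both splits.
train_ds = train_ds.map(preprocess_fn).batch(128).prefetch(tf.data.AUTOTUNE)
test_ds = test_ds.map(preprocess_fn).batch(128).prefetch(tf.data.AUTOTUNE)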
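As a quick sanity check on the scaling step, the sketch below applies the same `/ 255` arithmetic used in preprocess_fn to a few example uint8 pixel values, in plain Python with no TensorFlow required. The darkest pixel maps to 0.0 and the brightest (255) maps to exactly 1.0, so the scaled values lie in the closed interval [0, 1].

```python
# Example uint8 pixel intensities: darkest, mid-gray, brightest.
pixels = [0, 128, 255]

# Same scaling as preprocess_fn: divide by 255 to map into [0, 1].
scaled = [p / 255 for p in pixels]

print(scaled)
```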

Define the Model

Let’s define a convolutional neural network as our classifier. We can use keras.Sequential to stack up the layers.

input_shape = (28, 28, 1)
num_classes = 10

model = keras.Sequential(
    [
        keras.Input(shape=input_shape),
        keras.layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
        keras.layers.MaxPooling2D(pool_size=(2, 2)),
        keras.layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
        keras.layers.MaxPooling2D(pool_size=(2, 2)),
        # Flatten the feature maps so Dense outputs one score per class.
        keras.layers.Flatten(),
        keras.layers.Dense(num_classes, activation="softmax"),
    ]
)
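To see how a 28x28x1 image shrinks through this network, here is a small shape and parameter-count calculation in plain Python. It assumes the Keras defaults of "valid" (unpadded) convolutions with stride 1, pooling strides equal to the pool size, and a Flatten step between the last pooling layer and Dense. The helper names conv_out and pool_out are hypothetical and exist only for this illustration.

```python
def conv_out(size, kernel):
    """Spatial size after a 'valid' (unpadded), stride-1 convolution."""
    return size - kernel + 1

def pool_out(size, pool):
    """Spatial size after 'valid' max pooling with stride equal to the pool size."""
    return (size - pool) // pool + 1

size, channels = 28, 1
params = []

# Two Conv2D (3x3) + MaxPooling2D (2x2) stages with 32 and 64 filters.
for filters in (32, 64):
    params.append(3 * 3 * channels * filters + filters)  # kernel weights + biases
    size = pool_out(conv_out(size, 3), 2)
    channels = filters

flat_features = size * size * channels  # length of the vector Flatten produces
num_classes = 10
params.append(flat_features * num_classes + num_classes)  # Dense weights + biases

print(size, flat_features)  # 5 1600
print(params, sum(params))  # [320, 18496, 16010] 34826
```

Under these assumptions, each image reaches the classifier as a 1600-value vector, which is why the feature maps have to be flattened before the final Dense layer.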

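Set the training-related configuration: optimizer, loss function, and metrics.

model.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(),
    # Adam with default settings is assumed here; any Keras optimizer works.
    optimizer=keras.optimizers.Adam(),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
)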
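The training logs later in this guide report a loss and a sparse_categorical_accuracy metric. The following plain-Python sketch shows what those two quantities compute. It assumes the model outputs softmax probabilities, as the final Dense layer does, and that labels are integer class indices. The function names mirror the Keras names but are hypothetical illustrations, not the Keras implementation. Keras also clips probabilities internally to avoid log(0), which is omitted here.

```python
import math

def sparse_categorical_crossentropy(labels, probs):
    """Mean of -log(probability assigned to the true class)."""
    return sum(-math.log(p[y]) for y, p in zip(labels, probs)) / len(labels)

def sparse_categorical_accuracy(labels, probs):
    """Fraction of samples whose highest-probability class equals the label."""
    hits = sum(max(range(len(p)), key=p.__getitem__) == y for y, p in zip(labels, probs))
    return hits / len(labels)

# Two made-up samples over 3 classes (each row sums to 1, like softmax output).
labels = [0, 2]
probs = [
    [0.7, 0.2, 0.1],  # predicts class 0: correct
    [0.5, 0.3, 0.2],  # predicts class 0, but the label is 2: wrong
]

loss = sparse_categorical_crossentropy(labels, probs)
acc = sparse_categorical_accuracy(labels, probs)
print(f"loss={loss:.4f} accuracy={acc:.2f}")
```

Here the loss rewards high confidence in the correct class, while accuracy only checks whether the top prediction is correct.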


Set up tracking/visualization tool

In this tutorial, we will use Databricks CE as the MLflow tracking server. For other options, such as using your own local MLflow server, please read the Tracking Server Overview.

If you have not already done so, please register a Databricks Community Edition account. Registration should take no longer than a minute. Databricks CE is a free platform for trying out Databricks features. For this guide, we need its ML experiment dashboard to track our training progress.

After successfully registering an account on Databricks CE, let’s connect MLflow to Databricks CE. You will need to enter the following information:

- Databricks Host: the URL of your Databricks CE workspace
- Username: your sign-up email
- Password: your password

import mlflow

# Prompts for the Databricks host, username, and password.
mlflow.login()


Now this notebook is connected to the hosted tracking server. Let’s configure the MLflow metadata. There are two things to set up:

- mlflow.set_tracking_uri: always use “databricks”.
- mlflow.set_experiment: pick any name you like, starting with /.
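If you later run the same training code as a script outside a notebook, you can also supply both settings through environment variables, which MLflow reads from the environment. This is a minimal sketch: MLFLOW_TRACKING_URI and MLFLOW_EXPERIMENT_NAME are MLflow's standard variable names, and "/my-tf-experiment" is a hypothetical placeholder name.

```python
import os

# Tracking server: same value as mlflow.set_tracking_uri("databricks").
os.environ["MLFLOW_TRACKING_URI"] = "databricks"

# Experiment name: a hypothetical placeholder; like above, it starts with "/".
os.environ["MLFLOW_EXPERIMENT_NAME"] = "/my-tf-experiment"
```

In practice, these would usually be exported in the shell or job configuration before the script starts, rather than set in Python.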

Logging with MLflow

There are two ways you can log to MLflow from your Tensorflow pipeline:

- MLflow autologging.
- An MLflow callback.

Auto logging is simple to configure, but gives you less control. Using a callback is more flexible. Let’s see how each way is done.

MLflow Auto Logging

All you need to do is call mlflow.tensorflow.autolog() before starting training. MLflow will then automatically log the metrics to the tracking server you configured earlier, which in our case is Databricks CE.

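mlflow.set_tracking_uri("databricks")
# Choose any name that you like, starting with "/".
mlflow.set_experiment("/mlflow-tf-keras-mnist")
mlflow.tensorflow.autolog()
model.fit(x=train_ds, epochs=3)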
2023/11/15 01:53:35 INFO mlflow.utils.autologging_utils: Created MLflow autologging run with ID '7c1db53e417b43f0a1d9e095c9943acb', which will track hyperparameters, performance metrics, model artifacts, and lineage information for the current tensorflow workflow
Epoch 1/3
469/469 [==============================] - 13s 7ms/step - loss: 0.3610 - sparse_categorical_accuracy: 0.8890
Epoch 2/3
469/469 [==============================] - 3s 6ms/step - loss: 0.1035 - sparse_categorical_accuracy: 0.9681
Epoch 3/3
469/469 [==============================] - 4s 8ms/step - loss: 0.0798 - sparse_categorical_accuracy: 0.9760
2023/11/15 01:54:05 WARNING mlflow.tensorflow: Failed to infer model signature: could not sample data to infer model signature: tuple index out of range
2023/11/15 01:54:05 WARNING mlflow.models.model: Model logged without a signature. Signatures will be required for upcoming model registry features as they validate model inputs and denote the expected schema of model outputs. Please visit for instructions on setting a model signature on your logged model.
2023/11/15 01:54:05 WARNING mlflow.tensorflow: You are saving a TensorFlow Core model or Keras model without a signature. Inference with mlflow.pyfunc.spark_udf() will not work unless the model's pyfunc representation accepts pandas DataFrames as inference inputs.
2023/11/15 01:54:13 WARNING mlflow.utils.autologging_utils: MLflow autologging encountered a warning: "/usr/local/lib/python3.10/dist-packages/_distutils_hack/ UserWarning: Setuptools is replacing distutils."
2023/11/15 01:54:13 INFO The progress bar can be disabled by setting the environment variable MLFLOW_ENABLE_ARTIFACTS_PROGRESS_BAR to false
<keras.src.callbacks.History at 0x7d48e6556b60>
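The 469 steps per epoch in the training log above, and the 79 steps in the evaluation log, follow from the dataset sizes. Keras counts one step per batch, and a partial final batch still counts as a step. The sketch below reproduces those numbers. MNIST has 60,000 training and 10,000 test images. The batch size of 128 is an assumption that is consistent with both step counts.

```python
import math

# MNIST split sizes: 60,000 training images and 10,000 test images.
train_size, test_size = 60_000, 10_000
batch_size = 128  # assumed; consistent with the 469/469 and 79/79 progress bars

train_steps = math.ceil(train_size / batch_size)
test_steps = math.ceil(test_size / batch_size)

# The final training batch holds the leftover images and is smaller than the rest.
last_batch = train_size - (train_steps - 1) * batch_size

print(train_steps, test_steps, last_batch)  # 469 79 96
```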

While training is in progress, you can follow it in your dashboard. Log in to your Databricks CE account, click the menu at the top left, and select Machine Learning from the drop-down list. Then click the experiment icon, as shown in the screenshot below.

[Screenshot: landing page]

Clicking the Experiment button brings you to the experiment page, where you can find your runs. Click the most recent experiment and run to see your metrics, similar to the screenshot below.

[Screenshot: experiment page]

You can click on metrics to see the chart.

Let’s evaluate the training result.

score = model.evaluate(test_ds)

print(f"Test loss: {score[0]:.2f}")
print(f"Test accuracy: {score[1]:.2f}")
79/79 [==============================] - 1s 12ms/step - loss: 0.0484 - sparse_categorical_accuracy: 0.9838
Test loss: 0.05
Test accuracy: 0.98

Log with MLflow Callback

Autologging is powerful and convenient, but if you are looking for a more Tensorflow-native approach, you can pass mlflow.tensorflow.MlflowCallback to model.fit(). It will log:

- Your model configuration, layers, hyperparameters, and so on.
- The training stats, including losses and metrics configured with model.compile().

from mlflow.tensorflow import MlflowCallback

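# Turn off autologging.
mlflow.tensorflow.autolog(disable=True)

with mlflow.start_run() as run:
    model.fit(x=train_ds, epochs=2, callbacks=[MlflowCallback(run)])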
Epoch 1/2
469/469 [==============================] - 5s 10ms/step - loss: 0.0473 - sparse_categorical_accuracy: 0.9851
Epoch 2/2
469/469 [==============================] - 4s 8ms/step - loss: 0.0432 - sparse_categorical_accuracy: 0.9866

In the Databricks CE experiment view, you will see a dashboard similar to the one before.

Customize the MLflow Callback

If you want to add extra logging logic, you can customize the MLflow callback. You can either subclass keras.callbacks.Callback and write everything from scratch, or subclass mlflow.tensorflow.MlflowCallback to add your custom logging logic.

Let’s look at an example where we replace the loss with its logarithm before logging it to MLflow.

import math

# Create our own callback by subclassing `MlflowCallback`.
class MlflowCustomCallback(MlflowCallback):
    def on_epoch_end(self, epoch, logs=None):
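        if not self.log_every_epoch:
            # Metrics are logged per step instead, so skip epoch-level logging.
            return
        # Replace the raw loss with its natural logarithm.
        loss = logs["loss"]
        logs["log_loss"] = math.log(loss)
        del logs["loss"]
        # Send the modified metrics to MLflow for this epoch.
        self.metrics_logger.record_metrics(logs, epoch)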
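The callback above edits the logs dictionary that Keras passes to on_epoch_end in place. As a sketch of an alternative design, the same transformation can be isolated in a pure function that returns a copy. That makes it easy to test without TensorFlow and leaves the original dictionary untouched for any other callbacks that may read it. The helper name replace_loss_with_log is hypothetical. The input values are taken from the first-epoch log line of the run below.

```python
import math

def replace_loss_with_log(logs):
    """Return a copy of `logs` where "loss" is replaced by "log_loss" = ln(loss)."""
    new_logs = dict(logs)  # copy, so the caller's dict is not modified
    loss = new_logs.pop("loss")
    if loss <= 0:
        # math.log is undefined for zero or negative values.
        raise ValueError(f"log is undefined for non-positive loss: {loss}")
    new_logs["log_loss"] = math.log(loss)
    return new_logs

logs = {"loss": 0.0537, "sparse_categorical_accuracy": 0.9834}
result = replace_loss_with_log(logs)
print(result)
```

Inside a callback, such a helper's return value could be passed to the metrics logger in place of the mutated logs.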

Train the model with the new callback.

with mlflow.start_run() as run:
    run_id = run.info.run_id
    model.fit(x=train_ds, epochs=2, callbacks=[MlflowCustomCallback(run)])
Epoch 1/2
469/469 [==============================] - 5s 10ms/step - loss: 0.0537 - sparse_categorical_accuracy: 0.9834 - log_loss: -2.9237
Epoch 2/2
469/469 [==============================] - 4s 9ms/step - loss: 0.0497 - sparse_categorical_accuracy: 0.9846 - log_loss: -3.0022
2023/11/15 01:57:50 WARNING mlflow.utils.autologging_utils: Encountered unexpected error during tensorflow autologging: MLflow autologging must be turned off if an `MllflowCallback` is explicitly added to the callback list. You are creating an `MllflowCallback` while having autologging enabled. Please either call `mlflow.tensorflow.autolog(disable=True)` to disable autologging or remove `MllflowCallback` from the callback list.

On your Databricks CE page, you should see that log_loss has replaced the loss metric, as shown in the screenshot below.

[Screenshot: log loss]

Wrap up

You have now learned the basics of integrating MLflow with Tensorflow. A few topics are not covered in this quickstart, such as saving a TF model to MLflow and loading it back. For details, please refer to our main guide to the MLflow and Tensorflow integration.