from tensorflow import keras
from tensorflow.keras.callbacks import Callback
from mlflow import log_metrics, log_params, log_text
from mlflow.utils.autologging_utils import ExceptionSafeClass
from mlflow.utils.checkpoint_utils import MlflowModelCheckpointCallbackBase
class MlflowCallback(keras.callbacks.Callback, metaclass=ExceptionSafeClass):
    """Callback for logging Tensorflow training metrics to MLflow.

    This callback logs model information at training start, and logs training metrics every
    epoch or every n steps (defined by the user) to MLflow.

    Args:
        log_every_epoch: bool, If True, log metrics every epoch. If False, log metrics every n
            steps.
        log_every_n_steps: int, log metrics every n steps. If None, log metrics every epoch.
            Must be `None` if `log_every_epoch=True`.

    Raises:
        ValueError: If `log_every_epoch` is truthy while `log_every_n_steps` is set, or if
            `log_every_epoch` is falsy while `log_every_n_steps` is None.

    .. code-block:: python
        :caption: Example

        from tensorflow import keras
        import tensorflow as tf
        import mlflow
        import numpy as np
        from mlflow.tensorflow import MlflowCallback

        # Prepare data for a 2-class classification.
        data = tf.random.uniform([8, 28, 28, 3])
        label = tf.convert_to_tensor(np.random.randint(2, size=8))

        model = keras.Sequential(
            [
                keras.Input([28, 28, 3]),
                keras.layers.Flatten(),
                keras.layers.Dense(2),
            ]
        )

        model.compile(
            loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
            optimizer=keras.optimizers.Adam(0.001),
            metrics=[keras.metrics.SparseCategoricalAccuracy()],
        )

        with mlflow.start_run() as run:
            model.fit(
                data,
                label,
                batch_size=4,
                epochs=2,
                callbacks=[MlflowCallback()],
            )
    """

    def __init__(self, log_every_epoch=True, log_every_n_steps=None):
        # Let the Keras base class set up its own bookkeeping attributes before
        # this callback adds its configuration.
        super().__init__()
        self.log_every_epoch = log_every_epoch
        self.log_every_n_steps = log_every_n_steps

        # The two logging modes are mutually exclusive: exactly one must be selected.
        if log_every_epoch and log_every_n_steps is not None:
            raise ValueError(
                "`log_every_n_steps` must be None if `log_every_epoch=True`, received "
                f"`log_every_epoch={log_every_epoch}` and `log_every_n_steps={log_every_n_steps}`."
            )

        if not log_every_epoch and log_every_n_steps is None:
            raise ValueError(
                "`log_every_n_steps` must be specified if `log_every_epoch=False`, received "
                f"`log_every_epoch={log_every_epoch}` and `log_every_n_steps=None`."
            )

    def on_train_begin(self, logs=None):
        """Log model architecture and optimizer configuration when training begins."""
        # Optimizer hyperparameters are logged as params with an "opt_" prefix.
        config = self.model.optimizer.get_config()
        log_params({f"opt_{k}": v for k, v in config.items()})

        model_summary = []

        def print_fn(line, *args, **kwargs):
            # `summary` may pass extra positional/keyword arguments to `print_fn`;
            # only the text line is kept.
            model_summary.append(line)

        self.model.summary(print_fn=print_fn)
        summary = "\n".join(model_summary)
        log_text(summary, artifact_file="model_summary.txt")

    def on_epoch_end(self, epoch, logs=None):
        """Log metrics at the end of each epoch.

        Does nothing when per-epoch logging is disabled or when `logs` is None.
        """
        if not self.log_every_epoch or logs is None:
            return
        log_metrics(logs, step=epoch, synchronous=False)

    def on_batch_end(self, batch, logs=None):
        """Log metrics at the end of each batch with user specified frequency.

        Does nothing when `log_every_n_steps` is None or when `logs` is None.
        """
        if self.log_every_n_steps is None or logs is None:
            return
        # The optimizer's iteration counter is used as the step, so the step keeps
        # increasing across epochs instead of restarting with the `batch` index.
        current_iteration = int(self.model.optimizer.iterations.numpy())

        if current_iteration % self.log_every_n_steps == 0:
            log_metrics(logs, step=current_iteration, synchronous=False)

    def on_test_end(self, logs=None):
        """Log validation metrics at validation end.

        Each metric name is prefixed with "validation_". No step is attached.
        """
        if logs is None:
            return
        metrics = {"validation_" + k: v for k, v in logs.items()}
        log_metrics(metrics, synchronous=False)
class MlflowModelCheckpointCallback(Callback, MlflowModelCheckpointCallbackBase):
    """Callback for automatic Keras model checkpointing to MLflow.

    Args:
        monitor: In automatic model checkpointing, the metric name to monitor if
            you set `save_best_only` to True.
        save_best_only: If True, automatic model checkpointing only saves when
            the model is considered the "best" model according to the quantity
            monitored and previous checkpoint model is overwritten.
        mode: one of {"min", "max"}. In automatic model checkpointing,
            if save_best_only=True, the decision to overwrite the current save file is made
            based on either the maximization or the minimization of the monitored quantity.
        save_weights_only: In automatic model checkpointing, if True, then
            only the model's weights will be saved. Otherwise, the optimizer states,
            lr-scheduler states, etc are added in the checkpoint too.
        save_freq: `"epoch"` or integer. When using `"epoch"`, the callback
            saves the model after each epoch. When using integer, the callback
            saves the model at end of this many batches. Note that if the saving isn't
            aligned to epochs, the monitored metric may potentially be less reliable (it
            could reflect as little as 1 batch, since the metrics get reset
            every epoch). Defaults to `"epoch"`.

    .. code-block:: python
        :caption: Example

        from tensorflow import keras
        import tensorflow as tf
        import mlflow
        import numpy as np
        from mlflow.tensorflow import MlflowModelCheckpointCallback

        # Prepare data for a 2-class classification.
        data = tf.random.uniform([8, 28, 28, 3])
        label = tf.convert_to_tensor(np.random.randint(2, size=8))

        model = keras.Sequential(
            [
                keras.Input([28, 28, 3]),
                keras.layers.Flatten(),
                keras.layers.Dense(2),
            ]
        )

        model.compile(
            loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
            optimizer=keras.optimizers.Adam(0.001),
            metrics=[keras.metrics.SparseCategoricalAccuracy()],
        )

        mlflow_checkpoint_callback = MlflowModelCheckpointCallback(
            monitor="sparse_categorical_accuracy",
            mode="max",
            save_best_only=True,
            save_weights_only=False,
            save_freq="epoch",
        )

        with mlflow.start_run() as run:
            model.fit(
                data,
                label,
                batch_size=4,
                epochs=2,
                callbacks=[mlflow_checkpoint_callback],
            )
    """

    def __init__(
        self,
        monitor="val_loss",
        mode="min",
        save_best_only=True,
        save_weights_only=False,
        save_freq="epoch",
    ):
        # The two bases take different constructor arguments, so each is
        # initialised explicitly rather than through a cooperative `super()` chain.
        Callback.__init__(self)
        MlflowModelCheckpointCallbackBase.__init__(
            self,
            checkpoint_file_suffix=".h5",
            monitor=monitor,
            mode=mode,
            save_best_only=save_best_only,
            save_weights_only=save_weights_only,
            save_freq=save_freq,
        )
        self.trainer = None
        # Epoch index recorded by `on_epoch_begin`; None until the first epoch starts.
        self.current_epoch = None
        # Batch index passed to the most recent `on_train_batch_end` call.
        self._last_batch_seen = 0
        # Running count of training batches across all epochs.
        self.global_step = 0
        # Value of `global_step` at the most recent batch-frequency checkpoint check.
        self.global_step_last_saving = 0

    @staticmethod
    def _logs_to_metric_dict(logs):
        """Convert Keras `logs` into a `{name: float}` dict; None becomes an empty dict."""
        return {k: float(v) for k, v in (logs or {}).items()}

    def save_checkpoint(self, filepath: str):
        """Write the model, or only its weights, to `filepath`, overwriting any existing file."""
        if self.save_weights_only:
            self.model.save_weights(filepath, overwrite=True)
        else:
            self.model.save(filepath, overwrite=True)

    def on_epoch_begin(self, epoch, logs=None):
        """Remember the index of the epoch that is starting."""
        self.current_epoch = epoch

    def on_train_batch_end(self, batch, logs=None):
        """Advance the global step and run the checkpoint check every `save_freq` batches.

        Only acts when `save_freq` is an integer; a `logs` value of None is treated
        as an empty metric dict.
        """
        # Note that `on_train_batch_end` might be invoked by every N train steps,
        # (controlled by `steps_per_execution` argument in `model.compile` method).
        # the following logic is similar to
        # https://github.com/keras-team/keras/blob/e6e62405fa1b4444102601636d871610d91e5783/keras/callbacks/model_checkpoint.py#L212
        if batch <= self._last_batch_seen:
            # The batch index did not advance, so a new epoch has started;
            # batch indices are zero-based, hence the +1.
            add_batches = batch + 1
        else:
            add_batches = batch - self._last_batch_seen
        self._last_batch_seen = batch
        self.global_step += add_batches

        if isinstance(self.save_freq, int):
            if self.global_step - self.global_step_last_saving >= self.save_freq:
                self.check_and_save_checkpoint_if_needed(
                    current_epoch=self.current_epoch,
                    global_step=self.global_step,
                    metric_dict=self._logs_to_metric_dict(logs),
                )
                self.global_step_last_saving = self.global_step

    def on_epoch_end(self, epoch, logs=None):
        """Run the checkpoint check at the end of an epoch when `save_freq == "epoch"`.

        A `logs` value of None is treated as an empty metric dict.
        """
        if self.save_freq == "epoch":
            self.check_and_save_checkpoint_if_needed(
                current_epoch=self.current_epoch,
                global_step=self.global_step,
                metric_dict=self._logs_to_metric_dict(logs),
            )