"""Keras 3 callback to log information to MLflow."""
import keras
from mlflow import log_metrics, log_params, log_text
from mlflow.utils.annotations import experimental
from mlflow.utils.autologging_utils import ExceptionSafeClass


@experimental
class MlflowCallback(keras.callbacks.Callback, metaclass=ExceptionSafeClass):
"""Callback for logging Keras metrics/params/model/... to MLflow.
This callback logs model metadata at training begins, and logs training metrics every epoch or
every n steps (defined by the user) to MLflow.
Args:
log_every_epoch: bool, defaults to True. If True, log metrics every epoch. If False,
log metrics every n steps.
log_every_n_steps: int, defaults to None. If set, log metrics every n steps. If None,
log metrics every epoch. Must be `None` if `log_every_epoch=True`.
.. code-block:: python
:caption: Example
import keras
import mlflow
import numpy as np
# Prepare data for a 2-class classification.
data = np.random.uniform([8, 28, 28, 3])
label = np.random.randint(2, size=8)
model = keras.Sequential(
[
keras.Input([28, 28, 3]),
keras.layers.Flatten(),
keras.layers.Dense(2),
]
)
model.compile(
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer=keras.optimizers.Adam(0.001),
metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
with mlflow.start_run() as run:
model.fit(
data,
label,
batch_size=4,
epochs=2,
callbacks=[mlflow.keras.MlflowCallback()],
)
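
    To log every n steps instead of every epoch, pass both arguments. A minimal
    sketch (the step count of 5 is arbitrary, chosen only for illustration):

    .. code-block:: python

        with mlflow.start_run() as run:
            model.fit(
                data,
                label,
                batch_size=4,
                epochs=2,
                callbacks=[
                    mlflow.keras.MlflowCallback(
                        log_every_epoch=False, log_every_n_steps=5
                    )
                ],
            )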
"""

    def __init__(self, log_every_epoch=True, log_every_n_steps=None):
        super().__init__()
        self.log_every_epoch = log_every_epoch
        self.log_every_n_steps = log_every_n_steps

        if log_every_epoch and log_every_n_steps is not None:
            raise ValueError(
                "`log_every_n_steps` must be None if `log_every_epoch=True`, received "
                f"`log_every_epoch={log_every_epoch}` and "
                f"`log_every_n_steps={log_every_n_steps}`."
            )
        if not log_every_epoch and log_every_n_steps is None:
            raise ValueError(
                "`log_every_n_steps` must be specified if `log_every_epoch=False`, received "
                f"`log_every_epoch={log_every_epoch}` and "
                f"`log_every_n_steps={log_every_n_steps}`."
            )
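
    # Sketch of the constructor contract above (argument values are hypothetical):
    #
    #   MlflowCallback()                               # OK: log every epoch
    #   MlflowCallback(log_every_epoch=False,
    #                  log_every_n_steps=100)          # OK: log every 100 steps
    #   MlflowCallback(log_every_n_steps=100)          # ValueError: conflicts with
    #                                                  #   log_every_epoch=True
    #   MlflowCallback(log_every_epoch=False)          # ValueError: no frequency given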

    def on_train_begin(self, logs=None):
        """Log model architecture and optimizer configuration when training begins."""
        config = self.model.optimizer.get_config()
        log_params({f"optimizer_{k}": v for k, v in config.items()})

        # Collect the lines that `model.summary()` would print so the full
        # summary can be logged as a text artifact.
        model_summary = []

        def print_fn(line, *args, **kwargs):
            model_summary.append(line)

        self.model.summary(print_fn=print_fn)
        summary = "\n".join(model_summary)
        log_text(summary, artifact_file="model_summary.txt")
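
    # Illustration (hypothetical values): with keras.optimizers.Adam(0.001), the
    # call above logs params such as "optimizer_name": "adam" and
    # "optimizer_learning_rate": 0.001, one entry per key in the optimizer config.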

    def on_epoch_end(self, epoch, logs=None):
        """Log metrics at the end of each epoch."""
        if not self.log_every_epoch or logs is None:
            return
        # `synchronous=False` queues the metrics for background upload so the
        # logging call does not block the training loop.
        log_metrics(logs, step=epoch, synchronous=False)

    def on_batch_end(self, batch, logs=None):
        """Log metrics at the end of each batch with user-specified frequency."""
        if self.log_every_n_steps is None or logs is None:
            return
        # Use `optimizer.iterations` (the global step count across epochs) rather
        # than the per-epoch `batch` argument, so the step axis is monotonic.
        current_iteration = int(self.model.optimizer.iterations.numpy())
        if current_iteration % self.log_every_n_steps == 0:
            log_metrics(logs, step=current_iteration, synchronous=False)
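
    # Worked example (hypothetical numbers): with log_every_n_steps=100 and 250
    # batches per epoch, metrics are logged at global steps 100 and 200 during
    # epoch 1, then 300, 400, and 500 during epoch 2, and so on.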

    def on_test_end(self, logs=None):
        """Log validation metrics at validation end."""
        if logs is None:
            return
        # Prefix with "validation_" so these metrics are distinguishable from the
        # training metrics logged under the same names.
        metrics = {"validation_" + k: v for k, v in logs.items()}
        log_metrics(metrics, synchronous=False)
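
    # Illustration (hypothetical values): if evaluation produces
    # {"loss": 0.4, "sparse_categorical_accuracy": 0.8}, the metrics logged above
    # are "validation_loss" and "validation_sparse_categorical_accuracy".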