| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230 |
- from __future__ import annotations
- import abc
- from typing import Any
- from tensorflow.keras.callbacks import Callback # type: ignore
- import wandb
- from wandb.sdk.lib import telemetry
class WandbEvalCallback(Callback, abc.ABC):
    """Abstract base class to build Keras callbacks for model prediction visualization.

    You can build callbacks for visualizing model predictions `on_epoch_end`
    that can be passed to `model.fit()` for classification, object detection,
    segmentation, etc. tasks.

    To use this, inherit from this base callback class and implement the
    `add_ground_truth` and `add_model_predictions` methods.

    The base class will take care of the following:
    - Initialize `data_table` for logging the ground truth and
        `pred_table` for predictions.
    - The data uploaded to `data_table` is used as a reference for the
        `pred_table`. This is to reduce the memory footprint. The `data_table_ref`
        is the table object fetched back from the logged data artifact and can
        be used to access the referenced data.
        Check out the example below to see how it's done.
    - Log the tables to W&B as W&B Artifacts.
    - Each new `pred_table` is logged as a new version with aliases.

    Example:
        ```python
        class WandbClfEvalCallback(WandbEvalCallback):
            def __init__(self, validation_data, data_table_columns, pred_table_columns):
                super().__init__(data_table_columns, pred_table_columns)

                self.x = validation_data[0]
                self.y = validation_data[1]

            # `logs` must be accepted: the base class hooks pass it positionally.
            def add_ground_truth(self, logs=None):
                for idx, (image, label) in enumerate(zip(self.x, self.y)):
                    self.data_table.add_data(idx, wandb.Image(image), label)

            def add_model_predictions(self, epoch, logs=None):
                preds = self.model.predict(self.x, verbose=0)
                preds = tf.argmax(preds, axis=-1)

                data_table_ref = self.data_table_ref
                table_idxs = data_table_ref.get_index()

                for idx in table_idxs:
                    pred = preds[idx]
                    self.pred_table.add_data(
                        epoch,
                        data_table_ref.data[idx][0],
                        data_table_ref.data[idx][1],
                        data_table_ref.data[idx][2],
                        pred,
                    )


        model.fit(
            x,
            y,
            epochs=2,
            validation_data=(x, y),
            callbacks=[
                WandbClfEvalCallback(
                    validation_data=(x, y),
                    data_table_columns=["idx", "image", "label"],
                    pred_table_columns=["epoch", "idx", "image", "label", "pred"],
                )
            ],
        )
        ```

    To have more fine-grained control, you can override the `on_train_begin` and
    `on_epoch_end` methods. If you want to log the samples after N batches, you
    can implement `on_train_batch_end` method.
    """

    def __init__(
        self,
        data_table_columns: list[str],
        pred_table_columns: list[str],
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """Store the table schemas and record usage of this callback.

        Args:
            data_table_columns: (list) Column names used when `data_table` is
                created in `on_train_begin`.
            pred_table_columns: (list) Column names used when `pred_table` is
                created in `on_epoch_end`.
            *args: Forwarded to the parent `Callback` constructor.
            **kwargs: Forwarded to the parent `Callback` constructor.

        Raises:
            wandb.Error: If there is no active run (`wandb.run` is `None`).
        """
        super().__init__(*args, **kwargs)

        if wandb.run is None:
            raise wandb.Error(
                "You must call `wandb.init()` first before using this callback."
            )

        with telemetry.context(run=wandb.run) as tel:
            tel.feature.keras_wandb_eval_callback = True

        self.data_table_columns = data_table_columns
        self.pred_table_columns = pred_table_columns

    def on_train_begin(self, logs: dict[str, float] | None = None) -> None:
        """Create `data_table`, fill it with ground truth and log it.

        Args:
            logs: Forwarded unchanged to `add_ground_truth`.
        """
        # Initialize the data_table
        self.init_data_table(column_names=self.data_table_columns)
        # Log the ground truth data
        self.add_ground_truth(logs)
        # Log the data_table as W&B Artifacts
        self.log_data_table()

    def on_epoch_end(self, epoch: int, logs: dict[str, float] | None = None) -> None:
        """Create a fresh `pred_table`, fill it with predictions and log it.

        Args:
            epoch: Forwarded unchanged to `add_model_predictions`.
            logs: Forwarded unchanged to `add_model_predictions`.
        """
        # Initialize the pred_table
        self.init_pred_table(column_names=self.pred_table_columns)
        # Log the model prediction
        self.add_model_predictions(epoch, logs)
        # Log the pred_table as W&B Artifacts
        self.log_pred_table()

    @abc.abstractmethod
    def add_ground_truth(self, logs: dict[str, float] | None = None) -> None:
        """Add ground truth data to `data_table`.

        Use this method to write the logic for adding validation/training data to
        `data_table` initialized using `init_data_table` method.

        Example:
            ```python
            for idx, data in enumerate(dataloader):
                self.data_table.add_data(idx, data)
            ```
        This method is called once `on_train_begin` or equivalent hook.
        """
        raise NotImplementedError(f"{self.__class__.__name__}.add_ground_truth")

    @abc.abstractmethod
    def add_model_predictions(
        self, epoch: int, logs: dict[str, float] | None = None
    ) -> None:
        """Add a prediction from a model to `pred_table`.

        Use this method to write the logic for adding model prediction for validation/
        training data to `pred_table` initialized using `init_pred_table` method.

        Example:
            ```python
            # Assuming the dataloader is not shuffling the samples.
            for idx, data in enumerate(dataloader):
                preds = model.predict(data)
                self.pred_table.add_data(
                    self.data_table_ref.data[idx][0],
                    self.data_table_ref.data[idx][1],
                    preds,
                )
            ```
        This method is called `on_epoch_end` or equivalent hook.
        """
        raise NotImplementedError(f"{self.__class__.__name__}.add_model_predictions")

    def init_data_table(self, column_names: list[str]) -> None:
        """Initialize the W&B Tables for validation data.

        Call this method `on_train_begin` or equivalent hook. This is followed by adding
        data to the table row or column wise.

        Args:
            column_names: (list) Column names for W&B Tables.
        """
        self.data_table = wandb.Table(columns=column_names, allow_mixed_types=True)

    def init_pred_table(self, column_names: list[str]) -> None:
        """Initialize the W&B Tables for model evaluation.

        Call this method `on_epoch_end` or equivalent hook. This is followed by adding
        data to the table row or column wise.

        Args:
            column_names: (list) Column names for W&B Tables.
        """
        self.pred_table = wandb.Table(columns=column_names)

    def log_data_table(
        self, name: str = "val", type: str = "dataset", table_name: str = "val_data"
    ) -> None:
        """Log the `data_table` as W&B artifact and call `use_artifact` on it.

        This lets the evaluation table use the reference of already uploaded data
        (images, text, scalar, etc.) without re-uploading.

        Args:
            name: (str) A human-readable name for this artifact, which is how you can
                identify this artifact in the UI or reference it in use_artifact calls.
                (default is 'val')
            type: (str) The type of the artifact, which is used to organize and
                differentiate artifacts. (default is 'dataset')
            table_name: (str) The name of the table as will be displayed in the UI.
                (default is 'val_data').

        Raises:
            wandb.Error: If there is no active run (`wandb.run` is `None`).
        """
        # Explicit check instead of `assert`: asserts are stripped under
        # `python -O`, and `__init__` already reports this state as `wandb.Error`.
        # Binding to a local keeps the checked object and the used object the same.
        run = wandb.run
        if run is None:
            raise wandb.Error(
                "You must call `wandb.init()` first before using this callback."
            )

        data_artifact = wandb.Artifact(name, type=type)
        data_artifact.add(self.data_table, table_name)

        # Calling `use_artifact` uploads the data to W&B.
        run.use_artifact(data_artifact)
        # Wait on the artifact before fetching the table back from it.
        data_artifact.wait()

        # We get the reference table.
        self.data_table_ref = data_artifact.get(table_name)

    def log_pred_table(
        self,
        type: str = "evaluation",
        table_name: str = "eval_data",
        aliases: list[str] | None = None,
    ) -> None:
        """Log the W&B Tables for model evaluation.

        The table will be logged multiple times creating new version. Use this
        to compare models at different intervals interactively.

        Args:
            type: (str) The type of the artifact, which is used to organize and
                differentiate artifacts. (default is 'evaluation')
            table_name: (str) The name of the table as will be displayed in the UI.
                (default is 'eval_data')
            aliases: (List[str]) List of aliases for the prediction table.
                When omitted (or empty), `["latest"]` is used.

        Raises:
            wandb.Error: If there is no active run (`wandb.run` is `None`).
        """
        # See `log_data_table` for why this is an explicit check, not an `assert`.
        run = wandb.run
        if run is None:
            raise wandb.Error(
                "You must call `wandb.init()` first before using this callback."
            )

        # The artifact name is derived from the run id, so every call within the
        # same run logs under the same artifact name.
        pred_artifact = wandb.Artifact(f"run_{run.id}_pred", type=type)
        pred_artifact.add(self.pred_table, table_name)
        run.log_artifact(pred_artifact, aliases=aliases or ["latest"])
|