tables_builder.py 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230
  1. from __future__ import annotations
  2. import abc
  3. from typing import Any
  4. from tensorflow.keras.callbacks import Callback # type: ignore
  5. import wandb
  6. from wandb.sdk.lib import telemetry
  7. class WandbEvalCallback(Callback, abc.ABC):
  8. """Abstract base class to build Keras callbacks for model prediction visualization.
  9. You can build callbacks for visualizing model predictions `on_epoch_end`
  10. that can be passed to `model.fit()` for classification, object detection,
  11. segmentation, etc. tasks.
  12. To use this, inherit from this base callback class and implement the
  13. `add_ground_truth` and `add_model_prediction` methods.
  14. The base class will take care of the following:
  15. - Initialize `data_table` for logging the ground truth and
  16. `pred_table` for predictions.
  17. - The data uploaded to `data_table` is used as a reference for the
  18. `pred_table`. This is to reduce the memory footprint. The `data_table_ref`
  19. is a list that can be used to access the referenced data.
  20. Check out the example below to see how it's done.
  21. - Log the tables to W&B as W&B Artifacts.
  22. - Each new `pred_table` is logged as a new version with aliases.
  23. Example:
  24. ```python
  25. class WandbClfEvalCallback(WandbEvalCallback):
  26. def __init__(self, validation_data, data_table_columns, pred_table_columns):
  27. super().__init__(data_table_columns, pred_table_columns)
  28. self.x = validation_data[0]
  29. self.y = validation_data[1]
  30. def add_ground_truth(self):
  31. for idx, (image, label) in enumerate(zip(self.x, self.y)):
  32. self.data_table.add_data(idx, wandb.Image(image), label)
  33. def add_model_predictions(self, epoch):
  34. preds = self.model.predict(self.x, verbose=0)
  35. preds = tf.argmax(preds, axis=-1)
  36. data_table_ref = self.data_table_ref
  37. table_idxs = data_table_ref.get_index()
  38. for idx in table_idxs:
  39. pred = preds[idx]
  40. self.pred_table.add_data(
  41. epoch,
  42. data_table_ref.data[idx][0],
  43. data_table_ref.data[idx][1],
  44. data_table_ref.data[idx][2],
  45. pred,
  46. )
  47. model.fit(
  48. x,
  49. y,
  50. epochs=2,
  51. validation_data=(x, y),
  52. callbacks=[
  53. WandbClfEvalCallback(
  54. validation_data=(x, y),
  55. data_table_columns=["idx", "image", "label"],
  56. pred_table_columns=["epoch", "idx", "image", "label", "pred"],
  57. )
  58. ],
  59. )
  60. ```
  61. To have more fine-grained control, you can override the `on_train_begin` and
  62. `on_epoch_end` methods. If you want to log the samples after N batched, you
  63. can implement `on_train_batch_end` method.
  64. """
  65. def __init__(
  66. self,
  67. data_table_columns: list[str],
  68. pred_table_columns: list[str],
  69. *args: Any,
  70. **kwargs: Any,
  71. ) -> None:
  72. super().__init__(*args, **kwargs)
  73. if wandb.run is None:
  74. raise wandb.Error(
  75. "You must call `wandb.init()` first before using this callback."
  76. )
  77. with telemetry.context(run=wandb.run) as tel:
  78. tel.feature.keras_wandb_eval_callback = True
  79. self.data_table_columns = data_table_columns
  80. self.pred_table_columns = pred_table_columns
  81. def on_train_begin(self, logs: dict[str, float] | None = None) -> None:
  82. # Initialize the data_table
  83. self.init_data_table(column_names=self.data_table_columns)
  84. # Log the ground truth data
  85. self.add_ground_truth(logs)
  86. # Log the data_table as W&B Artifacts
  87. self.log_data_table()
  88. def on_epoch_end(self, epoch: int, logs: dict[str, float] | None = None) -> None:
  89. # Initialize the pred_table
  90. self.init_pred_table(column_names=self.pred_table_columns)
  91. # Log the model prediction
  92. self.add_model_predictions(epoch, logs)
  93. # Log the pred_table as W&B Artifacts
  94. self.log_pred_table()
  95. @abc.abstractmethod
  96. def add_ground_truth(self, logs: dict[str, float] | None = None) -> None:
  97. """Add ground truth data to `data_table`.
  98. Use this method to write the logic for adding validation/training data to
  99. `data_table` initialized using `init_data_table` method.
  100. Example:
  101. ```python
  102. for idx, data in enumerate(dataloader):
  103. self.data_table.add_data(idx, data)
  104. ```
  105. This method is called once `on_train_begin` or equivalent hook.
  106. """
  107. raise NotImplementedError(f"{self.__class__.__name__}.add_ground_truth")
  108. @abc.abstractmethod
  109. def add_model_predictions(
  110. self, epoch: int, logs: dict[str, float] | None = None
  111. ) -> None:
  112. """Add a prediction from a model to `pred_table`.
  113. Use this method to write the logic for adding model prediction for validation/
  114. training data to `pred_table` initialized using `init_pred_table` method.
  115. Example:
  116. ```python
  117. # Assuming the dataloader is not shuffling the samples.
  118. for idx, data in enumerate(dataloader):
  119. preds = model.predict(data)
  120. self.pred_table.add_data(
  121. self.data_table_ref.data[idx][0],
  122. self.data_table_ref.data[idx][1],
  123. preds,
  124. )
  125. ```
  126. This method is called `on_epoch_end` or equivalent hook.
  127. """
  128. raise NotImplementedError(f"{self.__class__.__name__}.add_model_predictions")
  129. def init_data_table(self, column_names: list[str]) -> None:
  130. """Initialize the W&B Tables for validation data.
  131. Call this method `on_train_begin` or equivalent hook. This is followed by adding
  132. data to the table row or column wise.
  133. Args:
  134. column_names: (list) Column names for W&B Tables.
  135. """
  136. self.data_table = wandb.Table(columns=column_names, allow_mixed_types=True)
  137. def init_pred_table(self, column_names: list[str]) -> None:
  138. """Initialize the W&B Tables for model evaluation.
  139. Call this method `on_epoch_end` or equivalent hook. This is followed by adding
  140. data to the table row or column wise.
  141. Args:
  142. column_names: (list) Column names for W&B Tables.
  143. """
  144. self.pred_table = wandb.Table(columns=column_names)
  145. def log_data_table(
  146. self, name: str = "val", type: str = "dataset", table_name: str = "val_data"
  147. ) -> None:
  148. """Log the `data_table` as W&B artifact and call `use_artifact` on it.
  149. This lets the evaluation table use the reference of already uploaded data
  150. (images, text, scalar, etc.) without re-uploading.
  151. Args:
  152. name: (str) A human-readable name for this artifact, which is how you can
  153. identify this artifact in the UI or reference it in use_artifact calls.
  154. (default is 'val')
  155. type: (str) The type of the artifact, which is used to organize and
  156. differentiate artifacts. (default is 'dataset')
  157. table_name: (str) The name of the table as will be displayed in the UI.
  158. (default is 'val_data').
  159. """
  160. data_artifact = wandb.Artifact(name, type=type)
  161. data_artifact.add(self.data_table, table_name)
  162. # Calling `use_artifact` uploads the data to W&B.
  163. assert wandb.run is not None
  164. wandb.run.use_artifact(data_artifact)
  165. data_artifact.wait()
  166. # We get the reference table.
  167. self.data_table_ref = data_artifact.get(table_name)
  168. def log_pred_table(
  169. self,
  170. type: str = "evaluation",
  171. table_name: str = "eval_data",
  172. aliases: list[str] | None = None,
  173. ) -> None:
  174. """Log the W&B Tables for model evaluation.
  175. The table will be logged multiple times creating new version. Use this
  176. to compare models at different intervals interactively.
  177. Args:
  178. type: (str) The type of the artifact, which is used to organize and
  179. differentiate artifacts. (default is 'evaluation')
  180. table_name: (str) The name of the table as will be displayed in the UI.
  181. (default is 'eval_data')
  182. aliases: (List[str]) List of aliases for the prediction table.
  183. """
  184. assert wandb.run is not None
  185. pred_artifact = wandb.Artifact(f"run_{wandb.run.id}_pred", type=type)
  186. pred_artifact.add(self.pred_table, table_name)
  187. wandb.run.log_artifact(pred_artifact, aliases=aliases or ["latest"])