keras.py 43 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084
  1. """keras init."""
  2. import logging
  3. import operator
  4. import os
  5. import shutil
  6. import sys
  7. from itertools import chain
  8. import numpy as np
  9. import tensorflow as tf
  10. import tensorflow.keras.backend as K # noqa: N812
  11. import wandb
  12. from wandb.proto.wandb_telemetry_pb2 import Deprecated
  13. from wandb.sdk.integration_utils.data_logging import ValidationDataLogger
  14. from wandb.sdk.lib import telemetry
  15. from wandb.sdk.lib.deprecation import warn_and_record_deprecation
  16. from wandb.util import add_import_hook
  def _check_keras_version():
      """Emit a terminal warning when the installed keras is older than 2.4.0."""
      from keras import __version__ as keras_version
      from packaging.version import parse

      installed = parse(keras_version)
      if installed >= parse("2.4.0"):
          return
      wandb.termwarn(
          f"Keras version {keras_version} is not fully supported. Required keras >= 2.4.0"
      )
  def _can_compute_flops() -> bool:
      """Report whether FLOPS can be computed, i.e. whether tensorflow is >= 2.0.0.

      FLOPS computation is restricted to TF 2.x as it requires tf.compat.v1.
      """
      from packaging.version import parse

      installed_tf = parse(tf.__version__)
      return installed_tf >= parse("2.0.0")
  # Run the keras version check now if keras has already been imported;
  # otherwise hand the check to add_import_hook for the "keras" module.
  if "keras" not in sys.modules:
      add_import_hook("keras", _check_keras_version)
  else:
      _check_keras_version()

  # Module-level logger named after this module.
  logger = logging.getLogger(__name__)
  def is_dataset(data):
      """Return True when `data` is an instance of tensorflow's DatasetV2 or DatasetV1.

      Returns False when the `dataset_ops` module cannot be obtained or does not
      expose `DatasetV2`. `DatasetV1` is only considered when the module has it.
      """
      dataset_ops = wandb.util.get_module("tensorflow.python.data.ops.dataset_ops")
      if not (dataset_ops and hasattr(dataset_ops, "DatasetV2")):
          return False

      accepted = [dataset_ops.DatasetV2]
      if hasattr(dataset_ops, "DatasetV1"):
          accepted.append(dataset_ops.DatasetV1)
      return isinstance(data, tuple(accepted))
  def is_generator_like(data):
      """Check whether `data` is a generator, Sequence, or Iterator.

      An object qualifies when it has a `next` or `__next__` attribute, or when it
      is an instance of `tf.keras.utils.Sequence` or of one of the iterator classes
      found in tensorflow's `iterator_ops` module.
      """
      accepted = [tf.keras.utils.Sequence]
      iterator_ops = wandb.util.get_module("tensorflow.python.data.ops.iterator_ops")
      if iterator_ops:
          accepted.append(iterator_ops.Iterator)
          # EagerIterator was in tensorflow < 2
          for legacy_name in ("EagerIterator", "IteratorV2"):
              if hasattr(iterator_ops, legacy_name):
                  # Only the first matching legacy class is added.
                  accepted.append(getattr(iterator_ops, legacy_name))
                  break

      if hasattr(data, "next") or hasattr(data, "__next__"):
          return True
      return isinstance(data, tuple(accepted))
  def patch_tf_keras():  # noqa: C901
      """Monkey-patch the Keras training entry points so WandbCallback sees validation data.

      The training modules are imported from `keras.engine` when the installed
      tensorflow is >= 2.6.0 and < 2.13.0, and from
      `tensorflow.python.keras.engine` otherwise. `fit_loop`, `fit_generator` and,
      when the corresponding module is available, `Loop.fit` or `Model.fit` are
      replaced by wrappers. Each wrapper copies the validation data found in the
      call's keyword arguments onto every `WandbCallback` listed in `callbacks`
      and then delegates to the original function. Every patched target is
      appended to `wandb.patched["keras"]`.

      If the required modules cannot be imported, an error is reported and the
      function returns without patching anything.
      """
      from packaging.version import parse
      from tensorflow.python.eager import context

      if parse("2.6.0") <= parse(tf.__version__) < parse("2.13.0"):
          keras_engine = "keras.engine"
          try:
              from keras.engine import training
              from keras.engine import training_arrays_v1 as training_arrays
              from keras.engine import training_generator_v1 as training_generator
          except (ImportError, AttributeError):
              wandb.termerror("Unable to patch Tensorflow/Keras")
              logger.exception("exception while trying to patch_tf_keras")
              return
      else:
          keras_engine = "tensorflow.python.keras.engine"

          from tensorflow.python.keras.engine import training

          try:
              from tensorflow.python.keras.engine import (
                  training_arrays_v1 as training_arrays,
              )
              from tensorflow.python.keras.engine import (
                  training_generator_v1 as training_generator,
              )
          except (ImportError, AttributeError):
              # Fall back to the un-suffixed module names.
              try:
                  from tensorflow.python.keras.engine import (
                      training_arrays,
                      training_generator,
                  )
              except (ImportError, AttributeError):
                  wandb.termerror("Unable to patch Tensorflow/Keras")
                  logger.exception("exception while trying to patch_tf_keras")
                  return

      # Tensorflow 2.1
      training_v2_1 = wandb.util.get_module("tensorflow.python.keras.engine.training_v2")
      # Tensorflow 2.2
      training_v2_2 = wandb.util.get_module(f"{keras_engine}.training_v1")

      # `old_v2` is only bound (and `new_v2` only installed) when one of the two
      # modules above was found.
      if training_v2_1:
          old_v2 = training_v2_1.Loop.fit
      elif training_v2_2:
          old_v2 = training.Model.fit

      old_arrays = training_arrays.fit_loop
      old_generator = training_generator.fit_generator

      def callbacks_from(kwargs):
          """Return the `callbacks` keyword argument, or an empty list when absent or None."""
          # `callbacks=None` is a legal explicit argument; `kwargs.get("callbacks", [])`
          # would hand back None in that case and the loops below would raise TypeError.
          cbks = kwargs.get("callbacks")
          return [] if cbks is None else cbks

      def set_wandb_attrs(cbk, val_data):
          """Attach `val_data` to `cbk` when `cbk` is a WandbCallback."""
          if isinstance(cbk, WandbCallback):
              if is_generator_like(val_data):
                  cbk.generator = val_data
              elif is_dataset(val_data):
                  if context.executing_eagerly():
                      cbk.generator = iter(val_data)
                  else:
                      wandb.termwarn(
                          "Found a validation dataset in graph mode, can't patch Keras."
                      )
              elif isinstance(val_data, tuple) and isinstance(val_data[0], tf.Tensor):
                  # Graph mode dataset generator
                  def gen():
                      while True:
                          yield K.get_session().run(val_data)

                  cbk.generator = gen()
              else:
                  cbk.validation_data = val_data

      def new_arrays(*args, **kwargs):
          """Wrapper installed in place of `training_arrays.fit_loop`."""
          cbks = callbacks_from(kwargs)
          val_inputs = kwargs.get("val_inputs")
          val_targets = kwargs.get("val_targets")
          # TODO: these could be generators, why index 0?
          if val_inputs and val_targets:
              for cbk in cbks:
                  set_wandb_attrs(cbk, (val_inputs[0], val_targets[0]))
          return old_arrays(*args, **kwargs)

      def new_generator(*args, **kwargs):
          """Wrapper installed in place of `training_generator.fit_generator`."""
          cbks = callbacks_from(kwargs)
          val_data = kwargs.get("validation_data")
          if val_data:
              for cbk in cbks:
                  set_wandb_attrs(cbk, val_data)
          return old_generator(*args, **kwargs)

      def new_v2(*args, **kwargs):
          """Wrapper installed in place of `Loop.fit` / `Model.fit`."""
          cbks = callbacks_from(kwargs)
          val_data = kwargs.get("validation_data")
          if val_data:
              for cbk in cbks:
                  set_wandb_attrs(cbk, val_data)
          return old_v2(*args, **kwargs)

      # Keep a handle on the originals next to the patched attributes.
      training_arrays.orig_fit_loop = old_arrays
      training_arrays.fit_loop = new_arrays
      training_generator.orig_fit_generator = old_generator
      training_generator.fit_generator = new_generator
      wandb.patched["keras"].append([f"{keras_engine}.training_arrays", "fit_loop"])
      wandb.patched["keras"].append(
          [f"{keras_engine}.training_generator", "fit_generator"]
      )

      if training_v2_1:
          training_v2_1.Loop.fit = new_v2
          wandb.patched["keras"].append(
              ["tensorflow.python.keras.engine.training_v2.Loop", "fit"]
          )
      elif training_v2_2:
          training.Model.fit = new_v2
          wandb.patched["keras"].append([f"{keras_engine}.training.Model", "fit"])
  155. def _array_has_dtype(array):
  156. return hasattr(array, "dtype")
  def _update_if_numeric(metrics, key, values):
      """Store a `wandb.Histogram` of `values` under `metrics[key]` when they are numeric.

      Values without a `dtype`, or with a non-numeric one, are skipped after a
      warning and `metrics` is left unchanged.
      """
      if not _array_has_dtype(values):
          _warn_not_logging(key)
      elif not is_numeric_array(values):
          _warn_not_logging_non_numeric(key)
      else:
          metrics[key] = wandb.Histogram(values)
  def is_numeric_array(array):
      """Return whether `array.dtype` is a sub-dtype of `np.number`."""
      element_type = array.dtype
      return np.issubdtype(element_type, np.number)
  def _warn_not_logging_non_numeric(name):
      """Warn that layer `name` is skipped because it holds non-numeric values."""
      message = f"Non-numeric values found in layer: {name}, not logging this layer"
      wandb.termwarn(message, repeat=False)
  def _warn_not_logging(name):
      """Warn that layer `name` is skipped because its datatype is undetermined."""
      message = f"Layer {name} has undetermined datatype not logging this layer"
      wandb.termwarn(message, repeat=False)
  # Import-time side effects: grab tensorflow's logger, then install the
  # Keras patches defined by patch_tf_keras().
  tf_logger = tf.get_logger()

  patch_tf_keras()
  179. ### For gradient logging ###
  def _get_custom_optimizer_parent_class():
      """Pick the optimizer base class that matches the installed tensorflow.

      Returns `tf.keras.optimizers.legacy.Optimizer` for tensorflow >= 2.9.0 and
      `tf.keras.optimizers.Optimizer` otherwise.
      """
      from packaging.version import parse

      use_legacy_namespace = parse(tf.__version__) >= parse("2.9.0")
      if use_legacy_namespace:
          return tf.keras.optimizers.legacy.Optimizer
      return tf.keras.optimizers.Optimizer


  # Resolved once at import time; used as the base class of _CustomOptimizer.
  _custom_optimizer_parent_class = _get_custom_optimizer_parent_class()
  class _CustomOptimizer(_custom_optimizer_parent_class):
      """Optimizer whose dense update overwrites each variable with its gradient."""

      def __init__(self):
          """Create the optimizer and wrap both apply hooks in `tf.function`."""
          super().__init__(name="CustomOptimizer")
          for hook_name in ("_resource_apply_dense", "_resource_apply_sparse"):
              setattr(self, hook_name, tf.function(getattr(self, hook_name)))

      def _resource_apply_dense(self, grad, var):
          """Assign `grad` to `var`."""
          var.assign(grad)

      # this needs to be implemented to prevent a NotImplementedError when
      # using Lookup layers.
      def _resource_apply_sparse(self, grad, var, indices):
          """Intentionally do nothing for sparse gradients."""

      def get_config(self):
          """Return the parent class's config unchanged."""
          config = super().get_config()
          return config
  class _GradAccumulatorCallback(tf.keras.callbacks.Callback):
      """Accumulates gradients during a fit() call when used in conjunction with _CustomOptimizer."""

      def set_model(self, model):
          """Snapshot the model's weights and allocate one zeroed accumulator per trainable weight."""
          super().set_model(model)
          self.og_weights = model.get_weights()
          self.grads = [
              np.zeros(tuple(weight.shape)) for weight in model.trainable_weights
          ]

      def on_batch_end(self, batch, logs=None):
          """Add the current trainable weights to the accumulators, then restore the snapshot."""
          for accumulator, weight in zip(self.grads, self.model.trainable_weights):
              # In-place add so the arrays held in self.grads are updated.
              np.add(accumulator, weight.numpy(), out=accumulator)
          self.model.set_weights(self.og_weights)

      def get_grads(self):
          """Return copies of the accumulated arrays."""
          return [accumulator.copy() for accumulator in self.grads]
  213. ###
  214. class WandbCallback(tf.keras.callbacks.Callback):
  215. """`WandbCallback` automatically integrates keras with wandb.
  216. Example:
  217. ```python
  218. model.fit(
  219. X_train,
  220. y_train,
  221. validation_data=(X_test, y_test),
  222. callbacks=[WandbCallback()],
  223. )
  224. ```
  225. `WandbCallback` will automatically log history data from any
  226. metrics collected by keras: loss and anything passed into `keras_model.compile()`.
  227. `WandbCallback` will set summary metrics for the run associated with the "best" training
  228. step, where "best" is defined by the `monitor` and `mode` attributes. This defaults
  229. to the epoch with the minimum `val_loss`. `WandbCallback` will by default save the model
  230. associated with the best `epoch`.
  231. `WandbCallback` can optionally log gradient and parameter histograms.
  232. `WandbCallback` can optionally save training and validation data for wandb to visualize.
  233. Args:
  234. monitor: (str) name of metric to monitor. Defaults to `val_loss`.
  235. mode: (str) one of {`auto`, `min`, `max`}.
  236. `min` - save model when monitor is minimized
  237. `max` - save model when monitor is maximized
  238. `auto` - try to guess when to save the model (default).
  239. save_model:
  240. True - save a model when monitor beats all previous epochs
  241. False - don't save models
  242. save_graph: (boolean) if True save model graph to wandb (default to True).
  243. save_weights_only: (boolean) if True, then only the model's weights will be
  244. saved (`model.save_weights(filepath)`), else the full model
  245. is saved (`model.save(filepath)`).
  246. log_weights: (boolean) if True save histograms of the model's layer's weights.
  247. log_gradients: (boolean) if True log histograms of the training gradients
  248. training_data: (tuple) Same format `(X,y)` as passed to `model.fit`. This is needed
  249. for calculating gradients - this is mandatory if `log_gradients` is `True`.
  250. validation_data: (tuple) Same format `(X,y)` as passed to `model.fit`. A set of data
  251. for wandb to visualize. If this is set, every epoch, wandb will
  252. make a small number of predictions and save the results for later visualization. In case
  253. you are working with image data, please also set `input_type` and `output_type` in order
  254. to log correctly.
  255. generator: (generator) a generator that returns validation data for wandb to visualize. This
  256. generator should return tuples `(X,y)`. Either `validation_data` or generator should
  257. be set for wandb to visualize specific data examples. In case you are working with image data,
  258. please also set `input_type` and `output_type` in order to log correctly.
  259. validation_steps: (int) if `validation_data` is a generator, how many
  260. steps to run the generator for the full validation set.
  261. labels: (list) If you are visualizing your data with wandb this list of labels
  262. will convert numeric output to understandable string if you are building a
  263. multiclass classifier. If you are making a binary classifier you can pass in
  264. a list of two labels ["label for false", "label for true"]. If `validation_data`
  265. and generator are both false, this won't do anything.
  266. predictions: (int) the number of predictions to make for visualization each epoch, max
  267. is 100.
  268. input_type: (string) type of the model input to help visualization. can be one of:
  269. (`image`, `images`, `segmentation_mask`, `auto`).
  270. output_type: (string) type of the model output to help visualization. can be one of:
  271. (`image`, `images`, `segmentation_mask`, `label`).
  272. log_evaluation: (boolean) if True, save a Table containing validation data and the
  273. model's predictions at each epoch. See `validation_indexes`,
  274. `validation_row_processor`, and `prediction_row_processor` for additional details.
  275. class_colors: ([float, float, float]) if the input or output is a segmentation mask,
  276. an array containing an rgb tuple (range 0-1) for each class.
  277. log_batch_frequency: (integer) if None, callback will log every epoch.
  278. If set to integer, callback will log training metrics every `log_batch_frequency`
  279. batches.
  280. log_best_prefix: (string) if None, no extra summary metrics will be saved.
  281. If set to a string, the monitored metric and epoch will be prepended with this value
  282. and stored as summary metrics.
  283. validation_indexes: ([wandb.data_types._TableLinkMixin]) an ordered list of index keys to associate
  284. with each validation example. If log_evaluation is True and `validation_indexes` is provided,
  285. then a Table of validation data will not be created and instead each prediction will
  286. be associated with the row represented by the `TableLinkMixin`. The most common way to obtain
  287. such keys is to use `Table.get_index()`, which will return a list of row keys.
  288. validation_row_processor: (Callable) a function to apply to the validation data, commonly used to visualize the data.
  289. The function will receive an `ndx` (int) and a `row` (dict). If your model has a single input,
  290. then `row["input"]` will be the input data for the row. Else, it will be keyed based on the name of the
  291. input slot. If your fit function takes a single target, then `row["target"]` will be the target data for the row. Else,
  292. it will be keyed based on the name of the output slots. For example, if your input data is a single ndarray,
  293. but you wish to visualize the data as an Image, then you can provide `lambda ndx, row: {"img": wandb.Image(row["input"])}`
  294. as the processor. Ignored if log_evaluation is False or `validation_indexes` are present.
  295. prediction_row_processor: (Callable) same as `validation_row_processor`, but applied to the model's output. `row["output"]` will contain
  296. the results of the model output.
  297. infer_missing_processors: (bool) Determines if `validation_row_processor` and `prediction_row_processor`
  298. should be inferred if missing. Defaults to True. If `labels` are provided, we will attempt to infer classification-type
  299. processors where appropriate.
  300. log_evaluation_frequency: (int) Determines the frequency which evaluation results will be logged. Default 0 (only at the end of training).
  301. Set to 1 to log every epoch, 2 to log every other epoch, and so on. Has no effect when log_evaluation is False.
  302. compute_flops: (bool) Compute the FLOPs of your Keras Sequential or Functional model in GigaFLOPs unit.
  303. """
  def __init__(
      self,
      monitor="val_loss",
      verbose=0,
      mode="auto",
      save_weights_only=False,
      log_weights=False,
      log_gradients=False,
      save_model=True,
      training_data=None,
      validation_data=None,
      labels=None,
      predictions=36,
      generator=None,
      input_type=None,
      output_type=None,
      log_evaluation=False,
      validation_steps=None,
      class_colors=None,
      log_batch_frequency=None,
      log_best_prefix="best_",
      save_graph=True,
      validation_indexes=None,
      validation_row_processor=None,
      prediction_row_processor=None,
      infer_missing_processors=True,
      log_evaluation_frequency=0,
      compute_flops=False,
      **kwargs,
  ):
      """Store the callback configuration; an active `wandb.run` is required.

      The legacy keyword argument `data_type` is still accepted through
      `**kwargs` and, when given, overrides `input_type`.

      Raises:
          wandb.Error: if `wandb.run` is None.
          Exception: if `log_gradients` is set and the tensorflow major version is below 2.
          ValueError: if `log_gradients` is set and `training_data` is None, or is
              a list/tuple whose length is not two.
      """
      if wandb.run is None:
          raise wandb.Error("You must call wandb.init() before WandbCallback()")

      callback_deprecation_message = (
          "WandbCallback is deprecated and will be removed in a future release. "
          "Please use the WandbMetricsLogger, WandbModelCheckpoint, and WandbEvalCallback "
          "callbacks instead. "
          "See https://docs.wandb.ai/models/integrations/keras for more information."
      )
      warn_and_record_deprecation(
          feature=Deprecated(keras_callback=True),
          message=callback_deprecation_message,
      )

      with telemetry.context(run=wandb.run) as tel:
          tel.feature.keras = True

      # This is kept around for legacy reasons: generator-like validation data
      # is routed to `generator` instead of `validation_data`.
      self.validation_data = None
      if validation_data is not None:
          if is_generator_like(validation_data):
              generator = validation_data
          else:
              self.validation_data = validation_data

      self.labels = [] if labels is None else labels
      # Never visualize more than 100 predictions.
      self.predictions = min(predictions, 100)

      self.monitor = monitor
      self.verbose = verbose
      self.save_weights_only = save_weights_only
      self.save_graph = save_graph

      best_model_filename = "model-best.h5"
      wandb.save(best_model_filename)
      self.filepath = os.path.join(wandb.run.dir, best_model_filename)
      self.save_model = save_model
      if save_model:
          save_model_deprecation_message = (
              "The save_model argument by default saves the model in the HDF5 format that cannot save "
              "custom objects like subclassed models and custom layers. This behavior will be deprecated "
              "in a future release in favor of the SavedModel format. Meanwhile, the HDF5 model is saved "
              "as W&B files and the SavedModel as W&B Artifacts."
          )
          warn_and_record_deprecation(
              feature=Deprecated(keras_callback__save_model=True),
              message=save_model_deprecation_message,
          )
      self.save_model_as_artifact = True

      self.log_weights = log_weights
      self.log_gradients = log_gradients
      self.training_data = training_data
      self.generator = generator
      self._graph_rendered = False

      legacy_data_type = kwargs.get("data_type")
      if legacy_data_type is not None:
          data_type_deprecation_message = (
              "The data_type argument of wandb.keras.WandbCallback is deprecated "
              "and will be removed in a future release. Please use input_type instead.\n"
              "Setting input_type = data_type."
          )
          warn_and_record_deprecation(
              feature=Deprecated(keras_callback__data_type=True),
              message=data_type_deprecation_message,
          )
          input_type = legacy_data_type
      self.input_type = input_type
      self.output_type = output_type

      self.log_evaluation = log_evaluation
      self.validation_steps = validation_steps
      if class_colors is None:
          self.class_colors = None
      else:
          self.class_colors = np.array(class_colors)
      self.log_batch_frequency = log_batch_frequency
      self.log_best_prefix = log_best_prefix
      self.compute_flops = compute_flops
      self._prediction_batch_size = None

      if self.log_gradients:
          tf_major_version = int(tf.__version__.split(".")[0])
          if tf_major_version < 2:
              raise Exception("Gradient logging requires tensorflow 2.0 or higher.")
          if self.training_data is None:
              raise ValueError(
                  "training_data argument is required for gradient logging."
              )
          if not isinstance(self.training_data, (list, tuple)):
              # generator, tf.data.Dataset etc
              self._training_data_x = self.training_data
              self._training_data_y = None
          else:
              if len(self.training_data) != 2:
                  raise ValueError("training data must be a tuple of length two")
              self._training_data_x, self._training_data_y = self.training_data

      # From Keras
      if mode not in ["auto", "min", "max"]:
          wandb.termwarn(
              f"WandbCallback mode {mode} is unknown, fallback to auto mode."
          )
          mode = "auto"

      if mode == "auto":
          # Guess the direction from the metric name.
          higher_is_better = "acc" in self.monitor or self.monitor.startswith(
              "fmeasure"
          )
      else:
          higher_is_better = mode == "max"

      if higher_is_better:
          self.monitor_op = operator.gt
          self.best = float("-inf")
      else:
          self.monitor_op = operator.lt
          self.best = float("inf")

      # Get the previous best metric for resumed runs
      summary_key = f"{self.log_best_prefix}{self.monitor}"
      previous_best = wandb.run.summary.get(summary_key)
      if previous_best is not None:
          self.best = previous_best

      self._validation_data_logger = None
      self._validation_indexes = validation_indexes
      self._validation_row_processor = validation_row_processor
      self._prediction_row_processor = prediction_row_processor
      self._infer_missing_processors = infer_missing_processors
      self._log_evaluation_frequency = log_evaluation_frequency
      self._model_trained_since_last_eval = False
  def _build_grad_accumulator_model(self):
      """Create the helper model and callback used for gradient logging.

      The helper model wraps `self.model` (same inputs, outputs produced by
      calling `self.model`) and is compiled with the model's loss and a
      `_CustomOptimizer`. Both objects are stored on the callback.
      """
      model_inputs = self.model.inputs
      model_outputs = self.model(model_inputs)
      accumulator_model = tf.keras.models.Model(model_inputs, model_outputs)
      accumulator_model.compile(loss=self.model.loss, optimizer=_CustomOptimizer())

      # make sure magic doesn't think this is a user model
      accumulator_model._wandb_internal_model = True

      self._grad_accumulator_model = accumulator_model
      self._grad_accumulator_callback = _GradAccumulatorCallback()
  456. def _implements_train_batch_hooks(self):
  457. return self.log_batch_frequency is not None
  458. def _implements_test_batch_hooks(self):
  459. return self.log_batch_frequency is not None
  460. def _implements_predict_batch_hooks(self):
  461. return self.log_batch_frequency is not None
  462. def set_params(self, params):
  463. self.params = params
  def set_model(self, model):
      """Attach `model` and derive input/output types and gradient helpers from it.

      With `input_type == "auto"` and a single model input, the input type is
      guessed from that input's shape. When an input type is set, no output
      type was given and the model has a single output, the output type is
      guessed from that output's shape. When gradient logging is enabled the
      gradient accumulator model is built.
      """
      super().set_model(model)

      wants_input_guess = self.input_type == "auto" and len(model.inputs) == 1
      if wants_input_guess:
          input_shape = model.inputs[0].shape
          self.input_type = wandb.util.guess_data_type(input_shape, risky=True)

      wants_output_guess = (
          self.input_type and self.output_type is None and len(model.outputs) == 1
      )
      if wants_output_guess:
          output_shape = model.outputs[0].shape
          self.output_type = wandb.util.guess_data_type(output_shape)

      if self.log_gradients:
          self._build_grad_accumulator_model()
  def _attempt_evaluation_log(self, commit=True):
      """Log model predictions through the validation data logger, if one is set up.

      Nothing happens unless `log_evaluation` is enabled and a validation data
      logger exists. Any exception raised while predicting or logging is
      turned into a terminal warning.
      """
      if not (self.log_evaluation and self._validation_data_logger):
          return

      try:
          if self.model:
              data_logger = self._validation_data_logger
              data_logger.log_predictions(
                  predictions=data_logger.make_predictions(self.model.predict),
                  commit=commit,
              )
              self._model_trained_since_last_eval = False
          else:
              wandb.termwarn("WandbCallback unable to read model from trainer")
      except Exception as e:
          wandb.termwarn("Error during prediction logging for epoch: " + str(e))
  def on_epoch_end(self, epoch, logs=None):
      """Log the epoch's metrics and update the best-model bookkeeping.

      In order: optionally logs weight and gradient data, optionally logs
      example images when the input or output type is image-like, optionally
      logs evaluation predictions every `log_evaluation_frequency` epochs,
      then logs `epoch` and `logs` (committing the step). Finally, when the
      monitored metric is present in `logs` and beats `self.best` according
      to `self.monitor_op`, the summary is updated, the model is saved if
      `save_model` is set, and `self.best` is replaced.

      Args:
          epoch: index of the epoch that just finished.
          logs: metrics dict for the epoch; treated as empty when None.
      """
      if logs is None:
          logs = {}

      if self.log_weights:
          wandb.log(self._log_weights(), commit=False)

      if self.log_gradients:
          wandb.log(self._log_gradients(), commit=False)

      visual_types = ("image", "images", "segmentation_mask")
      if self.input_type in visual_types or self.output_type in visual_types:
          if self.generator:
              self.validation_data = next(self.generator)
          if self.validation_data is None:
              wandb.termwarn(
                  "No validation_data set, pass a generator to the callback."
              )
          elif self.validation_data and len(self.validation_data) > 0:
              wandb.log(
                  {"examples": self._log_images(num_images=self.predictions)},
                  commit=False,
              )

      if (
          self._log_evaluation_frequency > 0
          and epoch % self._log_evaluation_frequency == 0
      ):
          self._attempt_evaluation_log(commit=False)

      wandb.log({"epoch": epoch}, commit=False)
      wandb.log(logs, commit=True)

      self.current = logs.get(self.monitor)
      # Compare against None explicitly: a monitored value of exactly 0 (for
      # example a loss of 0.0) is a valid measurement and must not be skipped.
      if self.current is not None and self.monitor_op(self.current, self.best):
          if self.log_best_prefix:
              wandb.run.summary[f"{self.log_best_prefix}{self.monitor}"] = (
                  self.current
              )
              wandb.run.summary[f"{self.log_best_prefix}epoch"] = epoch
              if self.verbose and not self.save_model:
                  wandb.termlog(
                      f"Epoch {epoch:05d}: {self.monitor} improved from {self.best:.5f} to {self.current:.5f}"
                  )
          if self.save_model:
              self._save_model(epoch)

          if self.save_model and self.save_model_as_artifact:
              self._save_model_as_artifact(epoch)

          self.best = self.current
  # This is what keras used pre tensorflow.keras
  def on_batch_begin(self, batch, logs=None):
      """Do nothing at the start of a batch."""
      return None
  # This is what keras used pre tensorflow.keras
  def on_batch_end(self, batch, logs=None):
      """Store the model graph once, then log `logs` every `log_batch_frequency` batches."""
      graph_pending = self.save_graph and not self._graph_rendered
      if graph_pending:
          # Couldn't do this in train_begin because keras may still not be built
          wandb.run.summary["graph"] = wandb.Graph.from_keras(self.model)
          self._graph_rendered = True

      frequency = self.log_batch_frequency
      if frequency and batch % frequency == 0:
          wandb.log(logs, commit=True)
  546. def on_train_batch_begin(self, batch, logs=None):
  547. self._model_trained_since_last_eval = True
  def on_train_batch_end(self, batch, logs=None):
      """Store the model graph once, then log `logs` every `log_batch_frequency` batches."""
      graph_pending = self.save_graph and not self._graph_rendered
      if graph_pending:
          # Couldn't do this in train_begin because keras may still not be built
          wandb.run.summary["graph"] = wandb.Graph.from_keras(self.model)
          self._graph_rendered = True

      frequency = self.log_batch_frequency
      if frequency and batch % frequency == 0:
          wandb.log(logs, commit=True)
  def on_test_begin(self, logs=None):
      """Do nothing when testing starts."""
      return None
  def on_test_end(self, logs=None):
      """Do nothing when testing ends."""
      return None
  def on_test_batch_begin(self, batch, logs=None):
      """Do nothing at the start of a test batch."""
      return None
  def on_test_batch_end(self, batch, logs=None):
      """Do nothing at the end of a test batch."""
      return None
  def on_train_begin(self, logs=None):
      """Prepare evaluation logging and, if requested, record the model's FLOPs.

      With `log_evaluation` enabled, validation data is taken from
      `self.validation_data`, or gathered by drawing `validation_steps` batches
      from `self.generator`, and wrapped in a `ValidationDataLogger`. Failures
      in this step are reported as warnings. With `compute_flops` enabled (and
      supported), the result of `self.get_flops()` is stored in the summary
      under "GFLOPs".
      """
      if self.log_evaluation:
          try:
              validation_data = None
              if self.validation_data:
                  validation_data = self.validation_data
              elif not self.generator:
                  wandb.termwarn(
                      "WandbCallback is unable to read validation_data from trainer "
                      "and therefore cannot log validation data. Ensure Keras is properly "
                      "patched by calling `from wandb.keras import WandbCallback` at the top of your script."
                  )
              elif not self.validation_steps:
                  wandb.termwarn(
                      "WandbCallback is unable to log validation data. "
                      "When using a generator for validation_data, you must pass validation_steps"
                  )
              else:
                  # Concatenate `validation_steps` batches along axis 0.
                  all_x = None
                  all_y = None
                  for _ in range(self.validation_steps):
                      batch_x, batch_y = next(self.generator)
                      if all_x is None:
                          all_x, all_y = batch_x, batch_y
                      else:
                          all_x, all_y = (
                              np.append(all_x, batch_x, axis=0),
                              np.append(all_y, batch_y, axis=0),
                          )
                  validation_data = (all_x, all_y)

              if validation_data:
                  self._validation_data_logger = ValidationDataLogger(
                      inputs=validation_data[0],
                      targets=validation_data[1],
                      indexes=self._validation_indexes,
                      validation_row_processor=self._validation_row_processor,
                      prediction_row_processor=self._prediction_row_processor,
                      class_labels=self.labels,
                      infer_missing_processors=self._infer_missing_processors,
                  )
          except Exception as e:
              wandb.termwarn(
                  "Error initializing ValidationDataLogger in WandbCallback. "
                  f"Skipping logging validation data. Error: {str(e)}"
              )

      if self.compute_flops and _can_compute_flops():
          try:
              wandb.summary["GFLOPs"] = self.get_flops()
          except Exception:
              logger.exception("Error computing FLOPs")
              wandb.termwarn("Unable to compute FLOPs for this model.")
  615. def on_train_end(self, logs=None):
  616. if self._model_trained_since_last_eval:
  617. self._attempt_evaluation_log()
  def on_predict_begin(self, logs=None):
      """Do nothing when prediction starts."""
      return None
  def on_predict_end(self, logs=None):
      """Do nothing when prediction ends."""
      return None
  def on_predict_batch_begin(self, batch, logs=None):
      """Do nothing at the start of a prediction batch."""
      return None
  def on_predict_batch_end(self, batch, logs=None):
      """Do nothing at the end of a prediction batch."""
      return None
  626. def _logits_to_captions(self, logits):
  627. if logits[0].shape[-1] == 1:
  628. # Scalar output from the model
  629. # TODO: handle validation_y
  630. if len(self.labels) == 2:
  631. # User has named true and false
  632. captions = [
  633. self.labels[1] if logits[0] > 0.5 else self.labels[0]
  634. for logit in logits
  635. ]
  636. else:
  637. if len(self.labels) != 0:
  638. wandb.termwarn(
  639. "keras model is producing a single output, "
  640. 'so labels should be a length two array: ["False label", "True label"].'
  641. )
  642. captions = [logit[0] for logit in logits]
  643. else:
  644. # Vector output from the model
  645. # TODO: handle validation_y
  646. labels = np.argmax(np.stack(logits), axis=1)
  647. if len(self.labels) > 0:
  648. # User has named the categories in self.labels
  649. captions = []
  650. for label in labels:
  651. try:
  652. captions.append(self.labels[label])
  653. except IndexError:
  654. captions.append(label)
  655. else:
  656. captions = labels
  657. return captions
  658. def _masks_to_pixels(self, masks):
  659. # if its a binary mask, just return it as grayscale instead of picking the argmax
  660. if len(masks[0].shape) == 2 or masks[0].shape[-1] == 1:
  661. return masks
  662. class_colors = (
  663. self.class_colors
  664. if self.class_colors is not None
  665. else np.array(wandb.util.class_colors(masks[0].shape[2]))
  666. )
  667. imgs = class_colors[np.argmax(masks, axis=-1)]
  668. return imgs
  669. def _log_images(self, num_images=36):
  670. validation_X = self.validation_data[0] # noqa: N806
  671. validation_y = self.validation_data[1]
  672. validation_length = len(validation_X)
  673. if validation_length > num_images:
  674. # pick some data at random
  675. indices = np.random.choice(validation_length, num_images, replace=False)
  676. else:
  677. indices = range(validation_length)
  678. test_data = []
  679. test_output = []
  680. for i in indices:
  681. test_example = validation_X[i]
  682. test_data.append(test_example)
  683. test_output.append(validation_y[i])
  684. if self.model.stateful:
  685. predictions = self.model.predict(np.stack(test_data), batch_size=1)
  686. self.model.reset_states()
  687. else:
  688. predictions = self.model.predict(
  689. np.stack(test_data), batch_size=self._prediction_batch_size
  690. )
  691. if len(predictions) != len(test_data):
  692. self._prediction_batch_size = 1
  693. predictions = self.model.predict(
  694. np.stack(test_data), batch_size=self._prediction_batch_size
  695. )
  696. if self.input_type == "label":
  697. if self.output_type in ("image", "images", "segmentation_mask"):
  698. captions = self._logits_to_captions(test_data)
  699. output_image_data = (
  700. self._masks_to_pixels(predictions)
  701. if self.output_type == "segmentation_mask"
  702. else predictions
  703. )
  704. reference_image_data = (
  705. self._masks_to_pixels(test_output)
  706. if self.output_type == "segmentation_mask"
  707. else test_output
  708. )
  709. output_images = [
  710. wandb.Image(data, caption=captions[i], grouping=2)
  711. for i, data in enumerate(output_image_data)
  712. ]
  713. reference_images = [
  714. wandb.Image(data, caption=captions[i])
  715. for i, data in enumerate(reference_image_data)
  716. ]
  717. return list(chain.from_iterable(zip(output_images, reference_images)))
  718. elif self.input_type in ("image", "images", "segmentation_mask"):
  719. input_image_data = (
  720. self._masks_to_pixels(test_data)
  721. if self.input_type == "segmentation_mask"
  722. else test_data
  723. )
  724. if self.output_type == "label":
  725. # we just use the predicted label as the caption for now
  726. captions = self._logits_to_captions(predictions)
  727. return [
  728. wandb.Image(data, caption=captions[i])
  729. for i, data in enumerate(test_data)
  730. ]
  731. elif self.output_type in ("image", "images", "segmentation_mask"):
  732. output_image_data = (
  733. self._masks_to_pixels(predictions)
  734. if self.output_type == "segmentation_mask"
  735. else predictions
  736. )
  737. reference_image_data = (
  738. self._masks_to_pixels(test_output)
  739. if self.output_type == "segmentation_mask"
  740. else test_output
  741. )
  742. input_images = [
  743. wandb.Image(data, grouping=3)
  744. for i, data in enumerate(input_image_data)
  745. ]
  746. output_images = [
  747. wandb.Image(data) for i, data in enumerate(output_image_data)
  748. ]
  749. reference_images = [
  750. wandb.Image(data) for i, data in enumerate(reference_image_data)
  751. ]
  752. return list(
  753. chain.from_iterable(
  754. zip(input_images, output_images, reference_images)
  755. )
  756. )
  757. else:
  758. # unknown output, just log the input images
  759. return [wandb.Image(img) for img in test_data]
  760. elif self.output_type in ("image", "images", "segmentation_mask"):
  761. # unknown input, just log the predicted and reference outputs without captions
  762. output_image_data = (
  763. self._masks_to_pixels(predictions)
  764. if self.output_type == "segmentation_mask"
  765. else predictions
  766. )
  767. reference_image_data = (
  768. self._masks_to_pixels(test_output)
  769. if self.output_type == "segmentation_mask"
  770. else test_output
  771. )
  772. output_images = [
  773. wandb.Image(data, grouping=2)
  774. for i, data in enumerate(output_image_data)
  775. ]
  776. reference_images = [
  777. wandb.Image(data) for i, data in enumerate(reference_image_data)
  778. ]
  779. return list(chain.from_iterable(zip(output_images, reference_images)))
  780. def _log_weights(self):
  781. metrics = {}
  782. for layer in self.model.layers:
  783. weights = layer.get_weights()
  784. if len(weights) == 1:
  785. _update_if_numeric(
  786. metrics, "parameters/" + layer.name + ".weights", weights[0]
  787. )
  788. elif len(weights) == 2:
  789. _update_if_numeric(
  790. metrics, "parameters/" + layer.name + ".weights", weights[0]
  791. )
  792. _update_if_numeric(
  793. metrics, "parameters/" + layer.name + ".bias", weights[1]
  794. )
  795. return metrics
  796. def _log_gradients(self):
  797. # Suppress callback warnings grad accumulator
  798. og_level = tf_logger.level
  799. tf_logger.setLevel("ERROR")
  800. self._grad_accumulator_model.fit(
  801. self._training_data_x,
  802. self._training_data_y,
  803. verbose=0,
  804. callbacks=[self._grad_accumulator_callback],
  805. )
  806. tf_logger.setLevel(og_level)
  807. weights = self.model.trainable_weights
  808. grads = self._grad_accumulator_callback.grads
  809. metrics = {}
  810. for weight, grad in zip(weights, grads):
  811. metrics["gradients/" + weight.name.split(":")[0] + ".gradient"] = (
  812. wandb.Histogram(grad)
  813. )
  814. return metrics
  815. def _log_dataframe(self):
  816. x, y_true, y_pred = None, None, None
  817. if self.validation_data:
  818. x, y_true = self.validation_data[0], self.validation_data[1]
  819. y_pred = self.model.predict(x)
  820. elif self.generator:
  821. if not self.validation_steps:
  822. wandb.termwarn(
  823. "when using a generator for validation data with dataframes, "
  824. "you must pass validation_steps. skipping"
  825. )
  826. return None
  827. for _ in range(self.validation_steps):
  828. bx, by_true = next(self.generator)
  829. by_pred = self.model.predict(bx)
  830. if x is None:
  831. x, y_true, y_pred = bx, by_true, by_pred
  832. else:
  833. x, y_true, y_pred = (
  834. np.append(x, bx, axis=0),
  835. np.append(y_true, by_true, axis=0),
  836. np.append(y_pred, by_pred, axis=0),
  837. )
  838. if self.input_type in ("image", "images") and self.output_type == "label":
  839. return wandb.image_categorizer_dataframe(
  840. x=x, y_true=y_true, y_pred=y_pred, labels=self.labels
  841. )
  842. elif (
  843. self.input_type in ("image", "images")
  844. and self.output_type == "segmentation_mask"
  845. ):
  846. return wandb.image_segmentation_dataframe(
  847. x=x,
  848. y_true=y_true,
  849. y_pred=y_pred,
  850. labels=self.labels,
  851. class_colors=self.class_colors,
  852. )
  853. else:
  854. wandb.termwarn(
  855. f"unknown dataframe type for input_type={self.input_type} and output_type={self.output_type}"
  856. )
  857. return None
  858. def _save_model(self, epoch):
  859. if wandb.run.disabled:
  860. return
  861. if self.verbose > 0:
  862. wandb.termlog(
  863. f"Epoch {epoch:05d}: {self.monitor} improved from {self.best:.5f} to {self.current:.5f}, "
  864. f"saving model to {self.filepath}"
  865. )
  866. try:
  867. if self.save_weights_only:
  868. self.model.save_weights(self.filepath, overwrite=True)
  869. else:
  870. self.model.save(self.filepath, overwrite=True)
  871. # Was getting `RuntimeError: Unable to create link` in TF 1.13.1
  872. # also saw `TypeError: can't pickle _thread.RLock objects`
  873. except (ImportError, RuntimeError, TypeError, AttributeError):
  874. logger.exception("Error saving model in the h5py format")
  875. wandb.termerror(
  876. "Can't save model in the h5py format. The model will be saved as "
  877. "as an W&B Artifact in the 'tf' format."
  878. )
  879. def _save_model_as_artifact(self, epoch):
  880. if wandb.run.disabled:
  881. return
  882. # Save the model in the SavedModel format.
  883. # TODO: Replace this manual artifact creation with the `log_model` method
  884. # after `log_model` is released from beta.
  885. self.model.save(self.filepath[:-3], overwrite=True, save_format="tf")
  886. # Log the model as artifact.
  887. name = wandb.util.make_artifact_name_safe(f"model-{wandb.run.name}")
  888. model_artifact = wandb.Artifact(name, type="model")
  889. model_artifact.add_dir(self.filepath[:-3])
  890. wandb.run.log_artifact(model_artifact, aliases=["latest", f"epoch_{epoch}"])
  891. # Remove the SavedModel from wandb dir as we don't want to log it to save memory.
  892. shutil.rmtree(self.filepath[:-3])
  893. def get_flops(self) -> float:
  894. """Calculate FLOPS [GFLOPs] for a tf.keras.Model or tf.keras.Sequential model in inference mode.
  895. It uses tf.compat.v1.profiler under the hood.
  896. """
  897. if not hasattr(self, "model"):
  898. raise wandb.Error("self.model must be set before using this method.")
  899. if not isinstance(
  900. self.model, (tf.keras.models.Sequential, tf.keras.models.Model)
  901. ):
  902. raise TypeError(
  903. "Calculating FLOPS is only supported for "
  904. "`tf.keras.Model` and `tf.keras.Sequential` instances."
  905. )
  906. from tensorflow.python.framework.convert_to_constants import (
  907. convert_variables_to_constants_v2_as_graph,
  908. )
  909. # Compute FLOPs for one sample
  910. batch_size = 1
  911. inputs = [
  912. tf.TensorSpec([batch_size] + inp.shape[1:], inp.dtype)
  913. for inp in self.model.inputs
  914. ]
  915. # convert tf.keras model into frozen graph to count FLOPs about operations used at inference
  916. real_model = tf.function(self.model).get_concrete_function(inputs)
  917. frozen_func, _ = convert_variables_to_constants_v2_as_graph(real_model)
  918. # Calculate FLOPs with tf.profiler
  919. run_meta = tf.compat.v1.RunMetadata()
  920. opts = (
  921. tf.compat.v1.profiler.ProfileOptionBuilder(
  922. tf.compat.v1.profiler.ProfileOptionBuilder().float_operation()
  923. )
  924. .with_empty_output()
  925. .build()
  926. )
  927. flops = tf.compat.v1.profiler.profile(
  928. graph=frozen_func.graph, run_meta=run_meta, cmd="scope", options=opts
  929. )
  930. # convert to GFLOPs
  931. return (flops.total_float_ops / 1e9) / 2