| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084 |
- """keras init."""
- import logging
- import operator
- import os
- import shutil
- import sys
- from itertools import chain
- import numpy as np
- import tensorflow as tf
- import tensorflow.keras.backend as K # noqa: N812
- import wandb
- from wandb.proto.wandb_telemetry_pb2 import Deprecated
- from wandb.sdk.integration_utils.data_logging import ValidationDataLogger
- from wandb.sdk.lib import telemetry
- from wandb.sdk.lib.deprecation import warn_and_record_deprecation
- from wandb.util import add_import_hook
- def _check_keras_version():
- from keras import __version__ as keras_version
- from packaging.version import parse
- if parse(keras_version) < parse("2.4.0"):
- wandb.termwarn(
- f"Keras version {keras_version} is not fully supported. Required keras >= 2.4.0"
- )
- def _can_compute_flops() -> bool:
- """FLOPS computation is restricted to TF 2.x as it requires tf.compat.v1."""
- from packaging.version import parse
- return parse(tf.__version__) >= parse("2.0.0")
- if "keras" in sys.modules:
- _check_keras_version()
- else:
- add_import_hook("keras", _check_keras_version)
- logger = logging.getLogger(__name__)
- def is_dataset(data):
- dataset_ops = wandb.util.get_module("tensorflow.python.data.ops.dataset_ops")
- if dataset_ops and hasattr(dataset_ops, "DatasetV2"):
- dataset_types = (dataset_ops.DatasetV2,)
- if hasattr(dataset_ops, "DatasetV1"):
- dataset_types = dataset_types + (dataset_ops.DatasetV1,)
- return isinstance(data, dataset_types)
- else:
- return False
- def is_generator_like(data):
- # Checks if data is a generator, Sequence, or Iterator.
- types = (tf.keras.utils.Sequence,)
- iterator_ops = wandb.util.get_module("tensorflow.python.data.ops.iterator_ops")
- if iterator_ops:
- types = types + (iterator_ops.Iterator,)
- # EagerIterator was in tensorflow < 2
- if hasattr(iterator_ops, "EagerIterator"):
- types = types + (iterator_ops.EagerIterator,)
- elif hasattr(iterator_ops, "IteratorV2"):
- types = types + (iterator_ops.IteratorV2,)
- return hasattr(data, "next") or hasattr(data, "__next__") or isinstance(data, types)
- def patch_tf_keras(): # noqa: C901
- from packaging.version import parse
- from tensorflow.python.eager import context
- if parse("2.6.0") <= parse(tf.__version__) < parse("2.13.0"):
- keras_engine = "keras.engine"
- try:
- from keras.engine import training
- from keras.engine import training_arrays_v1 as training_arrays
- from keras.engine import training_generator_v1 as training_generator
- except (ImportError, AttributeError):
- wandb.termerror("Unable to patch Tensorflow/Keras")
- logger.exception("exception while trying to patch_tf_keras")
- return
- else:
- keras_engine = "tensorflow.python.keras.engine"
- from tensorflow.python.keras.engine import training
- try:
- from tensorflow.python.keras.engine import (
- training_arrays_v1 as training_arrays,
- )
- from tensorflow.python.keras.engine import (
- training_generator_v1 as training_generator,
- )
- except (ImportError, AttributeError):
- try:
- from tensorflow.python.keras.engine import (
- training_arrays,
- training_generator,
- )
- except (ImportError, AttributeError):
- wandb.termerror("Unable to patch Tensorflow/Keras")
- logger.exception("exception while trying to patch_tf_keras")
- return
- # Tensorflow 2.1
- training_v2_1 = wandb.util.get_module("tensorflow.python.keras.engine.training_v2")
- # Tensorflow 2.2
- training_v2_2 = wandb.util.get_module(f"{keras_engine}.training_v1")
- if training_v2_1:
- old_v2 = training_v2_1.Loop.fit
- elif training_v2_2:
- old_v2 = training.Model.fit
- old_arrays = training_arrays.fit_loop
- old_generator = training_generator.fit_generator
- def set_wandb_attrs(cbk, val_data):
- if isinstance(cbk, WandbCallback):
- if is_generator_like(val_data):
- cbk.generator = val_data
- elif is_dataset(val_data):
- if context.executing_eagerly():
- cbk.generator = iter(val_data)
- else:
- wandb.termwarn(
- "Found a validation dataset in graph mode, can't patch Keras."
- )
- elif isinstance(val_data, tuple) and isinstance(val_data[0], tf.Tensor):
- # Graph mode dataset generator
- def gen():
- while True:
- yield K.get_session().run(val_data)
- cbk.generator = gen()
- else:
- cbk.validation_data = val_data
- def new_arrays(*args, **kwargs):
- cbks = kwargs.get("callbacks", [])
- val_inputs = kwargs.get("val_inputs")
- val_targets = kwargs.get("val_targets")
- # TODO: these could be generators, why index 0?
- if val_inputs and val_targets:
- for cbk in cbks:
- set_wandb_attrs(cbk, (val_inputs[0], val_targets[0]))
- return old_arrays(*args, **kwargs)
- def new_generator(*args, **kwargs):
- cbks = kwargs.get("callbacks", [])
- val_data = kwargs.get("validation_data")
- if val_data:
- for cbk in cbks:
- set_wandb_attrs(cbk, val_data)
- return old_generator(*args, **kwargs)
- def new_v2(*args, **kwargs):
- cbks = kwargs.get("callbacks", [])
- val_data = kwargs.get("validation_data")
- if val_data:
- for cbk in cbks:
- set_wandb_attrs(cbk, val_data)
- return old_v2(*args, **kwargs)
- training_arrays.orig_fit_loop = old_arrays
- training_arrays.fit_loop = new_arrays
- training_generator.orig_fit_generator = old_generator
- training_generator.fit_generator = new_generator
- wandb.patched["keras"].append([f"{keras_engine}.training_arrays", "fit_loop"])
- wandb.patched["keras"].append(
- [f"{keras_engine}.training_generator", "fit_generator"]
- )
- if training_v2_1:
- training_v2_1.Loop.fit = new_v2
- wandb.patched["keras"].append(
- ["tensorflow.python.keras.engine.training_v2.Loop", "fit"]
- )
- elif training_v2_2:
- training.Model.fit = new_v2
- wandb.patched["keras"].append([f"{keras_engine}.training.Model", "fit"])
- def _array_has_dtype(array):
- return hasattr(array, "dtype")
- def _update_if_numeric(metrics, key, values):
- if not _array_has_dtype(values):
- _warn_not_logging(key)
- return
- if not is_numeric_array(values):
- _warn_not_logging_non_numeric(key)
- return
- metrics[key] = wandb.Histogram(values)
- def is_numeric_array(array):
- return np.issubdtype(array.dtype, np.number)
- def _warn_not_logging_non_numeric(name):
- wandb.termwarn(
- f"Non-numeric values found in layer: {name}, not logging this layer",
- repeat=False,
- )
- def _warn_not_logging(name):
- wandb.termwarn(
- f"Layer {name} has undetermined datatype not logging this layer",
- repeat=False,
- )
- tf_logger = tf.get_logger()
- patch_tf_keras()
- ### For gradient logging ###
- def _get_custom_optimizer_parent_class():
- from packaging.version import parse
- if parse(tf.__version__) >= parse("2.9.0"):
- custom_optimizer_parent_class = tf.keras.optimizers.legacy.Optimizer
- else:
- custom_optimizer_parent_class = tf.keras.optimizers.Optimizer
- return custom_optimizer_parent_class
- _custom_optimizer_parent_class = _get_custom_optimizer_parent_class()
- class _CustomOptimizer(_custom_optimizer_parent_class):
- def __init__(self):
- super().__init__(name="CustomOptimizer")
- self._resource_apply_dense = tf.function(self._resource_apply_dense)
- self._resource_apply_sparse = tf.function(self._resource_apply_sparse)
- def _resource_apply_dense(self, grad, var):
- var.assign(grad)
- # this needs to be implemented to prevent a NotImplementedError when
- # using Lookup layers.
- def _resource_apply_sparse(self, grad, var, indices):
- pass
- def get_config(self):
- return super().get_config()
- class _GradAccumulatorCallback(tf.keras.callbacks.Callback):
- """Accumulates gradients during a fit() call when used in conjunction with the CustomOptimizer above."""
- def set_model(self, model):
- super().set_model(model)
- self.og_weights = model.get_weights()
- self.grads = [np.zeros(tuple(w.shape)) for w in model.trainable_weights]
- def on_batch_end(self, batch, logs=None):
- for g, w in zip(self.grads, self.model.trainable_weights):
- g += w.numpy()
- self.model.set_weights(self.og_weights)
- def get_grads(self):
- return [g.copy() for g in self.grads]
- ###
- class WandbCallback(tf.keras.callbacks.Callback):
- """`WandbCallback` automatically integrates keras with wandb.
- Example:
- ```python
- model.fit(
- X_train,
- y_train,
- validation_data=(X_test, y_test),
- callbacks=[WandbCallback()],
- )
- ```
- `WandbCallback` will automatically log history data from any
- metrics collected by keras: loss and anything passed into `keras_model.compile()`.
- `WandbCallback` will set summary metrics for the run associated with the "best" training
- step, where "best" is defined by the `monitor` and `mode` attributes. This defaults
- to the epoch with the minimum `val_loss`. `WandbCallback` will by default save the model
- associated with the best `epoch`.
- `WandbCallback` can optionally log gradient and parameter histograms.
- `WandbCallback` can optionally save training and validation data for wandb to visualize.
- Args:
- monitor: (str) name of metric to monitor. Defaults to `val_loss`.
- mode: (str) one of {`auto`, `min`, `max`}.
- `min` - save model when monitor is minimized
- `max` - save model when monitor is maximized
- `auto` - try to guess when to save the model (default).
- save_model:
- True - save a model when monitor beats all previous epochs
- False - don't save models
- save_graph: (boolean) if True save model graph to wandb (default to True).
- save_weights_only: (boolean) if True, then only the model's weights will be
- saved (`model.save_weights(filepath)`), else the full model
- is saved (`model.save(filepath)`).
- log_weights: (boolean) if True save histograms of the model's layer's weights.
- log_gradients: (boolean) if True log histograms of the training gradients
- training_data: (tuple) Same format `(X,y)` as passed to `model.fit`. This is needed
- for calculating gradients - this is mandatory if `log_gradients` is `True`.
- validation_data: (tuple) Same format `(X,y)` as passed to `model.fit`. A set of data
- for wandb to visualize. If this is set, every epoch, wandb will
- make a small number of predictions and save the results for later visualization. In case
- you are working with image data, please also set `input_type` and `output_type` in order
- to log correctly.
- generator: (generator) a generator that returns validation data for wandb to visualize. This
- generator should return tuples `(X,y)`. Either `validate_data` or generator should
- be set for wandb to visualize specific data examples. In case you are working with image data,
- please also set `input_type` and `output_type` in order to log correctly.
- validation_steps: (int) if `validation_data` is a generator, how many
- steps to run the generator for the full validation set.
- labels: (list) If you are visualizing your data with wandb this list of labels
- will convert numeric output to understandable string if you are building a
- multiclass classifier. If you are making a binary classifier you can pass in
- a list of two labels ["label for false", "label for true"]. If `validate_data`
- and generator are both false, this won't do anything.
- predictions: (int) the number of predictions to make for visualization each epoch, max
- is 100.
- input_type: (string) type of the model input to help visualization. can be one of:
- (`image`, `images`, `segmentation_mask`, `auto`).
- output_type: (string) type of the model output to help visualization. can be one of:
- (`image`, `images`, `segmentation_mask`, `label`).
- log_evaluation: (boolean) if True, save a Table containing validation data and the
- model's predictions at each epoch. See `validation_indexes`,
- `validation_row_processor`, and `output_row_processor` for additional details.
- class_colors: ([float, float, float]) if the input or output is a segmentation mask,
- an array containing an rgb tuple (range 0-1) for each class.
- log_batch_frequency: (integer) if None, callback will log every epoch.
- If set to integer, callback will log training metrics every `log_batch_frequency`
- batches.
- log_best_prefix: (string) if None, no extra summary metrics will be saved.
- If set to a string, the monitored metric and epoch will be prepended with this value
- and stored as summary metrics.
- validation_indexes: ([wandb.data_types._TableLinkMixin]) an ordered list of index keys to associate
- with each validation example. If log_evaluation is True and `validation_indexes` is provided,
- then a Table of validation data will not be created and instead each prediction will
- be associated with the row represented by the `TableLinkMixin`. The most common way to obtain
- such keys are is use `Table.get_index()` which will return a list of row keys.
- validation_row_processor: (Callable) a function to apply to the validation data, commonly used to visualize the data.
- The function will receive an `ndx` (int) and a `row` (dict). If your model has a single input,
- then `row["input"]` will be the input data for the row. Else, it will be keyed based on the name of the
- input slot. If your fit function takes a single target, then `row["target"]` will be the target data for the row. Else,
- it will be keyed based on the name of the output slots. For example, if your input data is a single ndarray,
- but you wish to visualize the data as an Image, then you can provide `lambda ndx, row: {"img": wandb.Image(row["input"])}`
- as the processor. Ignored if log_evaluation is False or `validation_indexes` are present.
- output_row_processor: (Callable) same as `validation_row_processor`, but applied to the model's output. `row["output"]` will contain
- the results of the model output.
- infer_missing_processors: (bool) Determines if `validation_row_processor` and `output_row_processor`
- should be inferred if missing. Defaults to True. If `labels` are provided, we will attempt to infer classification-type
- processors where appropriate.
- log_evaluation_frequency: (int) Determines the frequency which evaluation results will be logged. Default 0 (only at the end of training).
- Set to 1 to log every epoch, 2 to log every other epoch, and so on. Has no effect when log_evaluation is False.
- compute_flops: (bool) Compute the FLOPs of your Keras Sequential or Functional model in GigaFLOPs unit.
- """
- def __init__(
- self,
- monitor="val_loss",
- verbose=0,
- mode="auto",
- save_weights_only=False,
- log_weights=False,
- log_gradients=False,
- save_model=True,
- training_data=None,
- validation_data=None,
- labels=None,
- predictions=36,
- generator=None,
- input_type=None,
- output_type=None,
- log_evaluation=False,
- validation_steps=None,
- class_colors=None,
- log_batch_frequency=None,
- log_best_prefix="best_",
- save_graph=True,
- validation_indexes=None,
- validation_row_processor=None,
- prediction_row_processor=None,
- infer_missing_processors=True,
- log_evaluation_frequency=0,
- compute_flops=False,
- **kwargs,
- ):
- if wandb.run is None:
- raise wandb.Error("You must call wandb.init() before WandbCallback()")
- warn_and_record_deprecation(
- feature=Deprecated(keras_callback=True),
- message=(
- "WandbCallback is deprecated and will be removed in a future release. "
- "Please use the WandbMetricsLogger, WandbModelCheckpoint, and WandbEvalCallback "
- "callbacks instead. "
- "See https://docs.wandb.ai/models/integrations/keras for more information."
- ),
- )
- with telemetry.context(run=wandb.run) as tel:
- tel.feature.keras = True
- self.validation_data = None
- # This is kept around for legacy reasons
- if validation_data is not None:
- if is_generator_like(validation_data):
- generator = validation_data
- else:
- self.validation_data = validation_data
- if labels is None:
- labels = []
- self.labels = labels
- self.predictions = min(predictions, 100)
- self.monitor = monitor
- self.verbose = verbose
- self.save_weights_only = save_weights_only
- self.save_graph = save_graph
- wandb.save("model-best.h5")
- self.filepath = os.path.join(wandb.run.dir, "model-best.h5")
- self.save_model = save_model
- if save_model:
- warn_and_record_deprecation(
- feature=Deprecated(keras_callback__save_model=True),
- message=(
- "The save_model argument by default saves the model in the HDF5 format that cannot save "
- "custom objects like subclassed models and custom layers. This behavior will be deprecated "
- "in a future release in favor of the SavedModel format. Meanwhile, the HDF5 model is saved "
- "as W&B files and the SavedModel as W&B Artifacts."
- ),
- )
- self.save_model_as_artifact = True
- self.log_weights = log_weights
- self.log_gradients = log_gradients
- self.training_data = training_data
- self.generator = generator
- self._graph_rendered = False
- data_type = kwargs.get("data_type")
- if data_type is not None:
- warn_and_record_deprecation(
- feature=Deprecated(keras_callback__data_type=True),
- message=(
- "The data_type argument of wandb.keras.WandbCallback is deprecated "
- "and will be removed in a future release. Please use input_type instead.\n"
- "Setting input_type = data_type."
- ),
- )
- input_type = data_type
- self.input_type = input_type
- self.output_type = output_type
- self.log_evaluation = log_evaluation
- self.validation_steps = validation_steps
- self.class_colors = np.array(class_colors) if class_colors is not None else None
- self.log_batch_frequency = log_batch_frequency
- self.log_best_prefix = log_best_prefix
- self.compute_flops = compute_flops
- self._prediction_batch_size = None
- if self.log_gradients:
- if int(tf.__version__.split(".")[0]) < 2:
- raise Exception("Gradient logging requires tensorflow 2.0 or higher.")
- if self.training_data is None:
- raise ValueError(
- "training_data argument is required for gradient logging."
- )
- if isinstance(self.training_data, (list, tuple)):
- if len(self.training_data) != 2:
- raise ValueError("training data must be a tuple of length two")
- self._training_data_x, self._training_data_y = self.training_data
- else:
- self._training_data_x = (
- self.training_data
- ) # generator, tf.data.Dataset etc
- self._training_data_y = None
- # From Keras
- if mode not in ["auto", "min", "max"]:
- wandb.termwarn(
- f"WandbCallback mode {mode} is unknown, fallback to auto mode."
- )
- mode = "auto"
- if mode == "min":
- self.monitor_op = operator.lt
- self.best = float("inf")
- elif mode == "max":
- self.monitor_op = operator.gt
- self.best = float("-inf")
- else:
- if "acc" in self.monitor or self.monitor.startswith("fmeasure"):
- self.monitor_op = operator.gt
- self.best = float("-inf")
- else:
- self.monitor_op = operator.lt
- self.best = float("inf")
- # Get the previous best metric for resumed runs
- previous_best = wandb.run.summary.get(f"{self.log_best_prefix}{self.monitor}")
- if previous_best is not None:
- self.best = previous_best
- self._validation_data_logger = None
- self._validation_indexes = validation_indexes
- self._validation_row_processor = validation_row_processor
- self._prediction_row_processor = prediction_row_processor
- self._infer_missing_processors = infer_missing_processors
- self._log_evaluation_frequency = log_evaluation_frequency
- self._model_trained_since_last_eval = False
- def _build_grad_accumulator_model(self):
- inputs = self.model.inputs
- outputs = self.model(inputs)
- grad_acc_model = tf.keras.models.Model(inputs, outputs)
- grad_acc_model.compile(loss=self.model.loss, optimizer=_CustomOptimizer())
- # make sure magic doesn't think this is a user model
- grad_acc_model._wandb_internal_model = True
- self._grad_accumulator_model = grad_acc_model
- self._grad_accumulator_callback = _GradAccumulatorCallback()
- def _implements_train_batch_hooks(self):
- return self.log_batch_frequency is not None
- def _implements_test_batch_hooks(self):
- return self.log_batch_frequency is not None
- def _implements_predict_batch_hooks(self):
- return self.log_batch_frequency is not None
- def set_params(self, params):
- self.params = params
- def set_model(self, model):
- super().set_model(model)
- if self.input_type == "auto" and len(model.inputs) == 1:
- self.input_type = wandb.util.guess_data_type(
- model.inputs[0].shape, risky=True
- )
- if self.input_type and self.output_type is None and len(model.outputs) == 1:
- self.output_type = wandb.util.guess_data_type(model.outputs[0].shape)
- if self.log_gradients:
- self._build_grad_accumulator_model()
- def _attempt_evaluation_log(self, commit=True):
- if self.log_evaluation and self._validation_data_logger:
- try:
- if not self.model:
- wandb.termwarn("WandbCallback unable to read model from trainer")
- else:
- self._validation_data_logger.log_predictions(
- predictions=self._validation_data_logger.make_predictions(
- self.model.predict
- ),
- commit=commit,
- )
- self._model_trained_since_last_eval = False
- except Exception as e:
- wandb.termwarn("Error during prediction logging for epoch: " + str(e))
- def on_epoch_end(self, epoch, logs=None):
- if logs is None:
- logs = {}
- if self.log_weights:
- wandb.log(self._log_weights(), commit=False)
- if self.log_gradients:
- wandb.log(self._log_gradients(), commit=False)
- if self.input_type in (
- "image",
- "images",
- "segmentation_mask",
- ) or self.output_type in ("image", "images", "segmentation_mask"):
- if self.generator:
- self.validation_data = next(self.generator)
- if self.validation_data is None:
- wandb.termwarn(
- "No validation_data set, pass a generator to the callback."
- )
- elif self.validation_data and len(self.validation_data) > 0:
- wandb.log(
- {"examples": self._log_images(num_images=self.predictions)},
- commit=False,
- )
- if (
- self._log_evaluation_frequency > 0
- and epoch % self._log_evaluation_frequency == 0
- ):
- self._attempt_evaluation_log(commit=False)
- wandb.log({"epoch": epoch}, commit=False)
- wandb.log(logs, commit=True)
- self.current = logs.get(self.monitor)
- if self.current and self.monitor_op(self.current, self.best):
- if self.log_best_prefix:
- wandb.run.summary[f"{self.log_best_prefix}{self.monitor}"] = (
- self.current
- )
- wandb.run.summary["{}{}".format(self.log_best_prefix, "epoch")] = epoch
- if self.verbose and not self.save_model:
- wandb.termlog(
- f"Epoch {epoch:05d}: {self.monitor} improved from {self.best:.5f} to {self.current:.5f}"
- )
- if self.save_model:
- self._save_model(epoch)
- if self.save_model and self.save_model_as_artifact:
- self._save_model_as_artifact(epoch)
- self.best = self.current
- # This is what keras used pre tensorflow.keras
- def on_batch_begin(self, batch, logs=None):
- pass
- # This is what keras used pre tensorflow.keras
- def on_batch_end(self, batch, logs=None):
- if self.save_graph and not self._graph_rendered:
- # Couldn't do this in train_begin because keras may still not be built
- wandb.run.summary["graph"] = wandb.Graph.from_keras(self.model)
- self._graph_rendered = True
- if self.log_batch_frequency and batch % self.log_batch_frequency == 0:
- wandb.log(logs, commit=True)
- def on_train_batch_begin(self, batch, logs=None):
- self._model_trained_since_last_eval = True
- def on_train_batch_end(self, batch, logs=None):
- if self.save_graph and not self._graph_rendered:
- # Couldn't do this in train_begin because keras may still not be built
- wandb.run.summary["graph"] = wandb.Graph.from_keras(self.model)
- self._graph_rendered = True
- if self.log_batch_frequency and batch % self.log_batch_frequency == 0:
- wandb.log(logs, commit=True)
- def on_test_begin(self, logs=None):
- pass
- def on_test_end(self, logs=None):
- pass
- def on_test_batch_begin(self, batch, logs=None):
- pass
- def on_test_batch_end(self, batch, logs=None):
- pass
- def on_train_begin(self, logs=None):
- if self.log_evaluation:
- try:
- validation_data = None
- if self.validation_data:
- validation_data = self.validation_data
- elif self.generator:
- if not self.validation_steps:
- wandb.termwarn(
- "WandbCallback is unable to log validation data. "
- "When using a generator for validation_data, you must pass validation_steps"
- )
- else:
- x = None
- y_true = None
- for _ in range(self.validation_steps):
- bx, by_true = next(self.generator)
- if x is None:
- x, y_true = bx, by_true
- else:
- x, y_true = (
- np.append(x, bx, axis=0),
- np.append(y_true, by_true, axis=0),
- )
- validation_data = (x, y_true)
- else:
- wandb.termwarn(
- "WandbCallback is unable to read validation_data from trainer "
- "and therefore cannot log validation data. Ensure Keras is properly "
- "patched by calling `from wandb.keras import WandbCallback` at the top of your script."
- )
- if validation_data:
- self._validation_data_logger = ValidationDataLogger(
- inputs=validation_data[0],
- targets=validation_data[1],
- indexes=self._validation_indexes,
- validation_row_processor=self._validation_row_processor,
- prediction_row_processor=self._prediction_row_processor,
- class_labels=self.labels,
- infer_missing_processors=self._infer_missing_processors,
- )
- except Exception as e:
- wandb.termwarn(
- "Error initializing ValidationDataLogger in WandbCallback. "
- f"Skipping logging validation data. Error: {str(e)}"
- )
- if self.compute_flops and _can_compute_flops():
- try:
- wandb.summary["GFLOPs"] = self.get_flops()
- except Exception:
- logger.exception("Error computing FLOPs")
- wandb.termwarn("Unable to compute FLOPs for this model.")
- def on_train_end(self, logs=None):
- if self._model_trained_since_last_eval:
- self._attempt_evaluation_log()
- def on_predict_begin(self, logs=None):
- pass
- def on_predict_end(self, logs=None):
- pass
- def on_predict_batch_begin(self, batch, logs=None):
- pass
- def on_predict_batch_end(self, batch, logs=None):
- pass
- def _logits_to_captions(self, logits):
- if logits[0].shape[-1] == 1:
- # Scalar output from the model
- # TODO: handle validation_y
- if len(self.labels) == 2:
- # User has named true and false
- captions = [
- self.labels[1] if logits[0] > 0.5 else self.labels[0]
- for logit in logits
- ]
- else:
- if len(self.labels) != 0:
- wandb.termwarn(
- "keras model is producing a single output, "
- 'so labels should be a length two array: ["False label", "True label"].'
- )
- captions = [logit[0] for logit in logits]
- else:
- # Vector output from the model
- # TODO: handle validation_y
- labels = np.argmax(np.stack(logits), axis=1)
- if len(self.labels) > 0:
- # User has named the categories in self.labels
- captions = []
- for label in labels:
- try:
- captions.append(self.labels[label])
- except IndexError:
- captions.append(label)
- else:
- captions = labels
- return captions
- def _masks_to_pixels(self, masks):
- # if its a binary mask, just return it as grayscale instead of picking the argmax
- if len(masks[0].shape) == 2 or masks[0].shape[-1] == 1:
- return masks
- class_colors = (
- self.class_colors
- if self.class_colors is not None
- else np.array(wandb.util.class_colors(masks[0].shape[2]))
- )
- imgs = class_colors[np.argmax(masks, axis=-1)]
- return imgs
- def _log_images(self, num_images=36):
- validation_X = self.validation_data[0] # noqa: N806
- validation_y = self.validation_data[1]
- validation_length = len(validation_X)
- if validation_length > num_images:
- # pick some data at random
- indices = np.random.choice(validation_length, num_images, replace=False)
- else:
- indices = range(validation_length)
- test_data = []
- test_output = []
- for i in indices:
- test_example = validation_X[i]
- test_data.append(test_example)
- test_output.append(validation_y[i])
- if self.model.stateful:
- predictions = self.model.predict(np.stack(test_data), batch_size=1)
- self.model.reset_states()
- else:
- predictions = self.model.predict(
- np.stack(test_data), batch_size=self._prediction_batch_size
- )
- if len(predictions) != len(test_data):
- self._prediction_batch_size = 1
- predictions = self.model.predict(
- np.stack(test_data), batch_size=self._prediction_batch_size
- )
- if self.input_type == "label":
- if self.output_type in ("image", "images", "segmentation_mask"):
- captions = self._logits_to_captions(test_data)
- output_image_data = (
- self._masks_to_pixels(predictions)
- if self.output_type == "segmentation_mask"
- else predictions
- )
- reference_image_data = (
- self._masks_to_pixels(test_output)
- if self.output_type == "segmentation_mask"
- else test_output
- )
- output_images = [
- wandb.Image(data, caption=captions[i], grouping=2)
- for i, data in enumerate(output_image_data)
- ]
- reference_images = [
- wandb.Image(data, caption=captions[i])
- for i, data in enumerate(reference_image_data)
- ]
- return list(chain.from_iterable(zip(output_images, reference_images)))
- elif self.input_type in ("image", "images", "segmentation_mask"):
- input_image_data = (
- self._masks_to_pixels(test_data)
- if self.input_type == "segmentation_mask"
- else test_data
- )
- if self.output_type == "label":
- # we just use the predicted label as the caption for now
- captions = self._logits_to_captions(predictions)
- return [
- wandb.Image(data, caption=captions[i])
- for i, data in enumerate(test_data)
- ]
- elif self.output_type in ("image", "images", "segmentation_mask"):
- output_image_data = (
- self._masks_to_pixels(predictions)
- if self.output_type == "segmentation_mask"
- else predictions
- )
- reference_image_data = (
- self._masks_to_pixels(test_output)
- if self.output_type == "segmentation_mask"
- else test_output
- )
- input_images = [
- wandb.Image(data, grouping=3)
- for i, data in enumerate(input_image_data)
- ]
- output_images = [
- wandb.Image(data) for i, data in enumerate(output_image_data)
- ]
- reference_images = [
- wandb.Image(data) for i, data in enumerate(reference_image_data)
- ]
- return list(
- chain.from_iterable(
- zip(input_images, output_images, reference_images)
- )
- )
- else:
- # unknown output, just log the input images
- return [wandb.Image(img) for img in test_data]
- elif self.output_type in ("image", "images", "segmentation_mask"):
- # unknown input, just log the predicted and reference outputs without captions
- output_image_data = (
- self._masks_to_pixels(predictions)
- if self.output_type == "segmentation_mask"
- else predictions
- )
- reference_image_data = (
- self._masks_to_pixels(test_output)
- if self.output_type == "segmentation_mask"
- else test_output
- )
- output_images = [
- wandb.Image(data, grouping=2)
- for i, data in enumerate(output_image_data)
- ]
- reference_images = [
- wandb.Image(data) for i, data in enumerate(reference_image_data)
- ]
- return list(chain.from_iterable(zip(output_images, reference_images)))
- def _log_weights(self):
- metrics = {}
- for layer in self.model.layers:
- weights = layer.get_weights()
- if len(weights) == 1:
- _update_if_numeric(
- metrics, "parameters/" + layer.name + ".weights", weights[0]
- )
- elif len(weights) == 2:
- _update_if_numeric(
- metrics, "parameters/" + layer.name + ".weights", weights[0]
- )
- _update_if_numeric(
- metrics, "parameters/" + layer.name + ".bias", weights[1]
- )
- return metrics
- def _log_gradients(self):
- # Suppress callback warnings grad accumulator
- og_level = tf_logger.level
- tf_logger.setLevel("ERROR")
- self._grad_accumulator_model.fit(
- self._training_data_x,
- self._training_data_y,
- verbose=0,
- callbacks=[self._grad_accumulator_callback],
- )
- tf_logger.setLevel(og_level)
- weights = self.model.trainable_weights
- grads = self._grad_accumulator_callback.grads
- metrics = {}
- for weight, grad in zip(weights, grads):
- metrics["gradients/" + weight.name.split(":")[0] + ".gradient"] = (
- wandb.Histogram(grad)
- )
- return metrics
- def _log_dataframe(self):
- x, y_true, y_pred = None, None, None
- if self.validation_data:
- x, y_true = self.validation_data[0], self.validation_data[1]
- y_pred = self.model.predict(x)
- elif self.generator:
- if not self.validation_steps:
- wandb.termwarn(
- "when using a generator for validation data with dataframes, "
- "you must pass validation_steps. skipping"
- )
- return None
- for _ in range(self.validation_steps):
- bx, by_true = next(self.generator)
- by_pred = self.model.predict(bx)
- if x is None:
- x, y_true, y_pred = bx, by_true, by_pred
- else:
- x, y_true, y_pred = (
- np.append(x, bx, axis=0),
- np.append(y_true, by_true, axis=0),
- np.append(y_pred, by_pred, axis=0),
- )
- if self.input_type in ("image", "images") and self.output_type == "label":
- return wandb.image_categorizer_dataframe(
- x=x, y_true=y_true, y_pred=y_pred, labels=self.labels
- )
- elif (
- self.input_type in ("image", "images")
- and self.output_type == "segmentation_mask"
- ):
- return wandb.image_segmentation_dataframe(
- x=x,
- y_true=y_true,
- y_pred=y_pred,
- labels=self.labels,
- class_colors=self.class_colors,
- )
- else:
- wandb.termwarn(
- f"unknown dataframe type for input_type={self.input_type} and output_type={self.output_type}"
- )
- return None
- def _save_model(self, epoch):
- if wandb.run.disabled:
- return
- if self.verbose > 0:
- wandb.termlog(
- f"Epoch {epoch:05d}: {self.monitor} improved from {self.best:.5f} to {self.current:.5f}, "
- f"saving model to {self.filepath}"
- )
- try:
- if self.save_weights_only:
- self.model.save_weights(self.filepath, overwrite=True)
- else:
- self.model.save(self.filepath, overwrite=True)
- # Was getting `RuntimeError: Unable to create link` in TF 1.13.1
- # also saw `TypeError: can't pickle _thread.RLock objects`
- except (ImportError, RuntimeError, TypeError, AttributeError):
- logger.exception("Error saving model in the h5py format")
- wandb.termerror(
- "Can't save model in the h5py format. The model will be saved as "
- "as an W&B Artifact in the 'tf' format."
- )
- def _save_model_as_artifact(self, epoch):
- if wandb.run.disabled:
- return
- # Save the model in the SavedModel format.
- # TODO: Replace this manual artifact creation with the `log_model` method
- # after `log_model` is released from beta.
- self.model.save(self.filepath[:-3], overwrite=True, save_format="tf")
- # Log the model as artifact.
- name = wandb.util.make_artifact_name_safe(f"model-{wandb.run.name}")
- model_artifact = wandb.Artifact(name, type="model")
- model_artifact.add_dir(self.filepath[:-3])
- wandb.run.log_artifact(model_artifact, aliases=["latest", f"epoch_{epoch}"])
- # Remove the SavedModel from wandb dir as we don't want to log it to save memory.
- shutil.rmtree(self.filepath[:-3])
- def get_flops(self) -> float:
- """Calculate FLOPS [GFLOPs] for a tf.keras.Model or tf.keras.Sequential model in inference mode.
- It uses tf.compat.v1.profiler under the hood.
- """
- if not hasattr(self, "model"):
- raise wandb.Error("self.model must be set before using this method.")
- if not isinstance(
- self.model, (tf.keras.models.Sequential, tf.keras.models.Model)
- ):
- raise TypeError(
- "Calculating FLOPS is only supported for "
- "`tf.keras.Model` and `tf.keras.Sequential` instances."
- )
- from tensorflow.python.framework.convert_to_constants import (
- convert_variables_to_constants_v2_as_graph,
- )
- # Compute FLOPs for one sample
- batch_size = 1
- inputs = [
- tf.TensorSpec([batch_size] + inp.shape[1:], inp.dtype)
- for inp in self.model.inputs
- ]
- # convert tf.keras model into frozen graph to count FLOPs about operations used at inference
- real_model = tf.function(self.model).get_concrete_function(inputs)
- frozen_func, _ = convert_variables_to_constants_v2_as_graph(real_model)
- # Calculate FLOPs with tf.profiler
- run_meta = tf.compat.v1.RunMetadata()
- opts = (
- tf.compat.v1.profiler.ProfileOptionBuilder(
- tf.compat.v1.profiler.ProfileOptionBuilder().float_operation()
- )
- .with_empty_output()
- .build()
- )
- flops = tf.compat.v1.profiler.profile(
- graph=frozen_func.graph, run_meta=run_meta, cmd="scope", options=opts
- )
- # convert to GFLOPs
- return (flops.total_float_ops / 1e9) / 2
|