# LICENSE HEADER MANAGED BY add-license-header
#
# Copyright 2018 Kornia Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import Any, Callable, Dict, Optional, Tuple

from torch.optim import Optimizer, lr_scheduler
from torch.utils.data import DataLoader

from kornia.core import Module, Tensor, stack
from kornia.metrics import accuracy, mean_average_precision, mean_iou

from .trainer import Trainer
from .utils import Configuration


class ImageClassifierTrainer(Trainer):
    """Module to be used for image classification purposes.

    The module subclasses :py:class:`~kornia.x.Trainer` and overrides its
    ``compute_metrics`` method, implementing a standard
    :py:func:`~kornia.metrics.accuracy` topk@[1, 5].

    .. seealso::
        Learn how to use this class in the Kornia training examples.
    """

    def compute_metrics(self, *args: Tensor) -> Dict[str, float]:
        """Compute the top-1 and top-5 accuracy.

        Args:
            *args: exactly two positional values, ``(out, target)``, which are
                forwarded unchanged to :py:func:`~kornia.metrics.accuracy`.

        Returns:
            A dictionary with the keys ``"top1"`` and ``"top5"``.

        Raises:
            AssertionError: if the number of positional arguments is not two.
        """
        if len(args) != 2:
            raise AssertionError(f"Expected 2 arguments (out, target). Got {len(args)}.")
        out, target = args
        acc1, acc5 = accuracy(out, target, topk=(1, 5))
        return {"top1": acc1.item(), "top5": acc5.item()}


class SemanticSegmentationTrainer(Trainer):
    """Module to be used for semantic segmentation purposes.

    The module subclasses :py:class:`~kornia.x.Trainer` and overrides its
    ``compute_metrics`` method, implementing IoU through
    :py:func:`~kornia.metrics.mean_iou`.

    .. seealso::
        Learn how to use this class in the Kornia training examples.
    """

    def compute_metrics(self, *args: Tensor) -> Dict[str, float]:
        """Compute the mean intersection over union.

        The prediction is reduced with ``argmax`` over dimension 1 and the size
        of that dimension is used as the number of classes.

        Args:
            *args: exactly two positional values, ``(out, target)``.

        Returns:
            A dictionary with the single key ``"iou"``.

        Raises:
            AssertionError: if the number of positional arguments is not two.
        """
        if len(args) != 2:
            raise AssertionError(f"Expected 2 arguments (out, target). Got {len(args)}.")
        out, target = args
        iou = mean_iou(out.argmax(1), target, out.shape[1]).mean()
        return {"iou": iou.item()}


class ObjectDetectionTrainer(Trainer):
    """Module to be used for object detection purposes.

    The module subclasses :py:class:`~kornia.x.Trainer` and overrides its
    ``on_model``, ``compute_loss`` and ``compute_metrics`` methods. The metric
    reported is the mean average precision, computed with
    :py:func:`~kornia.metrics.mean_average_precision`.

    .. seealso::
        Learn how to use this class in the Kornia training examples.
    """

    def __init__(
        self,
        model: Module,
        train_dataloader: DataLoader[Any],
        valid_dataloader: DataLoader[Any],
        criterion: Optional[Module],
        optimizer: Optimizer,
        scheduler: lr_scheduler._LRScheduler,
        config: Configuration,
        num_classes: int,
        callbacks: Optional[Dict[str, Callable[..., None]]] = None,
        loss_computed_by_model: Optional[bool] = None,
    ) -> None:
        """Initialize the trainer.

        Args:
            model: forwarded to :py:class:`~kornia.x.Trainer`.
            train_dataloader: forwarded to :py:class:`~kornia.x.Trainer`.
            valid_dataloader: forwarded to :py:class:`~kornia.x.Trainer`.
            criterion: forwarded to :py:class:`~kornia.x.Trainer`. It may be
                ``None`` only when ``loss_computed_by_model`` is truthy.
            optimizer: forwarded to :py:class:`~kornia.x.Trainer`.
            scheduler: forwarded to :py:class:`~kornia.x.Trainer`.
            config: forwarded to :py:class:`~kornia.x.Trainer`.
            num_classes: passed as ``n_classes`` to
                :py:func:`~kornia.metrics.mean_average_precision`.
            callbacks: forwarded to :py:class:`~kornia.x.Trainer`; an empty
                dictionary is used when ``None``.
            loss_computed_by_model: when truthy, the model receives the targets
                while training and its output is treated as the loss values.
        """
        if callbacks is None:
            callbacks = {}
        super().__init__(model, train_dataloader, valid_dataloader, criterion, optimizer, scheduler, config, callbacks)
        # TODO: auto-detect if the model is from TorchVision
        self.loss_computed_by_model = loss_computed_by_model
        self.num_classes = num_classes

    def on_model(self, model: Module, sample: Dict[str, Tensor]) -> Tensor:
        """Run the model on a sample.

        Args:
            model: the module to call.
            sample: a dictionary holding ``"input"`` and, when the model
                computes its own loss during training, ``"target"``.

        Returns:
            Whatever ``model`` returns.
        """
        # The targets are only handed to the model while it is training and is
        # responsible for producing the loss itself.
        if self.loss_computed_by_model and model.training:
            return model(sample["input"], sample["target"])
        return model(sample["input"])

    def compute_loss(self, *args: Tensor) -> Tensor:
        """Compute the training loss.

        Args:
            *args: when ``loss_computed_by_model`` is truthy, ``args[0]`` holds
                the losses produced by the model (a dictionary or another
                iterable of tensors); otherwise the values are forwarded to
                ``criterion``.

        Returns:
            The mean of the values for a dictionary of losses, the sum for any
            other iterable of losses, or the output of ``criterion``.

        Raises:
            RuntimeError: if ``loss_computed_by_model`` is falsy and
                ``criterion`` is ``None``.
        """
        if self.loss_computed_by_model:
            # Note: in case of dict losses obtained
            if isinstance(args[0], dict):
                return stack(list(args[0].values())).mean()
            return stack(list(args[0])).sum()
        if self.criterion is None:
            raise RuntimeError("`criterion` should not be None if `loss_computed_by_model` is False.")
        return self.criterion(*args)

    def compute_metrics(self, *args: Tuple[Dict[str, Tensor]]) -> Dict[str, float]:
        """Compute the mean average precision when detections are available.

        Args:
            *args: ``(predictions, targets)``. The mean average precision is
                computed when ``predictions`` is a dictionary with the keys
                ``"boxes"``, ``"labels"`` and ``"scores"`` and ``targets`` is a
                dictionary with the keys ``"boxes"`` and ``"labels"``.

        Returns:
            ``{"mAP": value}`` when both dictionaries have the expected keys,
            otherwise the result of the parent ``compute_metrics``.
        """
        if len(args) >= 2:
            preds, targets = args[0], args[1]
            if (
                isinstance(preds, dict)
                and all(key in preds for key in ("boxes", "labels", "scores"))
                and isinstance(targets, dict)
                and all(key in targets for key in ("boxes", "labels"))
            ):
                # Iterating a dict yields its keys, so the dictionaries are
                # indexed directly; each tensor is wrapped in a one-element
                # list because the metric takes lists of tensors.
                mean_ap, _ = mean_average_precision(
                    [preds["boxes"]],
                    [preds["labels"]],
                    [preds["scores"]],
                    [targets["boxes"]],
                    [targets["labels"]],
                    n_classes=self.num_classes,
                    threshold=0.000001,
                )
                return {"mAP": mean_ap.item()}
        return super().compute_metrics(*args)