# LICENSE HEADER MANAGED BY add-license-header
#
# Copyright 2018 Kornia Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Any, Callable, Dict, Optional, Tuple
from torch.optim import Optimizer, lr_scheduler
from torch.utils.data import DataLoader
from kornia.core import Module, Tensor, stack
from kornia.metrics import accuracy, mean_average_precision, mean_iou
from .trainer import Trainer
from .utils import Configuration
class ImageClassifierTrainer(Trainer):
    """Module to be used for image classification purposes.

    The module subclasses :py:class:`~kornia.x.Trainer` and overrides its ``compute_metrics`` hook to report
    the standard :py:func:`~kornia.metrics.accuracy` at top-k 1 and 5.

    .. seealso::
        Learn how to use this class in the corresponding image classification training example.
    """

    def compute_metrics(self, *args: Tensor) -> Dict[str, float]:
        """Compute top-1 and top-5 accuracy for one batch.

        Args:
            args: exactly two tensors ``(out, target)`` -- the model output and the ground-truth labels --
                forwarded unchanged to :py:func:`~kornia.metrics.accuracy` (presumably ``out`` holds
                per-class scores; confirm against the model in use).

        Returns:
            A dict with the ``"top1"`` and ``"top5"`` accuracies converted to Python floats.

        Raises:
            AssertionError: if not exactly two positional arguments are given.
        """
        if len(args) != 2:
            # AssertionError is kept for backward compatibility with existing callers; the message makes
            # the failure actionable instead of an empty traceback.
            raise AssertionError(f"Expected 2 arguments (out, target), got {len(args)}.")
        out, target = args
        acc1, acc5 = accuracy(out, target, topk=(1, 5))
        return {"top1": acc1.item(), "top5": acc5.item()}
class SemanticSegmentationTrainer(Trainer):
    """Module to be used for semantic segmentation purposes.

    The module subclasses :py:class:`~kornia.x.Trainer` and overrides its ``compute_metrics`` hook to report
    the mean intersection-over-union computed by :py:func:`~kornia.metrics.mean_iou`.

    .. seealso::
        Learn how to use this class in the corresponding semantic segmentation training example.
    """

    def compute_metrics(self, *args: Tensor) -> Dict[str, float]:
        """Compute the mean IoU for one batch.

        Args:
            args: exactly two tensors ``(out, target)`` -- the model output and the ground-truth label map.

        Returns:
            A dict with the ``"iou"`` value (IoU averaged over all entries) as a Python float.

        Raises:
            AssertionError: if not exactly two positional arguments are given.
        """
        if len(args) != 2:
            # AssertionError is kept for backward compatibility with existing callers; the message makes
            # the failure actionable instead of an empty traceback.
            raise AssertionError(f"Expected 2 arguments (out, target), got {len(args)}.")
        out, target = args
        # Dim 1 of ``out`` is reduced with argmax and its size is used as the number of classes, so it is
        # expected to be the class dimension (presumably (B, C, H, W) logits -- confirm with the model).
        num_classes = out.shape[1]
        iou = mean_iou(out.argmax(1), target, num_classes).mean()
        return {"iou": iou.item()}
class ObjectDetectionTrainer(Trainer):
    """Module to be used for object detection purposes.

    The module subclasses :py:class:`~kornia.x.Trainer` and overrides ``on_model``, ``compute_loss`` and
    ``compute_metrics``; the latter reports the mean average precision computed by
    :py:func:`~kornia.metrics.mean_average_precision`.

    .. seealso::
        Learn how to use this class in the corresponding object detection training example.
    """

    def __init__(
        self,
        model: Module,
        train_dataloader: DataLoader[Any],
        valid_dataloader: DataLoader[Any],
        criterion: Optional[Module],
        optimizer: Optimizer,
        scheduler: lr_scheduler._LRScheduler,
        config: Configuration,
        num_classes: int,
        callbacks: Optional[Dict[str, Callable[..., None]]] = None,
        loss_computed_by_model: Optional[bool] = None,
    ) -> None:
        """Set up the trainer.

        Args:
            model: the detection model.
            train_dataloader: loader for the training split.
            valid_dataloader: loader for the validation split.
            criterion: loss module; required unless ``loss_computed_by_model`` is truthy.
            optimizer: optimizer forwarded to the base trainer.
            scheduler: learning-rate scheduler forwarded to the base trainer.
            config: trainer configuration forwarded to the base trainer.
            num_classes: number of classes, passed as ``n_classes`` to the mAP computation.
            callbacks: optional mapping of callbacks; ``None`` is replaced by an empty dict.
            loss_computed_by_model: if truthy, the model is called with ``(input, target)`` while training
                and its output is reduced to a scalar in ``compute_loss`` instead of using ``criterion``.
        """
        if callbacks is None:
            callbacks = {}
        super().__init__(model, train_dataloader, valid_dataloader, criterion, optimizer, scheduler, config, callbacks)
        # TODO: auto-detect if the model is from TorchVision
        self.loss_computed_by_model = loss_computed_by_model
        self.num_classes = num_classes

    def on_model(self, model: Module, sample: Dict[str, Tensor]) -> Tensor:
        """Run ``model`` on a sample, also passing the target when the model computes its own loss in training."""
        if self.loss_computed_by_model and model.training:
            return model(sample["input"], sample["target"])
        return model(sample["input"])

    def compute_loss(self, *args: Tensor) -> Tensor:
        """Reduce the model's losses to a scalar, or apply ``criterion`` when the model does not compute them.

        Raises:
            RuntimeError: if the loss is not computed by the model and ``criterion`` is ``None``.
        """
        if self.loss_computed_by_model:
            losses = args[0]
            if isinstance(losses, dict):
                # NOTE(review): dict losses are averaged while sequence losses are summed; this asymmetry
                # is preserved from the original implementation -- confirm it is intended.
                return stack(list(losses.values())).mean()
            return stack(list(losses)).sum()
        if self.criterion is None:
            raise RuntimeError("`criterion` should not be None if `loss_computed_by_model` is False.")
        return self.criterion(*args)

    def compute_metrics(self, *args: Any) -> Dict[str, float]:
        """Compute the mean average precision (mAP) of a batch of detections.

        Args:
            args: expected to be ``(predictions, targets)``. ``predictions`` is a list/tuple with one dict
                per image holding ``"boxes"``, ``"labels"`` and ``"scores"``; ``targets`` is a list/tuple
                with one dict per image holding ``"boxes"`` and ``"labels"``.

        Returns:
            ``{"mAP": value}`` when ``args`` matches the layout above; otherwise the result of the parent
            ``compute_metrics`` for the same ``args``.
        """
        if len(args) >= 2:
            preds = self._as_detection_batch(args[0], ("boxes", "labels", "scores"))
            targets = self._as_detection_batch(args[1], ("boxes", "labels"))
            if preds is not None and targets is not None:
                mean_ap, _ = mean_average_precision(
                    [p["boxes"] for p in preds],
                    [p["labels"] for p in preds],
                    [p["scores"] for p in preds],
                    [t["boxes"] for t in targets],
                    [t["labels"] for t in targets],
                    n_classes=self.num_classes,
                    # NOTE(review): near-zero overlap threshold kept from the original; confirm intent.
                    threshold=0.000001,
                )
                return {"mAP": mean_ap.item()}
        return super().compute_metrics(*args)

    @staticmethod
    def _as_detection_batch(obj: Any, keys: Tuple[str, ...]) -> Optional[Tuple[Dict[str, Tensor], ...]]:
        """Return ``obj`` as a tuple of per-image dicts that all contain ``keys``, or ``None`` if it does not match.

        A bare dict is rejected on purpose: iterating it would yield its string keys rather than per-image
        entries, and its batching layout cannot be inferred.
        """
        if not isinstance(obj, (list, tuple)) or len(obj) == 0:
            return None
        if not all(isinstance(item, dict) and all(k in item for k in keys) for item in obj):
            return None
        return tuple(obj)