trainers.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. from typing import Any, Callable, Dict, Optional, Tuple
  18. from torch.optim import Optimizer, lr_scheduler
  19. from torch.utils.data import DataLoader
  20. from kornia.core import Module, Tensor, stack
  21. from kornia.metrics import accuracy, mean_average_precision, mean_iou
  22. from .trainer import Trainer
  23. from .utils import Configuration
  24. class ImageClassifierTrainer(Trainer):
  25. """Module to be used for image classification purposes.
  26. The module subclasses :py:class:`~kornia.x.Trainer` and overrides the
  27. :py:func:`~kornia.x.Trainer.evaluate` function implementing a standard
  28. :py:func:`~kornia.metrics.accuracy` topk@[1, 5].
  29. .. seealso::
  30. Learn how to use this class in the following
  31. `example <https://github.com/kornia/tutorials/tree/master/scripts/training/image_classifier/>`__.
  32. """
  33. def compute_metrics(self, *args: Tensor) -> Dict[str, float]:
  34. if len(args) != 2:
  35. raise AssertionError
  36. out, target = args
  37. acc1, acc5 = accuracy(out, target, topk=(1, 5))
  38. return {"top1": acc1.item(), "top5": acc5.item()}
  39. class SemanticSegmentationTrainer(Trainer):
  40. """Module to be used for semantic segmentation purposes.
  41. The module subclasses :py:class:`~kornia.x.Trainer` and overrides the
  42. :py:func:`~kornia.x.Trainer.evaluate` function implementing IoU :py:func:`~kornia.metrics.mean_iou`.
  43. .. seealso::
  44. Learn how to use this class in the following
  45. `example <https://github.com/kornia/tutorials/tree/master/scripts/training/semantic_segmentation/>`__.
  46. """
  47. def compute_metrics(self, *args: Tensor) -> Dict[str, float]:
  48. if len(args) != 2:
  49. raise AssertionError
  50. out, target = args
  51. iou = mean_iou(out.argmax(1), target, out.shape[1]).mean()
  52. return {"iou": iou.item()}
  53. class ObjectDetectionTrainer(Trainer):
  54. """Module to be used for object detection purposes.
  55. The module subclasses :py:class:`~kornia.x.Trainer` and overrides the
  56. :py:func:`~kornia.x.Trainer.evaluate` function implementing IoU :py:func:`~kornia.metrics.mean_iou`.
  57. .. seealso::
  58. Learn how to use this class in the following
  59. `example <https://github.com/kornia/tutorials/tree/master/scripts/training/object_detection/>`__.
  60. """
  61. def __init__(
  62. self,
  63. model: Module,
  64. train_dataloader: DataLoader[Any],
  65. valid_dataloader: DataLoader[Any],
  66. criterion: Optional[Module],
  67. optimizer: Optimizer,
  68. scheduler: lr_scheduler._LRScheduler,
  69. config: Configuration,
  70. num_classes: int,
  71. callbacks: Optional[Dict[str, Callable[..., None]]] = None,
  72. loss_computed_by_model: Optional[bool] = None,
  73. ) -> None:
  74. if callbacks is None:
  75. callbacks = {}
  76. super().__init__(model, train_dataloader, valid_dataloader, criterion, optimizer, scheduler, config, callbacks)
  77. # TODO: auto-detect if the model is from TorchVision
  78. self.loss_computed_by_model = loss_computed_by_model
  79. self.num_classes = num_classes
  80. def on_model(self, model: Module, sample: Dict[str, Tensor]) -> Tensor:
  81. if self.loss_computed_by_model and model.training:
  82. return model(sample["input"], sample["target"])
  83. return model(sample["input"])
  84. def compute_loss(self, *args: Tensor) -> Tensor:
  85. if self.loss_computed_by_model:
  86. # Note: in case of dict losses obtained
  87. if isinstance(args[0], dict):
  88. return stack([v for _, v in args[0].items()]).mean()
  89. return stack(list(args[0])).sum()
  90. if self.criterion is None:
  91. raise RuntimeError("`criterion` should not be None if `loss_computed_by_model` is False.")
  92. return self.criterion(*args)
  93. def compute_metrics(self, *args: Tuple[Dict[str, Tensor]]) -> Dict[str, float]:
  94. if (
  95. isinstance(args[0], dict)
  96. and "boxes" in args[0]
  97. and "labels" in args[0]
  98. and "scores" in args[0]
  99. and isinstance(args[1], dict)
  100. and "boxes" in args[1]
  101. and "labels" in args[1]
  102. ):
  103. mAP, _ = mean_average_precision(
  104. [a["boxes"] for a in args[0]],
  105. [a["labels"] for a in args[0]],
  106. [a["scores"] for a in args[0]],
  107. [a["boxes"] for a in args[1]],
  108. [a["labels"] for a in args[1]],
  109. n_classes=self.num_classes,
  110. threshold=0.000001,
  111. )
  112. return {"mAP": mAP.item()}
  113. return super().compute_metrics(*args)