callbacks.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. from math import inf
  18. from pathlib import Path
  19. from typing import Callable, Dict, Optional, Union
  20. import torch
  21. from kornia.core import Module
  22. from kornia.metrics import AverageMeter
  23. from .utils import TrainerState
  24. def default_filename_fcn(epoch: Union[str, int], metric: Union[str, float]) -> str:
  25. """Generate the filename in the model checkpoint."""
  26. return f"model_epoch={epoch}_metricValue={metric}.pt"
  27. class EarlyStopping:
  28. """Callback that evaluates whether there is improvement in the loss function.
  29. The module track the losses and in case of finish patience sends a termination signal to the trainer.
  30. Args:
  31. monitor: the name of the value to track.
  32. min_delta: the minimum difference between losses to increase the patience counter.
  33. patience: the number of times to wait until the trainer does not terminate.
  34. max_mode: if true metric will be multiply by -1,
  35. turn this flag when increasing metric value is expected for example Accuracy
  36. **Usage example:**
  37. .. code:: python
  38. early_stop = EarlyStopping(
  39. monitor="loss", patience=10
  40. )
  41. trainer = ImageClassifierTrainer(
  42. callbacks={"on_epoch_end", early_stop}
  43. )
  44. """
  45. def __init__(
  46. self,
  47. monitor: str,
  48. min_delta: float = 0.0,
  49. patience: int = 8,
  50. max_mode: bool = False,
  51. ) -> None:
  52. self.monitor = monitor
  53. self.min_delta = min_delta
  54. self.patience = patience
  55. # flag to reverse metric, for example in case of accuracy metric where bigger value is better
  56. # In classical loss functions smaller value = better,
  57. # in case of max_mode training end with metric stable/decreasing
  58. self.max_mode = max_mode
  59. self.counter: int = 0
  60. self.best_score: float = -inf if max_mode else inf
  61. self.early_stop: bool = False
  62. def __call__(self, model: Module, epoch: int, valid_metric: Dict[str, AverageMeter]) -> TrainerState:
  63. score: float = valid_metric[self.monitor].avg
  64. is_best: bool = score > self.best_score if self.max_mode else score < self.best_score
  65. if is_best:
  66. self.best_score = score
  67. self.counter = 0
  68. else:
  69. # Example score = 1.9 best_score = 2.0 min_delta = 0.15
  70. # with max_mode (1.9 > (2.0 - 0.15)) == True
  71. # with min_mode (1.9 < (2.0 + 0.15)) == True
  72. is_within_delta: bool = (
  73. score > (self.best_score - self.min_delta)
  74. if self.max_mode
  75. else score < (self.best_score + self.min_delta)
  76. )
  77. if not is_within_delta:
  78. self.counter += 1
  79. if self.counter >= self.patience:
  80. self.early_stop = True
  81. if self.early_stop:
  82. print(f"[INFO] Early-Stopping the training process. Epoch: {epoch}.")
  83. return TrainerState.TERMINATE
  84. return TrainerState.TRAINING
  85. class ModelCheckpoint:
  86. """Callback that save the model at the end of every epoch.
  87. Args:
  88. filepath: the where to save the mode.
  89. monitor: the name of the value to track.
  90. max_mode: if true metric will be multiply by -1
  91. turn this flag when increasing metric value is expected for example Accuracy
  92. **Usage example:**
  93. .. code:: python
  94. model_checkpoint = ModelCheckpoint(
  95. filepath="./outputs", monitor="loss",
  96. )
  97. trainer = ImageClassifierTrainer(...,
  98. callbacks={"on_checkpoint", model_checkpoint}
  99. )
  100. """
  101. def __init__(
  102. self,
  103. filepath: str,
  104. monitor: str,
  105. filename_fcn: Optional[Callable[..., str]] = None,
  106. max_mode: bool = False,
  107. ) -> None:
  108. self.filepath = filepath
  109. self.monitor = monitor
  110. self._filename_fcn = filename_fcn or default_filename_fcn
  111. # track best model
  112. self.best_metric: float = -inf if max_mode else inf
  113. # flag to reverse metric, for example in case of accuracy metric where bigger value is better
  114. # In classical loss functions smaller value = better,
  115. # In case of max_mode checkpoints are saved if new metric value > old metric value
  116. self.max_mode = max_mode
  117. # create directory
  118. Path(self.filepath).mkdir(parents=True, exist_ok=True)
  119. def __call__(self, model: Module, epoch: int, valid_metric: Dict[str, AverageMeter]) -> None:
  120. valid_metric_value: float = valid_metric[self.monitor].avg
  121. is_best: bool = (
  122. valid_metric_value > self.best_metric if self.max_mode else valid_metric_value < self.best_metric
  123. )
  124. if is_best:
  125. self.best_metric = valid_metric_value
  126. # store old metric and save new model
  127. filename = Path(self.filepath) / self._filename_fcn(epoch, valid_metric_value)
  128. torch.save(model, filename)