trainer.py 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. import logging
  18. from typing import Any, Callable, Dict, Optional
# the `accelerate` library is a requirement for the Trainer,
# but it is optional for the base users of kornia.
  21. import torch
  22. from torch.optim import Optimizer, lr_scheduler
  23. from torch.utils.data import DataLoader
  24. try:
  25. from accelerate import Accelerator
  26. except ImportError:
  27. Accelerator = None
  28. from kornia.core import Module, Tensor
  29. from kornia.metrics import AverageMeter
  30. from .utils import Configuration, StatsTracker, TrainerState
  31. callbacks_whitelist = [
  32. # high level functions
  33. "preprocess",
  34. "augmentations",
  35. "evaluate",
  36. "fit",
  37. "fit_epoch",
  38. # events (by calling order)
  39. "on_epoch_start",
  40. "on_before_model",
  41. "on_after_model",
  42. "on_checkpoint",
  43. "on_epoch_end",
  44. ]
  45. class Trainer:
  46. """Base class to train the different models in kornia.
  47. .. warning::
  48. The API is experimental and subject to be modified based on the needs of kornia models.
  49. Args:
  50. model: the nn.Module to be optimized.
  51. train_dataloader: the data loader used in the training loop.
  52. valid_dataloader: the data loader used in the validation loop.
  53. criterion: the nn.Module with the function that computes the loss.
  54. optimizer: the torch optimizer object to be used during the optimization.
  55. scheduler: the torch scheduler object with defiing the scheduling strategy.
  56. accelerator: the Accelerator object to distribute the training.
  57. config: a TrainerConfiguration structure containing the experiment hyper parameters.
  58. callbacks: a dictionary containing the pointers to the functions to overrides. The
  59. main supported hooks are ``evaluate``, ``preprocess``, ``augmentations`` and ``fit``.
  60. .. important::
  61. The API heavily relies on `accelerate <https://github.com/huggingface/accelerate/>`_.
  62. In order to use it, you must: ``pip install kornia[x]``
  63. .. seealso::
  64. Learn how to use the API in our documentation
  65. `here <https://kornia.readthedocs.io/en/latest/get-started/training.html>`_.
  66. """
  67. def __init__(
  68. self,
  69. model: Module,
  70. train_dataloader: DataLoader[Any],
  71. valid_dataloader: DataLoader[Any],
  72. criterion: Optional[Module],
  73. optimizer: Optimizer,
  74. scheduler: lr_scheduler._LRScheduler,
  75. config: Configuration,
  76. callbacks: Optional[Dict[str, Callable[..., None]]] = None,
  77. ) -> None:
  78. # setup the accelerator
  79. if Accelerator is None:
  80. raise ModuleNotFoundError('accelerate library is not installed: pip install "kornia[x]"')
  81. self.accelerator = Accelerator()
  82. # setup the data related objects
  83. self.model = self.accelerator.prepare(model)
  84. self.train_dataloader = self.accelerator.prepare(train_dataloader)
  85. self.valid_dataloader = self.accelerator.prepare(valid_dataloader)
  86. self.criterion = None if criterion is None else criterion.to(self.device)
  87. self.optimizer = self.accelerator.prepare(optimizer)
  88. self.scheduler = scheduler
  89. self.config = config
  90. # configure callbacks
  91. if callbacks is None:
  92. callbacks = {}
  93. for fn_name, fn in callbacks.items():
  94. if fn_name not in callbacks_whitelist:
  95. raise ValueError(f"Not supported: {fn_name}.")
  96. setattr(Trainer, fn_name, fn)
  97. # hyper-params
  98. self.num_epochs = config.num_epochs
  99. self.state = TrainerState.STARTING
  100. self._logger = logging.getLogger("train")
  101. @property
  102. def device(self) -> torch.device:
  103. return self.accelerator.device
  104. def backward(self, loss: Tensor) -> None:
  105. self.accelerator.backward(loss)
  106. def fit_epoch(self, epoch: int) -> None:
  107. # train loop
  108. self.model.train()
  109. losses = AverageMeter()
  110. for sample_id, sample in enumerate(self.train_dataloader):
  111. sample = {"input": sample[0], "target": sample[1]} # new dataset api will come like this
  112. self.optimizer.zero_grad()
  113. # perform the preprocess and augmentations in batch
  114. sample = self.preprocess(sample)
  115. sample = self.augmentations(sample)
  116. sample = self.on_before_model(sample)
  117. # make the actual inference
  118. output = self.on_model(self.model, sample)
  119. self.on_after_model(output, sample) # for debugging purposes
  120. loss = self.compute_loss(output, sample["target"])
  121. self.backward(loss)
  122. self.optimizer.step()
  123. losses.update(loss.item(), len(sample["input"]))
  124. if sample_id % 50 == 0:
  125. self._logger.info(
  126. f"Train: {epoch + 1}/{self.num_epochs} "
  127. f"Sample: {sample_id + 1}/{len(self.train_dataloader)} "
  128. f"Loss: {losses.val:.3f} {losses.avg:.3f}"
  129. )
  130. def fit(self) -> None:
  131. # execute the main loop
  132. # NOTE: Do not change and keep this structure clear for readability.
  133. for epoch in range(self.num_epochs):
  134. # call internally the training loop
  135. # NOTE: override to customize your evaluation routine
  136. self.state = TrainerState.TRAINING
  137. self.fit_epoch(epoch)
  138. # call internally the evaluation loop
  139. # NOTE: override to customize your evaluation routine
  140. self.state = TrainerState.VALIDATE
  141. valid_stats = self.evaluate()
  142. self.on_checkpoint(self.model, epoch, valid_stats)
  143. self.on_epoch_end()
  144. if self.state == TrainerState.TERMINATE:
  145. break
  146. # END OF THE EPOCH
  147. self.scheduler.step()
  148. ...
  149. # events stubs
  150. @torch.no_grad()
  151. def evaluate(self) -> Dict[str, AverageMeter]:
  152. self.model.eval()
  153. stats = StatsTracker()
  154. for sample_id, sample in enumerate(self.valid_dataloader):
  155. sample = {"input": sample[0], "target": sample[1]} # new dataset api will come like this
  156. # perform the preprocess and augmentations in batch
  157. sample = self.preprocess(sample)
  158. sample = self.on_before_model(sample)
  159. # Forward
  160. out = self.on_model(self.model, sample)
  161. self.on_after_model(out, sample)
  162. batch_size: int = len(sample["input"])
  163. # measure accuracy and record loss
  164. # Loss computation
  165. if self.criterion is not None:
  166. val_loss = self.compute_loss(out, sample["target"])
  167. stats.update("losses", val_loss.item(), batch_size)
  168. stats.update_from_dict(self.compute_metrics(out, sample["target"]), batch_size)
  169. if sample_id % 10 == 0:
  170. self._logger.info(f"Test: {sample_id}/{len(self.valid_dataloader)} {stats}")
  171. return stats.as_dict()
  172. def on_epoch_start(self, *args: Any, **kwargs: Any) -> None: ...
  173. def preprocess(self, x: Dict[str, Tensor]) -> Dict[str, Tensor]:
  174. return x
  175. def augmentations(self, x: Dict[str, Tensor]) -> Dict[str, Tensor]:
  176. return x
  177. def compute_metrics(self, *args: Any) -> Dict[str, float]:
  178. """Compute metrics during the evaluation."""
  179. return {}
  180. def compute_loss(self, *args: Tensor) -> Tensor:
  181. if self.criterion is None:
  182. raise RuntimeError("`criterion` should not be None.")
  183. return self.criterion(*args)
  184. def on_before_model(self, x: Dict[str, Tensor]) -> Dict[str, Tensor]:
  185. return x
  186. def on_model(self, model: Module, sample: Dict[str, Tensor]) -> Tensor:
  187. return model(sample["input"])
  188. def on_after_model(self, output: Tensor, sample: Dict[str, Tensor]) -> None: ...
  189. def on_checkpoint(self, *args: Any, **kwargs: Dict[str, Any]) -> None: ...
  190. def on_epoch_end(self, *args: Any, **kwargs: Dict[str, Any]) -> None: ...