# LICENSE HEADER MANAGED BY add-license-header
#
# Copyright 2018 Kornia Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import logging
from typing import Any, Callable, Dict, Optional

import torch
from torch.optim import Optimizer, lr_scheduler
from torch.utils.data import DataLoader

# the accelerate library is a requirement for the Trainer,
# but it is optional for base users of kornia.
try:
    from accelerate import Accelerator
except ImportError:
    Accelerator = None

from kornia.core import Module, Tensor
from kornia.metrics import AverageMeter

from .utils import Configuration, StatsTracker, TrainerState

callbacks_whitelist = [
    # high level functions
    "preprocess",
    "augmentations",
    "evaluate",
    "fit",
    "fit_epoch",
    # events (in calling order; the default ``fit`` does not call ``on_epoch_start``)
    "on_epoch_start",
    "on_before_model",
    "on_after_model",
    "on_checkpoint",
    "on_epoch_end",
]
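
# Example (hypothetical user code, not part of this module): callbacks are passed
# as a dict keyed by names from ``callbacks_whitelist``; any other key makes
# ``Trainer.__init__`` raise ``ValueError``. ``my_preprocess`` is an illustrative
# name only.
#
#     def my_preprocess(self, sample):
#         # plain functions are bound as methods, so they receive the trainer as ``self``
#         sample["input"] = sample["input"].float() / 255.0
#         return sample
#
#     callbacks = {"preprocess": my_preprocess}  # accepted
#     callbacks = {"on_train_start": my_preprocess}  # ValueError: Not supported: on_train_start.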


class Trainer:
    """Base class to train the different models in kornia.

    The ``Accelerator`` object used to distribute the training is created internally.

    .. warning::
        The API is experimental and subject to be modified based on the needs of kornia models.

    Args:
        model: the nn.Module to be optimized.
        train_dataloader: the data loader used in the training loop.
        valid_dataloader: the data loader used in the validation loop.
        criterion: the nn.Module that computes the loss. If ``None``, the loss is skipped
            during evaluation, but the default ``compute_loss`` raises ``RuntimeError``,
            so the default training loop requires a criterion.
        optimizer: the torch optimizer object to be used during the optimization.
        scheduler: the torch scheduler object defining the scheduling strategy.
        config: a ``Configuration`` structure containing the experiment hyperparameters.
        callbacks: a dictionary mapping hook names to the functions that override them.
            Keys must be listed in ``callbacks_whitelist``; the main supported hooks are
            ``evaluate``, ``preprocess``, ``augmentations`` and ``fit``.

    .. important::
        The API heavily relies on `accelerate <https://github.com/huggingface/accelerate/>`_.
        In order to use it, you must run: ``pip install kornia[x]``

    .. seealso::
        Learn how to use the API in our documentation
        `here <https://kornia.readthedocs.io/en/latest/get-started/training.html>`_.
    """

    def __init__(
        self,
        model: Module,
        train_dataloader: DataLoader[Any],
        valid_dataloader: DataLoader[Any],
        criterion: Optional[Module],
        optimizer: Optimizer,
        scheduler: lr_scheduler._LRScheduler,
        config: Configuration,
        callbacks: Optional[Dict[str, Callable[..., Any]]] = None,
    ) -> None:
        # setup the accelerator
        if Accelerator is None:
            raise ModuleNotFoundError('accelerate library is not installed: pip install "kornia[x]"')
        self.accelerator = Accelerator()

        # setup the data related objects
        self.model = self.accelerator.prepare(model)
        self.train_dataloader = self.accelerator.prepare(train_dataloader)
        self.valid_dataloader = self.accelerator.prepare(valid_dataloader)
        self.criterion = None if criterion is None else criterion.to(self.device)
        self.optimizer = self.accelerator.prepare(optimizer)
        self.scheduler = scheduler
        self.config = config

        # configure callbacks
        if callbacks is None:
            callbacks = {}
        for fn_name, fn in callbacks.items():
            if fn_name not in callbacks_whitelist:
                raise ValueError(f"Not supported: {fn_name}.")
            setattr(Trainer, fn_name, fn)

        # hyper-params
        self.num_epochs = config.num_epochs

        self.state = TrainerState.STARTING
        self._logger = logging.getLogger("train")
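
    # Note on the callback mechanism above (illustrative, hypothetical names):
    # ``setattr(Trainer, fn_name, fn)`` patches the ``Trainer`` class itself, so the
    # override is also seen by other ``Trainer`` instances, including ones created
    # earlier, while a subclass that defines the same method keeps its own version.
    # How ``fn`` is called depends on its type:
    #
    #     def on_epoch_end(self):        # plain function -> bound method,
    #         ...                        # receives the trainer as ``self``
    #
    #     class EndHook:
    #         def __call__(self, *args):  # callable instance -> not bound,
    #             ...                     # receives only the call arguments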

    @property
    def device(self) -> torch.device:
        return self.accelerator.device

    def backward(self, loss: Tensor) -> None:
        self.accelerator.backward(loss)

    def fit_epoch(self, epoch: int) -> None:
        # train loop
        self.model.train()
        losses = AverageMeter()
        for sample_id, sample in enumerate(self.train_dataloader):
            # the new dataset API is expected to yield samples in this format
            sample = {"input": sample[0], "target": sample[1]}
            self.optimizer.zero_grad()
            # perform the preprocess and augmentations in batch
            sample = self.preprocess(sample)
            sample = self.augmentations(sample)
            sample = self.on_before_model(sample)
            # run the forward pass
            output = self.on_model(self.model, sample)
            self.on_after_model(output, sample)  # for debugging purposes
            loss = self.compute_loss(output, sample["target"])
            self.backward(loss)
            self.optimizer.step()

            losses.update(loss.item(), len(sample["input"]))

            if sample_id % 50 == 0:
                self._logger.info(
                    f"Train: {epoch + 1}/{self.num_epochs} "
                    f"Sample: {sample_id + 1}/{len(self.train_dataloader)} "
                    f"Loss: {losses.val:.3f} {losses.avg:.3f}"
                )
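
    # Example (hypothetical user code): per training batch, ``fit_epoch`` calls
    # ``preprocess`` -> ``augmentations`` -> ``on_before_model`` -> ``on_model``
    # -> ``on_after_model`` -> ``compute_loss`` -> ``backward`` -> ``optimizer.step``.
    # ``augmentations`` runs only here, not in ``evaluate``. A random horizontal flip
    # for a classification task (targets unaffected) could be plugged in as:
    #
    #     import kornia.augmentation as K
    #
    #     flip = K.RandomHorizontalFlip(p=0.5)
    #
    #     def my_augmentations(self, sample):
    #         sample["input"] = flip(sample["input"])
    #         return sample
    #
    #     callbacks = {"augmentations": my_augmentations}  # passed as Trainer(callbacks=...)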

    def fit(self) -> None:
        # execute the main loop
        # NOTE: Do not change and keep this structure clear for readability.
        for epoch in range(self.num_epochs):
            # run the training loop
            # NOTE: override ``fit_epoch`` to customize your training routine
            self.state = TrainerState.TRAINING
            self.fit_epoch(epoch)

            # run the evaluation loop
            # NOTE: override ``evaluate`` to customize your evaluation routine
            self.state = TrainerState.VALIDATE
            valid_stats = self.evaluate()

            self.on_checkpoint(self.model, epoch, valid_stats)

            self.on_epoch_end()
            if self.state == TrainerState.TERMINATE:
                break

            # END OF THE EPOCH
            self.scheduler.step()
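
    # Example (hypothetical user code): ``fit`` leaves the epoch loop when
    # ``self.state`` is ``TrainerState.TERMINATE`` after ``on_epoch_end`` returns,
    # so ``scheduler.step()`` is skipped for that epoch. A callback that stops
    # training after the first epoch could look like:
    #
    #     def stop_after_first_epoch(self):
    #         # ``TrainerState`` is the enum this module imports from ``.utils``
    #         self.state = TrainerState.TERMINATE
    #
    #     callbacks = {"on_epoch_end": stop_after_first_epoch}  # passed as Trainer(callbacks=...)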

    # event stubs
    @torch.no_grad()
    def evaluate(self) -> Dict[str, AverageMeter]:
        """Run the validation loop and return the aggregated statistics."""
        self.model.eval()
        stats = StatsTracker()
        for sample_id, sample in enumerate(self.valid_dataloader):
            # the new dataset API is expected to yield samples in this format
            sample = {"input": sample[0], "target": sample[1]}
            # perform the preprocess in batch (augmentations are not applied here)
            sample = self.preprocess(sample)
            sample = self.on_before_model(sample)
            # forward pass
            out = self.on_model(self.model, sample)
            self.on_after_model(out, sample)

            batch_size: int = len(sample["input"])
            # record the loss (only if a criterion is set) and the metrics
            if self.criterion is not None:
                val_loss = self.compute_loss(out, sample["target"])
                stats.update("losses", val_loss.item(), batch_size)
            stats.update_from_dict(self.compute_metrics(out, sample["target"]), batch_size)

            if sample_id % 10 == 0:
                self._logger.info(f"Test: {sample_id}/{len(self.valid_dataloader)} {stats}")

        return stats.as_dict()
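
    # Example (hypothetical subclass): ``evaluate`` passes the dict returned by
    # ``compute_metrics(out, target)`` to ``StatsTracker.update_from_dict`` together
    # with the batch size. ``compute_metrics`` is not in ``callbacks_whitelist``, so
    # it is customized by subclassing, e.g. for top-1 accuracy on class logits:
    #
    #     class MyClassifierTrainer(Trainer):
    #         def compute_metrics(self, out, target):
    #             acc = (out.argmax(dim=1) == target).float().mean().item()
    #             return {"top1": acc}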

    def on_epoch_start(self, *args: Any, **kwargs: Any) -> None: ...

    def preprocess(self, x: Dict[str, Tensor]) -> Dict[str, Tensor]:
        """Preprocess a batch; applied in training and evaluation. Identity by default."""
        return x

    def augmentations(self, x: Dict[str, Tensor]) -> Dict[str, Tensor]:
        """Augment a training batch; not applied during evaluation. Identity by default."""
        return x

    def compute_metrics(self, *args: Any) -> Dict[str, float]:
        """Compute metrics during the evaluation."""
        return {}

    def compute_loss(self, *args: Tensor) -> Tensor:
        """Compute the loss with ``criterion``; raises ``RuntimeError`` if it is ``None``."""
        if self.criterion is None:
            raise RuntimeError("`criterion` should not be None.")
        return self.criterion(*args)

    def on_before_model(self, x: Dict[str, Tensor]) -> Dict[str, Tensor]:
        return x

    def on_model(self, model: Module, sample: Dict[str, Tensor]) -> Tensor:
        """Run the model on a batch; by default only ``sample["input"]`` is passed."""
        return model(sample["input"])

    def on_after_model(self, output: Tensor, sample: Dict[str, Tensor]) -> None: ...

    def on_checkpoint(self, *args: Any, **kwargs: Any) -> None: ...

    def on_epoch_end(self, *args: Any, **kwargs: Any) -> None: ...
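
    # Example (hypothetical subclass): ``on_model`` controls how a batch is fed to
    # the model; the default passes only ``sample["input"]``. A model that also takes
    # a mask, stored under a hypothetical ``"mask"`` key by a ``preprocess`` override,
    # could be driven with:
    #
    #     class MaskedTrainer(Trainer):
    #         def on_model(self, model, sample):
    #             return model(sample["input"], sample["mask"])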