# LICENSE HEADER MANAGED BY add-license-header # # Copyright 2018 Kornia Team # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # import logging from typing import Any, Callable, Dict, Optional # the accelerator library is a requirement for the Trainer # but it is optional for grousnd base user of kornia. import torch from torch.optim import Optimizer, lr_scheduler from torch.utils.data import DataLoader try: from accelerate import Accelerator except ImportError: Accelerator = None from kornia.core import Module, Tensor from kornia.metrics import AverageMeter from .utils import Configuration, StatsTracker, TrainerState callbacks_whitelist = [ # high level functions "preprocess", "augmentations", "evaluate", "fit", "fit_epoch", # events (by calling order) "on_epoch_start", "on_before_model", "on_after_model", "on_checkpoint", "on_epoch_end", ] class Trainer: """Base class to train the different models in kornia. .. warning:: The API is experimental and subject to be modified based on the needs of kornia models. Args: model: the nn.Module to be optimized. train_dataloader: the data loader used in the training loop. valid_dataloader: the data loader used in the validation loop. criterion: the nn.Module with the function that computes the loss. optimizer: the torch optimizer object to be used during the optimization. scheduler: the torch scheduler object with defiing the scheduling strategy. accelerator: the Accelerator object to distribute the training. config: a TrainerConfiguration structure containing the experiment hyper parameters. callbacks: a dictionary containing the pointers to the functions to overrides. The main supported hooks are ``evaluate``, ``preprocess``, ``augmentations`` and ``fit``. .. important:: The API heavily relies on `accelerate `_. In order to use it, you must: ``pip install kornia[x]`` .. seealso:: Learn how to use the API in our documentation `here `_. """ def __init__( self, model: Module, train_dataloader: DataLoader[Any], valid_dataloader: DataLoader[Any], criterion: Optional[Module], optimizer: Optimizer, scheduler: lr_scheduler._LRScheduler, config: Configuration, callbacks: Optional[Dict[str, Callable[..., None]]] = None, ) -> None: # setup the accelerator if Accelerator is None: raise ModuleNotFoundError('accelerate library is not installed: pip install "kornia[x]"') self.accelerator = Accelerator() # setup the data related objects self.model = self.accelerator.prepare(model) self.train_dataloader = self.accelerator.prepare(train_dataloader) self.valid_dataloader = self.accelerator.prepare(valid_dataloader) self.criterion = None if criterion is None else criterion.to(self.device) self.optimizer = self.accelerator.prepare(optimizer) self.scheduler = scheduler self.config = config # configure callbacks if callbacks is None: callbacks = {} for fn_name, fn in callbacks.items(): if fn_name not in callbacks_whitelist: raise ValueError(f"Not supported: {fn_name}.") setattr(Trainer, fn_name, fn) # hyper-params self.num_epochs = config.num_epochs self.state = TrainerState.STARTING self._logger = logging.getLogger("train") @property def device(self) -> torch.device: return self.accelerator.device def backward(self, loss: Tensor) -> None: self.accelerator.backward(loss) def fit_epoch(self, epoch: int) -> None: # train loop self.model.train() losses = AverageMeter() for sample_id, sample in enumerate(self.train_dataloader): sample = {"input": sample[0], "target": sample[1]} # new dataset api will come like this self.optimizer.zero_grad() # perform the preprocess and augmentations in batch sample = self.preprocess(sample) sample = self.augmentations(sample) sample = self.on_before_model(sample) # make the actual inference output = self.on_model(self.model, sample) self.on_after_model(output, sample) # for debugging purposes loss = self.compute_loss(output, sample["target"]) self.backward(loss) self.optimizer.step() losses.update(loss.item(), len(sample["input"])) if sample_id % 50 == 0: self._logger.info( f"Train: {epoch + 1}/{self.num_epochs} " f"Sample: {sample_id + 1}/{len(self.train_dataloader)} " f"Loss: {losses.val:.3f} {losses.avg:.3f}" ) def fit(self) -> None: # execute the main loop # NOTE: Do not change and keep this structure clear for readability. for epoch in range(self.num_epochs): # call internally the training loop # NOTE: override to customize your evaluation routine self.state = TrainerState.TRAINING self.fit_epoch(epoch) # call internally the evaluation loop # NOTE: override to customize your evaluation routine self.state = TrainerState.VALIDATE valid_stats = self.evaluate() self.on_checkpoint(self.model, epoch, valid_stats) self.on_epoch_end() if self.state == TrainerState.TERMINATE: break # END OF THE EPOCH self.scheduler.step() ... # events stubs @torch.no_grad() def evaluate(self) -> Dict[str, AverageMeter]: self.model.eval() stats = StatsTracker() for sample_id, sample in enumerate(self.valid_dataloader): sample = {"input": sample[0], "target": sample[1]} # new dataset api will come like this # perform the preprocess and augmentations in batch sample = self.preprocess(sample) sample = self.on_before_model(sample) # Forward out = self.on_model(self.model, sample) self.on_after_model(out, sample) batch_size: int = len(sample["input"]) # measure accuracy and record loss # Loss computation if self.criterion is not None: val_loss = self.compute_loss(out, sample["target"]) stats.update("losses", val_loss.item(), batch_size) stats.update_from_dict(self.compute_metrics(out, sample["target"]), batch_size) if sample_id % 10 == 0: self._logger.info(f"Test: {sample_id}/{len(self.valid_dataloader)} {stats}") return stats.as_dict() def on_epoch_start(self, *args: Any, **kwargs: Any) -> None: ... def preprocess(self, x: Dict[str, Tensor]) -> Dict[str, Tensor]: return x def augmentations(self, x: Dict[str, Tensor]) -> Dict[str, Tensor]: return x def compute_metrics(self, *args: Any) -> Dict[str, float]: """Compute metrics during the evaluation.""" return {} def compute_loss(self, *args: Tensor) -> Tensor: if self.criterion is None: raise RuntimeError("`criterion` should not be None.") return self.criterion(*args) def on_before_model(self, x: Dict[str, Tensor]) -> Dict[str, Tensor]: return x def on_model(self, model: Module, sample: Dict[str, Tensor]) -> Tensor: return model(sample["input"]) def on_after_model(self, output: Tensor, sample: Dict[str, Tensor]) -> None: ... def on_checkpoint(self, *args: Any, **kwargs: Dict[str, Any]) -> None: ... def on_epoch_end(self, *args: Any, **kwargs: Dict[str, Any]) -> None: ...