classification.py 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990
  1. """Classification training task."""
  2. import logging
  3. from typing import Callable, Dict, Optional, Union
  4. import torch
  5. import torch.nn as nn
  6. from .task import TrainingTask
  7. _logger = logging.getLogger(__name__)
  8. class ClassificationTask(TrainingTask):
  9. """Standard supervised classification task.
  10. Simple task that performs a forward pass through the model and computes
  11. the classification loss.
  12. Args:
  13. model: The model to train
  14. criterion: Loss function (e.g., CrossEntropyLoss)
  15. device: Device for task tensors/buffers
  16. dtype: Dtype for task tensors/buffers
  17. verbose: Enable info logging
  18. Example:
  19. >>> task = ClassificationTask(model, nn.CrossEntropyLoss(), device=torch.device('cuda'))
  20. >>> result = task(input, target)
  21. >>> result['loss'].backward()
  22. """
  23. def __init__(
  24. self,
  25. model: nn.Module,
  26. criterion: Union[nn.Module, Callable],
  27. device: Optional[torch.device] = None,
  28. dtype: Optional[torch.dtype] = None,
  29. verbose: bool = True,
  30. ):
  31. super().__init__(device=device, dtype=dtype, verbose=verbose)
  32. self.model = model
  33. self.criterion = criterion
  34. if self.verbose:
  35. loss_name = getattr(criterion, '__name__', None) or type(criterion).__name__
  36. _logger.info(f"ClassificationTask: criterion={loss_name}")
  37. def prepare_distributed(
  38. self,
  39. device_ids: Optional[list] = None,
  40. **ddp_kwargs
  41. ) -> 'ClassificationTask':
  42. """Prepare task for distributed training.
  43. Wraps the model in DistributedDataParallel (DDP).
  44. Args:
  45. device_ids: List of device IDs for DDP (e.g., [local_rank])
  46. **ddp_kwargs: Additional arguments passed to DistributedDataParallel
  47. Returns:
  48. self (for method chaining)
  49. """
  50. from torch.nn.parallel import DistributedDataParallel as DDP
  51. self.model = DDP(self.model, device_ids=device_ids, **ddp_kwargs)
  52. return self
  53. def forward(
  54. self,
  55. input: torch.Tensor,
  56. target: torch.Tensor,
  57. ) -> Dict[str, torch.Tensor]:
  58. """Forward pass through model and compute classification loss.
  59. Args:
  60. input: Input tensor [B, C, H, W]
  61. target: Target labels [B]
  62. Returns:
  63. Dictionary containing:
  64. - 'loss': Classification loss
  65. - 'output': Model logits
  66. """
  67. output = self.model(input)
  68. loss = self.criterion(output, target)
  69. return {
  70. 'loss': loss,
  71. 'output': output,
  72. }