task.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100
  1. """Base training task abstraction.
  2. This module provides the base TrainingTask class that encapsulates a complete
  3. forward pass including loss computation. Tasks return a dictionary with loss
  4. components and outputs for logging.
  5. """
  6. from typing import Dict, Optional
  7. import torch
  8. import torch.nn as nn
  9. class TrainingTask(nn.Module):
  10. """Base class for training tasks.
  11. A training task encapsulates a complete forward pass including loss computation.
  12. Tasks return a dictionary containing the training loss and other components for logging.
  13. The returned dictionary must contain:
  14. - 'loss': The training loss for backward pass (required)
  15. - 'output': Model output/logits for metric computation (recommended)
  16. - Other task-specific loss components for logging (optional)
  17. Args:
  18. device: Device for task tensors/buffers (defaults to cpu)
  19. dtype: Dtype for task tensors/buffers (defaults to torch default)
  20. verbose: Enable info logging
  21. Example:
  22. >>> task = SomeTask(model, criterion, device=torch.device('cuda'))
  23. >>>
  24. >>> # Prepare for distributed training (if needed)
  25. >>> if distributed:
  26. >>> task.prepare_distributed(device_ids=[local_rank])
  27. >>>
  28. >>> # Training loop
  29. >>> result = task(input, target)
  30. >>> result['loss'].backward()
  31. """
  32. def __init__(
  33. self,
  34. device: Optional[torch.device] = None,
  35. dtype: Optional[torch.dtype] = None,
  36. verbose: bool = True,
  37. ):
  38. super().__init__()
  39. self.device = device if device is not None else torch.device('cpu')
  40. self.dtype = dtype if dtype is not None else torch.get_default_dtype()
  41. self.verbose = verbose
  42. def to(self, *args, **kwargs):
  43. """Move task to device/dtype, keeping self.device and self.dtype in sync."""
  44. dummy = torch.empty(0).to(*args, **kwargs)
  45. self.device = dummy.device
  46. self.dtype = dummy.dtype
  47. return super().to(*args, **kwargs)
  48. def prepare_distributed(
  49. self,
  50. device_ids: Optional[list] = None,
  51. **ddp_kwargs
  52. ) -> 'TrainingTask':
  53. """Prepare task for distributed training.
  54. This method wraps trainable components in DistributedDataParallel (DDP)
  55. while leaving non-trainable components (like frozen teacher models) unwrapped.
  56. Should be called after task initialization but before training loop.
  57. Args:
  58. device_ids: List of device IDs for DDP (e.g., [local_rank])
  59. **ddp_kwargs: Additional arguments passed to DistributedDataParallel
  60. Returns:
  61. self (for method chaining)
  62. Example:
  63. >>> task = LogitDistillationTask(student, teacher, criterion)
  64. >>> task.prepare_distributed(device_ids=[args.local_rank])
  65. >>> task = torch.compile(task) # Compile after DDP
  66. """
  67. # Default implementation - subclasses override if they need DDP
  68. return self
  69. def forward(
  70. self,
  71. input: torch.Tensor,
  72. target: torch.Tensor,
  73. ) -> Dict[str, torch.Tensor]:
  74. """Perform forward pass and compute loss.
  75. Args:
  76. input: Input tensor [B, C, H, W]
  77. target: Target labels [B]
  78. Returns:
  79. Dictionary with at least 'loss' key containing the training loss
  80. """
  81. raise NotImplementedError