| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100 |
- """Base training task abstraction.
- This module provides the base TrainingTask class that encapsulates a complete
- forward pass including loss computation. Tasks return a dictionary with loss
- components and outputs for logging.
- """
- from typing import Dict, Optional
- import torch
- import torch.nn as nn
class TrainingTask(nn.Module):
    """Base class for training tasks.

    A training task encapsulates a complete forward pass including loss
    computation. Tasks return a dictionary containing the training loss and
    other components for logging.

    The returned dictionary must contain:
        - 'loss': The training loss for backward pass (required)
        - 'output': Model output/logits for metric computation (recommended)
        - Other task-specific loss components for logging (optional)

    Args:
        device: Device for task tensors/buffers (defaults to cpu)
        dtype: Dtype for task tensors/buffers (defaults to torch default)
        verbose: Enable info logging

    Example:
        >>> task = SomeTask(model, criterion, device=torch.device('cuda'))
        >>>
        >>> # Prepare for distributed training (if needed)
        >>> if distributed:
        >>>     task.prepare_distributed(device_ids=[local_rank])
        >>>
        >>> # Training loop
        >>> result = task(input, target)
        >>> result['loss'].backward()
    """

    def __init__(
        self,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
        verbose: bool = True,
    ):
        super().__init__()
        self.device = device if device is not None else torch.device('cpu')
        self.dtype = dtype if dtype is not None else torch.get_default_dtype()
        self.verbose = verbose

    def to(self, *args, **kwargs):
        """Move task to device/dtype, keeping self.device and self.dtype in sync.

        Accepts the same arguments as ``nn.Module.to``. Only the properties
        actually specified by the call are changed: a device-only call keeps
        the current ``self.dtype`` and a dtype-only call keeps the current
        ``self.device``.

        Note that only ``to`` is overridden in this class; other movers are
        not intercepted here.

        Returns:
            The result of ``nn.Module.to`` (the module itself).
        """
        # Move the module first so a failing call does not leave
        # self.device / self.dtype out of sync with the parameters.
        moved = super().to(*args, **kwargs)

        # memory_format has no influence on device/dtype, and a layout such as
        # channels_last may be rejected for the 1-D probe tensor below.
        probe_kwargs = {
            key: value for key, value in kwargs.items() if key != 'memory_format'
        }

        # Seed the probe with the current device/dtype so that arguments the
        # caller did not specify are preserved instead of being reset to
        # cpu / the default dtype.
        probe = torch.empty(0, device=self.device, dtype=self.dtype)
        probe = probe.to(*args, **probe_kwargs)
        self.device = probe.device
        self.dtype = probe.dtype
        return moved

    def prepare_distributed(
        self,
        device_ids: Optional[list] = None,
        **ddp_kwargs
    ) -> 'TrainingTask':
        """Prepare task for distributed training.

        Intended contract for subclasses: wrap trainable components in
        DistributedDataParallel (DDP) while leaving non-trainable components
        (like frozen teacher models) unwrapped. Should be called after task
        initialization but before the training loop.

        The base implementation does nothing and returns ``self``; subclasses
        override it if they need DDP.

        Args:
            device_ids: List of device IDs for DDP (e.g., [local_rank])
            **ddp_kwargs: Additional arguments passed to DistributedDataParallel

        Returns:
            self (for method chaining)

        Example:
            >>> task = LogitDistillationTask(student, teacher, criterion)
            >>> task.prepare_distributed(device_ids=[args.local_rank])
            >>> task = torch.compile(task)  # Compile after DDP
        """
        # Default implementation - subclasses override if they need DDP
        return self

    def forward(
        self,
        input: torch.Tensor,
        target: torch.Tensor,
    ) -> Dict[str, torch.Tensor]:
        """Perform forward pass and compute loss.

        Must be implemented by subclasses; the base class raises.

        Args:
            input: Input tensor [B, C, H, W]
            target: Target labels [B]

        Returns:
            Dictionary with at least 'loss' key containing the training loss

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError
|