"""Training task abstractions for timm.

This module provides task-based abstractions for training loops where each task
encapsulates both the forward pass and loss computation, returning a dictionary
with loss components and outputs for logging.
"""
from .task import TrainingTask
from .classification import ClassificationTask
from .distillation import DistillationTeacher, LogitDistillationTask, FeatureDistillationTask
from .token_distillation import TokenDistillationTeacher, TokenDistillationTask

# Public API of the package: the task classes re-exported from the submodules above.
__all__ = [
    'TrainingTask',
    'ClassificationTask',
    'DistillationTeacher',
    'LogitDistillationTask',
    'FeatureDistillationTask',
    'TokenDistillationTeacher',
    'TokenDistillationTask',
]