# token_distillation.py
  1. """Token-based distillation training task for models with distillation heads."""
  2. import logging
  3. from typing import Dict, Optional, Union
  4. import torch
  5. import torch.nn as nn
  6. import torch.nn.functional as F
  7. from timm.models import create_model
  8. from timm.utils import unwrap_model
  9. from .task import TrainingTask
  10. _logger = logging.getLogger(__name__)
  11. class TokenDistillationTeacher(nn.Module):
  12. """Wrapper for a teacher model used in token-based distillation.
  13. Creates and manages a pre-trained teacher model for token distillation,
  14. handling model creation and normalization differences between teacher and student.
  15. Can be created from:
  16. - A model name string (creates the model internally)
  17. - An existing nn.Module (wraps it with the necessary interface)
  18. Args:
  19. model_name_or_module: Either a model name string or an nn.Module
  20. num_classes: Number of output classes (required if model_name_or_module is a string)
  21. in_chans: Number of input channels (used if model_name_or_module is a string)
  22. pretrained_path: Optional path to pretrained weights (used if model_name_or_module is a string)
  23. device: Device to place the model on
  24. dtype: Model dtype (uses float32 if None)
  25. """
  26. def __init__(
  27. self,
  28. model_name_or_module: Union[str, nn.Module],
  29. num_classes: Optional[int] = None,
  30. in_chans: int = 3,
  31. pretrained_path: Optional[str] = None,
  32. device: Optional[torch.device] = None,
  33. dtype: Optional[torch.dtype] = None,
  34. ):
  35. super().__init__()
  36. if isinstance(model_name_or_module, str):
  37. _logger.info(f"Creating token distillation teacher model: '{model_name_or_module}'")
  38. pretrained_kwargs = {'pretrained': True}
  39. if pretrained_path:
  40. pretrained_kwargs['pretrained_cfg_overlay'] = dict(
  41. file=pretrained_path,
  42. num_classes=num_classes,
  43. )
  44. model = create_model(
  45. model_name=model_name_or_module,
  46. num_classes=num_classes,
  47. in_chans=in_chans,
  48. device=device,
  49. dtype=dtype,
  50. **pretrained_kwargs,
  51. )
  52. elif isinstance(model_name_or_module, nn.Module):
  53. model = model_name_or_module
  54. else:
  55. raise TypeError(
  56. f"model_name_or_module must be a string or nn.Module, got {type(model_name_or_module).__name__}"
  57. )
  58. model.eval()
  59. self.model = model
  60. # Get normalization values from pretrained_cfg if available
  61. model_unwrapped = unwrap_model(model)
  62. if hasattr(model_unwrapped, 'pretrained_cfg'):
  63. mean = model_unwrapped.pretrained_cfg.get('mean', (0.485, 0.456, 0.406))
  64. std = model_unwrapped.pretrained_cfg.get('std', (0.229, 0.224, 0.225))
  65. else:
  66. mean = (0.485, 0.456, 0.406)
  67. std = (0.229, 0.224, 0.225)
  68. mean_kd = torch.tensor(mean, device=device, dtype=dtype).view(1, -1, 1, 1)
  69. std_kd = torch.tensor(std, device=device, dtype=dtype).view(1, -1, 1, 1)
  70. self.register_buffer('mean_kd', mean_kd, persistent=False)
  71. self.register_buffer('std_kd', std_kd, persistent=False)
  72. def forward(self, input: torch.Tensor) -> torch.Tensor:
  73. """Forward pass through teacher model.
  74. Args:
  75. input: Input tensor (should already be normalized for teacher)
  76. Returns:
  77. Teacher logits
  78. """
  79. return self.model(input)
  80. def normalize_input(
  81. self,
  82. input: torch.Tensor,
  83. student_mean: Optional[torch.Tensor] = None,
  84. student_std: Optional[torch.Tensor] = None,
  85. ) -> torch.Tensor:
  86. """Normalize input to match teacher's expected normalization.
  87. Args:
  88. input: Input tensor (already normalized for student)
  89. student_mean: Student normalization mean buffer [1, 3, 1, 1]
  90. student_std: Student normalization std buffer [1, 3, 1, 1]
  91. Returns:
  92. Input tensor normalized for the teacher model
  93. """
  94. if student_mean is None or student_std is None:
  95. return input
  96. if torch.equal(student_mean, self.mean_kd) and torch.equal(student_std, self.std_kd):
  97. return input
  98. return (input * student_std + student_mean - self.mean_kd) / self.std_kd
  99. class TokenDistillationTask(TrainingTask):
  100. """Token-based distillation task for models with distillation heads.
  101. For models like DeiT that have a dedicated distillation token/head that returns
  102. a tuple (main_logits, dist_logits) when distilled_training is enabled. The main
  103. head is trained against ground truth labels while the distillation head matches
  104. teacher outputs.
  105. Supports two distillation modes:
  106. - 'soft': KL divergence with temperature scaling (default)
  107. - 'hard': Cross-entropy with teacher's hard predictions (argmax)
  108. Loss weighting supports two modes:
  109. 1. Independent weights: loss = task_loss_weight * task_loss + distill_loss_weight * distill_loss
  110. 2. Complementary mode: loss = task_loss_weight * task_loss + (1 - task_loss_weight) * distill_loss
  111. (used when only task_loss_weight is specified)
  112. Args:
  113. student_model: Student model with set_distilled_training() method
  114. teacher_model: Teacher model - can be a model name string, nn.Module, or TokenDistillationTeacher
  115. criterion: Task loss function for main head (default: CrossEntropyLoss)
  116. teacher_pretrained_path: Path to teacher pretrained weights (used when teacher_model is a string)
  117. distill_type: 'soft' for KL-div or 'hard' for CE with teacher argmax
  118. distill_loss_weight: Weight for distillation loss
  119. task_loss_weight: Weight for task loss
  120. temperature: Softmax temperature for soft distillation (ignored for hard)
  121. device: Device for task tensors/buffers
  122. dtype: Dtype for task tensors/buffers
  123. verbose: Enable info logging
  124. Example:
  125. >>> # With model name string (num_classes/in_chans inferred from student)
  126. >>> task = TokenDistillationTask(
  127. ... student_model=model, teacher_model='deit_base_patch16_224',
  128. ... criterion=nn.CrossEntropyLoss(),
  129. ... distill_type='soft', temperature=3.0, task_loss_weight=0.5,
  130. ... device=torch.device('cuda'),
  131. ... )
  132. >>> # With raw model
  133. >>> task = TokenDistillationTask(
  134. ... student_model=model, teacher_model=my_teacher_model,
  135. ... criterion=nn.CrossEntropyLoss(),
  136. ... distill_type='hard', task_loss_weight=0.5,
  137. ... )
  138. """
  139. def __init__(
  140. self,
  141. student_model: nn.Module,
  142. teacher_model: Union[str, nn.Module, TokenDistillationTeacher],
  143. criterion: Optional[nn.Module] = None,
  144. teacher_pretrained_path: Optional[str] = None,
  145. distill_type: str = 'soft',
  146. distill_loss_weight: Optional[float] = None,
  147. task_loss_weight: Optional[float] = None,
  148. temperature: float = 1.0,
  149. device: Optional[torch.device] = None,
  150. dtype: Optional[torch.dtype] = None,
  151. verbose: bool = True,
  152. ):
  153. super().__init__(device=device, dtype=dtype, verbose=verbose)
  154. # Validate model has set_distilled_training method
  155. student_unwrapped = unwrap_model(student_model)
  156. if not hasattr(student_unwrapped, 'set_distilled_training'):
  157. raise ValueError(
  158. f"Model {student_unwrapped.__class__.__name__} does not have 'set_distilled_training' method. "
  159. "TokenDistillationTask requires a model with a distillation head (e.g., DeiT distilled variants)."
  160. )
  161. # Enable distilled training mode
  162. student_unwrapped.set_distilled_training(True)
  163. # Handle different teacher input types
  164. if isinstance(teacher_model, TokenDistillationTeacher):
  165. teacher = teacher_model
  166. elif isinstance(teacher_model, str) or isinstance(teacher_model, nn.Module):
  167. # Get num_classes and in_chans from student
  168. num_classes = student_unwrapped.num_classes
  169. in_chans = student_unwrapped.in_chans
  170. teacher = TokenDistillationTeacher(
  171. model_name_or_module=teacher_model,
  172. num_classes=num_classes,
  173. in_chans=in_chans,
  174. pretrained_path=teacher_pretrained_path,
  175. device=self.device,
  176. dtype=self.dtype,
  177. )
  178. else:
  179. raise TypeError(
  180. f"teacher_model must be a model name string, nn.Module, or TokenDistillationTeacher, "
  181. f"got {type(teacher_model).__name__}"
  182. )
  183. self.student = student_model
  184. self.teacher = teacher
  185. self.criterion = criterion if criterion is not None else nn.CrossEntropyLoss()
  186. self.distill_type = distill_type
  187. self.temperature = temperature
  188. if distill_type not in ('soft', 'hard'):
  189. raise ValueError(f"Unsupported distill_type '{distill_type}'. Must be 'soft' or 'hard'.")
  190. # Register student normalization values as non-persistent buffers
  191. student_mean = torch.tensor(
  192. student_unwrapped.pretrained_cfg['mean'],
  193. device=self.device,
  194. dtype=self.dtype,
  195. ).view(1, -1, 1, 1)
  196. student_std = torch.tensor(
  197. student_unwrapped.pretrained_cfg['std'],
  198. device=self.device,
  199. dtype=self.dtype,
  200. ).view(1, -1, 1, 1)
  201. self.register_buffer('student_mean', student_mean, persistent=False)
  202. self.register_buffer('student_std', student_std, persistent=False)
  203. # Determine weighting mode
  204. if distill_loss_weight is not None:
  205. # Mode 1: distill_weight specified - independent weights (task defaults to 1.0 if not set)
  206. self.distill_loss_weight = distill_loss_weight
  207. self.task_loss_weight = task_loss_weight if task_loss_weight is not None else 1.0
  208. if self.verbose:
  209. _logger.info(
  210. f"TokenDistillationTask: Independent weights - "
  211. f"task_weight={self.task_loss_weight}, distill_weight={distill_loss_weight}"
  212. )
  213. elif task_loss_weight is not None:
  214. # Mode 2: only task_weight specified - complementary mode (distill = 1 - task)
  215. self.task_loss_weight = task_loss_weight
  216. self.distill_loss_weight = 1.0 - task_loss_weight
  217. if self.verbose:
  218. _logger.info(
  219. f"TokenDistillationTask: Complementary mode - "
  220. f"task_weight={task_loss_weight}, distill_weight={self.distill_loss_weight}"
  221. )
  222. else:
  223. # Mode 3: neither specified - equal weights (both 1.0)
  224. self.distill_loss_weight = 1.0
  225. self.task_loss_weight = 1.0
  226. if self.verbose:
  227. _logger.info(
  228. f"TokenDistillationTask: Default equal weights - "
  229. f"task_weight={self.task_loss_weight}, distill_weight={self.distill_loss_weight}"
  230. )
  231. if self.verbose:
  232. _logger.info(
  233. f"TokenDistillationTask: distill_type={distill_type}, temperature={temperature}"
  234. )
  235. def prepare_distributed(
  236. self,
  237. device_ids: Optional[list] = None,
  238. **ddp_kwargs
  239. ) -> 'TokenDistillationTask':
  240. """Prepare task for distributed training.
  241. Wraps the student model in DistributedDataParallel (DDP) while leaving
  242. the frozen teacher model unwrapped.
  243. Args:
  244. device_ids: List of device IDs for DDP (e.g., [local_rank])
  245. **ddp_kwargs: Additional arguments passed to DistributedDataParallel
  246. Returns:
  247. self (for method chaining)
  248. """
  249. from torch.nn.parallel import DistributedDataParallel as DDP
  250. for param in self.teacher.parameters():
  251. param.requires_grad = False
  252. self.student = DDP(self.student, device_ids=device_ids, **ddp_kwargs)
  253. return self
  254. def forward(
  255. self,
  256. input: torch.Tensor,
  257. target: torch.Tensor,
  258. ) -> Dict[str, torch.Tensor]:
  259. """Forward pass with token distillation.
  260. Args:
  261. input: Input tensor [B, C, H, W]
  262. target: Target labels [B]
  263. Returns:
  264. Dictionary containing:
  265. - 'loss': Combined training loss (task + distillation)
  266. - 'output': Main head logits (for metrics)
  267. - 'task_loss': Classification loss component
  268. - 'distill_loss': Distillation loss component
  269. """
  270. # Student forward pass - returns tuple (main_logits, dist_logits)
  271. student_output = self.student(input)
  272. main_logits, dist_logits = student_output
  273. # Compute task loss on main head
  274. task_loss = self.criterion(main_logits, target)
  275. # Teacher forward pass (no gradient)
  276. with torch.no_grad():
  277. input_kd = self.teacher.normalize_input(input, self.student_mean, self.student_std)
  278. teacher_logits = self.teacher(input_kd.detach())
  279. # Compute distillation loss on distillation head
  280. if self.distill_type == 'soft':
  281. prob_s = F.log_softmax(dist_logits / self.temperature, dim=-1)
  282. prob_t = F.log_softmax(teacher_logits / self.temperature, dim=-1)
  283. distill_loss = F.kl_div(prob_s, prob_t, reduction='batchmean', log_target=True) * (self.temperature ** 2)
  284. else:
  285. teacher_hard = teacher_logits.argmax(dim=-1)
  286. distill_loss = F.cross_entropy(dist_logits, teacher_hard)
  287. total_loss = self.task_loss_weight * task_loss + self.distill_loss_weight * distill_loss
  288. return {
  289. 'loss': total_loss,
  290. 'output': main_logits,
  291. 'task_loss': task_loss,
  292. 'distill_loss': distill_loss,
  293. }