distillation.py 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612
  1. """Knowledge distillation training tasks and components."""
  2. import logging
  3. from typing import Dict, Optional, Tuple, Union
  4. import torch
  5. import torch.nn as nn
  6. import torch.nn.functional as F
  7. from timm.models import create_model
  8. from timm.utils import unwrap_model
  9. from .task import TrainingTask
  10. _logger = logging.getLogger(__name__)
  11. class DistillationTeacher(nn.Module):
  12. """Wrapper for a teacher model used in knowledge distillation.
  13. Creates and manages a pre-trained teacher model for knowledge distillation,
  14. handling model creation and normalization differences between teacher and student.
  15. Can be created from:
  16. - A model name string (creates the model internally with pretrained weights)
  17. - An existing nn.Module (wraps it with the necessary interface)
  18. Args:
  19. model_name_or_module: Either a model name string or an nn.Module
  20. num_classes: Number of output classes (required if model_name_or_module is a string)
  21. in_chans: Number of input channels (used if model_name_or_module is a string)
  22. pretrained_path: Optional path to pretrained weights (used if model_name_or_module is a string)
  23. device: Device to place the model on
  24. dtype: Model dtype (uses float32 if None)
  25. """
  26. def __init__(
  27. self,
  28. model_name_or_module: Union[str, nn.Module],
  29. num_classes: Optional[int] = None,
  30. in_chans: int = 3,
  31. pretrained_path: Optional[str] = None,
  32. device: Optional[torch.device] = None,
  33. dtype: Optional[torch.dtype] = None,
  34. ):
  35. super().__init__()
  36. if isinstance(model_name_or_module, str):
  37. _logger.info(f"Creating KD teacher model: '{model_name_or_module}'")
  38. pretrained_kwargs = {'pretrained': True}
  39. if pretrained_path:
  40. pretrained_kwargs['pretrained_cfg_overlay'] = dict(
  41. file=pretrained_path,
  42. num_classes=num_classes,
  43. )
  44. model = create_model(
  45. model_name=model_name_or_module,
  46. num_classes=num_classes,
  47. in_chans=in_chans,
  48. device=device,
  49. dtype=dtype,
  50. **pretrained_kwargs,
  51. )
  52. elif isinstance(model_name_or_module, nn.Module):
  53. model = model_name_or_module
  54. else:
  55. raise TypeError(
  56. f"model_name_or_module must be a string or nn.Module, got {type(model_name_or_module).__name__}"
  57. )
  58. model.eval()
  59. self.model = model
  60. # Get normalization values from pretrained_cfg if available
  61. model_unwrapped = unwrap_model(model)
  62. if hasattr(model_unwrapped, 'pretrained_cfg'):
  63. mean = model_unwrapped.pretrained_cfg.get('mean', (0.485, 0.456, 0.406))
  64. std = model_unwrapped.pretrained_cfg.get('std', (0.229, 0.224, 0.225))
  65. else:
  66. mean = (0.485, 0.456, 0.406)
  67. std = (0.229, 0.224, 0.225)
  68. mean_kd = torch.tensor(mean, device=device, dtype=dtype).view(1, -1, 1, 1)
  69. std_kd = torch.tensor(std, device=device, dtype=dtype).view(1, -1, 1, 1)
  70. self.register_buffer('mean_kd', mean_kd, persistent=False)
  71. self.register_buffer('std_kd', std_kd, persistent=False)
  72. def forward(
  73. self,
  74. input: torch.Tensor,
  75. return_features: bool = False,
  76. ) -> torch.Tensor:
  77. """Forward pass through teacher model.
  78. Args:
  79. input: Input tensor (should already be normalized for teacher)
  80. return_features: Whether to return pooled pre-logits features instead of logits
  81. Returns:
  82. Logits or pooled pre-logits features depending on return_features flag
  83. """
  84. if return_features:
  85. if not hasattr(self.model, 'forward_features') or not hasattr(self.model, 'forward_head'):
  86. raise ValueError(
  87. f"Model {self.model.__class__.__name__} does not support feature extraction. "
  88. "Ensure the model has 'forward_features' and 'forward_head' methods."
  89. )
  90. feature_map = self.model.forward_features(input)
  91. return self.model.forward_head(feature_map, pre_logits=True)
  92. else:
  93. return self.model(input)
  94. def normalize_input(
  95. self,
  96. input: torch.Tensor,
  97. student_mean: Optional[torch.Tensor] = None,
  98. student_std: Optional[torch.Tensor] = None,
  99. ) -> torch.Tensor:
  100. """Normalize input to match teacher's expected normalization.
  101. Args:
  102. input: Input tensor (already normalized for student)
  103. student_mean: Student normalization mean buffer [1, 3, 1, 1]
  104. student_std: Student normalization std buffer [1, 3, 1, 1]
  105. Returns:
  106. Input tensor normalized for the teacher model
  107. """
  108. if student_mean is None or student_std is None:
  109. return input
  110. if torch.equal(student_mean, self.mean_kd) and torch.equal(student_std, self.std_kd):
  111. return input
  112. return (input * student_std + student_mean - self.mean_kd) / self.std_kd
  113. def _resolve_teacher(
  114. teacher: Union[str, nn.Module, DistillationTeacher],
  115. student_model: nn.Module,
  116. pretrained_path: Optional[str],
  117. device: Optional[torch.device],
  118. dtype: Optional[torch.dtype],
  119. ) -> DistillationTeacher:
  120. """Resolve teacher input to a DistillationTeacher instance.
  121. Args:
  122. teacher: Model name string, nn.Module, or DistillationTeacher
  123. student_model: Student model to infer num_classes/in_chans from
  124. pretrained_path: Optional path to teacher pretrained weights
  125. device: Device for teacher
  126. dtype: Dtype for teacher
  127. Returns:
  128. DistillationTeacher instance
  129. """
  130. if isinstance(teacher, DistillationTeacher):
  131. return teacher
  132. # Get num_classes and in_chans from student
  133. student_unwrapped = unwrap_model(student_model)
  134. num_classes = student_unwrapped.num_classes
  135. in_chans = student_unwrapped.in_chans
  136. return DistillationTeacher(
  137. model_name_or_module=teacher,
  138. num_classes=num_classes,
  139. in_chans=in_chans,
  140. pretrained_path=pretrained_path,
  141. device=device,
  142. dtype=dtype,
  143. )
  144. class LogitDistillationTask(TrainingTask):
  145. """Logit-based knowledge distillation task.
  146. Performs distillation by matching student and teacher output logits using
  147. KL divergence with temperature scaling.
  148. Loss weighting supports two modes:
  149. 1. Independent weights: loss = task_loss_weight * task_loss + distill_loss_weight * distill_loss
  150. 2. Complementary mode: loss = task_loss_weight * task_loss + (1 - task_loss_weight) * distill_loss
  151. (used when only task_loss_weight is specified)
  152. Args:
  153. student_model: Student model to train
  154. teacher_model: Teacher model - can be a model name string, nn.Module, or DistillationTeacher
  155. criterion: Task loss function (default: CrossEntropyLoss)
  156. teacher_pretrained_path: Path to teacher pretrained weights (used when teacher_model is a string)
  157. loss_type: Type of distillation loss (currently only 'kl' supported)
  158. distill_loss_weight: Weight for distillation loss
  159. task_loss_weight: Weight for task loss
  160. temperature: Softmax temperature for distillation (typical values: 1-4)
  161. device: Device for task tensors/buffers
  162. dtype: Dtype for task tensors/buffers
  163. verbose: Enable info logging
  164. Example:
  165. >>> # With model name string (num_classes/in_chans inferred from student)
  166. >>> task = LogitDistillationTask(
  167. ... student_model=model, teacher_model='resnet50',
  168. ... criterion=nn.CrossEntropyLoss(),
  169. ... task_loss_weight=0.3, temperature=4.0,
  170. ... device=torch.device('cuda'),
  171. ... )
  172. >>> # With raw model
  173. >>> task = LogitDistillationTask(
  174. ... student_model=model, teacher_model=my_teacher_model,
  175. ... criterion=nn.CrossEntropyLoss(),
  176. ... task_loss_weight=0.3, temperature=4.0,
  177. ... )
  178. """
  179. def __init__(
  180. self,
  181. student_model: nn.Module,
  182. teacher_model: Union[str, nn.Module, DistillationTeacher],
  183. criterion: Optional[nn.Module] = None,
  184. teacher_pretrained_path: Optional[str] = None,
  185. loss_type: str = 'kl',
  186. distill_loss_weight: Optional[float] = None,
  187. task_loss_weight: Optional[float] = None,
  188. temperature: float = 1.0,
  189. device: Optional[torch.device] = None,
  190. dtype: Optional[torch.dtype] = None,
  191. verbose: bool = True,
  192. ):
  193. super().__init__(device=device, dtype=dtype, verbose=verbose)
  194. # Resolve teacher to DistillationTeacher
  195. teacher = _resolve_teacher(
  196. teacher_model,
  197. student_model,
  198. teacher_pretrained_path,
  199. self.device,
  200. self.dtype,
  201. )
  202. self.student = student_model
  203. self.teacher = teacher
  204. self.criterion = criterion if criterion is not None else nn.CrossEntropyLoss()
  205. self.loss_type = loss_type
  206. self.temperature = temperature
  207. if loss_type != 'kl':
  208. raise ValueError(f"Unsupported loss_type '{loss_type}'. Currently only 'kl' is supported.")
  209. # Register student normalization values as non-persistent buffers
  210. student_unwrapped = unwrap_model(student_model)
  211. student_mean = torch.tensor(
  212. student_unwrapped.pretrained_cfg['mean'],
  213. device=self.device,
  214. dtype=self.dtype,
  215. ).view(1, -1, 1, 1)
  216. student_std = torch.tensor(
  217. student_unwrapped.pretrained_cfg['std'],
  218. device=self.device,
  219. dtype=self.dtype,
  220. ).view(1, -1, 1, 1)
  221. self.register_buffer('student_mean', student_mean, persistent=False)
  222. self.register_buffer('student_std', student_std, persistent=False)
  223. # Determine weighting mode
  224. if distill_loss_weight is not None:
  225. # Mode 1: distill_weight specified - independent weights (task defaults to 1.0 if not set)
  226. self.distill_loss_weight = distill_loss_weight
  227. self.task_loss_weight = task_loss_weight if task_loss_weight is not None else 1.0
  228. if self.verbose:
  229. _logger.info(
  230. f"LogitDistillationTask: Independent weights - "
  231. f"task_weight={self.task_loss_weight}, distill_weight={distill_loss_weight}"
  232. )
  233. elif task_loss_weight is not None:
  234. # Mode 2: only task_weight specified - complementary mode (distill = 1 - task)
  235. self.task_loss_weight = task_loss_weight
  236. self.distill_loss_weight = 1.0 - task_loss_weight
  237. if self.verbose:
  238. _logger.info(
  239. f"LogitDistillationTask: Complementary mode - "
  240. f"task_weight={task_loss_weight}, distill_weight={self.distill_loss_weight}"
  241. )
  242. else:
  243. # Mode 3: neither specified - equal weights (both 1.0)
  244. self.distill_loss_weight = 1.0
  245. self.task_loss_weight = 1.0
  246. if self.verbose:
  247. _logger.info(
  248. f"LogitDistillationTask: Default equal weights - "
  249. f"task_weight={self.task_loss_weight}, distill_weight={self.distill_loss_weight}"
  250. )
  251. if self.verbose:
  252. _logger.info(
  253. f"LogitDistillationTask: loss_type={loss_type}, temperature={temperature}"
  254. )
  255. def prepare_distributed(
  256. self,
  257. device_ids: Optional[list] = None,
  258. **ddp_kwargs
  259. ) -> 'LogitDistillationTask':
  260. """Prepare task for distributed training.
  261. Wraps the student model in DistributedDataParallel (DDP) while leaving
  262. the frozen teacher model unwrapped.
  263. Args:
  264. device_ids: List of device IDs for DDP (e.g., [local_rank])
  265. **ddp_kwargs: Additional arguments passed to DistributedDataParallel
  266. Returns:
  267. self (for method chaining)
  268. """
  269. from torch.nn.parallel import DistributedDataParallel as DDP
  270. for param in self.teacher.parameters():
  271. param.requires_grad = False
  272. self.student = DDP(self.student, device_ids=device_ids, **ddp_kwargs)
  273. return self
  274. def forward(
  275. self,
  276. input: torch.Tensor,
  277. target: torch.Tensor,
  278. ) -> Dict[str, torch.Tensor]:
  279. """Forward pass with logit distillation.
  280. Args:
  281. input: Input tensor [B, C, H, W]
  282. target: Target labels [B]
  283. Returns:
  284. Dictionary containing:
  285. - 'loss': Combined training loss (task + distillation)
  286. - 'output': Student logits (for metrics)
  287. - 'task_loss': Classification loss component
  288. - 'kd_loss': Logit distillation loss component
  289. """
  290. student_logits = self.student(input)
  291. task_loss = self.criterion(student_logits, target)
  292. with torch.no_grad():
  293. input_kd = self.teacher.normalize_input(input, self.student_mean, self.student_std)
  294. teacher_logits = self.teacher(input_kd.detach(), return_features=False)
  295. prob_s = F.log_softmax(student_logits / self.temperature, dim=-1)
  296. prob_t = F.log_softmax(teacher_logits / self.temperature, dim=-1)
  297. kd_loss = F.kl_div(prob_s, prob_t, reduction='batchmean', log_target=True) * (self.temperature ** 2)
  298. total_loss = self.task_loss_weight * task_loss + self.distill_loss_weight * kd_loss
  299. return {
  300. 'loss': total_loss,
  301. 'output': student_logits,
  302. 'task_loss': task_loss,
  303. 'kd_loss': kd_loss,
  304. }
  305. class FeatureDistillationTrainableModule(nn.Module):
  306. """Trainable module for feature distillation.
  307. Wraps student model and projection layer into a single module where all
  308. trainable forward operations happen inside forward(). This ensures proper
  309. DDP wrapping when the module is used with DistributedDataParallel.
  310. """
  311. def __init__(
  312. self,
  313. student_model: nn.Module,
  314. projection: Optional[nn.Module] = None,
  315. ):
  316. """ Create trainable module wrapper for feature distillation.
  317. Args:
  318. student_model: Student model to train
  319. projection: Optional projection layer (Linear layer or None)
  320. """
  321. super().__init__()
  322. self.student = student_model
  323. self.projection = projection
  324. def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
  325. """Forward pass through student and projection.
  326. Args:
  327. input: Input tensor [B, C, H, W]
  328. Returns:
  329. Tuple of (student_logits, student_features) where features are
  330. optionally projected to match teacher dimension.
  331. """
  332. feature_map = self.student.forward_features(input)
  333. student_logits = self.student.forward_head(feature_map)
  334. student_features = self.student.forward_head(feature_map, pre_logits=True)
  335. if self.projection is not None:
  336. student_features = self.projection(student_features)
  337. return student_logits, student_features
  338. class FeatureDistillationTask(TrainingTask):
  339. """Feature-based knowledge distillation task.
  340. Performs distillation by matching student and teacher intermediate features
  341. (pooled pre-logits) using MSE loss. Automatically creates a projection layer
  342. if student and teacher feature dimensions differ.
  343. Loss weighting supports two modes:
  344. 1. Independent weights: loss = task_loss_weight * task_loss + distill_loss_weight * distill_loss
  345. 2. Complementary mode: loss = task_loss_weight * task_loss + (1 - task_loss_weight) * distill_loss
  346. (used when only task_loss_weight is specified)
  347. Args:
  348. student_model: Student model to train
  349. teacher_model: Teacher model - can be a model name string, nn.Module, or DistillationTeacher
  350. criterion: Task loss function (default: CrossEntropyLoss)
  351. teacher_pretrained_path: Path to teacher pretrained weights (used when teacher_model is a string)
  352. distill_loss_weight: Weight for distillation loss
  353. task_loss_weight: Weight for task loss
  354. student_feature_dim: Student pre-logits dimension (auto-detected if None)
  355. teacher_feature_dim: Teacher pre-logits dimension (auto-detected if None)
  356. device: Device for task tensors/buffers
  357. dtype: Dtype for task tensors/buffers
  358. verbose: Enable info logging
  359. Example:
  360. >>> # With model name string (num_classes/in_chans inferred from student)
  361. >>> task = FeatureDistillationTask(
  362. ... student_model=model, teacher_model='resnet50',
  363. ... criterion=nn.CrossEntropyLoss(),
  364. ... distill_loss_weight=5.0, task_loss_weight=1.0,
  365. ... device=torch.device('cuda'),
  366. ... )
  367. """
  368. def __init__(
  369. self,
  370. student_model: nn.Module,
  371. teacher_model: Union[str, nn.Module, DistillationTeacher],
  372. criterion: Optional[nn.Module] = None,
  373. teacher_pretrained_path: Optional[str] = None,
  374. distill_loss_weight: Optional[float] = None,
  375. task_loss_weight: Optional[float] = None,
  376. student_feature_dim: Optional[int] = None,
  377. teacher_feature_dim: Optional[int] = None,
  378. device: Optional[torch.device] = None,
  379. dtype: Optional[torch.dtype] = None,
  380. verbose: bool = True,
  381. ):
  382. super().__init__(device=device, dtype=dtype, verbose=verbose)
  383. # Resolve teacher to DistillationTeacher
  384. teacher = _resolve_teacher(
  385. teacher_model,
  386. student_model,
  387. teacher_pretrained_path,
  388. self.device,
  389. self.dtype,
  390. )
  391. self.teacher = teacher
  392. self.criterion = criterion if criterion is not None else nn.CrossEntropyLoss()
  393. # Determine weighting mode
  394. if distill_loss_weight is not None:
  395. # Mode 1: distill_weight specified - independent weights (task defaults to 1.0 if not set)
  396. self.distill_loss_weight = distill_loss_weight
  397. self.task_loss_weight = task_loss_weight if task_loss_weight is not None else 1.0
  398. if self.verbose:
  399. _logger.info(
  400. f"FeatureDistillationTask: Independent weights - "
  401. f"task_weight={self.task_loss_weight}, distill_weight={distill_loss_weight}"
  402. )
  403. elif task_loss_weight is not None:
  404. # Mode 2: only task_weight specified - complementary mode (distill = 1 - task)
  405. self.task_loss_weight = task_loss_weight
  406. self.distill_loss_weight = 1.0 - task_loss_weight
  407. if self.verbose:
  408. _logger.info(
  409. f"FeatureDistillationTask: Complementary mode - "
  410. f"task_weight={task_loss_weight}, distill_weight={self.distill_loss_weight}"
  411. )
  412. else:
  413. # Mode 3: neither specified - equal weights (both 1.0)
  414. self.distill_loss_weight = 1.0
  415. self.task_loss_weight = 1.0
  416. if self.verbose:
  417. _logger.info(
  418. f"FeatureDistillationTask: Default equal weights - "
  419. f"task_weight={self.task_loss_weight}, distill_weight={self.distill_loss_weight}"
  420. )
  421. # Auto-detect feature dimensions if not provided
  422. if student_feature_dim is None:
  423. student_feature_dim = self._detect_feature_dim(student_model)
  424. if teacher_feature_dim is None:
  425. teacher_feature_dim = self._detect_feature_dim(teacher.model)
  426. # Create projection layer if dimensions differ
  427. projection = None
  428. if student_feature_dim != teacher_feature_dim:
  429. if self.verbose:
  430. _logger.info(
  431. f"Creating projection layer: {student_feature_dim} -> {teacher_feature_dim}"
  432. )
  433. projection = nn.Linear(student_feature_dim, teacher_feature_dim, device=self.device, dtype=self.dtype)
  434. else:
  435. if self.verbose:
  436. _logger.info("Feature dimensions match, no projection needed")
  437. self.trainable_module = FeatureDistillationTrainableModule(student_model, projection)
  438. # Register student normalization values
  439. student_unwrapped = unwrap_model(student_model)
  440. student_mean = torch.tensor(
  441. student_unwrapped.pretrained_cfg['mean'],
  442. device=self.device,
  443. dtype=self.dtype,
  444. ).view(1, -1, 1, 1)
  445. student_std = torch.tensor(
  446. student_unwrapped.pretrained_cfg['std'],
  447. device=self.device,
  448. dtype=self.dtype,
  449. ).view(1, -1, 1, 1)
  450. self.register_buffer('student_mean', student_mean, persistent=False)
  451. self.register_buffer('student_std', student_std, persistent=False)
  452. if self.verbose:
  453. _logger.info(
  454. f"FeatureDistillationTask: "
  455. f"student_dim={student_feature_dim}, teacher_dim={teacher_feature_dim}"
  456. )
  457. @staticmethod
  458. def _detect_feature_dim(model: nn.Module) -> int:
  459. """Auto-detect feature dimension from model."""
  460. model = unwrap_model(model)
  461. if hasattr(model, 'head_hidden_size'):
  462. return model.head_hidden_size
  463. elif hasattr(model, 'num_features'):
  464. return model.num_features
  465. else:
  466. raise ValueError(
  467. "Cannot auto-detect feature dimension. Model must have "
  468. "'head_hidden_size' or 'num_features' attribute, or you must "
  469. "specify student_feature_dim and teacher_feature_dim explicitly."
  470. )
  471. def prepare_distributed(
  472. self,
  473. device_ids: Optional[list] = None,
  474. **ddp_kwargs,
  475. ) -> 'FeatureDistillationTask':
  476. """Prepare task for distributed training.
  477. Wraps the trainable module (student + projection) in DistributedDataParallel
  478. (DDP) while leaving the frozen teacher model unwrapped.
  479. Args:
  480. device_ids: List of device IDs for DDP (e.g., [local_rank])
  481. **ddp_kwargs: Additional arguments passed to DistributedDataParallel
  482. Returns:
  483. self (for method chaining)
  484. """
  485. from torch.nn.parallel import DistributedDataParallel as DDP
  486. for param in self.teacher.parameters():
  487. param.requires_grad = False
  488. self.trainable_module = DDP(self.trainable_module, device_ids=device_ids, **ddp_kwargs)
  489. return self
  490. def forward(
  491. self,
  492. input: torch.Tensor,
  493. target: torch.Tensor,
  494. ) -> Dict[str, torch.Tensor]:
  495. """Forward pass with feature distillation.
  496. Args:
  497. input: Input tensor [B, C, H, W]
  498. target: Target labels [B]
  499. Returns:
  500. Dictionary containing:
  501. - 'loss': Combined training loss (task + distillation)
  502. - 'output': Student logits (for metrics)
  503. - 'task_loss': Classification loss component
  504. - 'kd_loss': Feature distillation loss component
  505. """
  506. student_logits, student_features = self.trainable_module(input)
  507. task_loss = self.criterion(student_logits, target)
  508. with torch.no_grad():
  509. input_kd = self.teacher.normalize_input(input, self.student_mean, self.student_std)
  510. teacher_features = self.teacher(input_kd.detach(), return_features=True)
  511. kd_loss = F.mse_loss(student_features, teacher_features)
  512. total_loss = self.task_loss_weight * task_loss + self.distill_loss_weight * kd_loss
  513. return {
  514. 'loss': total_loss,
  515. 'output': student_logits,
  516. 'task_loss': task_loss,
  517. 'kd_loss': kd_loss,
  518. }