backbone_utils.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376
  1. # Copyright 2026 The HuggingFace Inc. team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """Collection of utils to be used by backbones and their components."""
  15. import enum
  16. import functools
  17. import inspect
  18. from huggingface_hub import repo_exists
  19. from .utils import logging
  20. from .utils.output_capturing import maybe_install_capturing_hooks
# Module-level logger obtained from the package's own `logging` helper (imported from `.utils`),
# named after this module.
logger = logging.get_logger(__name__)
  22. class BackboneType(enum.Enum):
  23. TIMM = "timm"
  24. TRANSFORMERS = "transformers"
  25. class BackboneConfigMixin:
  26. """
  27. A Mixin to support handling the `out_features` and `out_indices` attributes for the backbone configurations.
  28. """
  29. def set_output_features_output_indices(
  30. self,
  31. out_features: list | None,
  32. out_indices: list | None,
  33. ):
  34. """
  35. Sets output indices and features to new values and aligns them with the given `stage_names`.
  36. If one of the inputs is not given, find the corresponding `out_features` or `out_indices`
  37. for the given `stage_names`.
  38. Args:
  39. out_features (`list[str]`, *optional*):
  40. The names of the features for the backbone to output. Defaults to `config._out_features` if not provided.
  41. out_indices (`list[int]` or `tuple[int]`, *optional*):
  42. The indices of the features for the backbone to output. Defaults to `config._out_indices` if not provided.
  43. """
  44. self._out_features = out_features
  45. self._out_indices = list(out_indices) if isinstance(out_indices, tuple) else out_indices
  46. # First verify that the out_features and out_indices are valid
  47. self.verify_out_features_out_indices()
  48. # Align output features with indices
  49. out_features, out_indices = self._out_features, self._out_indices
  50. if out_indices is None and out_features is None:
  51. out_indices = [len(self.stage_names) - 1]
  52. out_features = [self.stage_names[-1]]
  53. elif out_indices is None and out_features is not None:
  54. out_indices = [self.stage_names.index(layer) for layer in out_features]
  55. elif out_features is None and out_indices is not None:
  56. out_features = [self.stage_names[idx] for idx in out_indices]
  57. # Update values and verify that the aligned out_features and out_indices are valid
  58. self._out_features, self._out_indices = out_features, out_indices
  59. self.verify_out_features_out_indices()
  60. def verify_out_features_out_indices(self):
  61. """
  62. Verify that out_indices and out_features are valid for the given stage_names.
  63. """
  64. if self.stage_names is None:
  65. raise ValueError("Stage_names must be set for transformers backbones")
  66. if self._out_features is not None:
  67. if not isinstance(self._out_features, (list,)):
  68. raise ValueError(f"out_features must be a list got {type(self._out_features)}")
  69. if any(feat not in self.stage_names for feat in self._out_features):
  70. raise ValueError(
  71. f"out_features must be a subset of stage_names: {self.stage_names} got {self._out_features}"
  72. )
  73. if len(self._out_features) != len(set(self._out_features)):
  74. raise ValueError(f"out_features must not contain any duplicates, got {self._out_features}")
  75. if self._out_features != (
  76. sorted_feats := [feat for feat in self.stage_names if feat in self._out_features]
  77. ):
  78. raise ValueError(
  79. f"out_features must be in the same order as stage_names, expected {sorted_feats} got {self._out_features}"
  80. )
  81. if self._out_indices is not None:
  82. if not isinstance(self._out_indices, list):
  83. raise ValueError(f"out_indices must be a list, got {type(self._out_indices)}")
  84. # Convert negative indices to their positive equivalent: [-1,] -> [len(stage_names) - 1,]
  85. positive_indices = tuple(idx % len(self.stage_names) if idx < 0 else idx for idx in self._out_indices)
  86. if any(idx for idx in positive_indices if idx not in range(len(self.stage_names))):
  87. raise ValueError(
  88. f"out_indices must be valid indices for stage_names {self.stage_names}, got {self._out_indices}"
  89. )
  90. if len(positive_indices) != len(set(positive_indices)):
  91. msg = f"out_indices must not contain any duplicates, got {self._out_indices}"
  92. msg += f"(equivalent to {positive_indices}))" if positive_indices != self._out_indices else ""
  93. raise ValueError(msg)
  94. if positive_indices != tuple(sorted(positive_indices)):
  95. sorted_negative = [
  96. idx for _, idx in sorted(zip(positive_indices, self._out_indices), key=lambda x: x[0])
  97. ]
  98. raise ValueError(
  99. f"out_indices must be in the same order as stage_names, expected {sorted_negative} got {self._out_indices}"
  100. )
  101. if self._out_features is not None and self._out_indices is not None:
  102. if len(self._out_features) != len(self._out_indices):
  103. raise ValueError("out_features and out_indices should have the same length if both are set")
  104. if self._out_features != [self.stage_names[idx] for idx in self._out_indices]:
  105. raise ValueError("out_features and out_indices should correspond to the same stages if both are set")
  106. @property
  107. def out_features(self):
  108. return self._out_features
  109. @out_features.setter
  110. def out_features(self, out_features: list[str]):
  111. """
  112. Set the out_features attribute. This will also update the out_indices attribute to match the new out_features.
  113. """
  114. self.set_output_features_output_indices(out_features=out_features, out_indices=None)
  115. @property
  116. def out_indices(self):
  117. return self._out_indices
  118. @out_indices.setter
  119. def out_indices(self, out_indices: tuple[int, ...] | list[int]):
  120. """
  121. Set the out_indices attribute. This will also update the out_features attribute to match the new out_indices.
  122. """
  123. out_indices = list(out_indices) if out_indices is not None else out_indices
  124. self.set_output_features_output_indices(out_features=None, out_indices=out_indices)
  125. def to_dict(self):
  126. """
  127. Serializes this instance to a Python dictionary. Override the default `to_dict()` from `PreTrainedConfig` to
  128. include the `out_features` and `out_indices` attributes.
  129. """
  130. output = super().to_dict()
  131. output["out_features"] = output.pop("_out_features", None)
  132. output["out_indices"] = output.pop("_out_indices", None)
  133. return output
  134. def filter_output_hidden_states(forward_function):
  135. """
  136. Wrapper to filer out `hidden_states` as backbones tend to always use them to get their feature maps, i.e.
  137. they also always output `hidden_states`. This controls for user-defined behavior again.
  138. NOTE: We assume a `can_return_tuple` decorator to be applied before so that we always expect a dict like
  139. object to remove the hidden states.
  140. """
  141. @functools.wraps(forward_function)
  142. def wrapper(self, *args, **kwargs):
  143. output_hidden_states = kwargs.get("output_hidden_states", getattr(self.config, "output_hidden_states", False))
  144. output = forward_function(self, *args, **kwargs)
  145. if not output_hidden_states:
  146. filtered_output_data = {k: v for k, v in output.items() if k not in ("hidden_states")}
  147. output = type(output)(**filtered_output_data)
  148. return output
  149. return wrapper
  150. class BackboneMixin:
  151. backbone_type: BackboneType | None = None
  152. # Attribute to indicate if the backbone has attention and can return attention outputs.
  153. # Should be set to `False` for conv-based models to be able to run `forward_with_filtered_kwargs`
  154. has_attentions: bool = True
  155. def __init__(self, *args, **kwargs) -> None:
  156. """
  157. Method to initialize the backbone. This method is called by the constructor of the base class after the
  158. pretrained model weights have been loaded.
  159. """
  160. super().__init__(*args, **kwargs)
  161. timm_backbone = kwargs.pop("timm_backbone", None)
  162. if timm_backbone is not None:
  163. self.backbone_type = BackboneType.TIMM
  164. else:
  165. self.backbone_type = BackboneType.TRANSFORMERS
  166. if self.backbone_type == BackboneType.TIMM:
  167. self._init_timm_backbone(backbone=timm_backbone)
  168. elif self.backbone_type == BackboneType.TRANSFORMERS:
  169. self._init_transformers_backbone()
  170. else:
  171. raise ValueError(f"backbone_type {self.backbone_type} not supported.")
  172. def post_init(self):
  173. """
  174. Override `post_init` to always install capturing hooks, as backbone will ALWAYS capture outputs. We need to do
  175. it in `post_init`, as modules need to be already instantiated.
  176. It avoids some mixups with `torch.compile`, as the first hook installation will need/create a graph break,
  177. which can clash with external user call such as `model = torch.compile(model...)`.
  178. """
  179. # NOTE: Since this class is ALWAYS used as a Mixin with another PreTrainedModel class, this `super` call
  180. # will call the PreTrained's `post_init`
  181. super().post_init()
  182. maybe_install_capturing_hooks(self)
  183. def _init_timm_backbone(self, backbone) -> None:
  184. """
  185. Initialize the backbone model from timm. The backbone must already be loaded to backbone
  186. """
  187. out_features_from_config = getattr(self.config, "out_features", None)
  188. stage_names_from_config = getattr(self.config, "stage_names", None)
  189. # These will disagree with the defaults for the transformers models e.g. for resnet50
  190. # the transformer model has out_features = ['stem', 'stage1', 'stage2', 'stage3', 'stage4']
  191. # the timm model has out_features = ['act', 'layer1', 'layer2', 'layer3', 'layer4']
  192. self.stage_names = [stage["module"] for stage in backbone.feature_info.info]
  193. self.num_features = [stage["num_chs"] for stage in backbone.feature_info.info]
  194. out_indices = list(backbone.feature_info.out_indices)
  195. out_features = backbone.feature_info.module_name()
  196. if out_features_from_config is not None and out_features_from_config != out_features:
  197. raise ValueError(
  198. f"Config has `out_features` set to {out_features_from_config} which doesn't match `out_features` "
  199. "from backbone's feature_info. Please check if your checkpoint has correct out features/indices saved."
  200. )
  201. if stage_names_from_config is not None and stage_names_from_config != self.stage_names:
  202. raise ValueError(
  203. f"Config has `stage_names` set to {stage_names_from_config} which doesn't match `stage_names` "
  204. "from backbone's feature_info. Please check if your checkpoint has correct `stage_names` saved."
  205. )
  206. # We set, align and verify out indices, out features and stage names
  207. self.config.stage_names = self.stage_names
  208. self.config.set_output_features_output_indices(out_features, out_indices)
  209. def _init_transformers_backbone(self) -> None:
  210. self.stage_names = self.config.stage_names
  211. self.config.verify_out_features_out_indices()
  212. # Number of channels for each stage. This is set in the transformer backbone model init
  213. self.num_features = None
  214. @property
  215. def out_features(self):
  216. return self.config._out_features
  217. @out_features.setter
  218. def out_features(self, out_features: list[str]):
  219. """
  220. Set the out_features attribute. This will also update the out_indices attribute to match the new out_features.
  221. """
  222. self.config.out_features = out_features
  223. @property
  224. def out_indices(self):
  225. return self.config._out_indices
  226. @out_indices.setter
  227. def out_indices(self, out_indices: tuple[int] | list[int]):
  228. """
  229. Set the out_indices attribute. This will also update the out_features attribute to match the new out_indices.
  230. """
  231. self.config.out_indices = out_indices
  232. @property
  233. def out_feature_channels(self):
  234. # the current backbones will output the number of channels for each stage
  235. # even if that stage is not in the out_features list.
  236. return {stage: self.num_features[i] for i, stage in enumerate(self.stage_names)}
  237. @property
  238. def channels(self):
  239. return [self.out_feature_channels[name] for name in self.out_features]
  240. def forward_with_filtered_kwargs(self, *args, **kwargs):
  241. if not self.has_attentions:
  242. kwargs.pop("output_attentions", None)
  243. if self.backbone_type == BackboneType.TIMM:
  244. signature = dict(inspect.signature(self.forward).parameters)
  245. kwargs = {k: v for k, v in kwargs.items() if k in signature}
  246. return self(*args, **kwargs)
  247. def forward(
  248. self,
  249. pixel_values,
  250. output_hidden_states: bool | None = None,
  251. output_attentions: bool | None = None,
  252. return_dict: bool | None = None,
  253. ):
  254. raise NotImplementedError("This method should be implemented by the derived class.")
  255. def consolidate_backbone_kwargs_to_config(
  256. backbone_config,
  257. default_backbone: str | None = None,
  258. default_config_type: str | None = None,
  259. default_config_kwargs: dict | None = None,
  260. timm_default_kwargs: dict | None = None,
  261. **kwargs,
  262. ):
  263. # Lazy import to avoid circular import issues. Can be imported properly
  264. # after deleting ref to `BackboneMixin` in `utils/backbone_utils.py`
  265. from .configuration_utils import PreTrainedConfig
  266. from .models.auto import CONFIG_MAPPING
  267. use_timm_backbone = kwargs.pop("use_timm_backbone", True)
  268. backbone_kwargs = kwargs.pop("backbone_kwargs", {})
  269. backbone = kwargs.pop("backbone") if kwargs.get("backbone") is not None else default_backbone
  270. kwargs.pop("use_pretrained_backbone", None)
  271. # Init timm backbone with hardcoded values for BC. If everything is set to `None` and there is
  272. # a default timm config, we use it to init the backbone.
  273. if (
  274. timm_default_kwargs is not None
  275. and use_timm_backbone
  276. and backbone is not None
  277. and backbone_config is None
  278. and not backbone_kwargs
  279. ):
  280. backbone_config = CONFIG_MAPPING["timm_backbone"](backbone=backbone, **timm_default_kwargs)
  281. elif backbone is not None and backbone_config is None:
  282. if repo_exists(backbone):
  283. config_dict, _ = PreTrainedConfig.get_config_dict(backbone)
  284. config_class = CONFIG_MAPPING[config_dict["model_type"]]
  285. config_dict.update(backbone_kwargs)
  286. backbone_config = config_class(**config_dict)
  287. else:
  288. backbone_config = CONFIG_MAPPING["timm_backbone"](backbone=backbone, **backbone_kwargs)
  289. elif backbone_config is None and default_config_type is not None:
  290. logger.info(
  291. f"`backbone_config` is `None`. Initializing the config with the default `{default_config_type}` vision config."
  292. )
  293. default_config_kwargs = default_config_kwargs or {}
  294. backbone_config = CONFIG_MAPPING[default_config_type](**default_config_kwargs)
  295. elif isinstance(backbone_config, dict):
  296. backbone_model_type = backbone_config.get("model_type")
  297. config_class = CONFIG_MAPPING[backbone_model_type]
  298. backbone_config = config_class.from_dict(backbone_config)
  299. return backbone_config, kwargs
  300. def load_backbone(config):
  301. """
  302. Loads the backbone model from a config object.
  303. If the config is from the backbone model itself, then we return a backbone model with randomly initialized
  304. weights.
  305. If the config is from the parent model of the backbone model itself, then we load the pretrained backbone weights
  306. if specified.
  307. """
  308. from transformers import AutoBackbone
  309. backbone_config = getattr(config, "backbone_config", None)
  310. if backbone_config is None:
  311. backbone = AutoBackbone.from_config(config=config)
  312. else:
  313. backbone = AutoBackbone.from_config(config=backbone_config)
  314. return backbone