modeling_eomt.py 53 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/eomt/modular_eomt.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_eomt.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # Copyright 2025 Mobile Perception Systems Lab at TU/e and The HuggingFace Inc. team. All rights reserved.
  8. #
  9. # Licensed under the Apache License, Version 2.0 (the "License");
  10. # you may not use this file except in compliance with the License.
  11. # You may obtain a copy of the License at
  12. #
  13. # http://www.apache.org/licenses/LICENSE-2.0
  14. #
  15. # Unless required by applicable law or agreed to in writing, software
  16. # distributed under the License is distributed on an "AS IS" BASIS,
  17. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  18. # See the License for the specific language governing permissions and
  19. # limitations under the License.
  20. import collections.abc
  21. import math
  22. from collections.abc import Callable
  23. from dataclasses import dataclass
  24. import numpy as np
  25. import torch
  26. import torch.nn.functional as F
  27. from torch import Tensor, nn
  28. from ... import initialization as init
  29. from ...activations import ACT2FN
  30. from ...file_utils import ModelOutput, is_scipy_available, requires_backends
  31. from ...modeling_layers import GradientCheckpointingLayer
  32. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  33. from ...processing_utils import Unpack
  34. from ...utils import TransformersKwargs, auto_docstring, is_accelerate_available
  35. from ...utils.generic import merge_with_config_defaults
  36. from ...utils.output_capturing import capture_outputs
  37. from .configuration_eomt import EomtConfig
  38. if is_scipy_available():
  39. from scipy.optimize import linear_sum_assignment
  40. if is_accelerate_available():
  41. from accelerate import PartialState
  42. from accelerate.utils import reduce
  43. @dataclass
  44. @auto_docstring(
  45. custom_intro="""
  46. Class for outputs of [`EomtForUniversalSegmentationOutput`].
  47. This output can be directly passed to [`~EomtImageProcessor.post_process_semantic_segmentation`] or
  48. [`~EomtImageProcessor.post_process_instance_segmentation`] or
  49. [`~EomtImageProcessor.post_process_panoptic_segmentation`] to compute final segmentation maps. Please, see
  50. [`~EomtImageProcessor] for details regarding usage.
  51. """
  52. )
  53. class EomtForUniversalSegmentationOutput(ModelOutput):
  54. r"""
  55. loss (`torch.Tensor`, *optional*):
  56. The computed loss, returned when labels are present.
  57. class_queries_logits (`torch.FloatTensor`):
  58. A tensor of shape `(batch_size, num_queries, num_labels + 1)` representing the proposed classes for each
  59. query. Note the `+ 1` is needed because we incorporate the null class.
  60. masks_queries_logits (`torch.FloatTensor`):
  61. A tensor of shape `(batch_size, num_queries, height, width)` representing the proposed masks for each
  62. query.
  63. last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
  64. Last hidden states (final feature map) of the last layer.
  65. hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
  66. Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
  67. shape `(batch_size, sequence_length, hidden_size)`. Hidden-states all layers of the model.
  68. attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
  69. Tuple of `tuple(torch.FloatTensor)` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
  70. sequence_length)`. Self and Cross Attentions weights from transformer decoder.
  71. patch_offsets (`list[torch.Tensor]`, *optional*):
  72. list of tuples indicating the image index and start and end positions of patches for semantic segmentation.
  73. """
  74. loss: torch.FloatTensor | None = None
  75. class_queries_logits: torch.FloatTensor | None = None
  76. masks_queries_logits: torch.FloatTensor | None = None
  77. last_hidden_state: torch.FloatTensor | None = None
  78. hidden_states: tuple[torch.FloatTensor] | None = None
  79. attentions: tuple[torch.FloatTensor] | None = None
  80. patch_offsets: list[torch.Tensor] | None = None
  81. # Adapted from https://github.com/facebookresearch/detectron2/blob/main/projects/PointRend/point_rend/point_features.py
  82. def sample_point(
  83. input_features: torch.Tensor, point_coordinates: torch.Tensor, add_dim=False, **kwargs
  84. ) -> torch.Tensor:
  85. """
  86. A wrapper around `torch.nn.functional.grid_sample` to support 3D point_coordinates tensors.
  87. Args:
  88. input_features (`torch.Tensor` of shape (batch_size, channels, height, width)):
  89. A tensor that contains features map on a height * width grid
  90. point_coordinates (`torch.Tensor` of shape (batch_size, num_points, 2) or (batch_size, grid_height, grid_width,:
  91. 2)):
  92. A tensor that contains [0, 1] * [0, 1] normalized point coordinates
  93. add_dim (`bool`):
  94. boolean value to keep track of added dimension
  95. Returns:
  96. point_features (`torch.Tensor` of shape (batch_size, channels, num_points) or (batch_size, channels,
  97. height_grid, width_grid):
  98. A tensor that contains features for points in `point_coordinates`.
  99. """
  100. if point_coordinates.dim() == 3:
  101. add_dim = True
  102. point_coordinates = point_coordinates.unsqueeze(2)
  103. # use nn.function.grid_sample to get features for points in `point_coordinates` via bilinear interpolation
  104. point_features = torch.nn.functional.grid_sample(input_features, 2.0 * point_coordinates - 1.0, **kwargs)
  105. if add_dim:
  106. point_features = point_features.squeeze(3)
  107. return point_features
  108. def pair_wise_dice_loss(inputs: Tensor, labels: Tensor) -> Tensor:
  109. """
  110. A pair wise version of the dice loss, see `dice_loss` for usage.
  111. Args:
  112. inputs (`torch.Tensor`):
  113. A tensor representing a mask
  114. labels (`torch.Tensor`):
  115. A tensor with the same shape as inputs. Stores the binary classification labels for each element in inputs
  116. (0 for the negative class and 1 for the positive class).
  117. Returns:
  118. `torch.Tensor`: The computed loss between each pairs.
  119. """
  120. inputs = inputs.sigmoid().flatten(1)
  121. numerator = 2 * torch.matmul(inputs, labels.T)
  122. # using broadcasting to get a [num_queries, NUM_CLASSES] matrix
  123. denominator = inputs.sum(-1)[:, None] + labels.sum(-1)[None, :]
  124. loss = 1 - (numerator + 1) / (denominator + 1)
  125. return loss
  126. def pair_wise_sigmoid_cross_entropy_loss(inputs: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
  127. r"""
  128. A pair wise version of the cross entropy loss, see `sigmoid_cross_entropy_loss` for usage.
  129. Args:
  130. inputs (`torch.Tensor`):
  131. A tensor representing a mask.
  132. labels (`torch.Tensor`):
  133. A tensor with the same shape as inputs. Stores the binary classification labels for each element in inputs
  134. (0 for the negative class and 1 for the positive class).
  135. Returns:
  136. loss (`torch.Tensor`): The computed loss between each pairs.
  137. """
  138. height_and_width = inputs.shape[1]
  139. criterion = nn.BCEWithLogitsLoss(reduction="none")
  140. cross_entropy_loss_pos = criterion(inputs, torch.ones_like(inputs))
  141. cross_entropy_loss_neg = criterion(inputs, torch.zeros_like(inputs))
  142. loss_pos = torch.matmul(cross_entropy_loss_pos / height_and_width, labels.T)
  143. loss_neg = torch.matmul(cross_entropy_loss_neg / height_and_width, (1 - labels).T)
  144. loss = loss_pos + loss_neg
  145. return loss
  146. # Adapted from https://github.com/facebookresearch/Eomt/blob/main/eomt/modeling/matcher.py
  147. class EomtHungarianMatcher(nn.Module):
  148. """This class computes an assignment between the labels and the predictions of the network.
  149. For efficiency reasons, the labels don't include the no_object. Because of this, in general, there are more
  150. predictions than labels. In this case, we do a 1-to-1 matching of the best predictions, while the others are
  151. un-matched (and thus treated as non-objects).
  152. """
  153. def __init__(
  154. self, cost_class: float = 1.0, cost_mask: float = 1.0, cost_dice: float = 1.0, num_points: int = 12544
  155. ):
  156. """Creates the matcher
  157. Params:
  158. cost_class (`float`, *optional*, defaults to 1.0):
  159. Relative weight of the classification error in the matching cost.
  160. cost_mask (`float`, *optional*, defaults to 1.0):
  161. This is the relative weight of the focal loss of the binary mask in the matching cost.
  162. cost_dice (`float`, *optional*, defaults to 1.0):
  163. This is the relative weight of the dice loss of the binary mask in the matching cost.
  164. num_points (`int`, *optional*, defaults to 12544):
  165. No. of points to sample on which the mask loss will be calculated. The same set of K points are
  166. uniformly sampled for all prediction and ground truth masks to construct the cost matrix for bipartite
  167. matching.
  168. """
  169. super().__init__()
  170. if cost_class == 0 and cost_mask == 0 and cost_dice == 0:
  171. raise ValueError("All costs can't be 0")
  172. self.num_points = num_points
  173. self.cost_class = cost_class
  174. self.cost_mask = cost_mask
  175. self.cost_dice = cost_dice
  176. @torch.no_grad()
  177. def forward(
  178. self,
  179. masks_queries_logits: torch.Tensor,
  180. class_queries_logits: torch.Tensor,
  181. mask_labels: torch.Tensor,
  182. class_labels: torch.Tensor,
  183. ) -> list[tuple[Tensor]]:
  184. """
  185. Params:
  186. masks_queries_logits (`torch.Tensor`):
  187. A tensor of dim `batch_size, num_queries, num_labels` with the classification logits.
  188. class_queries_logits (`torch.Tensor`):
  189. A tensor of dim `batch_size, num_queries, height, width` with the predicted masks.
  190. class_labels (`torch.Tensor`):
  191. A tensor of dim `num_target_boxes` (where num_target_boxes is the number of ground-truth objects in the
  192. target) containing the class labels.
  193. mask_labels (`torch.Tensor`):
  194. A tensor of dim `num_target_boxes, height, width` containing the target masks.
  195. Returns:
  196. matched_indices (`list[tuple[Tensor]]`): A list of size batch_size, containing tuples of (index_i, index_j)
  197. where:
  198. - index_i is the indices of the selected predictions (in order)
  199. - index_j is the indices of the corresponding selected labels (in order)
  200. For each batch element, it holds:
  201. len(index_i) = len(index_j) = min(num_queries, num_target_boxes).
  202. """
  203. indices: list[tuple[np.array]] = []
  204. # iterate through batch size
  205. batch_size = masks_queries_logits.shape[0]
  206. for i in range(batch_size):
  207. pred_probs = class_queries_logits[i].softmax(-1)
  208. pred_mask = masks_queries_logits[i]
  209. # Compute the classification cost. Contrary to the loss, we don't use the NLL, but approximate it in 1 - proba[target class]. The 1 is a constant that doesn't change the matching, it can be omitted.
  210. cost_class = -pred_probs[:, class_labels[i]]
  211. target_mask = mask_labels[i].to(pred_mask)
  212. target_mask = target_mask[:, None]
  213. pred_mask = pred_mask[:, None]
  214. # Sample ground truth and predicted masks
  215. point_coordinates = torch.rand(1, self.num_points, 2, device=pred_mask.device)
  216. target_coordinates = point_coordinates.repeat(target_mask.shape[0], 1, 1)
  217. target_mask = sample_point(target_mask, target_coordinates, align_corners=False).squeeze(1)
  218. pred_coordinates = point_coordinates.repeat(pred_mask.shape[0], 1, 1)
  219. pred_mask = sample_point(pred_mask, pred_coordinates, align_corners=False).squeeze(1)
  220. # compute the cross entropy loss between each mask pairs -> shape (num_queries, num_labels)
  221. cost_mask = pair_wise_sigmoid_cross_entropy_loss(pred_mask, target_mask)
  222. # Compute the dice loss between each mask pairs -> shape (num_queries, num_labels)
  223. cost_dice = pair_wise_dice_loss(pred_mask, target_mask)
  224. # final cost matrix
  225. cost_matrix = self.cost_mask * cost_mask + self.cost_class * cost_class + self.cost_dice * cost_dice
  226. # eliminate infinite values in cost_matrix to avoid the error ``ValueError: cost matrix is infeasible``
  227. cost_matrix = torch.minimum(cost_matrix, torch.tensor(1e10))
  228. cost_matrix = torch.maximum(cost_matrix, torch.tensor(-1e10))
  229. cost_matrix = torch.nan_to_num(cost_matrix, 0)
  230. # do the assignment using the hungarian algorithm in scipy
  231. assigned_indices: tuple[np.array] = linear_sum_assignment(cost_matrix.cpu())
  232. indices.append(assigned_indices)
  233. # It could be stacked in one tensor
  234. matched_indices = [
  235. (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices
  236. ]
  237. return matched_indices
  238. def dice_loss(inputs: Tensor, labels: Tensor, num_masks: int) -> Tensor:
  239. r"""
  240. Compute the DICE loss, similar to generalized IOU for masks as follows:
  241. $$ \mathcal{L}_{\text{dice}(x, y) = 1 - \frac{2 * x \cap y }{x \cup y + 1}} $$
  242. In practice, since `labels` is a binary mask, (only 0s and 1s), dice can be computed as follow
  243. $$ \mathcal{L}_{\text{dice}(x, y) = 1 - \frac{2 * x * y }{x + y + 1}} $$
  244. Args:
  245. inputs (`torch.Tensor`):
  246. A tensor representing a mask.
  247. labels (`torch.Tensor`):
  248. A tensor with the same shape as inputs. Stores the binary classification labels for each element in inputs
  249. (0 for the negative class and 1 for the positive class).
  250. num_masks (`int`):
  251. The number of masks present in the current batch, used for normalization.
  252. Returns:
  253. `torch.Tensor`: The computed loss.
  254. """
  255. probs = inputs.sigmoid().flatten(1)
  256. numerator = 2 * (probs * labels).sum(-1)
  257. denominator = probs.sum(-1) + labels.sum(-1)
  258. loss = 1 - (numerator + 1) / (denominator + 1)
  259. loss = loss.sum() / num_masks
  260. return loss
  261. def sigmoid_cross_entropy_loss(inputs: torch.Tensor, labels: torch.Tensor, num_masks: int) -> torch.Tensor:
  262. r"""
  263. Args:
  264. inputs (`torch.Tensor`):
  265. A float tensor of arbitrary shape.
  266. labels (`torch.Tensor`):
  267. A tensor with the same shape as inputs. Stores the binary classification labels for each element in inputs
  268. (0 for the negative class and 1 for the positive class).
  269. Returns:
  270. loss (`torch.Tensor`): The computed loss.
  271. """
  272. criterion = nn.BCEWithLogitsLoss(reduction="none")
  273. cross_entropy_loss = criterion(inputs, labels)
  274. loss = cross_entropy_loss.mean(1).sum() / num_masks
  275. return loss
  276. # Adapted from https://github.com/facebookresearch/Eomt/blob/main/eomt/modeling/criterion.py
  277. class EomtLoss(nn.Module):
  278. def __init__(self, config: EomtConfig, weight_dict: dict[str, float]):
  279. """
  280. The Eomt Loss. The loss is computed very similar to DETR. The process happens in two steps: 1) we
  281. compute hungarian assignment between ground truth masks and the outputs of the model 2) we supervise each pair
  282. of matched ground-truth / prediction (supervise class and mask)
  283. Args:
  284. config (`EomtConfig`):
  285. The configuration for Eomt model also containing loss calculation specific parameters.
  286. weight_dict (`dict[str, float]`):
  287. A dictionary of weights to be applied to the different losses.
  288. """
  289. super().__init__()
  290. requires_backends(self, ["scipy"])
  291. self.num_labels = config.num_labels
  292. self.weight_dict = weight_dict
  293. # Weight to apply to the null class
  294. self.eos_coef = config.no_object_weight
  295. empty_weight = torch.ones(self.num_labels + 1)
  296. empty_weight[-1] = self.eos_coef
  297. self.register_buffer("empty_weight", empty_weight)
  298. # pointwise mask loss parameters
  299. self.num_points = config.train_num_points
  300. self.oversample_ratio = config.oversample_ratio
  301. self.importance_sample_ratio = config.importance_sample_ratio
  302. self.matcher = EomtHungarianMatcher(
  303. cost_class=config.class_weight,
  304. cost_dice=config.dice_weight,
  305. cost_mask=config.mask_weight,
  306. num_points=self.num_points,
  307. )
  308. def _max_by_axis(self, sizes: list[list[int]]) -> list[int]:
  309. maxes = sizes[0]
  310. for sublist in sizes[1:]:
  311. for index, item in enumerate(sublist):
  312. maxes[index] = max(maxes[index], item)
  313. return maxes
  314. # Adapted from nested_tensor_from_tensor_list() in original implementation
  315. def _pad_images_to_max_in_batch(self, tensors: list[Tensor]) -> tuple[Tensor, Tensor]:
  316. # get the maximum size in the batch
  317. max_size = self._max_by_axis([list(tensor.shape) for tensor in tensors])
  318. # compute final size
  319. batch_shape = [len(tensors)] + max_size
  320. batch_size, _, height, width = batch_shape
  321. dtype = tensors[0].dtype
  322. device = tensors[0].device
  323. padded_tensors = torch.zeros(batch_shape, dtype=dtype, device=device)
  324. padding_masks = torch.ones((batch_size, height, width), dtype=torch.bool, device=device)
  325. # pad the tensors to the size of the biggest one
  326. for tensor, padded_tensor, padding_mask in zip(tensors, padded_tensors, padding_masks):
  327. padded_tensor[: tensor.shape[0], : tensor.shape[1], : tensor.shape[2]].copy_(tensor)
  328. padding_mask[: tensor.shape[1], : tensor.shape[2]] = False
  329. return padded_tensors, padding_masks
  330. def loss_labels(
  331. self, class_queries_logits: Tensor, class_labels: list[Tensor], indices: tuple[np.array]
  332. ) -> dict[str, Tensor]:
  333. """Compute the losses related to the labels using cross entropy.
  334. Args:
  335. class_queries_logits (`torch.Tensor`):
  336. A tensor of shape `batch_size, num_queries, num_labels`
  337. class_labels (`list[torch.Tensor]`):
  338. List of class labels of shape `(labels)`.
  339. indices (`tuple[np.array])`:
  340. The indices computed by the Hungarian matcher.
  341. Returns:
  342. `dict[str, Tensor]`: A dict of `torch.Tensor` containing the following key:
  343. - **loss_cross_entropy** -- The loss computed using cross entropy on the predicted and ground truth labels.
  344. """
  345. pred_logits = class_queries_logits
  346. batch_size, num_queries, _ = pred_logits.shape
  347. criterion = nn.CrossEntropyLoss(weight=self.empty_weight)
  348. idx = self._get_predictions_permutation_indices(indices) # shape of (batch_size, num_queries)
  349. target_classes_o = torch.cat(
  350. [target[j] for target, (_, j) in zip(class_labels, indices)]
  351. ) # shape of (batch_size, num_queries)
  352. target_classes = torch.full(
  353. (batch_size, num_queries), fill_value=self.num_labels, dtype=torch.int64, device=pred_logits.device
  354. )
  355. target_classes[idx] = target_classes_o
  356. # Permute target_classes (batch_size, num_queries, num_labels) -> (batch_size, num_labels, num_queries)
  357. pred_logits_transposed = pred_logits.transpose(1, 2)
  358. loss_ce = criterion(pred_logits_transposed, target_classes)
  359. losses = {"loss_cross_entropy": loss_ce}
  360. return losses
  361. def loss_masks(
  362. self,
  363. masks_queries_logits: torch.Tensor,
  364. mask_labels: list[torch.Tensor],
  365. indices: tuple[np.array],
  366. num_masks: int,
  367. ) -> dict[str, torch.Tensor]:
  368. """Compute the losses related to the masks using sigmoid_cross_entropy_loss and dice loss.
  369. Args:
  370. masks_queries_logits (`torch.Tensor`):
  371. A tensor of shape `(batch_size, num_queries, height, width)`.
  372. mask_labels (`torch.Tensor`):
  373. List of mask labels of shape `(labels, height, width)`.
  374. indices (`tuple[np.array])`:
  375. The indices computed by the Hungarian matcher.
  376. num_masks (`int)`:
  377. The number of masks, used for normalization.
  378. Returns:
  379. losses (`dict[str, Tensor]`): A dict of `torch.Tensor` containing two keys:
  380. - **loss_mask** -- The loss computed using sigmoid cross entropy loss on the predicted and ground truth.
  381. masks.
  382. - **loss_dice** -- The loss computed using dice loss on the predicted on the predicted and ground truth,
  383. masks.
  384. """
  385. src_idx = self._get_predictions_permutation_indices(indices)
  386. tgt_idx = self._get_targets_permutation_indices(indices)
  387. # shape (batch_size * num_queries, height, width)
  388. pred_masks = masks_queries_logits[src_idx]
  389. # shape (batch_size, num_queries, height, width)
  390. # pad all and stack the targets to the num_labels dimension
  391. target_masks, _ = self._pad_images_to_max_in_batch(mask_labels)
  392. target_masks = target_masks[tgt_idx]
  393. # No need to upsample predictions as we are using normalized coordinates
  394. pred_masks = pred_masks[:, None]
  395. target_masks = target_masks[:, None]
  396. # Sample point coordinates
  397. with torch.no_grad():
  398. point_coordinates = self.sample_points_using_uncertainty(
  399. pred_masks,
  400. lambda logits: self.calculate_uncertainty(logits),
  401. self.num_points,
  402. self.oversample_ratio,
  403. self.importance_sample_ratio,
  404. )
  405. point_labels = sample_point(target_masks, point_coordinates, align_corners=False).squeeze(1)
  406. point_logits = sample_point(pred_masks, point_coordinates, align_corners=False).squeeze(1)
  407. losses = {
  408. "loss_mask": sigmoid_cross_entropy_loss(point_logits, point_labels, num_masks),
  409. "loss_dice": dice_loss(point_logits, point_labels, num_masks),
  410. }
  411. del pred_masks
  412. del target_masks
  413. return losses
  414. def _get_predictions_permutation_indices(self, indices):
  415. # Permute predictions following indices
  416. batch_indices = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
  417. predictions_indices = torch.cat([src for (src, _) in indices])
  418. return batch_indices, predictions_indices
  419. def _get_targets_permutation_indices(self, indices):
  420. # Permute labels following indices
  421. batch_indices = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
  422. target_indices = torch.cat([tgt for (_, tgt) in indices])
  423. return batch_indices, target_indices
  424. def calculate_uncertainty(self, logits: torch.Tensor) -> torch.Tensor:
  425. """
  426. In Eomt paper, uncertainty is estimated as L1 distance between 0.0 and the logit prediction in 'logits'
  427. for the foreground class in `classes`.
  428. Args:
  429. logits (`torch.Tensor`):
  430. A tensor of shape (R, 1, ...) for class-specific or class-agnostic, where R is the total number of predicted masks in all images and C is:
  431. the number of foreground classes. The values are logits.
  432. Returns:
  433. scores (`torch.Tensor`): A tensor of shape (R, 1, ...) that contains uncertainty scores with the most
  434. uncertain locations having the highest uncertainty score.
  435. """
  436. uncertainty_scores = -(torch.abs(logits))
  437. return uncertainty_scores
  438. def sample_points_using_uncertainty(
  439. self,
  440. logits: torch.Tensor,
  441. uncertainty_function,
  442. num_points: int,
  443. oversample_ratio: int,
  444. importance_sample_ratio: float,
  445. ) -> torch.Tensor:
  446. """
  447. This function is meant for sampling points in [0, 1] * [0, 1] coordinate space based on their uncertainty. The
  448. uncertainty is calculated for each point using the passed `uncertainty function` that takes points logit
  449. prediction as input.
  450. Args:
  451. logits (`float`):
  452. Logit predictions for P points.
  453. uncertainty_function:
  454. A function that takes logit predictions for P points and returns their uncertainties.
  455. num_points (`int`):
  456. The number of points P to sample.
  457. oversample_ratio (`int`):
  458. Oversampling parameter.
  459. importance_sample_ratio (`float`):
  460. Ratio of points that are sampled via importance sampling.
  461. Returns:
  462. point_coordinates (`torch.Tensor`):
  463. Coordinates for P sampled points.
  464. """
  465. num_boxes = logits.shape[0]
  466. num_points_sampled = int(num_points * oversample_ratio)
  467. # Get random point coordinates
  468. point_coordinates = torch.rand(num_boxes, num_points_sampled, 2, device=logits.device)
  469. # Get sampled prediction value for the point coordinates
  470. point_logits = sample_point(logits, point_coordinates, align_corners=False)
  471. # Calculate the uncertainties based on the sampled prediction values of the points
  472. point_uncertainties = uncertainty_function(point_logits)
  473. num_uncertain_points = int(importance_sample_ratio * num_points)
  474. num_random_points = num_points - num_uncertain_points
  475. idx = torch.topk(point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1]
  476. shift = num_points_sampled * torch.arange(num_boxes, dtype=torch.long, device=logits.device)
  477. idx += shift[:, None]
  478. point_coordinates = point_coordinates.view(-1, 2)[idx.view(-1), :].view(num_boxes, num_uncertain_points, 2)
  479. if num_random_points > 0:
  480. point_coordinates = torch.cat(
  481. [point_coordinates, torch.rand(num_boxes, num_random_points, 2, device=logits.device)],
  482. dim=1,
  483. )
  484. return point_coordinates
  485. def forward(
  486. self,
  487. masks_queries_logits: torch.Tensor,
  488. class_queries_logits: torch.Tensor,
  489. mask_labels: list[torch.Tensor],
  490. class_labels: list[torch.Tensor],
  491. auxiliary_predictions: dict[str, torch.Tensor] | None = None,
  492. ) -> dict[str, torch.Tensor]:
  493. """
  494. This performs the loss computation.
  495. Args:
  496. masks_queries_logits (`torch.Tensor`):
  497. A tensor of shape `(batch_size, num_queries, height, width)`.
  498. class_queries_logits (`torch.Tensor`):
  499. A tensor of shape `(batch_size, num_queries, num_labels)`.
  500. mask_labels (`torch.Tensor`):
  501. List of mask labels of shape `(labels, height, width)`.
  502. class_labels (`list[torch.Tensor]`):
  503. List of class labels of shape `(labels)`.
  504. auxiliary_predictions (`dict[str, torch.Tensor]`, *optional*):
  505. if `use_auxiliary_loss` was set to `true` in [`EomtConfig`], then it contains the logits from
  506. the inner layers of the EomtMaskedAttentionDecoder.
  507. Returns:
  508. losses (`dict[str, Tensor]`): A dict of `torch.Tensor` containing three keys:
  509. - **loss_cross_entropy** -- The loss computed using cross entropy on the predicted and ground truth labels.
  510. - **loss_mask** -- The loss computed using sigmoid cross_entropy loss on the predicted and ground truth
  511. masks.
  512. - **loss_dice** -- The loss computed using dice loss on the predicted on the predicted and ground truth
  513. masks.
  514. if `use_auxiliary_loss` was set to `true` in [`EomtConfig`], the dictionary contains additional
  515. losses for each auxiliary predictions.
  516. """
  517. # retrieve the matching between the outputs of the last layer and the labels
  518. indices = self.matcher(masks_queries_logits, class_queries_logits, mask_labels, class_labels)
  519. # compute the average number of target masks for normalization purposes
  520. num_masks = self.get_num_masks(class_labels, device=class_labels[0].device)
  521. # get all the losses
  522. losses: dict[str, Tensor] = {
  523. **self.loss_masks(masks_queries_logits, mask_labels, indices, num_masks),
  524. **self.loss_labels(class_queries_logits, class_labels, indices),
  525. }
  526. # in case of auxiliary losses, we repeat this process with the output of each intermediate layer.
  527. if auxiliary_predictions is not None:
  528. for idx, aux_outputs in enumerate(auxiliary_predictions):
  529. masks_queries_logits = aux_outputs["masks_queries_logits"]
  530. class_queries_logits = aux_outputs["class_queries_logits"]
  531. loss_dict = self.forward(masks_queries_logits, class_queries_logits, mask_labels, class_labels)
  532. loss_dict = {f"{key}_{idx}": value for key, value in loss_dict.items()}
  533. losses.update(loss_dict)
  534. return losses
  535. def get_num_masks(self, class_labels: torch.Tensor, device: torch.device) -> torch.Tensor:
  536. """
  537. Computes the average number of target masks across the batch, for normalization purposes.
  538. """
  539. num_masks = sum(len(classes) for classes in class_labels)
  540. num_masks = torch.as_tensor(num_masks, dtype=torch.float, device=device)
  541. world_size = 1
  542. if is_accelerate_available():
  543. if PartialState._shared_state != {}:
  544. num_masks = reduce(num_masks)
  545. world_size = PartialState().num_processes
  546. num_masks = torch.clamp(num_masks / world_size, min=1)
  547. return num_masks
  548. class EomtPatchEmbeddings(nn.Module):
  549. """
  550. This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
  551. `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
  552. Transformer.
  553. """
  554. def __init__(self, config):
  555. super().__init__()
  556. image_size, patch_size = config.image_size, config.patch_size
  557. num_channels, hidden_size = config.num_channels, config.hidden_size
  558. image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
  559. patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
  560. num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
  561. self.image_size = image_size
  562. self.patch_size = patch_size
  563. self.num_channels = num_channels
  564. self.num_patches = num_patches
  565. self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
  566. def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
  567. num_channels = pixel_values.shape[1]
  568. if num_channels != self.num_channels:
  569. raise ValueError(
  570. "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
  571. f" Expected {self.num_channels} but got {num_channels}."
  572. )
  573. embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
  574. return embeddings
  575. class EomtEmbeddings(nn.Module):
  576. """
  577. Construct the CLS token, mask token, position and patch embeddings.
  578. """
  579. def __init__(self, config: EomtConfig) -> None:
  580. super().__init__()
  581. self.config = config
  582. self.patch_size = config.patch_size
  583. self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size))
  584. self.register_tokens = nn.Parameter(torch.zeros(1, config.num_register_tokens, config.hidden_size))
  585. self.patch_embeddings = EomtPatchEmbeddings(config)
  586. num_patches = self.patch_embeddings.num_patches
  587. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  588. self.num_prefix_tokens = 1 + config.num_register_tokens # 1 for [CLS]
  589. self.position_embeddings = nn.Embedding(num_patches, config.hidden_size)
  590. self.register_buffer("position_ids", torch.arange(num_patches).expand((1, -1)), persistent=False)
  591. def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
  592. batch_size, _, _, _ = pixel_values.shape
  593. target_dtype = self.patch_embeddings.projection.weight.dtype
  594. embeddings = self.patch_embeddings(pixel_values.to(dtype=target_dtype))
  595. cls_tokens = self.cls_token.expand(batch_size, -1, -1)
  596. register_tokens = self.register_tokens.expand(batch_size, -1, -1)
  597. embeddings = embeddings + self.position_embeddings(self.position_ids)
  598. embeddings = torch.cat([cls_tokens, register_tokens, embeddings], dim=1)
  599. embeddings = self.dropout(embeddings)
  600. return embeddings
  601. def eager_attention_forward(
  602. module: nn.Module,
  603. query: torch.Tensor,
  604. key: torch.Tensor,
  605. value: torch.Tensor,
  606. attention_mask: torch.Tensor | None,
  607. scaling: float,
  608. dropout: float = 0.0,
  609. **kwargs,
  610. ):
  611. attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
  612. if attention_mask is not None:
  613. attn_weights = attn_weights + attention_mask
  614. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  615. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  616. attn_output = torch.matmul(attn_weights, value)
  617. attn_output = attn_output.transpose(1, 2).contiguous()
  618. return attn_output, attn_weights
  619. class EomtAttention(nn.Module):
  620. """Multi-headed attention from 'Attention Is All You Need' paper"""
  621. def __init__(self, config):
  622. super().__init__()
  623. self.config = config
  624. self.embed_dim = config.hidden_size
  625. self.num_heads = config.num_attention_heads
  626. self.head_dim = self.embed_dim // self.num_heads
  627. if self.head_dim * self.num_heads != self.embed_dim:
  628. raise ValueError(
  629. f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
  630. f" {self.num_heads})."
  631. )
  632. self.scale = self.head_dim**-0.5
  633. self.dropout = config.attention_dropout
  634. self.is_causal = False
  635. self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
  636. self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
  637. self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
  638. self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
  639. def forward(
  640. self,
  641. hidden_states: torch.Tensor,
  642. attention_mask: torch.Tensor | None = None,
  643. **kwargs,
  644. ) -> tuple[torch.Tensor, torch.Tensor | None]:
  645. """Input shape: Batch x Time x Channel"""
  646. input_shape = hidden_states.shape[:-1]
  647. hidden_shape = (*input_shape, -1, self.head_dim)
  648. queries = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  649. keys = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  650. values = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  651. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  652. self.config._attn_implementation, eager_attention_forward
  653. )
  654. attn_output, attn_weights = attention_interface(
  655. self,
  656. queries,
  657. keys,
  658. values,
  659. attention_mask,
  660. is_causal=self.is_causal,
  661. scaling=self.scale,
  662. dropout=0.0 if not self.training else self.dropout,
  663. )
  664. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  665. attn_output = self.out_proj(attn_output)
  666. return attn_output, attn_weights
  667. class EomtLayerScale(nn.Module):
  668. def __init__(self, config) -> None:
  669. super().__init__()
  670. self.lambda1 = nn.Parameter(config.layerscale_value * torch.ones(config.hidden_size))
  671. def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
  672. return hidden_state * self.lambda1
  673. def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
  674. """
  675. Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
  676. """
  677. if drop_prob == 0.0 or not training:
  678. return input
  679. keep_prob = 1 - drop_prob
  680. shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
  681. random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
  682. random_tensor.floor_() # binarize
  683. output = input.div(keep_prob) * random_tensor
  684. return output
  685. class EomtDropPath(nn.Module):
  686. """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
  687. def __init__(self, drop_prob: float | None = None) -> None:
  688. super().__init__()
  689. self.drop_prob = drop_prob
  690. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  691. return drop_path(hidden_states, self.drop_prob, self.training)
  692. def extra_repr(self) -> str:
  693. return f"p={self.drop_prob}"
  694. class EomtMLP(nn.Module):
  695. def __init__(self, config) -> None:
  696. super().__init__()
  697. in_features = out_features = config.hidden_size
  698. hidden_features = int(config.hidden_size * config.mlp_ratio)
  699. self.fc1 = nn.Linear(in_features, hidden_features, bias=True)
  700. if isinstance(config.hidden_act, str):
  701. self.activation = ACT2FN[config.hidden_act]
  702. else:
  703. self.activation = config.hidden_act
  704. self.fc2 = nn.Linear(hidden_features, out_features, bias=True)
  705. def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
  706. hidden_state = self.fc1(hidden_state)
  707. hidden_state = self.activation(hidden_state)
  708. hidden_state = self.fc2(hidden_state)
  709. return hidden_state
  710. class EomtSwiGLUFFN(nn.Module):
  711. def __init__(self, config) -> None:
  712. super().__init__()
  713. in_features = out_features = config.hidden_size
  714. hidden_features = int(config.hidden_size * config.mlp_ratio)
  715. hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
  716. self.weights_in = nn.Linear(in_features, 2 * hidden_features, bias=True)
  717. self.weights_out = nn.Linear(hidden_features, out_features, bias=True)
  718. def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
  719. hidden_state = self.weights_in(hidden_state)
  720. x1, x2 = hidden_state.chunk(2, dim=-1)
  721. hidden = nn.functional.silu(x1) * x2
  722. return self.weights_out(hidden)
  723. class EomtLayer(GradientCheckpointingLayer):
  724. """This corresponds to the Block class in the original implementation."""
  725. def __init__(self, config: EomtConfig) -> None:
  726. super().__init__()
  727. self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  728. self.attention = EomtAttention(config)
  729. self.layer_scale1 = EomtLayerScale(config)
  730. self.drop_path = EomtDropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity()
  731. self.norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  732. if config.use_swiglu_ffn:
  733. self.mlp = EomtSwiGLUFFN(config)
  734. else:
  735. self.mlp = EomtMLP(config)
  736. self.layer_scale2 = EomtLayerScale(config)
  737. def forward(
  738. self,
  739. hidden_states: torch.Tensor,
  740. attention_mask: torch.Tensor | None = None,
  741. ) -> torch.Tensor:
  742. hidden_states_norm = self.norm1(hidden_states)
  743. self_attention_output, _ = self.attention(hidden_states_norm, attention_mask)
  744. self_attention_output = self.layer_scale1(self_attention_output)
  745. # first residual connection
  746. hidden_states = self.drop_path(self_attention_output) + hidden_states
  747. # in Eomt, layernorm is also applied after self-attention
  748. layer_output = self.norm2(hidden_states)
  749. layer_output = self.mlp(layer_output)
  750. layer_output = self.layer_scale2(layer_output)
  751. # second residual connection
  752. layer_output = self.drop_path(layer_output) + hidden_states
  753. return layer_output
  754. class EomtLayerNorm2d(nn.LayerNorm):
  755. def __init__(self, num_channels, eps=1e-6, affine=True):
  756. super().__init__(num_channels, eps=eps, elementwise_affine=affine)
  757. def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
  758. hidden_state = hidden_state.permute(0, 2, 3, 1)
  759. hidden_state = F.layer_norm(hidden_state, self.normalized_shape, self.weight, self.bias, self.eps)
  760. hidden_state = hidden_state.permute(0, 3, 1, 2)
  761. return hidden_state
  762. class EomtScaleLayer(nn.Module):
  763. def __init__(self, config: EomtConfig):
  764. super().__init__()
  765. hidden_size = config.hidden_size
  766. self.conv1 = nn.ConvTranspose2d(hidden_size, hidden_size, kernel_size=2, stride=2)
  767. self.activation = ACT2FN[config.hidden_act]
  768. self.conv2 = nn.Conv2d(
  769. hidden_size,
  770. hidden_size,
  771. kernel_size=3,
  772. padding=1,
  773. groups=hidden_size,
  774. bias=False,
  775. )
  776. self.layernorm2d = EomtLayerNorm2d(hidden_size)
  777. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  778. hidden_states = self.conv1(hidden_states)
  779. hidden_states = self.activation(hidden_states)
  780. hidden_states = self.conv2(hidden_states)
  781. hidden_states = self.layernorm2d(hidden_states)
  782. return hidden_states
  783. class EomtScaleBlock(nn.Module):
  784. def __init__(self, config: EomtConfig):
  785. super().__init__()
  786. self.num_blocks = config.num_upscale_blocks
  787. self.block = nn.ModuleList([EomtScaleLayer(config) for _ in range(self.num_blocks)])
  788. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  789. for block in self.block:
  790. hidden_states = block(hidden_states)
  791. return hidden_states
  792. class EomtMaskHead(nn.Module):
  793. def __init__(self, config: EomtConfig):
  794. super().__init__()
  795. hidden_size = config.hidden_size
  796. self.fc1 = nn.Linear(hidden_size, hidden_size)
  797. self.fc2 = nn.Linear(hidden_size, hidden_size)
  798. self.fc3 = nn.Linear(hidden_size, hidden_size)
  799. self.activation = ACT2FN[config.hidden_act]
  800. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  801. hidden_states = self.activation(self.fc1(hidden_states))
  802. hidden_states = self.activation(self.fc2(hidden_states))
  803. hidden_states = self.fc3(hidden_states)
  804. return hidden_states
@auto_docstring
class EomtPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    # The attributes below are read by the `PreTrainedModel` base class / framework; their exact semantics are
    # defined outside this file.
    config: EomtConfig
    base_model_prefix = "eomt"
    main_input_name = "pixel_values"
    input_modalities = ("image",)
    # NOTE(review): this is False although `EomtLayer` derives from `GradientCheckpointingLayer` -- confirm intended.
    supports_gradient_checkpointing = False
    _no_split_modules = ["EomtLayer"]
    _supports_sdpa = True
    # Maps the recorded output names to the module classes whose outputs are captured.
    _can_record_outputs = {
        "hidden_states": EomtLayer,
        "attentions": EomtAttention,
    }

    @torch.no_grad()
    def _init_weights(self, module: nn.Module) -> None:
        """
        Initialises the parameters/buffers of a single `module`, dispatching on its type.

        The branches are mutually exclusive and checked in order, so a module is handled by the first matching
        `isinstance` test (e.g. `EomtLayerNorm2d` subclasses `nn.LayerNorm` and takes that branch).
        """
        # Only used for the CLS token below.
        std = self.config.initializer_range
        if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
            # Kaiming-uniform weights with a=sqrt(5); bias uniform in [-1/sqrt(fan_in), 1/sqrt(fan_in)].
            init.kaiming_uniform_(module.weight, a=math.sqrt(5))
            if module.bias is not None:
                fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(module.weight)
                bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
                init.uniform_(module.bias, -bound, bound)
        elif isinstance(module, nn.LayerNorm):
            # Identity affine transform: unit weight, zero bias.
            init.ones_(module.weight)
            init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            # Unit-variance normal init (not `initializer_range`).
            init.normal_(module.weight, mean=0.0, std=1)
            # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it loses the flag
            if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False):
                init.zeros_(module.weight[module.padding_idx])
        elif isinstance(module, EomtLayerScale):
            if hasattr(module, "lambda1"):
                init.constant_(module.lambda1, self.config.layerscale_value)
        elif isinstance(module, EomtEmbeddings):
            init.trunc_normal_(module.cls_token, mean=0.0, std=std)
            init.zeros_(module.register_tokens)
            # Rebuild the non-persistent `position_ids` buffer as 0..N-1 with a leading batch dimension of 1.
            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
        elif isinstance(module, EomtLoss):
            # Class weights: 1 for every label, `eos_coef` for the extra last entry.
            empty_weight = torch.ones(module.num_labels + 1)
            empty_weight[-1] = module.eos_coef
            init.copy_(module.empty_weight, empty_weight)
        elif isinstance(module, EomtForUniversalSegmentation):
            # Start with all per-block attention-mask probabilities at 1.
            init.ones_(module.attn_mask_probs)
@auto_docstring(
    custom_intro="""
    The EoMT Model with head on top for instance/semantic/panoptic segmentation.
    """
)
class EomtForUniversalSegmentation(EomtPreTrainedModel):
    main_input_name = "pixel_values"

    def __init__(self, config: EomtConfig):
        super().__init__(config)
        self.config = config
        self.num_hidden_layers = config.num_hidden_layers
        self.embeddings = EomtEmbeddings(config)
        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        # Learned query embeddings; they are prepended to the token sequence inside `forward`.
        self.query = nn.Embedding(config.num_queries, config.hidden_size)
        self.layers = nn.ModuleList([EomtLayer(config) for _ in range(config.num_hidden_layers)])
        self.upscale_block = EomtScaleBlock(config)
        self.mask_head = EomtMaskHead(config)
        # One extra output on top of `num_labels`.
        self.class_predictor = nn.Linear(config.hidden_size, config.num_labels + 1)
        # NOTE(review): uses `//` directly, so `image_size`/`patch_size` must be ints here, whereas
        # `EomtPatchEmbeddings` also accepts iterables -- confirm the config never provides tuples.
        self.grid_size = (config.image_size // config.patch_size, config.image_size // config.patch_size)
        self.weight_dict: dict[str, float] = {
            "loss_cross_entropy": config.class_weight,
            "loss_mask": config.mask_weight,
            "loss_dice": config.dice_weight,
        }
        self.criterion = EomtLoss(config=config, weight_dict=self.weight_dict)
        # One probability per masked block, indexed by the block's position among the last `num_blocks` layers.
        self.register_buffer("attn_mask_probs", torch.ones(config.num_blocks))
        self.post_init()

    def get_loss_dict(
        self,
        masks_queries_logits: Tensor,
        class_queries_logits: Tensor,
        mask_labels: list[Tensor],
        class_labels: list[Tensor],
        auxiliary_predictions: dict[str, Tensor] | None,
    ) -> dict[str, Tensor]:
        """
        Runs the criterion and scales every returned loss by its weight from `self.weight_dict`.

        Weights are matched by substring (`key in loss_key`), so suffixed entries such as auxiliary losses are
        weighted as well. The scaling uses `*=` on the values held in the returned dict.
        """
        loss_dict: dict[str, Tensor] = self.criterion(
            masks_queries_logits=masks_queries_logits,
            class_queries_logits=class_queries_logits,
            mask_labels=mask_labels,
            class_labels=class_labels,
            auxiliary_predictions=auxiliary_predictions,
        )
        # weight each loss by `self.weight_dict[<LOSS_NAME>]` including auxiliary losses
        for key, weight in self.weight_dict.items():
            for loss_key, loss in loss_dict.items():
                if key in loss_key:
                    loss *= weight
        return loss_dict

    def get_loss(self, loss_dict: dict[str, Tensor]) -> Tensor:
        """Returns the sum of all values in `loss_dict`."""
        return sum(loss_dict.values())

    @merge_with_config_defaults
    @capture_outputs
    @auto_docstring
    def forward(
        self,
        pixel_values: Tensor,
        mask_labels: list[Tensor] | None = None,
        class_labels: list[Tensor] | None = None,
        patch_offsets: list[Tensor] | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> EomtForUniversalSegmentationOutput:
        r"""
        mask_labels (`list[torch.Tensor]`, *optional*):
            list of mask labels of shape `(num_labels, height, width)` to be fed to a model
        class_labels (`list[torch.LongTensor]`, *optional*):
            list of target class labels of shape `(num_labels)` to be fed to a model. They identify the
            labels of `mask_labels`, e.g. the label of `mask_labels[i][j]` is `class_labels[i][j]`.
        patch_offsets (`list[torch.Tensor]`, *optional*):
            list of tuples indicating the image index and start and end positions of patches for semantic segmentation.
        """
        # Predictions gathered from the masked blocks plus the final prediction, in layer order.
        masks_queries_logits_per_layer, class_queries_logits_per_layer = (), ()
        # NOTE(review): this is only reassigned inside the masked-block branch below, so once set it is reused
        # by later layers for which that branch is skipped.
        attention_mask = None
        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")
        hidden_states = self.embeddings(pixel_values)
        for idx, layer_module in enumerate(self.layers):
            # The queries join the sequence right before the last `num_blocks` layers, placed in front of
            # the prefix and patch tokens.
            if idx == self.num_hidden_layers - self.config.num_blocks:
                query = self.query.weight[None, :, :].expand(hidden_states.shape[0], -1, -1).to(hidden_states.device)
                hidden_states = torch.cat((query, hidden_states), dim=1)
            # In the last `num_blocks` layers, build a prediction-driven attention mask when training or when
            # this block's mask probability is positive.
            if idx >= self.num_hidden_layers - self.config.num_blocks and (
                self.training or self.attn_mask_probs[idx - self.num_hidden_layers + self.config.num_blocks] > 0
            ):
                norm_hidden_states = self.layernorm(hidden_states)
                masks_queries_logits, class_queries_logits = self.predict(norm_hidden_states)
                masks_queries_logits_per_layer += (masks_queries_logits,)
                class_queries_logits_per_layer += (class_queries_logits,)
                # Boolean (batch, seq, seq) mask, initially allowing every pair of tokens.
                attention_mask = torch.ones(
                    hidden_states.shape[0],
                    hidden_states.shape[1],
                    hidden_states.shape[1],
                    device=hidden_states.device,
                    dtype=torch.bool,
                )
                # Resize the mask logits to the patch grid and flatten the grid into one axis.
                interpolated_logits = F.interpolate(masks_queries_logits, size=self.grid_size, mode="bilinear")
                interpolated_logits = interpolated_logits.view(
                    interpolated_logits.size(0), interpolated_logits.size(1), -1
                )
                num_query_tokens = self.config.num_queries
                encoder_start_tokens = num_query_tokens + self.embeddings.num_prefix_tokens
                # Set attention mask for queries to focus on encoder tokens based on interpolated logits
                attention_mask[:, :num_query_tokens, encoder_start_tokens:] = interpolated_logits > 0
                # Disable attention mask for random query tokens.
                attention_mask = self._disable_attention_mask(
                    attention_mask,
                    prob=self.attn_mask_probs[idx - self.num_hidden_layers + self.config.num_blocks],
                    num_query_tokens=num_query_tokens,
                    encoder_start_tokens=encoder_start_tokens,
                    device=attention_mask.device,
                )
                # Expand attention mask to 4d mask.
                attention_mask = attention_mask[:, None, ...].expand(-1, self.config.num_attention_heads, -1, -1)
                # Convert to a float mask: allowed positions become 1.0, blocked positions -1e9.
                attention_mask = attention_mask.float().masked_fill(~attention_mask, -1e9)
            hidden_states = layer_module(hidden_states, attention_mask)
        # Final prediction from the normalised output of the last layer.
        sequence_output = self.layernorm(hidden_states)
        masks_queries_logits, class_queries_logits = self.predict(sequence_output)
        masks_queries_logits_per_layer += (masks_queries_logits,)
        class_queries_logits_per_layer += (class_queries_logits,)
        loss = None
        if mask_labels is not None and class_labels is not None:
            loss = 0.0
            # Every collected prediction (intermediate and final) contributes its own loss term.
            # The loop rebinds `masks_queries_logits`/`class_queries_logits`; after it they hold the last
            # tuple entries, i.e. the final-layer predictions appended just above.
            for masks_queries_logits, class_queries_logits in zip(
                masks_queries_logits_per_layer, class_queries_logits_per_layer
            ):
                loss_dict = self.get_loss_dict(
                    masks_queries_logits=masks_queries_logits,
                    class_queries_logits=class_queries_logits,
                    mask_labels=mask_labels,
                    class_labels=class_labels,
                    auxiliary_predictions=None,
                )
                loss += self.get_loss(loss_dict)
        return EomtForUniversalSegmentationOutput(
            loss=loss,
            masks_queries_logits=masks_queries_logits,
            class_queries_logits=class_queries_logits,
            last_hidden_state=sequence_output,
            patch_offsets=patch_offsets,
        )

    def get_input_embeddings(self):
        """Returns the patch embedding module."""
        return self.embeddings.patch_embeddings

    def predict(self, logits: torch.Tensor):
        """
        Computes mask and class logits from a (normalised) token sequence laid out as
        `[queries, prefix tokens, patch tokens]`.

        Returns:
            `(mask_logits, class_logits)`; `mask_logits` is the per-query dot product between the mask-head
            output and the upscaled patch feature map.
        """
        query_tokens = logits[:, : self.config.num_queries, :]
        class_logits = self.class_predictor(query_tokens)
        # Despite the name, this slice skips the queries and prefix tokens and keeps the patch tokens.
        prefix_tokens = logits[:, self.config.num_queries + self.embeddings.num_prefix_tokens :, :]
        # (batch, patches, hidden) -> (batch, hidden, grid_h, grid_w)
        prefix_tokens = prefix_tokens.transpose(1, 2)
        prefix_tokens = prefix_tokens.reshape(prefix_tokens.shape[0], -1, *self.grid_size)
        query_tokens = self.mask_head(query_tokens)
        prefix_tokens = self.upscale_block(prefix_tokens)
        mask_logits = torch.einsum("bqc, bchw -> bqhw", query_tokens, prefix_tokens)
        return mask_logits, class_logits

    @staticmethod
    def _disable_attention_mask(attn_mask, prob, num_query_tokens, encoder_start_tokens, device):
        """
        Randomly lifts the masking for some queries, modifying `attn_mask` in place and returning it.

        When `prob < 1`, every (sample, query) pair whose uniform draw exceeds `prob` has its whole
        query-to-patch row set to 1, i.e. that query may attend to every patch token. With `prob >= 1` the
        mask is returned unchanged.
        """
        if prob < 1:
            # Generate random queries to disable based on the probs
            random_queries = torch.rand(attn_mask.shape[0], num_query_tokens, device=device) > prob
            # Disable attention to the query tokens, considering the prefix tokens
            attn_mask[:, :num_query_tokens, encoder_start_tokens:][random_queries] = 1
        return attn_mask
# Names exported by `from <module> import *`.
__all__ = ["EomtPreTrainedModel", "EomtForUniversalSegmentation"]