modular_eomt.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544
  1. # Copyright 2025 Mobile Perception Systems Lab at TU/e and The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch EoMT model."""
  15. import math
  16. from dataclasses import dataclass
  17. import torch
  18. import torch.nn.functional as F
  19. from huggingface_hub.dataclasses import strict
  20. from torch import Tensor, nn
  21. from ... import initialization as init
  22. from ...activations import ACT2FN
  23. from ...file_utils import (
  24. ModelOutput,
  25. )
  26. from ...modeling_utils import PreTrainedModel
  27. from ...processing_utils import Unpack
  28. from ...utils import (
  29. TransformersKwargs,
  30. auto_docstring,
  31. logging,
  32. )
  33. from ...utils.generic import merge_with_config_defaults
  34. from ...utils.output_capturing import capture_outputs
  35. from ..dinov2.modeling_dinov2 import (
  36. Dinov2Embeddings,
  37. Dinov2Layer,
  38. Dinov2LayerScale,
  39. Dinov2PatchEmbeddings,
  40. )
  41. from ..mask2former.modeling_mask2former import Mask2FormerForUniversalSegmentation, Mask2FormerLoss
  42. from ..siglip.modeling_siglip import SiglipAttention
  43. from ..vit.configuration_vit import ViTConfig
  44. logger = logging.get_logger(__name__)
  45. @auto_docstring(checkpoint="tue-mps/coco_panoptic_eomt_large_640")
  46. @strict
  47. class EomtConfig(ViTConfig):
  48. r"""
  49. layerscale_value (`float`, *optional*, defaults to 1.0):
  50. Initial value for the LayerScale parameter.
  51. num_upscale_blocks (`int`, *optional*, defaults to 2):
  52. Number of upsampling blocks used in the decoder or segmentation head.
  53. use_swiglu_ffn (`bool`, *optional*, defaults to `False`):
  54. Whether to use the SwiGLU feedforward neural network.
  55. num_blocks (`int`, *optional*, defaults to 4):
  56. Number of feature blocks or stages in the architecture.
  57. no_object_weight (`float`, *optional*, defaults to 0.1):
  58. Loss weight for the 'no object' class in panoptic/instance segmentation.
  59. class_weight (`float`, *optional*, defaults to 2.0):
  60. Loss weight for classification targets.
  61. mask_weight (`float`, *optional*, defaults to 5.0):
  62. Loss weight for mask prediction.
  63. train_num_points (`int`, *optional*, defaults to 12544):
  64. Number of points to sample for mask loss computation during training.
  65. oversample_ratio (`float`, *optional*, defaults to 3.0):
  66. Oversampling ratio used in point sampling for mask training.
  67. importance_sample_ratio (`float`, *optional*, defaults to 0.75):
  68. Ratio of points to sample based on importance during training.
  69. num_queries (`int`, *optional*, defaults to 200):
  70. Number of object queries in the Transformer.
  71. num_register_tokens (`int`, *optional*, defaults to 4):
  72. Number of learnable register tokens added to the transformer input.
  73. Example:
  74. ```python
  75. >>> from transformers import EomtConfig, EomtForUniversalSegmentation
  76. >>> # Initialize configuration
  77. >>> config = EomtConfig()
  78. >>> # Initialize model
  79. >>> model = EomtForUniversalSegmentation(config)
  80. >>> # Access config
  81. >>> config = model.config
  82. ```"""
  83. model_type = "eomt"
  84. hidden_size: int = 1024
  85. num_hidden_layers: int = 24
  86. num_attention_heads: int = 16
  87. mlp_ratio: int = 4
  88. hidden_act: str = "gelu"
  89. hidden_dropout_prob: float | int = 0.0
  90. initializer_range: float = 0.02
  91. layer_norm_eps: float = 1e-6
  92. image_size: int | list[int] | tuple[int, int] = 640
  93. patch_size: int | list[int] | tuple[int, int] = 16
  94. num_channels: int = 3
  95. layerscale_value: float = 1.0
  96. drop_path_rate: float | int = 0.0
  97. num_upscale_blocks: int = 2
  98. attention_dropout: float | int = 0.0
  99. use_swiglu_ffn: bool = False
  100. num_blocks: int = 4
  101. no_object_weight: float = 0.1
  102. class_weight: float = 2.0
  103. mask_weight: float = 5.0
  104. dice_weight: float = 5.0
  105. train_num_points: int = 12544
  106. oversample_ratio: float = 3.0
  107. importance_sample_ratio: float = 0.75
  108. num_queries: int = 200
  109. num_register_tokens: int = 4
  110. intermediate_size = AttributeError()
  111. qkv_bias = AttributeError()
  112. pooler_act = AttributeError()
  113. pooler_output_size = AttributeError()
  114. encoder_stride = AttributeError()
  115. attention_probs_dropout_prob = AttributeError()
  116. def __post_init__(self, **kwargs):
  117. raise AttributeError("Not needed for Eomt")
  118. @dataclass
  119. @auto_docstring(
  120. custom_intro="""
  121. Class for outputs of [`EomtForUniversalSegmentationOutput`].
  122. This output can be directly passed to [`~EomtImageProcessor.post_process_semantic_segmentation`] or
  123. [`~EomtImageProcessor.post_process_instance_segmentation`] or
  124. [`~EomtImageProcessor.post_process_panoptic_segmentation`] to compute final segmentation maps. Please, see
  125. [`~EomtImageProcessor] for details regarding usage.
  126. """
  127. )
  128. class EomtForUniversalSegmentationOutput(ModelOutput):
  129. r"""
  130. loss (`torch.Tensor`, *optional*):
  131. The computed loss, returned when labels are present.
  132. class_queries_logits (`torch.FloatTensor`):
  133. A tensor of shape `(batch_size, num_queries, num_labels + 1)` representing the proposed classes for each
  134. query. Note the `+ 1` is needed because we incorporate the null class.
  135. masks_queries_logits (`torch.FloatTensor`):
  136. A tensor of shape `(batch_size, num_queries, height, width)` representing the proposed masks for each
  137. query.
  138. last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
  139. Last hidden states (final feature map) of the last layer.
  140. hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
  141. Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
  142. shape `(batch_size, sequence_length, hidden_size)`. Hidden-states all layers of the model.
  143. attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
  144. Tuple of `tuple(torch.FloatTensor)` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
  145. sequence_length)`. Self and Cross Attentions weights from transformer decoder.
  146. patch_offsets (`list[torch.Tensor]`, *optional*):
  147. list of tuples indicating the image index and start and end positions of patches for semantic segmentation.
  148. """
  149. loss: torch.FloatTensor | None = None
  150. class_queries_logits: torch.FloatTensor | None = None
  151. masks_queries_logits: torch.FloatTensor | None = None
  152. last_hidden_state: torch.FloatTensor | None = None
  153. hidden_states: tuple[torch.FloatTensor] | None = None
  154. attentions: tuple[torch.FloatTensor] | None = None
  155. patch_offsets: list[torch.Tensor] | None = None
  156. class EomtLoss(Mask2FormerLoss):
  157. pass
  158. class EomtPatchEmbeddings(Dinov2PatchEmbeddings):
  159. pass
  160. class EomtEmbeddings(Dinov2Embeddings):
  161. def __init__(self, config: EomtConfig) -> None:
  162. nn.Module.__init__(self)
  163. self.config = config
  164. self.patch_size = config.patch_size
  165. self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size))
  166. self.register_tokens = nn.Parameter(torch.zeros(1, config.num_register_tokens, config.hidden_size))
  167. self.patch_embeddings = EomtPatchEmbeddings(config)
  168. num_patches = self.patch_embeddings.num_patches
  169. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  170. self.num_prefix_tokens = 1 + config.num_register_tokens # 1 for [CLS]
  171. self.position_embeddings = nn.Embedding(num_patches, config.hidden_size)
  172. self.register_buffer("position_ids", torch.arange(num_patches).expand((1, -1)), persistent=False)
  173. def interpolate_pos_encoding(self):
  174. raise AttributeError("Not needed for Eomt Model")
  175. def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
  176. batch_size, _, _, _ = pixel_values.shape
  177. target_dtype = self.patch_embeddings.projection.weight.dtype
  178. embeddings = self.patch_embeddings(pixel_values.to(dtype=target_dtype))
  179. cls_tokens = self.cls_token.expand(batch_size, -1, -1)
  180. register_tokens = self.register_tokens.expand(batch_size, -1, -1)
  181. embeddings = embeddings + self.position_embeddings(self.position_ids)
  182. embeddings = torch.cat([cls_tokens, register_tokens, embeddings], dim=1)
  183. embeddings = self.dropout(embeddings)
  184. return embeddings
  185. class EomtAttention(SiglipAttention):
  186. pass
  187. class EomtLayerScale(Dinov2LayerScale):
  188. pass
  189. class EomtLayer(Dinov2Layer):
  190. def forward(
  191. self,
  192. hidden_states: torch.Tensor,
  193. attention_mask: torch.Tensor | None = None,
  194. ) -> torch.Tensor:
  195. hidden_states_norm = self.norm1(hidden_states)
  196. self_attention_output, _ = self.attention(hidden_states_norm, attention_mask)
  197. self_attention_output = self.layer_scale1(self_attention_output)
  198. # first residual connection
  199. hidden_states = self.drop_path(self_attention_output) + hidden_states
  200. # in Eomt, layernorm is also applied after self-attention
  201. layer_output = self.norm2(hidden_states)
  202. layer_output = self.mlp(layer_output)
  203. layer_output = self.layer_scale2(layer_output)
  204. # second residual connection
  205. layer_output = self.drop_path(layer_output) + hidden_states
  206. return layer_output
  207. class EomtLayerNorm2d(nn.LayerNorm):
  208. def __init__(self, num_channels, eps=1e-6, affine=True):
  209. super().__init__(num_channels, eps=eps, elementwise_affine=affine)
  210. def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
  211. hidden_state = hidden_state.permute(0, 2, 3, 1)
  212. hidden_state = F.layer_norm(hidden_state, self.normalized_shape, self.weight, self.bias, self.eps)
  213. hidden_state = hidden_state.permute(0, 3, 1, 2)
  214. return hidden_state
  215. class EomtScaleLayer(nn.Module):
  216. def __init__(self, config: EomtConfig):
  217. super().__init__()
  218. hidden_size = config.hidden_size
  219. self.conv1 = nn.ConvTranspose2d(hidden_size, hidden_size, kernel_size=2, stride=2)
  220. self.activation = ACT2FN[config.hidden_act]
  221. self.conv2 = nn.Conv2d(
  222. hidden_size,
  223. hidden_size,
  224. kernel_size=3,
  225. padding=1,
  226. groups=hidden_size,
  227. bias=False,
  228. )
  229. self.layernorm2d = EomtLayerNorm2d(hidden_size)
  230. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  231. hidden_states = self.conv1(hidden_states)
  232. hidden_states = self.activation(hidden_states)
  233. hidden_states = self.conv2(hidden_states)
  234. hidden_states = self.layernorm2d(hidden_states)
  235. return hidden_states
  236. class EomtScaleBlock(nn.Module):
  237. def __init__(self, config: EomtConfig):
  238. super().__init__()
  239. self.num_blocks = config.num_upscale_blocks
  240. self.block = nn.ModuleList([EomtScaleLayer(config) for _ in range(self.num_blocks)])
  241. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  242. for block in self.block:
  243. hidden_states = block(hidden_states)
  244. return hidden_states
  245. class EomtMaskHead(nn.Module):
  246. def __init__(self, config: EomtConfig):
  247. super().__init__()
  248. hidden_size = config.hidden_size
  249. self.fc1 = nn.Linear(hidden_size, hidden_size)
  250. self.fc2 = nn.Linear(hidden_size, hidden_size)
  251. self.fc3 = nn.Linear(hidden_size, hidden_size)
  252. self.activation = ACT2FN[config.hidden_act]
  253. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  254. hidden_states = self.activation(self.fc1(hidden_states))
  255. hidden_states = self.activation(self.fc2(hidden_states))
  256. hidden_states = self.fc3(hidden_states)
  257. return hidden_states
  258. @auto_docstring
  259. class EomtPreTrainedModel(PreTrainedModel):
  260. """
  261. An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
  262. models.
  263. """
  264. config: EomtConfig
  265. base_model_prefix = "eomt"
  266. main_input_name = "pixel_values"
  267. input_modalities = ("image",)
  268. supports_gradient_checkpointing = False
  269. _no_split_modules = ["EomtLayer"]
  270. _supports_sdpa = True
  271. _can_record_outputs = {
  272. "hidden_states": EomtLayer,
  273. "attentions": EomtAttention,
  274. }
  275. @torch.no_grad()
  276. def _init_weights(self, module: nn.Module) -> None:
  277. std = self.config.initializer_range
  278. if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
  279. init.kaiming_uniform_(module.weight, a=math.sqrt(5))
  280. if module.bias is not None:
  281. fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(module.weight)
  282. bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
  283. init.uniform_(module.bias, -bound, bound)
  284. elif isinstance(module, nn.LayerNorm):
  285. init.ones_(module.weight)
  286. init.zeros_(module.bias)
  287. elif isinstance(module, nn.Embedding):
  288. init.normal_(module.weight, mean=0.0, std=1)
  289. # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it looses the flag
  290. if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False):
  291. init.zeros_(module.weight[module.padding_idx])
  292. elif isinstance(module, EomtLayerScale):
  293. if hasattr(module, "lambda1"):
  294. init.constant_(module.lambda1, self.config.layerscale_value)
  295. elif isinstance(module, EomtEmbeddings):
  296. init.trunc_normal_(module.cls_token, mean=0.0, std=std)
  297. init.zeros_(module.register_tokens)
  298. init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
  299. elif isinstance(module, EomtLoss):
  300. empty_weight = torch.ones(module.num_labels + 1)
  301. empty_weight[-1] = module.eos_coef
  302. init.copy_(module.empty_weight, empty_weight)
  303. elif isinstance(module, EomtForUniversalSegmentation):
  304. init.ones_(module.attn_mask_probs)
  305. @auto_docstring(
  306. custom_intro="""
  307. The EoMT Model with head on top for instance/semantic/panoptic segmentation.
  308. """
  309. )
  310. class EomtForUniversalSegmentation(Mask2FormerForUniversalSegmentation):
  311. def __init__(self, config: EomtConfig):
  312. PreTrainedModel.__init__(self, config)
  313. self.config = config
  314. self.num_hidden_layers = config.num_hidden_layers
  315. self.embeddings = EomtEmbeddings(config)
  316. self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  317. self.query = nn.Embedding(config.num_queries, config.hidden_size)
  318. self.layers = nn.ModuleList([EomtLayer(config) for _ in range(config.num_hidden_layers)])
  319. self.upscale_block = EomtScaleBlock(config)
  320. self.mask_head = EomtMaskHead(config)
  321. self.class_predictor = nn.Linear(config.hidden_size, config.num_labels + 1)
  322. self.grid_size = (config.image_size // config.patch_size, config.image_size // config.patch_size)
  323. self.weight_dict: dict[str, float] = {
  324. "loss_cross_entropy": config.class_weight,
  325. "loss_mask": config.mask_weight,
  326. "loss_dice": config.dice_weight,
  327. }
  328. self.criterion = EomtLoss(config=config, weight_dict=self.weight_dict)
  329. self.register_buffer("attn_mask_probs", torch.ones(config.num_blocks))
  330. self.post_init()
  331. def get_input_embeddings(self):
  332. return self.embeddings.patch_embeddings
  333. def get_auxiliary_logits(self):
  334. raise AttributeError("Note needed for Eomt Model.")
  335. def predict(self, logits: torch.Tensor):
  336. query_tokens = logits[:, : self.config.num_queries, :]
  337. class_logits = self.class_predictor(query_tokens)
  338. prefix_tokens = logits[:, self.config.num_queries + self.embeddings.num_prefix_tokens :, :]
  339. prefix_tokens = prefix_tokens.transpose(1, 2)
  340. prefix_tokens = prefix_tokens.reshape(prefix_tokens.shape[0], -1, *self.grid_size)
  341. query_tokens = self.mask_head(query_tokens)
  342. prefix_tokens = self.upscale_block(prefix_tokens)
  343. mask_logits = torch.einsum("bqc, bchw -> bqhw", query_tokens, prefix_tokens)
  344. return mask_logits, class_logits
  345. @staticmethod
  346. def _disable_attention_mask(attn_mask, prob, num_query_tokens, encoder_start_tokens, device):
  347. if prob < 1:
  348. # Generate random queries to disable based on the probs
  349. random_queries = torch.rand(attn_mask.shape[0], num_query_tokens, device=device) > prob
  350. # Disable attention to the query tokens, considering the prefix tokens
  351. attn_mask[:, :num_query_tokens, encoder_start_tokens:][random_queries] = 1
  352. return attn_mask
  353. @merge_with_config_defaults
  354. @capture_outputs
  355. @auto_docstring
  356. def forward(
  357. self,
  358. pixel_values: Tensor,
  359. mask_labels: list[Tensor] | None = None,
  360. class_labels: list[Tensor] | None = None,
  361. patch_offsets: list[Tensor] | None = None,
  362. **kwargs: Unpack[TransformersKwargs],
  363. ) -> EomtForUniversalSegmentationOutput:
  364. r"""
  365. mask_labels (`list[torch.Tensor]`, *optional*):
  366. list of mask labels of shape `(num_labels, height, width)` to be fed to a model
  367. class_labels (`list[torch.LongTensor]`, *optional*):
  368. list of target class labels of shape `(num_labels, height, width)` to be fed to a model. They identify the
  369. labels of `mask_labels`, e.g. the label of `mask_labels[i][j]` if `class_labels[i][j]`.
  370. patch_offsets (`list[torch.Tensor]`, *optional*):
  371. list of tuples indicating the image index and start and end positions of patches for semantic segmentation.
  372. """
  373. masks_queries_logits_per_layer, class_queries_logits_per_layer = (), ()
  374. attention_mask = None
  375. if pixel_values is None:
  376. raise ValueError("You have to specify pixel_values")
  377. hidden_states = self.embeddings(pixel_values)
  378. for idx, layer_module in enumerate(self.layers):
  379. if idx == self.num_hidden_layers - self.config.num_blocks:
  380. query = self.query.weight[None, :, :].expand(hidden_states.shape[0], -1, -1).to(hidden_states.device)
  381. hidden_states = torch.cat((query, hidden_states), dim=1)
  382. if idx >= self.num_hidden_layers - self.config.num_blocks and (
  383. self.training or self.attn_mask_probs[idx - self.num_hidden_layers + self.config.num_blocks] > 0
  384. ):
  385. norm_hidden_states = self.layernorm(hidden_states)
  386. masks_queries_logits, class_queries_logits = self.predict(norm_hidden_states)
  387. masks_queries_logits_per_layer += (masks_queries_logits,)
  388. class_queries_logits_per_layer += (class_queries_logits,)
  389. attention_mask = torch.ones(
  390. hidden_states.shape[0],
  391. hidden_states.shape[1],
  392. hidden_states.shape[1],
  393. device=hidden_states.device,
  394. dtype=torch.bool,
  395. )
  396. interpolated_logits = F.interpolate(masks_queries_logits, size=self.grid_size, mode="bilinear")
  397. interpolated_logits = interpolated_logits.view(
  398. interpolated_logits.size(0), interpolated_logits.size(1), -1
  399. )
  400. num_query_tokens = self.config.num_queries
  401. encoder_start_tokens = num_query_tokens + self.embeddings.num_prefix_tokens
  402. # Set attention mask for queries to focus on encoder tokens based on interpolated logits
  403. attention_mask[:, :num_query_tokens, encoder_start_tokens:] = interpolated_logits > 0
  404. # Disable attention mask for random query tokens.
  405. attention_mask = self._disable_attention_mask(
  406. attention_mask,
  407. prob=self.attn_mask_probs[idx - self.num_hidden_layers + self.config.num_blocks],
  408. num_query_tokens=num_query_tokens,
  409. encoder_start_tokens=encoder_start_tokens,
  410. device=attention_mask.device,
  411. )
  412. # Expand attention mask to 4d mask.
  413. attention_mask = attention_mask[:, None, ...].expand(-1, self.config.num_attention_heads, -1, -1)
  414. attention_mask = attention_mask.float().masked_fill(~attention_mask, -1e9)
  415. hidden_states = layer_module(hidden_states, attention_mask)
  416. sequence_output = self.layernorm(hidden_states)
  417. masks_queries_logits, class_queries_logits = self.predict(sequence_output)
  418. masks_queries_logits_per_layer += (masks_queries_logits,)
  419. class_queries_logits_per_layer += (class_queries_logits,)
  420. loss = None
  421. if mask_labels is not None and class_labels is not None:
  422. loss = 0.0
  423. for masks_queries_logits, class_queries_logits in zip(
  424. masks_queries_logits_per_layer, class_queries_logits_per_layer
  425. ):
  426. loss_dict = self.get_loss_dict(
  427. masks_queries_logits=masks_queries_logits,
  428. class_queries_logits=class_queries_logits,
  429. mask_labels=mask_labels,
  430. class_labels=class_labels,
  431. auxiliary_predictions=None,
  432. )
  433. loss += self.get_loss(loss_dict)
  434. return EomtForUniversalSegmentationOutput(
  435. loss=loss,
  436. masks_queries_logits=masks_queries_logits,
  437. class_queries_logits=class_queries_logits,
  438. last_hidden_state=sequence_output,
  439. patch_offsets=patch_offsets,
  440. )
  441. __all__ = ["EomtConfig", "EomtPreTrainedModel", "EomtForUniversalSegmentation"]