modeling_bros.py 39 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971
  1. # Copyright 2023-present NAVER Corp, The Microsoft Research Asia LayoutLM Team Authors and the HuggingFace Inc. team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch Bros model."""
  15. import math
  16. from dataclasses import dataclass
  17. import torch
  18. from torch import nn
  19. from torch.nn import CrossEntropyLoss
  20. from ... import initialization as init
  21. from ...activations import ACT2FN
  22. from ...modeling_layers import GradientCheckpointingLayer
  23. from ...modeling_outputs import (
  24. BaseModelOutputWithCrossAttentions,
  25. BaseModelOutputWithPoolingAndCrossAttentions,
  26. TokenClassifierOutput,
  27. )
  28. from ...modeling_utils import PreTrainedModel
  29. from ...processing_utils import Unpack
  30. from ...pytorch_utils import apply_chunking_to_forward
  31. from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, logging
  32. from ...utils.generic import merge_with_config_defaults
  33. from ...utils.output_capturing import OutputRecorder, capture_outputs
  34. from .configuration_bros import BrosConfig
# Module-level logger, named after this module via the package's logging utilities.
logger = logging.get_logger(__name__)
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for outputs of token classification models.
    """
)
class BrosSpadeOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Classification loss.
    initial_token_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`):
        Classification scores for entity initial tokens (before SoftMax).
    subsequent_token_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, sequence_length+1)`):
        Classification scores for entity sequence tokens (before SoftMax).
    """

    # NOTE: field order matters - ModelOutput exposes these positionally / in tuple form, so do not reorder.
    loss: torch.FloatTensor | None = None
    initial_token_logits: torch.FloatTensor | None = None
    # The extra "+1" column presumably corresponds to the dummy node appended by BrosRelationExtractor - confirm.
    subsequent_token_logits: torch.FloatTensor | None = None
    # Optional per-layer diagnostics; presumably copied from the backbone outputs by the producing model.
    hidden_states: tuple[torch.FloatTensor] | None = None
    attentions: tuple[torch.FloatTensor] | None = None
  56. class BrosPositionalEmbedding1D(nn.Module):
  57. # Reference: https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/mem_transformer.py#L15
  58. def __init__(self, config):
  59. super().__init__()
  60. self.dim_bbox_sinusoid_emb_1d = config.dim_bbox_sinusoid_emb_1d
  61. inv_freq = 1 / (
  62. 10000 ** (torch.arange(0.0, self.dim_bbox_sinusoid_emb_1d, 2.0) / self.dim_bbox_sinusoid_emb_1d)
  63. )
  64. self.register_buffer("inv_freq", inv_freq)
  65. def forward(self, pos_seq: torch.Tensor) -> torch.Tensor:
  66. seq_size = pos_seq.size()
  67. b1, b2, b3 = seq_size
  68. sinusoid_inp = pos_seq.view(b1, b2, b3, 1) * self.inv_freq.view(1, 1, 1, self.dim_bbox_sinusoid_emb_1d // 2)
  69. pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1)
  70. return pos_emb
  71. class BrosPositionalEmbedding2D(nn.Module):
  72. def __init__(self, config):
  73. super().__init__()
  74. self.dim_bbox = config.dim_bbox
  75. self.x_pos_emb = BrosPositionalEmbedding1D(config)
  76. self.y_pos_emb = BrosPositionalEmbedding1D(config)
  77. def forward(self, bbox: torch.Tensor) -> torch.Tensor:
  78. stack = []
  79. for i in range(self.dim_bbox):
  80. if i % 2 == 0:
  81. stack.append(self.x_pos_emb(bbox[..., i]))
  82. else:
  83. stack.append(self.y_pos_emb(bbox[..., i]))
  84. bbox_pos_emb = torch.cat(stack, dim=-1)
  85. return bbox_pos_emb
  86. class BrosBboxEmbeddings(nn.Module):
  87. def __init__(self, config):
  88. super().__init__()
  89. self.bbox_sinusoid_emb = BrosPositionalEmbedding2D(config)
  90. self.bbox_projection = nn.Linear(config.dim_bbox_sinusoid_emb_2d, config.dim_bbox_projection, bias=False)
  91. def forward(self, bbox: torch.Tensor):
  92. bbox_t = bbox.transpose(0, 1)
  93. bbox_pos = bbox_t[None, :, :, :] - bbox_t[:, None, :, :]
  94. bbox_pos_emb = self.bbox_sinusoid_emb(bbox_pos)
  95. bbox_pos_emb = self.bbox_projection(bbox_pos_emb)
  96. return bbox_pos_emb
  97. class BrosTextEmbeddings(nn.Module):
  98. """Construct the embeddings from word, position and token_type embeddings."""
  99. def __init__(self, config):
  100. super().__init__()
  101. self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
  102. self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
  103. self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
  104. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  105. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  106. # position_ids (1, len position emb) is contiguous in memory and exported when serialized
  107. self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
  108. self.register_buffer(
  109. "token_type_ids",
  110. torch.zeros(
  111. self.position_ids.size(),
  112. dtype=torch.long,
  113. device=self.position_ids.device,
  114. ),
  115. persistent=False,
  116. )
  117. def forward(
  118. self,
  119. input_ids: torch.Tensor | None = None,
  120. token_type_ids: torch.Tensor | None = None,
  121. position_ids: torch.Tensor | None = None,
  122. inputs_embeds: torch.Tensor | None = None,
  123. ) -> torch.Tensor:
  124. if input_ids is not None:
  125. input_shape = input_ids.size()
  126. else:
  127. input_shape = inputs_embeds.size()[:-1]
  128. seq_length = input_shape[1]
  129. if position_ids is None:
  130. position_ids = self.position_ids[:, :seq_length]
  131. if token_type_ids is None:
  132. if hasattr(self, "token_type_ids"):
  133. buffered_token_type_ids = self.token_type_ids[:, :seq_length]
  134. buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
  135. token_type_ids = buffered_token_type_ids_expanded
  136. else:
  137. token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
  138. if inputs_embeds is None:
  139. inputs_embeds = self.word_embeddings(input_ids)
  140. token_type_embeddings = self.token_type_embeddings(token_type_ids)
  141. embeddings = inputs_embeds + token_type_embeddings
  142. position_embeddings = self.position_embeddings(position_ids)
  143. embeddings += position_embeddings
  144. embeddings = self.LayerNorm(embeddings)
  145. embeddings = self.dropout(embeddings)
  146. return embeddings
  147. class BrosSelfAttention(nn.Module):
  148. def __init__(self, config):
  149. super().__init__()
  150. if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
  151. raise ValueError(
  152. f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
  153. f"heads ({config.num_attention_heads})"
  154. )
  155. self.num_attention_heads = config.num_attention_heads
  156. self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
  157. self.all_head_size = self.num_attention_heads * self.attention_head_size
  158. self.query = nn.Linear(config.hidden_size, self.all_head_size)
  159. self.key = nn.Linear(config.hidden_size, self.all_head_size)
  160. self.value = nn.Linear(config.hidden_size, self.all_head_size)
  161. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  162. self.is_decoder = config.is_decoder
  163. def forward(
  164. self,
  165. hidden_states: torch.Tensor,
  166. bbox_pos_emb: torch.Tensor,
  167. attention_mask: torch.Tensor | None = None,
  168. encoder_hidden_states: torch.Tensor | None = None,
  169. encoder_attention_mask: torch.Tensor | None = None,
  170. ) -> tuple[torch.Tensor]:
  171. hidden_shape = (hidden_states.shape[0], -1, self.num_attention_heads, self.attention_head_size)
  172. query_layer = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
  173. # If this is instantiated as a cross-attention module, the keys
  174. # and values come from an encoder; the attention mask needs to be
  175. # such that the encoder's padding tokens are not attended to.
  176. is_cross_attention = encoder_hidden_states is not None
  177. if is_cross_attention:
  178. key_layer = self.key(encoder_hidden_states).view(hidden_shape).transpose(1, 2)
  179. value_layer = self.value(encoder_hidden_states).view(hidden_shape).transpose(1, 2)
  180. attention_mask = encoder_attention_mask
  181. else:
  182. key_layer = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
  183. value_layer = self.value(hidden_states).view(hidden_shape).transpose(1, 2)
  184. # Take the dot product between "query" and "key" to get the raw attention scores.
  185. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
  186. # bbox positional encoding
  187. batch_size, n_head, seq_length, d_head = query_layer.shape
  188. bbox_pos_emb = bbox_pos_emb.view(seq_length, seq_length, batch_size, d_head)
  189. bbox_pos_emb = bbox_pos_emb.permute([2, 0, 1, 3])
  190. bbox_pos_scores = torch.einsum("bnid,bijd->bnij", (query_layer, bbox_pos_emb))
  191. attention_scores = attention_scores + bbox_pos_scores
  192. attention_scores = attention_scores / math.sqrt(self.attention_head_size)
  193. if attention_mask is not None:
  194. # Apply the attention mask is (precomputed for all layers in BrosModel forward() function)
  195. attention_scores = attention_scores + attention_mask
  196. # Normalize the attention scores to probabilities.
  197. attention_probs = nn.Softmax(dim=-1)(attention_scores)
  198. # This is actually dropping out entire tokens to attend to, which might
  199. # seem a bit unusual, but is taken from the original Transformer paper.
  200. attention_probs = self.dropout(attention_probs)
  201. context_layer = torch.matmul(attention_probs, value_layer)
  202. context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
  203. new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
  204. context_layer = context_layer.view(*new_context_layer_shape)
  205. return context_layer, attention_probs
# Copied from transformers.models.bert.modeling_bert.BertSelfOutput with Bert->Bros
class BrosSelfOutput(nn.Module):
    """Post-attention projection: dense -> dropout -> residual add -> LayerNorm."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        # hidden_states: attention context; input_tensor: the sub-layer input used as the residual.
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states
  218. class BrosAttention(nn.Module):
  219. def __init__(self, config):
  220. super().__init__()
  221. self.self = BrosSelfAttention(config)
  222. self.output = BrosSelfOutput(config)
  223. def forward(
  224. self,
  225. hidden_states: torch.Tensor,
  226. bbox_pos_emb: torch.Tensor,
  227. attention_mask: torch.Tensor | None = None,
  228. encoder_hidden_states: torch.Tensor | None = None,
  229. encoder_attention_mask: torch.Tensor | None = None,
  230. ) -> torch.Tensor:
  231. residual = hidden_states
  232. hidden_states, _ = self.self(
  233. hidden_states,
  234. bbox_pos_emb=bbox_pos_emb,
  235. attention_mask=attention_mask,
  236. encoder_hidden_states=encoder_hidden_states,
  237. encoder_attention_mask=encoder_attention_mask,
  238. )
  239. hidden_states = self.output(hidden_states, residual)
  240. return hidden_states
# Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->Bros
class BrosIntermediate(nn.Module):
    """Feed-forward up-projection (hidden_size -> intermediate_size) followed by the configured activation."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        # config.hidden_act may be a registry name (looked up in ACT2FN) or a callable used directly.
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states
  254. class BrosOutput(nn.Module):
  255. def __init__(self, config):
  256. super().__init__()
  257. self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
  258. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  259. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  260. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  261. hidden_states = self.dense(hidden_states)
  262. hidden_states = self.dropout(hidden_states)
  263. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  264. return hidden_states
  265. class BrosLayer(GradientCheckpointingLayer):
  266. def __init__(self, config):
  267. super().__init__()
  268. self.chunk_size_feed_forward = config.chunk_size_feed_forward
  269. self.seq_len_dim = 1
  270. self.attention = BrosAttention(config)
  271. self.is_decoder = config.is_decoder
  272. self.add_cross_attention = config.add_cross_attention
  273. if self.add_cross_attention:
  274. if not self.is_decoder:
  275. raise Exception(f"{self} should be used as a decoder model if cross attention is added")
  276. self.crossattention = BrosAttention(config)
  277. self.intermediate = BrosIntermediate(config)
  278. self.output = BrosOutput(config)
  279. def forward(
  280. self,
  281. hidden_states: torch.Tensor,
  282. bbox_pos_emb: torch.Tensor,
  283. attention_mask: torch.FloatTensor | None = None,
  284. encoder_hidden_states: torch.FloatTensor | None = None,
  285. encoder_attention_mask: torch.FloatTensor | None = None,
  286. **kwargs: Unpack[TransformersKwargs],
  287. ) -> torch.Tensor:
  288. hidden_states = self.attention(
  289. hidden_states,
  290. bbox_pos_emb=bbox_pos_emb,
  291. attention_mask=attention_mask,
  292. )
  293. if self.is_decoder and encoder_hidden_states is not None:
  294. if hasattr(self, "crossattention"):
  295. raise Exception(
  296. f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
  297. )
  298. hidden_states, _ = self.crossattention(
  299. hidden_states,
  300. attention_mask=attention_mask,
  301. encoder_hidden_states=encoder_hidden_states,
  302. encoder_attention_mask=encoder_attention_mask,
  303. **kwargs,
  304. )
  305. hidden_states = apply_chunking_to_forward(
  306. self.feed_forward_chunk,
  307. self.chunk_size_feed_forward,
  308. self.seq_len_dim,
  309. hidden_states,
  310. )
  311. return hidden_states
  312. def feed_forward_chunk(self, attention_output):
  313. intermediate_output = self.intermediate(attention_output)
  314. layer_output = self.output(intermediate_output, attention_output)
  315. return layer_output
# Copied from transformers.models.bert.modeling_bert.BertPooler with Bert->Bros
class BrosPooler(nn.Module):
    """Pool a sequence into one vector: tanh(dense(hidden state of the first token))."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output
  329. class BrosRelationExtractor(nn.Module):
  330. def __init__(self, config):
  331. super().__init__()
  332. self.n_relations = config.n_relations
  333. self.backbone_hidden_size = config.hidden_size
  334. self.head_hidden_size = config.hidden_size
  335. self.classifier_dropout_prob = config.classifier_dropout_prob
  336. self.drop = nn.Dropout(self.classifier_dropout_prob)
  337. self.query = nn.Linear(self.backbone_hidden_size, self.n_relations * self.head_hidden_size)
  338. self.key = nn.Linear(self.backbone_hidden_size, self.n_relations * self.head_hidden_size)
  339. self.dummy_node = nn.Parameter(torch.zeros(1, self.backbone_hidden_size))
  340. def forward(self, query_layer: torch.Tensor, key_layer: torch.Tensor):
  341. query_layer = self.query(self.drop(query_layer))
  342. dummy_vec = self.dummy_node.unsqueeze(0).repeat(1, key_layer.size(1), 1)
  343. key_layer = torch.cat([key_layer, dummy_vec], axis=0)
  344. key_layer = self.key(self.drop(key_layer))
  345. query_layer = query_layer.view(
  346. query_layer.size(0), query_layer.size(1), self.n_relations, self.head_hidden_size
  347. )
  348. key_layer = key_layer.view(key_layer.size(0), key_layer.size(1), self.n_relations, self.head_hidden_size)
  349. relation_score = torch.matmul(
  350. query_layer.permute(2, 1, 0, 3), key_layer.permute(2, 1, 3, 0)
  351. ) # equivalent to torch.einsum("ibnd,jbnd->nbij", (query_layer, key_layer))
  352. return relation_score
  353. @auto_docstring
  354. class BrosPreTrainedModel(PreTrainedModel):
  355. config: BrosConfig
  356. base_model_prefix = "bros"
  357. _can_record_outputs = {
  358. "hidden_states": BrosLayer,
  359. "attentions": OutputRecorder(BrosSelfAttention, index=1, layer_name="attention"),
  360. "cross_attentions": OutputRecorder(BrosSelfAttention, index=1, layer_name="crossattention"),
  361. }
  362. @torch.no_grad()
  363. def _init_weights(self, module: nn.Module):
  364. """Initialize the weights"""
  365. super()._init_weights(module)
  366. std = self.config.initializer_range
  367. if isinstance(module, BrosRelationExtractor):
  368. init.normal_(module.dummy_node, std=std)
  369. elif isinstance(module, BrosTextEmbeddings):
  370. init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
  371. init.zeros_(module.token_type_ids)
  372. elif isinstance(module, BrosPositionalEmbedding1D):
  373. inv_freq = 1 / (
  374. 10000 ** (torch.arange(0.0, module.dim_bbox_sinusoid_emb_1d, 2.0) / module.dim_bbox_sinusoid_emb_1d)
  375. )
  376. init.copy_(module.inv_freq, inv_freq)
  377. class BrosEncoder(BrosPreTrainedModel):
  378. def __init__(self, config):
  379. super().__init__(config)
  380. self.layer = nn.ModuleList([BrosLayer(config) for _ in range(config.num_hidden_layers)])
  381. self.post_init()
  382. @merge_with_config_defaults
  383. @capture_outputs
  384. def forward(
  385. self,
  386. hidden_states: torch.Tensor,
  387. bbox_pos_emb: torch.Tensor,
  388. attention_mask: torch.FloatTensor | None = None,
  389. encoder_hidden_states: torch.FloatTensor | None = None,
  390. encoder_attention_mask: torch.FloatTensor | None = None,
  391. **kwargs: Unpack[TransformersKwargs],
  392. ) -> tuple[torch.Tensor] | BaseModelOutputWithCrossAttentions:
  393. for layer_module in self.layer:
  394. hidden_states = layer_module(
  395. hidden_states,
  396. bbox_pos_emb=bbox_pos_emb,
  397. attention_mask=attention_mask,
  398. encoder_hidden_states=encoder_hidden_states,
  399. encoder_attention_mask=encoder_attention_mask,
  400. **kwargs,
  401. )
  402. return BaseModelOutputWithCrossAttentions(
  403. last_hidden_state=hidden_states,
  404. )
  405. @auto_docstring
  406. class BrosModel(BrosPreTrainedModel):
  407. def __init__(self, config, add_pooling_layer=True):
  408. r"""
  409. add_pooling_layer (bool, *optional*, defaults to `True`):
  410. Whether to add a pooling layer
  411. """
  412. super().__init__(config)
  413. self.config = config
  414. self.embeddings = BrosTextEmbeddings(config)
  415. self.bbox_embeddings = BrosBboxEmbeddings(config)
  416. self.encoder = BrosEncoder(config)
  417. self.pooler = BrosPooler(config) if add_pooling_layer else None
  418. self.post_init()
  419. def get_input_embeddings(self):
  420. return self.embeddings.word_embeddings
  421. def set_input_embeddings(self, value):
  422. self.embeddings.word_embeddings = value
  423. @can_return_tuple
  424. @auto_docstring
  425. def forward(
  426. self,
  427. input_ids: torch.Tensor | None = None,
  428. bbox: torch.Tensor | None = None,
  429. attention_mask: torch.Tensor | None = None,
  430. token_type_ids: torch.Tensor | None = None,
  431. position_ids: torch.Tensor | None = None,
  432. inputs_embeds: torch.Tensor | None = None,
  433. encoder_hidden_states: torch.Tensor | None = None,
  434. encoder_attention_mask: torch.Tensor | None = None,
  435. **kwargs: Unpack[TransformersKwargs],
  436. ) -> tuple[torch.Tensor] | BaseModelOutputWithPoolingAndCrossAttentions:
  437. r"""
  438. bbox ('torch.FloatTensor' of shape '(batch_size, num_boxes, 4)'):
  439. Bounding box coordinates for each token in the input sequence. Each bounding box is a list of four values
  440. (x1, y1, x2, y2), where (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner of the
  441. bounding box.
  442. Examples:
  443. ```python
  444. >>> import torch
  445. >>> from transformers import BrosProcessor, BrosModel
  446. >>> processor = BrosProcessor.from_pretrained("jinho8345/bros-base-uncased")
  447. >>> model = BrosModel.from_pretrained("jinho8345/bros-base-uncased")
  448. >>> encoding = processor("Hello, my dog is cute", add_special_tokens=False, return_tensors="pt")
  449. >>> bbox = torch.tensor([[[0, 0, 1, 1]]]).repeat(1, encoding["input_ids"].shape[-1], 1)
  450. >>> encoding["bbox"] = bbox
  451. >>> outputs = model(**encoding)
  452. >>> last_hidden_states = outputs.last_hidden_state
  453. ```"""
  454. if (input_ids is None) ^ (inputs_embeds is not None):
  455. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  456. if bbox is None:
  457. raise ValueError("You have to specify bbox")
  458. embedding_output = self.embeddings(
  459. input_ids=input_ids,
  460. position_ids=position_ids,
  461. token_type_ids=token_type_ids,
  462. inputs_embeds=inputs_embeds,
  463. )
  464. input_shape = embedding_output.shape[:-1]
  465. device = embedding_output.device
  466. if attention_mask is None:
  467. attention_mask = torch.ones(input_shape, device=device)
  468. # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
  469. # ourselves in which case we just need to make it broadcastable to all heads.
  470. extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
  471. # If a 2D or 3D attention mask is provided for the cross-attention
  472. # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
  473. if self.config.is_decoder and encoder_hidden_states is not None:
  474. encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
  475. encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
  476. if encoder_attention_mask is None:
  477. encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
  478. encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
  479. else:
  480. encoder_extended_attention_mask = None
  481. # if bbox has 2 points (4 float tensors) per token, convert it to 4 points (8 float tensors) per token
  482. if bbox.shape[-1] == 4:
  483. bbox = bbox[:, :, [0, 1, 2, 1, 2, 3, 0, 3]]
  484. scaled_bbox = bbox * self.config.bbox_scale
  485. bbox_position_embeddings = self.bbox_embeddings(scaled_bbox)
  486. encoder_outputs: BaseModelOutputWithCrossAttentions = self.encoder(
  487. embedding_output,
  488. bbox_pos_emb=bbox_position_embeddings,
  489. attention_mask=extended_attention_mask,
  490. encoder_hidden_states=encoder_hidden_states,
  491. encoder_attention_mask=encoder_extended_attention_mask,
  492. **kwargs,
  493. )
  494. sequence_output = encoder_outputs[0]
  495. pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
  496. return BaseModelOutputWithPoolingAndCrossAttentions(
  497. last_hidden_state=sequence_output,
  498. pooler_output=pooled_output,
  499. hidden_states=encoder_outputs.hidden_states,
  500. attentions=encoder_outputs.attentions,
  501. cross_attentions=encoder_outputs.cross_attentions,
  502. )
  503. @auto_docstring
  504. class BrosForTokenClassification(BrosPreTrainedModel):
  505. _keys_to_ignore_on_load_unexpected = [r"pooler"]
  506. def __init__(self, config):
  507. super().__init__(config)
  508. self.num_labels = config.num_labels
  509. self.bros = BrosModel(config)
  510. classifier_dropout = (
  511. config.classifier_dropout if hasattr(config, "classifier_dropout") else config.hidden_dropout_prob
  512. )
  513. self.dropout = nn.Dropout(classifier_dropout)
  514. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  515. self.post_init()
  516. @can_return_tuple
  517. @auto_docstring
  518. def forward(
  519. self,
  520. input_ids: torch.Tensor | None = None,
  521. bbox: torch.Tensor | None = None,
  522. attention_mask: torch.Tensor | None = None,
  523. bbox_first_token_mask: torch.Tensor | None = None,
  524. token_type_ids: torch.Tensor | None = None,
  525. position_ids: torch.Tensor | None = None,
  526. inputs_embeds: torch.Tensor | None = None,
  527. labels: torch.Tensor | None = None,
  528. **kwargs: Unpack[TransformersKwargs],
  529. ) -> tuple[torch.Tensor] | TokenClassifierOutput:
  530. r"""
  531. bbox ('torch.FloatTensor' of shape '(batch_size, num_boxes, 4)'):
  532. Bounding box coordinates for each token in the input sequence. Each bounding box is a list of four values
  533. (x1, y1, x2, y2), where (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner of the
  534. bounding box.
  535. bbox_first_token_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
  536. Mask to indicate the first token of each bounding box. Mask values selected in `[0, 1]`:
  537. - 1 for tokens that are **not masked**,
  538. - 0 for tokens that are **masked**.
  539. Examples:
  540. ```python
  541. >>> import torch
  542. >>> from transformers import BrosProcessor, BrosForTokenClassification
  543. >>> processor = BrosProcessor.from_pretrained("jinho8345/bros-base-uncased")
  544. >>> model = BrosForTokenClassification.from_pretrained("jinho8345/bros-base-uncased")
  545. >>> encoding = processor("Hello, my dog is cute", add_special_tokens=False, return_tensors="pt")
  546. >>> bbox = torch.tensor([[[0, 0, 1, 1]]]).repeat(1, encoding["input_ids"].shape[-1], 1)
  547. >>> encoding["bbox"] = bbox
  548. >>> outputs = model(**encoding)
  549. ```"""
  550. outputs: BaseModelOutputWithPoolingAndCrossAttentions = self.bros(
  551. input_ids,
  552. bbox=bbox,
  553. attention_mask=attention_mask,
  554. token_type_ids=token_type_ids,
  555. position_ids=position_ids,
  556. inputs_embeds=inputs_embeds,
  557. **kwargs,
  558. )
  559. sequence_output = outputs[0]
  560. sequence_output = self.dropout(sequence_output)
  561. logits = self.classifier(sequence_output)
  562. loss = None
  563. if labels is not None:
  564. loss_fct = CrossEntropyLoss()
  565. if bbox_first_token_mask is not None:
  566. bbox_first_token_mask = bbox_first_token_mask.view(-1)
  567. loss = loss_fct(
  568. logits.view(-1, self.num_labels)[bbox_first_token_mask], labels.view(-1)[bbox_first_token_mask]
  569. )
  570. else:
  571. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  572. return TokenClassifierOutput(
  573. loss=loss,
  574. logits=logits,
  575. hidden_states=outputs.hidden_states,
  576. attentions=outputs.attentions,
  577. )
  578. @auto_docstring(
  579. custom_intro="""
  580. Bros Model with a token classification head on top (initial_token_layers and subsequent_token_layer on top of the
  581. hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. The initial_token_classifier is used to
  582. predict the first token of each entity, and the subsequent_token_classifier is used to predict the subsequent
  583. tokens within an entity. Compared to BrosForTokenClassification, this model is more robust to serialization errors
  584. since it predicts next token from one token.
  585. """
  586. )
  587. class BrosSpadeEEForTokenClassification(BrosPreTrainedModel):
  588. _keys_to_ignore_on_load_unexpected = [r"pooler"]
  589. def __init__(self, config):
  590. super().__init__(config)
  591. self.config = config
  592. self.num_labels = config.num_labels
  593. self.n_relations = config.n_relations
  594. self.backbone_hidden_size = config.hidden_size
  595. self.bros = BrosModel(config)
  596. classifier_dropout = (
  597. config.classifier_dropout if hasattr(config, "classifier_dropout") else config.hidden_dropout_prob
  598. )
  599. # Initial token classification for Entity Extraction (NER)
  600. self.initial_token_classifier = nn.Sequential(
  601. nn.Dropout(classifier_dropout),
  602. nn.Linear(config.hidden_size, config.hidden_size),
  603. nn.Dropout(classifier_dropout),
  604. nn.Linear(config.hidden_size, config.num_labels),
  605. )
  606. # Subsequent token classification for Entity Extraction (NER)
  607. self.subsequent_token_classifier = BrosRelationExtractor(config)
  608. self.post_init()
  609. @can_return_tuple
  610. @auto_docstring
  611. def forward(
  612. self,
  613. input_ids: torch.Tensor | None = None,
  614. bbox: torch.Tensor | None = None,
  615. attention_mask: torch.Tensor | None = None,
  616. bbox_first_token_mask: torch.Tensor | None = None,
  617. token_type_ids: torch.Tensor | None = None,
  618. position_ids: torch.Tensor | None = None,
  619. inputs_embeds: torch.Tensor | None = None,
  620. initial_token_labels: torch.Tensor | None = None,
  621. subsequent_token_labels: torch.Tensor | None = None,
  622. **kwargs: Unpack[TransformersKwargs],
  623. ) -> tuple[torch.Tensor] | BrosSpadeOutput:
  624. r"""
  625. bbox ('torch.FloatTensor' of shape '(batch_size, num_boxes, 4)'):
  626. Bounding box coordinates for each token in the input sequence. Each bounding box is a list of four values
  627. (x1, y1, x2, y2), where (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner of the
  628. bounding box.
  629. bbox_first_token_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
  630. Mask to indicate the first token of each bounding box. Mask values selected in `[0, 1]`:
  631. - 1 for tokens that are **not masked**,
  632. - 0 for tokens that are **masked**.
  633. initial_token_labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  634. Labels for the initial token classification.
  635. subsequent_token_labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  636. Labels for the subsequent token classification.
  637. Examples:
  638. ```python
  639. >>> import torch
  640. >>> from transformers import BrosProcessor, BrosSpadeEEForTokenClassification
  641. >>> processor = BrosProcessor.from_pretrained("jinho8345/bros-base-uncased")
  642. >>> model = BrosSpadeEEForTokenClassification.from_pretrained("jinho8345/bros-base-uncased")
  643. >>> encoding = processor("Hello, my dog is cute", add_special_tokens=False, return_tensors="pt")
  644. >>> bbox = torch.tensor([[[0, 0, 1, 1]]]).repeat(1, encoding["input_ids"].shape[-1], 1)
  645. >>> encoding["bbox"] = bbox
  646. >>> outputs = model(**encoding)
  647. ```"""
  648. outputs: BaseModelOutputWithPoolingAndCrossAttentions = self.bros(
  649. input_ids=input_ids,
  650. bbox=bbox,
  651. attention_mask=attention_mask,
  652. token_type_ids=token_type_ids,
  653. position_ids=position_ids,
  654. inputs_embeds=inputs_embeds,
  655. **kwargs,
  656. )
  657. last_hidden_states = outputs[0]
  658. last_hidden_states = last_hidden_states.transpose(0, 1).contiguous()
  659. initial_token_logits = self.initial_token_classifier(last_hidden_states).transpose(0, 1).contiguous()
  660. subsequent_token_logits = self.subsequent_token_classifier(last_hidden_states, last_hidden_states).squeeze(0)
  661. # make subsequent token (sequence token classification) mask
  662. inv_attention_mask = 1 - attention_mask
  663. batch_size, max_seq_length = inv_attention_mask.shape
  664. device = inv_attention_mask.device
  665. invalid_token_mask = torch.cat([inv_attention_mask, torch.zeros([batch_size, 1]).to(device)], axis=1).bool()
  666. subsequent_token_logits = subsequent_token_logits.masked_fill(
  667. invalid_token_mask[:, None, :], torch.finfo(subsequent_token_logits.dtype).min
  668. )
  669. self_token_mask = torch.eye(max_seq_length, max_seq_length + 1).to(device=device, dtype=torch.bool)
  670. subsequent_token_logits = subsequent_token_logits.masked_fill(
  671. self_token_mask[None, :, :], torch.finfo(subsequent_token_logits.dtype).min
  672. )
  673. subsequent_token_mask = attention_mask.view(-1).bool()
  674. loss = None
  675. if initial_token_labels is not None and subsequent_token_labels is not None:
  676. loss_fct = CrossEntropyLoss()
  677. # get initial token loss
  678. initial_token_labels = initial_token_labels.view(-1)
  679. if bbox_first_token_mask is not None:
  680. bbox_first_token_mask = bbox_first_token_mask.view(-1)
  681. initial_token_loss = loss_fct(
  682. initial_token_logits.view(-1, self.num_labels)[bbox_first_token_mask],
  683. initial_token_labels[bbox_first_token_mask],
  684. )
  685. else:
  686. initial_token_loss = loss_fct(initial_token_logits.view(-1, self.num_labels), initial_token_labels)
  687. subsequent_token_labels = subsequent_token_labels.view(-1)
  688. subsequent_token_loss = loss_fct(
  689. subsequent_token_logits.view(-1, max_seq_length + 1)[subsequent_token_mask],
  690. subsequent_token_labels[subsequent_token_mask],
  691. )
  692. loss = initial_token_loss + subsequent_token_loss
  693. return BrosSpadeOutput(
  694. loss=loss,
  695. initial_token_logits=initial_token_logits,
  696. subsequent_token_logits=subsequent_token_logits,
  697. hidden_states=outputs.hidden_states,
  698. attentions=outputs.attentions,
  699. )
  700. @auto_docstring(
  701. custom_intro="""
  702. Bros Model with a token classification head on top (a entity_linker layer on top of the hidden-states output) e.g.
  703. for Entity-Linking. The entity_linker is used to predict intra-entity links (one entity to another entity).
  704. """
  705. )
  706. class BrosSpadeELForTokenClassification(BrosPreTrainedModel):
  707. _keys_to_ignore_on_load_unexpected = [r"pooler"]
  708. def __init__(self, config):
  709. super().__init__(config)
  710. self.config = config
  711. self.num_labels = config.num_labels
  712. self.n_relations = config.n_relations
  713. self.backbone_hidden_size = config.hidden_size
  714. self.bros = BrosModel(config)
  715. (config.classifier_dropout if hasattr(config, "classifier_dropout") else config.hidden_dropout_prob)
  716. self.entity_linker = BrosRelationExtractor(config)
  717. self.post_init()
  718. @can_return_tuple
  719. @auto_docstring
  720. def forward(
  721. self,
  722. input_ids: torch.Tensor | None = None,
  723. bbox: torch.Tensor | None = None,
  724. attention_mask: torch.Tensor | None = None,
  725. bbox_first_token_mask: torch.Tensor | None = None,
  726. token_type_ids: torch.Tensor | None = None,
  727. position_ids: torch.Tensor | None = None,
  728. inputs_embeds: torch.Tensor | None = None,
  729. labels: torch.Tensor | None = None,
  730. **kwargs: Unpack[TransformersKwargs],
  731. ) -> tuple[torch.Tensor] | TokenClassifierOutput:
  732. r"""
  733. bbox ('torch.FloatTensor' of shape '(batch_size, num_boxes, 4)'):
  734. Bounding box coordinates for each token in the input sequence. Each bounding box is a list of four values
  735. (x1, y1, x2, y2), where (x1, y1) is the top left corner, and (x2, y2) is the bottom right corner of the
  736. bounding box.
  737. bbox_first_token_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
  738. Mask to indicate the first token of each bounding box. Mask values selected in `[0, 1]`:
  739. - 1 for tokens that are **not masked**,
  740. - 0 for tokens that are **masked**.
  741. Examples:
  742. ```python
  743. >>> import torch
  744. >>> from transformers import BrosProcessor, BrosSpadeELForTokenClassification
  745. >>> processor = BrosProcessor.from_pretrained("jinho8345/bros-base-uncased")
  746. >>> model = BrosSpadeELForTokenClassification.from_pretrained("jinho8345/bros-base-uncased")
  747. >>> encoding = processor("Hello, my dog is cute", add_special_tokens=False, return_tensors="pt")
  748. >>> bbox = torch.tensor([[[0, 0, 1, 1]]]).repeat(1, encoding["input_ids"].shape[-1], 1)
  749. >>> encoding["bbox"] = bbox
  750. >>> outputs = model(**encoding)
  751. ```"""
  752. outputs: BaseModelOutputWithPoolingAndCrossAttentions = self.bros(
  753. input_ids=input_ids,
  754. bbox=bbox,
  755. attention_mask=attention_mask,
  756. token_type_ids=token_type_ids,
  757. position_ids=position_ids,
  758. inputs_embeds=inputs_embeds,
  759. **kwargs,
  760. )
  761. last_hidden_states = outputs[0]
  762. last_hidden_states = last_hidden_states.transpose(0, 1).contiguous()
  763. logits = self.entity_linker(last_hidden_states, last_hidden_states).squeeze(0)
  764. loss = None
  765. if labels is not None:
  766. loss_fct = CrossEntropyLoss()
  767. batch_size, max_seq_length = attention_mask.shape
  768. device = attention_mask.device
  769. self_token_mask = torch.eye(max_seq_length, max_seq_length + 1).to(device=device, dtype=torch.bool)
  770. mask = bbox_first_token_mask.view(-1)
  771. bbox_first_token_mask = torch.cat(
  772. [
  773. ~bbox_first_token_mask,
  774. torch.zeros([batch_size, 1], dtype=torch.bool, device=device),
  775. ],
  776. axis=1,
  777. )
  778. logits = logits.masked_fill(bbox_first_token_mask[:, None, :], torch.finfo(logits.dtype).min)
  779. logits = logits.masked_fill(self_token_mask[None, :, :], torch.finfo(logits.dtype).min)
  780. loss = loss_fct(logits.view(-1, max_seq_length + 1)[mask], labels.view(-1)[mask])
  781. return TokenClassifierOutput(
  782. loss=loss,
  783. logits=logits,
  784. hidden_states=outputs.hidden_states,
  785. attentions=outputs.attentions,
  786. )
  787. __all__ = [
  788. "BrosPreTrainedModel",
  789. "BrosModel",
  790. "BrosForTokenClassification",
  791. "BrosSpadeEEForTokenClassification",
  792. "BrosSpadeELForTokenClassification",
  793. ]