modeling_lilt.py 41 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995
  1. # Copyright 2022 The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch LiLT model."""
  15. import math
  16. import torch
  17. from torch import nn
  18. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  19. from ... import initialization as init
  20. from ...activations import ACT2FN
  21. from ...modeling_layers import GradientCheckpointingLayer
  22. from ...modeling_outputs import (
  23. BaseModelOutput,
  24. BaseModelOutputWithPooling,
  25. QuestionAnsweringModelOutput,
  26. SequenceClassifierOutput,
  27. TokenClassifierOutput,
  28. )
  29. from ...modeling_utils import PreTrainedModel
  30. from ...pytorch_utils import apply_chunking_to_forward
  31. from ...utils import auto_docstring, logging
  32. from .configuration_lilt import LiltConfig
# Module-level logger obtained from the library's `logging.get_logger` helper.
logger = logging.get_logger(__name__)
  34. class LiltTextEmbeddings(nn.Module):
  35. def __init__(self, config):
  36. super().__init__()
  37. self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
  38. self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
  39. self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
  40. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  41. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  42. # position_ids (1, len position emb) is contiguous in memory and exported when serialized
  43. self.register_buffer(
  44. "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
  45. )
  46. # End copy
  47. self.padding_idx = config.pad_token_id
  48. self.position_embeddings = nn.Embedding(
  49. config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
  50. )
  51. def forward(
  52. self,
  53. input_ids=None,
  54. token_type_ids=None,
  55. position_ids=None,
  56. inputs_embeds=None,
  57. ):
  58. if position_ids is None:
  59. if input_ids is not None:
  60. # Create the position ids from the input token ids. Any padded tokens remain padded.
  61. position_ids = self.create_position_ids_from_input_ids(input_ids, self.padding_idx).to(
  62. input_ids.device
  63. )
  64. else:
  65. position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
  66. if input_ids is not None:
  67. input_shape = input_ids.size()
  68. else:
  69. input_shape = inputs_embeds.size()[:-1]
  70. if token_type_ids is None:
  71. token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
  72. if inputs_embeds is None:
  73. inputs_embeds = self.word_embeddings(input_ids)
  74. token_type_embeddings = self.token_type_embeddings(token_type_ids)
  75. embeddings = inputs_embeds + token_type_embeddings
  76. position_embeddings = self.position_embeddings(position_ids)
  77. embeddings += position_embeddings
  78. embeddings = self.LayerNorm(embeddings)
  79. embeddings = self.dropout(embeddings)
  80. return embeddings, position_ids
  81. def create_position_ids_from_input_ids(self, input_ids, padding_idx):
  82. """
  83. Args:
  84. Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding
  85. symbols are ignored. This is modified from fairseq's `utils.make_positions`.
  86. x: torch.Tensor x:
  87. Returns: torch.Tensor
  88. """
  89. # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
  90. mask = input_ids.ne(padding_idx).int()
  91. incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask)) * mask
  92. return incremental_indices.long() + padding_idx
  93. def create_position_ids_from_inputs_embeds(self, inputs_embeds):
  94. """
  95. Args:
  96. We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.:
  97. inputs_embeds: torch.Tensor
  98. Returns: torch.Tensor
  99. """
  100. input_shape = inputs_embeds.size()[:-1]
  101. sequence_length = input_shape[1]
  102. position_ids = torch.arange(
  103. self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
  104. )
  105. return position_ids.unsqueeze(0).expand(input_shape)
  106. class LiltLayoutEmbeddings(nn.Module):
  107. def __init__(self, config):
  108. super().__init__()
  109. # we divide the hidden_size by 6 here as there are 6 different layout embeddings,
  110. # namely left_position, upper_position, right_position, lower_position, height, width
  111. self.x_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.hidden_size // 6)
  112. self.y_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.hidden_size // 6)
  113. self.h_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.hidden_size // 6)
  114. self.w_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.hidden_size // 6)
  115. self.padding_idx = config.pad_token_id
  116. self.box_position_embeddings = nn.Embedding(
  117. config.max_position_embeddings,
  118. config.hidden_size // config.channel_shrink_ratio,
  119. padding_idx=self.padding_idx,
  120. )
  121. self.box_linear_embeddings = nn.Linear(
  122. in_features=config.hidden_size, out_features=config.hidden_size // config.channel_shrink_ratio
  123. )
  124. self.LayerNorm = nn.LayerNorm(config.hidden_size // config.channel_shrink_ratio, eps=config.layer_norm_eps)
  125. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  126. def forward(self, bbox=None, position_ids=None):
  127. try:
  128. left_position_embeddings = self.x_position_embeddings(bbox[:, :, 0])
  129. upper_position_embeddings = self.y_position_embeddings(bbox[:, :, 1])
  130. right_position_embeddings = self.x_position_embeddings(bbox[:, :, 2])
  131. lower_position_embeddings = self.y_position_embeddings(bbox[:, :, 3])
  132. except IndexError as e:
  133. raise IndexError("The `bbox` coordinate values should be within 0-1000 range.") from e
  134. h_position_embeddings = self.h_position_embeddings(bbox[:, :, 3] - bbox[:, :, 1])
  135. w_position_embeddings = self.w_position_embeddings(bbox[:, :, 2] - bbox[:, :, 0])
  136. spatial_position_embeddings = torch.cat(
  137. [
  138. left_position_embeddings,
  139. upper_position_embeddings,
  140. right_position_embeddings,
  141. lower_position_embeddings,
  142. h_position_embeddings,
  143. w_position_embeddings,
  144. ],
  145. dim=-1,
  146. )
  147. spatial_position_embeddings = self.box_linear_embeddings(spatial_position_embeddings)
  148. box_position_embeddings = self.box_position_embeddings(position_ids)
  149. spatial_position_embeddings = spatial_position_embeddings + box_position_embeddings
  150. spatial_position_embeddings = self.LayerNorm(spatial_position_embeddings)
  151. spatial_position_embeddings = self.dropout(spatial_position_embeddings)
  152. return spatial_position_embeddings
  153. class LiltSelfAttention(nn.Module):
  154. def __init__(self, config, layer_idx=None):
  155. super().__init__()
  156. if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
  157. raise ValueError(
  158. f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
  159. f"heads ({config.num_attention_heads})"
  160. )
  161. self.num_attention_heads = config.num_attention_heads
  162. self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
  163. self.all_head_size = self.num_attention_heads * self.attention_head_size
  164. self.query = nn.Linear(config.hidden_size, self.all_head_size)
  165. self.key = nn.Linear(config.hidden_size, self.all_head_size)
  166. self.value = nn.Linear(config.hidden_size, self.all_head_size)
  167. self.layout_query = nn.Linear(
  168. config.hidden_size // config.channel_shrink_ratio, self.all_head_size // config.channel_shrink_ratio
  169. )
  170. self.layout_key = nn.Linear(
  171. config.hidden_size // config.channel_shrink_ratio, self.all_head_size // config.channel_shrink_ratio
  172. )
  173. self.layout_value = nn.Linear(
  174. config.hidden_size // config.channel_shrink_ratio, self.all_head_size // config.channel_shrink_ratio
  175. )
  176. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  177. self.channel_shrink_ratio = config.channel_shrink_ratio
  178. self.layer_idx = layer_idx
  179. def transpose_for_scores(self, x, r=1):
  180. new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size // r)
  181. x = x.view(*new_x_shape)
  182. return x.permute(0, 2, 1, 3)
  183. def forward(
  184. self,
  185. hidden_states,
  186. layout_inputs,
  187. attention_mask=None,
  188. output_attentions=False,
  189. ):
  190. layout_value_layer = self.transpose_for_scores(self.layout_value(layout_inputs), r=self.channel_shrink_ratio)
  191. layout_key_layer = self.transpose_for_scores(self.layout_key(layout_inputs), r=self.channel_shrink_ratio)
  192. layout_query_layer = self.transpose_for_scores(self.layout_query(layout_inputs), r=self.channel_shrink_ratio)
  193. mixed_query_layer = self.query(hidden_states)
  194. key_layer = self.transpose_for_scores(self.key(hidden_states))
  195. value_layer = self.transpose_for_scores(self.value(hidden_states))
  196. query_layer = self.transpose_for_scores(mixed_query_layer)
  197. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
  198. layout_attention_scores = torch.matmul(layout_query_layer, layout_key_layer.transpose(-1, -2))
  199. tmp_attention_scores = attention_scores / math.sqrt(self.attention_head_size)
  200. tmp_layout_attention_scores = layout_attention_scores / math.sqrt(
  201. self.attention_head_size // self.channel_shrink_ratio
  202. )
  203. attention_scores = tmp_attention_scores + tmp_layout_attention_scores
  204. layout_attention_scores = tmp_layout_attention_scores + tmp_attention_scores
  205. if attention_mask is not None:
  206. # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
  207. layout_attention_scores = layout_attention_scores + attention_mask
  208. # Normalize the attention scores to probabilities.
  209. layout_attention_probs = nn.Softmax(dim=-1)(layout_attention_scores)
  210. # This is actually dropping out entire tokens to attend to, which might
  211. # seem a bit unusual, but is taken from the original Transformer paper.
  212. layout_attention_probs = self.dropout(layout_attention_probs)
  213. layout_context_layer = torch.matmul(layout_attention_probs, layout_value_layer)
  214. layout_context_layer = layout_context_layer.permute(0, 2, 1, 3).contiguous()
  215. new_context_layer_shape = layout_context_layer.size()[:-2] + (self.all_head_size // self.channel_shrink_ratio,)
  216. layout_context_layer = layout_context_layer.view(*new_context_layer_shape)
  217. if attention_mask is not None:
  218. # Apply the attention mask is (precomputed for all layers in RobertaModel forward() function)
  219. attention_scores = attention_scores + attention_mask
  220. # Normalize the attention scores to probabilities.
  221. attention_probs = nn.Softmax(dim=-1)(attention_scores)
  222. # This is actually dropping out entire tokens to attend to, which might
  223. # seem a bit unusual, but is taken from the original Transformer paper.
  224. attention_probs = self.dropout(attention_probs)
  225. context_layer = torch.matmul(attention_probs, value_layer)
  226. context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
  227. new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
  228. context_layer = context_layer.view(*new_context_layer_shape)
  229. outputs = (context_layer, layout_context_layer)
  230. if output_attentions:
  231. outputs = outputs + (attention_probs,)
  232. return outputs
  233. # Copied from transformers.models.bert.modeling_bert.BertSelfOutput
  234. class LiltSelfOutput(nn.Module):
  235. def __init__(self, config):
  236. super().__init__()
  237. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  238. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  239. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  240. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  241. hidden_states = self.dense(hidden_states)
  242. hidden_states = self.dropout(hidden_states)
  243. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  244. return hidden_states
  245. class LiltAttention(nn.Module):
  246. def __init__(self, config, layer_idx=None):
  247. super().__init__()
  248. self.self = LiltSelfAttention(config, layer_idx=layer_idx)
  249. self.output = LiltSelfOutput(config)
  250. ori_hidden_size = config.hidden_size
  251. config.hidden_size = config.hidden_size // config.channel_shrink_ratio
  252. self.layout_output = LiltSelfOutput(config)
  253. config.hidden_size = ori_hidden_size
  254. def forward(
  255. self,
  256. hidden_states: torch.Tensor,
  257. layout_inputs: torch.Tensor,
  258. attention_mask: torch.FloatTensor | None = None,
  259. output_attentions: bool | None = False,
  260. ) -> tuple[torch.Tensor]:
  261. self_outputs = self.self(
  262. hidden_states,
  263. layout_inputs,
  264. attention_mask,
  265. output_attentions,
  266. )
  267. attention_output = self.output(self_outputs[0], hidden_states)
  268. layout_attention_output = self.layout_output(self_outputs[1], layout_inputs)
  269. outputs = (attention_output, layout_attention_output) + self_outputs[2:] # add attentions if we output them
  270. return outputs
  271. # Copied from transformers.models.bert.modeling_bert.BertIntermediate
  272. class LiltIntermediate(nn.Module):
  273. def __init__(self, config):
  274. super().__init__()
  275. self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
  276. if isinstance(config.hidden_act, str):
  277. self.intermediate_act_fn = ACT2FN[config.hidden_act]
  278. else:
  279. self.intermediate_act_fn = config.hidden_act
  280. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  281. hidden_states = self.dense(hidden_states)
  282. hidden_states = self.intermediate_act_fn(hidden_states)
  283. return hidden_states
  284. # Copied from transformers.models.bert.modeling_bert.BertOutput
  285. class LiltOutput(nn.Module):
  286. def __init__(self, config):
  287. super().__init__()
  288. self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
  289. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  290. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  291. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  292. hidden_states = self.dense(hidden_states)
  293. hidden_states = self.dropout(hidden_states)
  294. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  295. return hidden_states
  296. class LiltLayer(GradientCheckpointingLayer):
  297. def __init__(self, config, layer_idx=None):
  298. super().__init__()
  299. self.chunk_size_feed_forward = config.chunk_size_feed_forward
  300. self.seq_len_dim = 1
  301. self.attention = LiltAttention(config, layer_idx=layer_idx)
  302. self.intermediate = LiltIntermediate(config)
  303. self.output = LiltOutput(config)
  304. ori_hidden_size = config.hidden_size
  305. ori_intermediate_size = config.intermediate_size
  306. config.hidden_size = config.hidden_size // config.channel_shrink_ratio
  307. config.intermediate_size = config.intermediate_size // config.channel_shrink_ratio
  308. self.layout_intermediate = LiltIntermediate(config)
  309. self.layout_output = LiltOutput(config)
  310. config.hidden_size = ori_hidden_size
  311. config.intermediate_size = ori_intermediate_size
  312. def forward(
  313. self,
  314. hidden_states: torch.Tensor,
  315. layout_inputs: torch.Tensor,
  316. attention_mask: torch.FloatTensor | None = None,
  317. output_attentions: bool | None = False,
  318. ) -> tuple[torch.Tensor]:
  319. self_attention_outputs = self.attention(
  320. hidden_states,
  321. layout_inputs,
  322. attention_mask,
  323. output_attentions=output_attentions,
  324. )
  325. attention_output = self_attention_outputs[0]
  326. layout_attention_output = self_attention_outputs[1]
  327. outputs = self_attention_outputs[2:] # add self attentions if we output attention weights
  328. layer_output = apply_chunking_to_forward(
  329. self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
  330. )
  331. layout_layer_output = apply_chunking_to_forward(
  332. self.layout_feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, layout_attention_output
  333. )
  334. outputs = (layer_output, layout_layer_output) + outputs
  335. return outputs
  336. # Copied from transformers.models.bert.modeling_bert.BertLayer.feed_forward_chunk
  337. def feed_forward_chunk(self, attention_output):
  338. intermediate_output = self.intermediate(attention_output)
  339. layer_output = self.output(intermediate_output, attention_output)
  340. return layer_output
  341. def layout_feed_forward_chunk(self, attention_output):
  342. intermediate_output = self.layout_intermediate(attention_output)
  343. layer_output = self.layout_output(intermediate_output, attention_output)
  344. return layer_output
  345. class LiltEncoder(nn.Module):
  346. def __init__(self, config):
  347. super().__init__()
  348. self.config = config
  349. self.layer = nn.ModuleList([LiltLayer(config) for _ in range(config.num_hidden_layers)])
  350. def forward(
  351. self,
  352. hidden_states: torch.Tensor,
  353. layout_inputs: torch.Tensor,
  354. attention_mask: torch.FloatTensor | None = None,
  355. output_attentions: bool | None = False,
  356. output_hidden_states: bool | None = False,
  357. return_dict: bool | None = True,
  358. ) -> tuple[torch.Tensor] | BaseModelOutput:
  359. all_hidden_states = () if output_hidden_states else None
  360. all_self_attentions = () if output_attentions else None
  361. for i, layer_module in enumerate(self.layer):
  362. if output_hidden_states:
  363. all_hidden_states = all_hidden_states + (hidden_states,)
  364. layer_outputs = layer_module(
  365. hidden_states,
  366. layout_inputs,
  367. attention_mask,
  368. output_attentions,
  369. )
  370. hidden_states = layer_outputs[0]
  371. layout_inputs = layer_outputs[1]
  372. if output_attentions:
  373. all_self_attentions = all_self_attentions + (layer_outputs[2],)
  374. if output_hidden_states:
  375. all_hidden_states = all_hidden_states + (hidden_states,)
  376. if not return_dict:
  377. return tuple(
  378. v
  379. for v in [
  380. hidden_states,
  381. all_hidden_states,
  382. all_self_attentions,
  383. ]
  384. if v is not None
  385. )
  386. return BaseModelOutput(
  387. last_hidden_state=hidden_states,
  388. hidden_states=all_hidden_states,
  389. attentions=all_self_attentions,
  390. )
  391. # Copied from transformers.models.bert.modeling_bert.BertPooler
  392. class LiltPooler(nn.Module):
  393. def __init__(self, config):
  394. super().__init__()
  395. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  396. self.activation = nn.Tanh()
  397. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  398. # We "pool" the model by simply taking the hidden state corresponding
  399. # to the first token.
  400. first_token_tensor = hidden_states[:, 0]
  401. pooled_output = self.dense(first_token_tensor)
  402. pooled_output = self.activation(pooled_output)
  403. return pooled_output
  404. @auto_docstring
  405. class LiltPreTrainedModel(PreTrainedModel):
  406. config: LiltConfig
  407. base_model_prefix = "lilt"
  408. supports_gradient_checkpointing = True
  409. _no_split_modules = []
  410. def _init_weights(self, module):
  411. super()._init_weights(module)
  412. if isinstance(module, LiltTextEmbeddings):
  413. init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
  414. @auto_docstring
  415. class LiltModel(LiltPreTrainedModel):
  416. def __init__(self, config, add_pooling_layer=True):
  417. r"""
  418. add_pooling_layer (bool, *optional*, defaults to `True`):
  419. Whether to add a pooling layer
  420. """
  421. super().__init__(config)
  422. self.config = config
  423. self.embeddings = LiltTextEmbeddings(config)
  424. self.layout_embeddings = LiltLayoutEmbeddings(config)
  425. self.encoder = LiltEncoder(config)
  426. self.pooler = LiltPooler(config) if add_pooling_layer else None
  427. # Initialize weights and apply final processing
  428. self.post_init()
  429. def get_input_embeddings(self):
  430. return self.embeddings.word_embeddings
  431. def set_input_embeddings(self, value):
  432. self.embeddings.word_embeddings = value
  433. @auto_docstring
  434. def forward(
  435. self,
  436. input_ids: torch.Tensor | None = None,
  437. bbox: torch.Tensor | None = None,
  438. attention_mask: torch.Tensor | None = None,
  439. token_type_ids: torch.Tensor | None = None,
  440. position_ids: torch.Tensor | None = None,
  441. inputs_embeds: torch.Tensor | None = None,
  442. output_attentions: bool | None = None,
  443. output_hidden_states: bool | None = None,
  444. return_dict: bool | None = None,
  445. **kwargs,
  446. ) -> tuple[torch.Tensor] | BaseModelOutputWithPooling:
  447. r"""
  448. bbox (`torch.LongTensor` of shape `(batch_size, sequence_length, 4)`, *optional*):
  449. Bounding boxes of each input sequence tokens. Selected in the range `[0,
  450. config.max_2d_position_embeddings-1]`. Each bounding box should be a normalized version in (x0, y0, x1, y1)
  451. format, where (x0, y0) corresponds to the position of the upper left corner in the bounding box, and (x1,
  452. y1) represents the position of the lower right corner. See [Overview](#Overview) for normalization.
  453. Examples:
  454. ```python
  455. >>> from transformers import AutoTokenizer, AutoModel
  456. >>> from datasets import load_dataset
  457. >>> tokenizer = AutoTokenizer.from_pretrained("SCUT-DLVCLab/lilt-roberta-en-base")
  458. >>> model = AutoModel.from_pretrained("SCUT-DLVCLab/lilt-roberta-en-base")
  459. >>> dataset = load_dataset("nielsr/funsd-layoutlmv3", split="train")
  460. >>> example = dataset[0]
  461. >>> words = example["tokens"]
  462. >>> boxes = example["bboxes"]
  463. >>> encoding = tokenizer(words, boxes=boxes, return_tensors="pt")
  464. >>> outputs = model(**encoding)
  465. >>> last_hidden_states = outputs.last_hidden_state
  466. ```"""
  467. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  468. output_hidden_states = (
  469. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  470. )
  471. return_dict = return_dict if return_dict is not None else self.config.return_dict
  472. if input_ids is not None and inputs_embeds is not None:
  473. raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
  474. elif input_ids is not None:
  475. self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
  476. input_shape = input_ids.size()
  477. elif inputs_embeds is not None:
  478. input_shape = inputs_embeds.size()[:-1]
  479. else:
  480. raise ValueError("You have to specify either input_ids or inputs_embeds")
  481. batch_size, seq_length = input_shape
  482. device = input_ids.device if input_ids is not None else inputs_embeds.device
  483. if bbox is None:
  484. bbox = torch.zeros(input_shape + (4,), dtype=torch.long, device=device)
  485. if attention_mask is None:
  486. attention_mask = torch.ones(((batch_size, seq_length)), device=device)
  487. if token_type_ids is None:
  488. if hasattr(self.embeddings, "token_type_ids"):
  489. buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
  490. buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
  491. token_type_ids = buffered_token_type_ids_expanded
  492. else:
  493. token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
  494. # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
  495. # ourselves in which case we just need to make it broadcastable to all heads.
  496. extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
  497. embedding_output, position_ids = self.embeddings(
  498. input_ids=input_ids,
  499. position_ids=position_ids,
  500. token_type_ids=token_type_ids,
  501. inputs_embeds=inputs_embeds,
  502. )
  503. layout_embedding_output = self.layout_embeddings(bbox=bbox, position_ids=position_ids)
  504. encoder_outputs = self.encoder(
  505. embedding_output,
  506. layout_embedding_output,
  507. attention_mask=extended_attention_mask,
  508. output_attentions=output_attentions,
  509. output_hidden_states=output_hidden_states,
  510. return_dict=return_dict,
  511. )
  512. sequence_output = encoder_outputs[0]
  513. pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
  514. if not return_dict:
  515. return (sequence_output, pooled_output) + encoder_outputs[1:]
  516. return BaseModelOutputWithPooling(
  517. last_hidden_state=sequence_output,
  518. pooler_output=pooled_output,
  519. hidden_states=encoder_outputs.hidden_states,
  520. attentions=encoder_outputs.attentions,
  521. )
  522. @auto_docstring(
  523. custom_intro="""
  524. LiLT Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
  525. output) e.g. for GLUE tasks.
  526. """
  527. )
  528. class LiltForSequenceClassification(LiltPreTrainedModel):
  529. # Copied from transformers.models.roberta.modeling_roberta.RobertaForSequenceClassification.__init__ with Roberta->Lilt, roberta->lilt
  530. def __init__(self, config):
  531. super().__init__(config)
  532. self.num_labels = config.num_labels
  533. self.config = config
  534. self.lilt = LiltModel(config, add_pooling_layer=False)
  535. self.classifier = LiltClassificationHead(config)
  536. # Initialize weights and apply final processing
  537. self.post_init()
  538. @auto_docstring
  539. def forward(
  540. self,
  541. input_ids: torch.LongTensor | None = None,
  542. bbox: torch.Tensor | None = None,
  543. attention_mask: torch.FloatTensor | None = None,
  544. token_type_ids: torch.LongTensor | None = None,
  545. position_ids: torch.LongTensor | None = None,
  546. inputs_embeds: torch.FloatTensor | None = None,
  547. labels: torch.LongTensor | None = None,
  548. output_attentions: bool | None = None,
  549. output_hidden_states: bool | None = None,
  550. return_dict: bool | None = None,
  551. **kwargs,
  552. ) -> tuple[torch.Tensor] | SequenceClassifierOutput:
  553. r"""
  554. bbox (`torch.LongTensor` of shape `(batch_size, sequence_length, 4)`, *optional*):
  555. Bounding boxes of each input sequence tokens. Selected in the range `[0,
  556. config.max_2d_position_embeddings-1]`. Each bounding box should be a normalized version in (x0, y0, x1, y1)
  557. format, where (x0, y0) corresponds to the position of the upper left corner in the bounding box, and (x1,
  558. y1) represents the position of the lower right corner. See [Overview](#Overview) for normalization.
  559. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  560. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  561. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  562. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  563. Examples:
  564. ```python
  565. >>> from transformers import AutoTokenizer, AutoModelForSequenceClassification
  566. >>> from datasets import load_dataset
  567. >>> tokenizer = AutoTokenizer.from_pretrained("SCUT-DLVCLab/lilt-roberta-en-base")
  568. >>> model = AutoModelForSequenceClassification.from_pretrained("SCUT-DLVCLab/lilt-roberta-en-base")
  569. >>> dataset = load_dataset("nielsr/funsd-layoutlmv3", split="train")
  570. >>> example = dataset[0]
  571. >>> words = example["tokens"]
  572. >>> boxes = example["bboxes"]
  573. >>> encoding = tokenizer(words, boxes=boxes, return_tensors="pt")
  574. >>> outputs = model(**encoding)
  575. >>> predicted_class_idx = outputs.logits.argmax(-1).item()
  576. >>> predicted_class = model.config.id2label[predicted_class_idx]
  577. ```"""
  578. return_dict = return_dict if return_dict is not None else self.config.return_dict
  579. outputs = self.lilt(
  580. input_ids,
  581. bbox=bbox,
  582. attention_mask=attention_mask,
  583. token_type_ids=token_type_ids,
  584. position_ids=position_ids,
  585. inputs_embeds=inputs_embeds,
  586. output_attentions=output_attentions,
  587. output_hidden_states=output_hidden_states,
  588. return_dict=return_dict,
  589. )
  590. sequence_output = outputs[0]
  591. logits = self.classifier(sequence_output)
  592. loss = None
  593. if labels is not None:
  594. # move labels to correct device
  595. labels = labels.to(logits.device)
  596. if self.config.problem_type is None:
  597. if self.num_labels == 1:
  598. self.config.problem_type = "regression"
  599. elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
  600. self.config.problem_type = "single_label_classification"
  601. else:
  602. self.config.problem_type = "multi_label_classification"
  603. if self.config.problem_type == "regression":
  604. loss_fct = MSELoss()
  605. if self.num_labels == 1:
  606. loss = loss_fct(logits.squeeze(), labels.squeeze())
  607. else:
  608. loss = loss_fct(logits, labels)
  609. elif self.config.problem_type == "single_label_classification":
  610. loss_fct = CrossEntropyLoss()
  611. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  612. elif self.config.problem_type == "multi_label_classification":
  613. loss_fct = BCEWithLogitsLoss()
  614. loss = loss_fct(logits, labels)
  615. if not return_dict:
  616. output = (logits,) + outputs[2:]
  617. return ((loss,) + output) if loss is not None else output
  618. return SequenceClassifierOutput(
  619. loss=loss,
  620. logits=logits,
  621. hidden_states=outputs.hidden_states,
  622. attentions=outputs.attentions,
  623. )
  624. @auto_docstring
  625. class LiltForTokenClassification(LiltPreTrainedModel):
  626. # Copied from transformers.models.roberta.modeling_roberta.RobertaForTokenClassification.__init__ with Roberta->Lilt, roberta->lilt
  627. def __init__(self, config):
  628. super().__init__(config)
  629. self.num_labels = config.num_labels
  630. self.lilt = LiltModel(config, add_pooling_layer=False)
  631. classifier_dropout = (
  632. config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
  633. )
  634. self.dropout = nn.Dropout(classifier_dropout)
  635. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  636. # Initialize weights and apply final processing
  637. self.post_init()
  638. @auto_docstring
  639. def forward(
  640. self,
  641. input_ids: torch.LongTensor | None = None,
  642. bbox: torch.LongTensor | None = None,
  643. attention_mask: torch.FloatTensor | None = None,
  644. token_type_ids: torch.LongTensor | None = None,
  645. position_ids: torch.LongTensor | None = None,
  646. inputs_embeds: torch.FloatTensor | None = None,
  647. labels: torch.LongTensor | None = None,
  648. output_attentions: bool | None = None,
  649. output_hidden_states: bool | None = None,
  650. return_dict: bool | None = None,
  651. **kwargs,
  652. ) -> tuple[torch.Tensor] | TokenClassifierOutput:
  653. r"""
  654. bbox (`torch.LongTensor` of shape `(batch_size, sequence_length, 4)`, *optional*):
  655. Bounding boxes of each input sequence tokens. Selected in the range `[0,
  656. config.max_2d_position_embeddings-1]`. Each bounding box should be a normalized version in (x0, y0, x1, y1)
  657. format, where (x0, y0) corresponds to the position of the upper left corner in the bounding box, and (x1,
  658. y1) represents the position of the lower right corner. See [Overview](#Overview) for normalization.
  659. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  660. Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
  661. Examples:
  662. ```python
  663. >>> from transformers import AutoTokenizer, AutoModelForTokenClassification
  664. >>> from datasets import load_dataset
  665. >>> tokenizer = AutoTokenizer.from_pretrained("SCUT-DLVCLab/lilt-roberta-en-base")
  666. >>> model = AutoModelForTokenClassification.from_pretrained("SCUT-DLVCLab/lilt-roberta-en-base")
  667. >>> dataset = load_dataset("nielsr/funsd-layoutlmv3", split="train")
  668. >>> example = dataset[0]
  669. >>> words = example["tokens"]
  670. >>> boxes = example["bboxes"]
  671. >>> encoding = tokenizer(words, boxes=boxes, return_tensors="pt")
  672. >>> outputs = model(**encoding)
  673. >>> predicted_class_indices = outputs.logits.argmax(-1)
  674. ```"""
  675. return_dict = return_dict if return_dict is not None else self.config.return_dict
  676. outputs = self.lilt(
  677. input_ids,
  678. bbox=bbox,
  679. attention_mask=attention_mask,
  680. token_type_ids=token_type_ids,
  681. position_ids=position_ids,
  682. inputs_embeds=inputs_embeds,
  683. output_attentions=output_attentions,
  684. output_hidden_states=output_hidden_states,
  685. return_dict=return_dict,
  686. )
  687. sequence_output = outputs[0]
  688. sequence_output = self.dropout(sequence_output)
  689. logits = self.classifier(sequence_output)
  690. loss = None
  691. if labels is not None:
  692. # move labels to correct device
  693. labels = labels.to(logits.device)
  694. loss_fct = CrossEntropyLoss()
  695. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  696. if not return_dict:
  697. output = (logits,) + outputs[2:]
  698. return ((loss,) + output) if loss is not None else output
  699. return TokenClassifierOutput(
  700. loss=loss,
  701. logits=logits,
  702. hidden_states=outputs.hidden_states,
  703. attentions=outputs.attentions,
  704. )
  705. # Copied from transformers.models.roberta.modeling_roberta.RobertaClassificationHead with Roberta->Lilt
  706. class LiltClassificationHead(nn.Module):
  707. """Head for sentence-level classification tasks."""
  708. def __init__(self, config):
  709. super().__init__()
  710. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  711. classifier_dropout = (
  712. config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
  713. )
  714. self.dropout = nn.Dropout(classifier_dropout)
  715. self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
  716. def forward(self, features, **kwargs):
  717. x = features[:, 0, :] # take <s> token (equiv. to [CLS])
  718. x = self.dropout(x)
  719. x = self.dense(x)
  720. x = torch.tanh(x)
  721. x = self.dropout(x)
  722. x = self.out_proj(x)
  723. return x
  724. @auto_docstring
  725. class LiltForQuestionAnswering(LiltPreTrainedModel):
  726. # Copied from transformers.models.roberta.modeling_roberta.RobertaForQuestionAnswering.__init__ with Roberta->Lilt, roberta->lilt
  727. def __init__(self, config):
  728. super().__init__(config)
  729. self.num_labels = config.num_labels
  730. self.lilt = LiltModel(config, add_pooling_layer=False)
  731. self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
  732. # Initialize weights and apply final processing
  733. self.post_init()
  734. @auto_docstring
  735. def forward(
  736. self,
  737. input_ids: torch.LongTensor | None = None,
  738. bbox: torch.LongTensor | None = None,
  739. attention_mask: torch.FloatTensor | None = None,
  740. token_type_ids: torch.LongTensor | None = None,
  741. position_ids: torch.LongTensor | None = None,
  742. inputs_embeds: torch.FloatTensor | None = None,
  743. start_positions: torch.LongTensor | None = None,
  744. end_positions: torch.LongTensor | None = None,
  745. output_attentions: bool | None = None,
  746. output_hidden_states: bool | None = None,
  747. return_dict: bool | None = None,
  748. **kwargs,
  749. ) -> tuple[torch.Tensor] | QuestionAnsweringModelOutput:
  750. r"""
  751. bbox (`torch.LongTensor` of shape `(batch_size, sequence_length, 4)`, *optional*):
  752. Bounding boxes of each input sequence tokens. Selected in the range `[0,
  753. config.max_2d_position_embeddings-1]`. Each bounding box should be a normalized version in (x0, y0, x1, y1)
  754. format, where (x0, y0) corresponds to the position of the upper left corner in the bounding box, and (x1,
  755. y1) represents the position of the lower right corner. See [Overview](#Overview) for normalization.
  756. Examples:
  757. ```python
  758. >>> from transformers import AutoTokenizer, AutoModelForQuestionAnswering
  759. >>> from datasets import load_dataset
  760. >>> tokenizer = AutoTokenizer.from_pretrained("SCUT-DLVCLab/lilt-roberta-en-base")
  761. >>> model = AutoModelForQuestionAnswering.from_pretrained("SCUT-DLVCLab/lilt-roberta-en-base")
  762. >>> dataset = load_dataset("nielsr/funsd-layoutlmv3", split="train")
  763. >>> example = dataset[0]
  764. >>> words = example["tokens"]
  765. >>> boxes = example["bboxes"]
  766. >>> encoding = tokenizer(words, boxes=boxes, return_tensors="pt")
  767. >>> outputs = model(**encoding)
  768. >>> answer_start_index = outputs.start_logits.argmax()
  769. >>> answer_end_index = outputs.end_logits.argmax()
  770. >>> predict_answer_tokens = encoding.input_ids[0, answer_start_index : answer_end_index + 1]
  771. >>> predicted_answer = tokenizer.decode(predict_answer_tokens)
  772. ```"""
  773. return_dict = return_dict if return_dict is not None else self.config.return_dict
  774. outputs = self.lilt(
  775. input_ids,
  776. bbox=bbox,
  777. attention_mask=attention_mask,
  778. token_type_ids=token_type_ids,
  779. position_ids=position_ids,
  780. inputs_embeds=inputs_embeds,
  781. output_attentions=output_attentions,
  782. output_hidden_states=output_hidden_states,
  783. return_dict=return_dict,
  784. )
  785. sequence_output = outputs[0]
  786. logits = self.qa_outputs(sequence_output)
  787. start_logits, end_logits = logits.split(1, dim=-1)
  788. start_logits = start_logits.squeeze(-1).contiguous()
  789. end_logits = end_logits.squeeze(-1).contiguous()
  790. total_loss = None
  791. if start_positions is not None and end_positions is not None:
  792. # If we are on multi-GPU, split add a dimension
  793. if len(start_positions.size()) > 1:
  794. start_positions = start_positions.squeeze(-1)
  795. if len(end_positions.size()) > 1:
  796. end_positions = end_positions.squeeze(-1)
  797. # sometimes the start/end positions are outside our model inputs, we ignore these terms
  798. ignored_index = start_logits.size(1)
  799. start_positions = start_positions.clamp(0, ignored_index)
  800. end_positions = end_positions.clamp(0, ignored_index)
  801. loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
  802. start_loss = loss_fct(start_logits, start_positions)
  803. end_loss = loss_fct(end_logits, end_positions)
  804. total_loss = (start_loss + end_loss) / 2
  805. if not return_dict:
  806. output = (start_logits, end_logits) + outputs[2:]
  807. return ((total_loss,) + output) if total_loss is not None else output
  808. return QuestionAnsweringModelOutput(
  809. loss=total_loss,
  810. start_logits=start_logits,
  811. end_logits=end_logits,
  812. hidden_states=outputs.hidden_states,
  813. attentions=outputs.attentions,
  814. )
# Public API of this module: the names exported by `from ... import *`.
# NOTE(review): transformers' lazy-import tooling appears to read `__all__` from the source text
# (one quoted name per line), so keep it a plain multi-line list of string literals — confirm
# before restructuring.
__all__ = [
    "LiltForQuestionAnswering",
    "LiltForSequenceClassification",
    "LiltForTokenClassification",
    "LiltModel",
    "LiltPreTrainedModel",
]