modeling_layoutlm.py 40 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012
  1. # Copyright 2018 The Microsoft Research Asia LayoutLM Team Authors and the HuggingFace Inc. team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch LayoutLM model."""
  15. from collections.abc import Callable
  16. import torch
  17. from torch import nn
  18. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  19. from ... import initialization as init
  20. from ...activations import ACT2FN
  21. from ...modeling_layers import GradientCheckpointingLayer
  22. from ...modeling_outputs import (
  23. BaseModelOutput,
  24. BaseModelOutputWithPooling,
  25. MaskedLMOutput,
  26. QuestionAnsweringModelOutput,
  27. SequenceClassifierOutput,
  28. TokenClassifierOutput,
  29. )
  30. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  31. from ...processing_utils import Unpack
  32. from ...pytorch_utils import apply_chunking_to_forward
  33. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
  34. from ...utils.generic import merge_with_config_defaults
  35. from ...utils.output_capturing import capture_outputs
  36. from .configuration_layoutlm import LayoutLMConfig
# Module-level logger obtained from the library's logging helper, keyed by this module's name.
logger = logging.get_logger(__name__)

# LayoutLM-specific alias for `torch.nn.LayerNorm`.
LayoutLMLayerNorm = nn.LayerNorm
  39. class LayoutLMEmbeddings(nn.Module):
  40. """Construct the embeddings from word, position and token_type embeddings."""
  41. def __init__(self, config):
  42. super().__init__()
  43. self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
  44. self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
  45. self.x_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.hidden_size)
  46. self.y_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.hidden_size)
  47. self.h_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.hidden_size)
  48. self.w_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.hidden_size)
  49. self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
  50. self.LayerNorm = LayoutLMLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  51. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  52. self.register_buffer(
  53. "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
  54. )
  55. def forward(
  56. self,
  57. input_ids=None,
  58. bbox=None,
  59. token_type_ids=None,
  60. position_ids=None,
  61. inputs_embeds=None,
  62. ):
  63. if input_ids is not None:
  64. input_shape = input_ids.size()
  65. else:
  66. input_shape = inputs_embeds.size()[:-1]
  67. seq_length = input_shape[1]
  68. device = input_ids.device if input_ids is not None else inputs_embeds.device
  69. if position_ids is None:
  70. position_ids = self.position_ids[:, :seq_length]
  71. if token_type_ids is None:
  72. token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
  73. if inputs_embeds is None:
  74. inputs_embeds = self.word_embeddings(input_ids)
  75. words_embeddings = inputs_embeds
  76. position_embeddings = self.position_embeddings(position_ids)
  77. try:
  78. left_position_embeddings = self.x_position_embeddings(bbox[:, :, 0])
  79. upper_position_embeddings = self.y_position_embeddings(bbox[:, :, 1])
  80. right_position_embeddings = self.x_position_embeddings(bbox[:, :, 2])
  81. lower_position_embeddings = self.y_position_embeddings(bbox[:, :, 3])
  82. except IndexError as e:
  83. raise IndexError("The `bbox`coordinate values should be within 0-1000 range.") from e
  84. h_position_embeddings = self.h_position_embeddings(bbox[:, :, 3] - bbox[:, :, 1])
  85. w_position_embeddings = self.w_position_embeddings(bbox[:, :, 2] - bbox[:, :, 0])
  86. token_type_embeddings = self.token_type_embeddings(token_type_ids)
  87. embeddings = (
  88. words_embeddings
  89. + position_embeddings
  90. + left_position_embeddings
  91. + upper_position_embeddings
  92. + right_position_embeddings
  93. + lower_position_embeddings
  94. + h_position_embeddings
  95. + w_position_embeddings
  96. + token_type_embeddings
  97. )
  98. embeddings = self.LayerNorm(embeddings)
  99. embeddings = self.dropout(embeddings)
  100. return embeddings
  101. # Copied from transformers.models.align.modeling_align.eager_attention_forward
  102. def eager_attention_forward(
  103. module: nn.Module,
  104. query: torch.Tensor,
  105. key: torch.Tensor,
  106. value: torch.Tensor,
  107. attention_mask: torch.Tensor | None,
  108. scaling: float,
  109. dropout: float = 0.0,
  110. **kwargs,
  111. ):
  112. attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
  113. if attention_mask is not None:
  114. attn_weights = attn_weights + attention_mask
  115. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  116. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  117. attn_output = torch.matmul(attn_weights, value)
  118. attn_output = attn_output.transpose(1, 2).contiguous()
  119. return attn_output, attn_weights
  120. # Copied from transformers.models.align.modeling_align.AlignTextSelfAttention with AlignText->LayoutLM
  121. class LayoutLMSelfAttention(nn.Module):
  122. def __init__(self, config):
  123. super().__init__()
  124. if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
  125. raise ValueError(
  126. f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
  127. f"heads ({config.num_attention_heads})"
  128. )
  129. self.config = config
  130. self.num_attention_heads = config.num_attention_heads
  131. self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
  132. self.all_head_size = self.num_attention_heads * self.attention_head_size
  133. self.query = nn.Linear(config.hidden_size, self.all_head_size)
  134. self.key = nn.Linear(config.hidden_size, self.all_head_size)
  135. self.value = nn.Linear(config.hidden_size, self.all_head_size)
  136. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  137. self.attention_dropout = config.attention_probs_dropout_prob
  138. self.scaling = self.attention_head_size**-0.5
  139. def forward(
  140. self,
  141. hidden_states: torch.Tensor,
  142. attention_mask: torch.FloatTensor | None = None,
  143. **kwargs: Unpack[TransformersKwargs],
  144. ) -> tuple[torch.Tensor, torch.Tensor | None]:
  145. input_shape = hidden_states.shape[:-1]
  146. hidden_shape = (*input_shape, -1, self.attention_head_size)
  147. query_states = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
  148. key_states = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
  149. value_states = self.value(hidden_states).view(hidden_shape).transpose(1, 2)
  150. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  151. self.config._attn_implementation, eager_attention_forward
  152. )
  153. attn_output, attn_weights = attention_interface(
  154. self,
  155. query_states,
  156. key_states,
  157. value_states,
  158. attention_mask,
  159. dropout=0.0 if not self.training else self.attention_dropout,
  160. scaling=self.scaling,
  161. **kwargs,
  162. )
  163. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  164. return attn_output, attn_weights
  165. # Copied from transformers.models.bert.modeling_bert.BertSelfOutput with Bert->LayoutLM
  166. class LayoutLMSelfOutput(nn.Module):
  167. def __init__(self, config):
  168. super().__init__()
  169. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  170. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  171. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  172. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  173. hidden_states = self.dense(hidden_states)
  174. hidden_states = self.dropout(hidden_states)
  175. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  176. return hidden_states
  177. # Copied from transformers.models.align.modeling_align.AlignTextAttention with AlignText->LayoutLM
  178. class LayoutLMAttention(nn.Module):
  179. def __init__(self, config):
  180. super().__init__()
  181. self.self = LayoutLMSelfAttention(config)
  182. self.output = LayoutLMSelfOutput(config)
  183. def forward(
  184. self,
  185. hidden_states: torch.Tensor,
  186. attention_mask: torch.FloatTensor | None = None,
  187. **kwargs: Unpack[TransformersKwargs],
  188. ) -> torch.Tensor:
  189. residual = hidden_states
  190. hidden_states, _ = self.self(
  191. hidden_states,
  192. attention_mask=attention_mask,
  193. **kwargs,
  194. )
  195. hidden_states = self.output(hidden_states, residual)
  196. return hidden_states
  197. # Copied from transformers.models.bert.modeling_bert.BertIntermediate
  198. class LayoutLMIntermediate(nn.Module):
  199. def __init__(self, config):
  200. super().__init__()
  201. self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
  202. if isinstance(config.hidden_act, str):
  203. self.intermediate_act_fn = ACT2FN[config.hidden_act]
  204. else:
  205. self.intermediate_act_fn = config.hidden_act
  206. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  207. hidden_states = self.dense(hidden_states)
  208. hidden_states = self.intermediate_act_fn(hidden_states)
  209. return hidden_states
  210. # Copied from transformers.models.bert.modeling_bert.BertOutput with Bert->LayoutLM
  211. class LayoutLMOutput(nn.Module):
  212. def __init__(self, config):
  213. super().__init__()
  214. self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
  215. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  216. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  217. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  218. hidden_states = self.dense(hidden_states)
  219. hidden_states = self.dropout(hidden_states)
  220. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  221. return hidden_states
  222. # Copied from transformers.models.align.modeling_align.AlignTextLayer with AlignText->LayoutLM
  223. class LayoutLMLayer(GradientCheckpointingLayer):
  224. def __init__(self, config):
  225. super().__init__()
  226. self.chunk_size_feed_forward = config.chunk_size_feed_forward
  227. self.seq_len_dim = 1
  228. self.attention = LayoutLMAttention(config)
  229. self.intermediate = LayoutLMIntermediate(config)
  230. self.output = LayoutLMOutput(config)
  231. def forward(
  232. self,
  233. hidden_states: torch.Tensor,
  234. attention_mask: torch.FloatTensor | None = None,
  235. **kwargs: Unpack[TransformersKwargs],
  236. ) -> torch.Tensor:
  237. hidden_states = self.attention(
  238. hidden_states,
  239. attention_mask=attention_mask,
  240. **kwargs,
  241. )
  242. hidden_states = apply_chunking_to_forward(
  243. self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, hidden_states
  244. )
  245. return hidden_states
  246. def feed_forward_chunk(self, attention_output):
  247. intermediate_output = self.intermediate(attention_output)
  248. layer_output = self.output(intermediate_output, attention_output)
  249. return layer_output
  250. # Copied from transformers.models.align.modeling_align.AlignTextEncoder with AlignText->LayoutLM
  251. class LayoutLMEncoder(nn.Module):
  252. def __init__(self, config):
  253. super().__init__()
  254. self.config = config
  255. self.layer = nn.ModuleList([LayoutLMLayer(config) for i in range(config.num_hidden_layers)])
  256. self.gradient_checkpointing = False
  257. def forward(
  258. self,
  259. hidden_states: torch.Tensor,
  260. attention_mask: torch.FloatTensor | None = None,
  261. **kwargs: Unpack[TransformersKwargs],
  262. ) -> BaseModelOutput:
  263. for layer_module in self.layer:
  264. hidden_states = layer_module(
  265. hidden_states,
  266. attention_mask,
  267. **kwargs,
  268. )
  269. return BaseModelOutput(
  270. last_hidden_state=hidden_states,
  271. )
  272. # Copied from transformers.models.bert.modeling_bert.BertPooler
  273. class LayoutLMPooler(nn.Module):
  274. def __init__(self, config):
  275. super().__init__()
  276. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  277. self.activation = nn.Tanh()
  278. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  279. # We "pool" the model by simply taking the hidden state corresponding
  280. # to the first token.
  281. first_token_tensor = hidden_states[:, 0]
  282. pooled_output = self.dense(first_token_tensor)
  283. pooled_output = self.activation(pooled_output)
  284. return pooled_output
  285. # Copied from transformers.models.bert.modeling_bert.BertPredictionHeadTransform with Bert->LayoutLM
  286. class LayoutLMPredictionHeadTransform(nn.Module):
  287. def __init__(self, config):
  288. super().__init__()
  289. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  290. if isinstance(config.hidden_act, str):
  291. self.transform_act_fn = ACT2FN[config.hidden_act]
  292. else:
  293. self.transform_act_fn = config.hidden_act
  294. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  295. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  296. hidden_states = self.dense(hidden_states)
  297. hidden_states = self.transform_act_fn(hidden_states)
  298. hidden_states = self.LayerNorm(hidden_states)
  299. return hidden_states
  300. # Copied from transformers.models.bert.modeling_bert.BertLMPredictionHead with Bert->LayoutLM
  301. class LayoutLMLMPredictionHead(nn.Module):
  302. def __init__(self, config):
  303. super().__init__()
  304. self.transform = LayoutLMPredictionHeadTransform(config)
  305. # The output weights are the same as the input embeddings, but there is
  306. # an output-only bias for each token.
  307. self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=True)
  308. self.bias = nn.Parameter(torch.zeros(config.vocab_size))
  309. def forward(self, hidden_states):
  310. hidden_states = self.transform(hidden_states)
  311. hidden_states = self.decoder(hidden_states)
  312. return hidden_states
  313. # Copied from transformers.models.bert.modeling_bert.BertOnlyMLMHead with Bert->LayoutLM
  314. class LayoutLMOnlyMLMHead(nn.Module):
  315. def __init__(self, config):
  316. super().__init__()
  317. self.predictions = LayoutLMLMPredictionHead(config)
  318. def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
  319. prediction_scores = self.predictions(sequence_output)
  320. return prediction_scores
  321. @auto_docstring
  322. class LayoutLMPreTrainedModel(PreTrainedModel):
  323. config: LayoutLMConfig
  324. base_model_prefix = "layoutlm"
  325. supports_gradient_checkpointing = True
  326. _can_record_outputs = {
  327. "hidden_states": LayoutLMLayer,
  328. "attentions": LayoutLMSelfAttention,
  329. }
  330. @torch.no_grad()
  331. def _init_weights(self, module):
  332. """Initialize the weights"""
  333. super()._init_weights(module)
  334. if isinstance(module, LayoutLMLMPredictionHead):
  335. init.zeros_(module.bias)
  336. elif isinstance(module, LayoutLMEmbeddings):
  337. init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
  338. @auto_docstring
  339. class LayoutLMModel(LayoutLMPreTrainedModel):
  340. def __init__(self, config):
  341. super().__init__(config)
  342. self.config = config
  343. self.embeddings = LayoutLMEmbeddings(config)
  344. self.encoder = LayoutLMEncoder(config)
  345. self.pooler = LayoutLMPooler(config)
  346. # Initialize weights and apply final processing
  347. self.post_init()
  348. def get_input_embeddings(self):
  349. return self.embeddings.word_embeddings
  350. def set_input_embeddings(self, value):
  351. self.embeddings.word_embeddings = value
  352. @merge_with_config_defaults
  353. @capture_outputs
  354. @auto_docstring
  355. def forward(
  356. self,
  357. input_ids: torch.LongTensor | None = None,
  358. bbox: torch.LongTensor | None = None,
  359. attention_mask: torch.FloatTensor | None = None,
  360. token_type_ids: torch.LongTensor | None = None,
  361. position_ids: torch.LongTensor | None = None,
  362. inputs_embeds: torch.FloatTensor | None = None,
  363. **kwargs: Unpack[TransformersKwargs],
  364. ) -> tuple | BaseModelOutputWithPooling:
  365. r"""
  366. bbox (`torch.LongTensor` of shape `(batch_size, sequence_length, 4)`, *optional*):
  367. Bounding boxes of each input sequence tokens. Selected in the range `[0,
  368. config.max_2d_position_embeddings-1]`. Each bounding box should be a normalized version in (x0, y0, x1, y1)
  369. format, where (x0, y0) corresponds to the position of the upper left corner in the bounding box, and (x1,
  370. y1) represents the position of the lower right corner. See [Overview](#Overview) for normalization.
  371. Examples:
  372. ```python
  373. >>> from transformers import AutoTokenizer, LayoutLMModel
  374. >>> import torch
  375. >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/layoutlm-base-uncased")
  376. >>> model = LayoutLMModel.from_pretrained("microsoft/layoutlm-base-uncased")
  377. >>> words = ["Hello", "world"]
  378. >>> normalized_word_boxes = [637, 773, 693, 782], [698, 773, 733, 782]
  379. >>> token_boxes = []
  380. >>> for word, box in zip(words, normalized_word_boxes):
  381. ... word_tokens = tokenizer.tokenize(word)
  382. ... token_boxes.extend([box] * len(word_tokens))
  383. >>> # add bounding boxes of cls + sep tokens
  384. >>> token_boxes = [[0, 0, 0, 0]] + token_boxes + [[1000, 1000, 1000, 1000]]
  385. >>> encoding = tokenizer(" ".join(words), return_tensors="pt")
  386. >>> input_ids = encoding["input_ids"]
  387. >>> attention_mask = encoding["attention_mask"]
  388. >>> token_type_ids = encoding["token_type_ids"]
  389. >>> bbox = torch.tensor([token_boxes])
  390. >>> outputs = model(
  391. ... input_ids=input_ids, bbox=bbox, attention_mask=attention_mask, token_type_ids=token_type_ids
  392. ... )
  393. >>> last_hidden_states = outputs.last_hidden_state
  394. ```"""
  395. if input_ids is not None and inputs_embeds is not None:
  396. raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
  397. elif input_ids is not None:
  398. self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
  399. input_shape = input_ids.size()
  400. elif inputs_embeds is not None:
  401. input_shape = inputs_embeds.size()[:-1]
  402. else:
  403. raise ValueError("You have to specify either input_ids or inputs_embeds")
  404. device = input_ids.device if input_ids is not None else inputs_embeds.device
  405. if attention_mask is None:
  406. attention_mask = torch.ones(input_shape, device=device)
  407. if token_type_ids is None:
  408. token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
  409. if bbox is None:
  410. bbox = torch.zeros(input_shape + (4,), dtype=torch.long, device=device)
  411. extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
  412. extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)
  413. extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(self.dtype).min
  414. embedding_output = self.embeddings(
  415. input_ids=input_ids,
  416. bbox=bbox,
  417. position_ids=position_ids,
  418. token_type_ids=token_type_ids,
  419. inputs_embeds=inputs_embeds,
  420. )
  421. encoder_outputs = self.encoder(
  422. embedding_output,
  423. extended_attention_mask,
  424. **kwargs,
  425. )
  426. sequence_output = encoder_outputs[0]
  427. pooled_output = self.pooler(sequence_output)
  428. return BaseModelOutputWithPooling(
  429. last_hidden_state=sequence_output,
  430. pooler_output=pooled_output,
  431. )
  432. @auto_docstring
  433. class LayoutLMForMaskedLM(LayoutLMPreTrainedModel):
  434. _tied_weights_keys = {
  435. "cls.predictions.decoder.bias": "cls.predictions.bias",
  436. "cls.predictions.decoder.weight": "layoutlm.embeddings.word_embeddings.weight",
  437. }
  438. def __init__(self, config):
  439. super().__init__(config)
  440. self.layoutlm = LayoutLMModel(config)
  441. self.cls = LayoutLMOnlyMLMHead(config)
  442. # Initialize weights and apply final processing
  443. self.post_init()
  444. def get_input_embeddings(self):
  445. return self.layoutlm.embeddings.word_embeddings
  446. def get_output_embeddings(self):
  447. return self.cls.predictions.decoder
  448. def set_output_embeddings(self, new_embeddings):
  449. self.cls.predictions.decoder = new_embeddings
  450. self.cls.predictions.bias = new_embeddings.bias
  451. @can_return_tuple
  452. @auto_docstring
  453. def forward(
  454. self,
  455. input_ids: torch.LongTensor | None = None,
  456. bbox: torch.LongTensor | None = None,
  457. attention_mask: torch.FloatTensor | None = None,
  458. token_type_ids: torch.LongTensor | None = None,
  459. position_ids: torch.LongTensor | None = None,
  460. inputs_embeds: torch.FloatTensor | None = None,
  461. labels: torch.LongTensor | None = None,
  462. **kwargs: Unpack[TransformersKwargs],
  463. ) -> tuple | MaskedLMOutput:
  464. r"""
  465. bbox (`torch.LongTensor` of shape `(batch_size, sequence_length, 4)`, *optional*):
  466. Bounding boxes of each input sequence tokens. Selected in the range `[0,
  467. config.max_2d_position_embeddings-1]`. Each bounding box should be a normalized version in (x0, y0, x1, y1)
  468. format, where (x0, y0) corresponds to the position of the upper left corner in the bounding box, and (x1,
  469. y1) represents the position of the lower right corner. See [Overview](#Overview) for normalization.
  470. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  471. Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
  472. config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
  473. loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
  474. Examples:
  475. ```python
  476. >>> from transformers import AutoTokenizer, LayoutLMForMaskedLM
  477. >>> import torch
  478. >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/layoutlm-base-uncased")
  479. >>> model = LayoutLMForMaskedLM.from_pretrained("microsoft/layoutlm-base-uncased")
  480. >>> words = ["Hello", "[MASK]"]
  481. >>> normalized_word_boxes = [637, 773, 693, 782], [698, 773, 733, 782]
  482. >>> token_boxes = []
  483. >>> for word, box in zip(words, normalized_word_boxes):
  484. ... word_tokens = tokenizer.tokenize(word)
  485. ... token_boxes.extend([box] * len(word_tokens))
  486. >>> # add bounding boxes of cls + sep tokens
  487. >>> token_boxes = [[0, 0, 0, 0]] + token_boxes + [[1000, 1000, 1000, 1000]]
  488. >>> encoding = tokenizer(" ".join(words), return_tensors="pt")
  489. >>> input_ids = encoding["input_ids"]
  490. >>> attention_mask = encoding["attention_mask"]
  491. >>> token_type_ids = encoding["token_type_ids"]
  492. >>> bbox = torch.tensor([token_boxes])
  493. >>> labels = tokenizer("Hello world", return_tensors="pt")["input_ids"]
  494. >>> outputs = model(
  495. ... input_ids=input_ids,
  496. ... bbox=bbox,
  497. ... attention_mask=attention_mask,
  498. ... token_type_ids=token_type_ids,
  499. ... labels=labels,
  500. ... )
  501. >>> loss = outputs.loss
  502. ```"""
  503. outputs = self.layoutlm(
  504. input_ids,
  505. bbox,
  506. attention_mask=attention_mask,
  507. token_type_ids=token_type_ids,
  508. position_ids=position_ids,
  509. inputs_embeds=inputs_embeds,
  510. **kwargs,
  511. )
  512. sequence_output = outputs[0]
  513. prediction_scores = self.cls(sequence_output)
  514. masked_lm_loss = None
  515. if labels is not None:
  516. loss_fct = CrossEntropyLoss()
  517. masked_lm_loss = loss_fct(
  518. prediction_scores.view(-1, self.config.vocab_size),
  519. labels.view(-1),
  520. )
  521. return MaskedLMOutput(
  522. loss=masked_lm_loss,
  523. logits=prediction_scores,
  524. hidden_states=outputs.hidden_states,
  525. attentions=outputs.attentions,
  526. )
  527. @auto_docstring(
  528. custom_intro="""
  529. LayoutLM Model with a sequence classification head on top (a linear layer on top of the pooled output) e.g. for
  530. document image classification tasks such as the [RVL-CDIP](https://www.cs.cmu.edu/~aharley/rvl-cdip/) dataset.
  531. """
  532. )
  533. class LayoutLMForSequenceClassification(LayoutLMPreTrainedModel):
  534. def __init__(self, config):
  535. super().__init__(config)
  536. self.num_labels = config.num_labels
  537. self.layoutlm = LayoutLMModel(config)
  538. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  539. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  540. # Initialize weights and apply final processing
  541. self.post_init()
  542. def get_input_embeddings(self):
  543. return self.layoutlm.embeddings.word_embeddings
  544. @can_return_tuple
  545. @auto_docstring
  546. def forward(
  547. self,
  548. input_ids: torch.LongTensor | None = None,
  549. bbox: torch.LongTensor | None = None,
  550. attention_mask: torch.FloatTensor | None = None,
  551. token_type_ids: torch.LongTensor | None = None,
  552. position_ids: torch.LongTensor | None = None,
  553. inputs_embeds: torch.FloatTensor | None = None,
  554. labels: torch.LongTensor | None = None,
  555. **kwargs: Unpack[TransformersKwargs],
  556. ) -> tuple | SequenceClassifierOutput:
  557. r"""
  558. bbox (`torch.LongTensor` of shape `(batch_size, sequence_length, 4)`, *optional*):
  559. Bounding boxes of each input sequence tokens. Selected in the range `[0,
  560. config.max_2d_position_embeddings-1]`. Each bounding box should be a normalized version in (x0, y0, x1, y1)
  561. format, where (x0, y0) corresponds to the position of the upper left corner in the bounding box, and (x1,
  562. y1) represents the position of the lower right corner. See [Overview](#Overview) for normalization.
  563. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  564. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  565. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  566. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  567. Examples:
  568. ```python
  569. >>> from transformers import AutoTokenizer, LayoutLMForSequenceClassification
  570. >>> import torch
  571. >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/layoutlm-base-uncased")
  572. >>> model = LayoutLMForSequenceClassification.from_pretrained("microsoft/layoutlm-base-uncased")
  573. >>> words = ["Hello", "world"]
  574. >>> normalized_word_boxes = [637, 773, 693, 782], [698, 773, 733, 782]
  575. >>> token_boxes = []
  576. >>> for word, box in zip(words, normalized_word_boxes):
  577. ... word_tokens = tokenizer.tokenize(word)
  578. ... token_boxes.extend([box] * len(word_tokens))
  579. >>> # add bounding boxes of cls + sep tokens
  580. >>> token_boxes = [[0, 0, 0, 0]] + token_boxes + [[1000, 1000, 1000, 1000]]
  581. >>> encoding = tokenizer(" ".join(words), return_tensors="pt")
  582. >>> input_ids = encoding["input_ids"]
  583. >>> attention_mask = encoding["attention_mask"]
  584. >>> token_type_ids = encoding["token_type_ids"]
  585. >>> bbox = torch.tensor([token_boxes])
  586. >>> sequence_label = torch.tensor([1])
  587. >>> outputs = model(
  588. ... input_ids=input_ids,
  589. ... bbox=bbox,
  590. ... attention_mask=attention_mask,
  591. ... token_type_ids=token_type_ids,
  592. ... labels=sequence_label,
  593. ... )
  594. >>> loss = outputs.loss
  595. >>> logits = outputs.logits
  596. ```"""
  597. outputs = self.layoutlm(
  598. input_ids=input_ids,
  599. bbox=bbox,
  600. attention_mask=attention_mask,
  601. token_type_ids=token_type_ids,
  602. position_ids=position_ids,
  603. inputs_embeds=inputs_embeds,
  604. **kwargs,
  605. )
  606. pooled_output = outputs[1]
  607. pooled_output = self.dropout(pooled_output)
  608. logits = self.classifier(pooled_output)
  609. loss = None
  610. if labels is not None:
  611. if self.config.problem_type is None:
  612. if self.num_labels == 1:
  613. self.config.problem_type = "regression"
  614. elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
  615. self.config.problem_type = "single_label_classification"
  616. else:
  617. self.config.problem_type = "multi_label_classification"
  618. if self.config.problem_type == "regression":
  619. loss_fct = MSELoss()
  620. if self.num_labels == 1:
  621. loss = loss_fct(logits.squeeze(), labels.squeeze())
  622. else:
  623. loss = loss_fct(logits, labels)
  624. elif self.config.problem_type == "single_label_classification":
  625. loss_fct = CrossEntropyLoss()
  626. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  627. elif self.config.problem_type == "multi_label_classification":
  628. loss_fct = BCEWithLogitsLoss()
  629. loss = loss_fct(logits, labels)
  630. return SequenceClassifierOutput(
  631. loss=loss,
  632. logits=logits,
  633. hidden_states=outputs.hidden_states,
  634. attentions=outputs.attentions,
  635. )
  636. @auto_docstring(
  637. custom_intro="""
  638. LayoutLM Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
  639. sequence labeling (information extraction) tasks such as the [FUNSD](https://guillaumejaume.github.io/FUNSD/)
  640. dataset and the [SROIE](https://rrc.cvc.uab.es/?ch=13) dataset.
  641. """
  642. )
  643. class LayoutLMForTokenClassification(LayoutLMPreTrainedModel):
  644. def __init__(self, config):
  645. super().__init__(config)
  646. self.num_labels = config.num_labels
  647. self.layoutlm = LayoutLMModel(config)
  648. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  649. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  650. # Initialize weights and apply final processing
  651. self.post_init()
  652. def get_input_embeddings(self):
  653. return self.layoutlm.embeddings.word_embeddings
  654. @can_return_tuple
  655. @auto_docstring
  656. def forward(
  657. self,
  658. input_ids: torch.LongTensor | None = None,
  659. bbox: torch.LongTensor | None = None,
  660. attention_mask: torch.FloatTensor | None = None,
  661. token_type_ids: torch.LongTensor | None = None,
  662. position_ids: torch.LongTensor | None = None,
  663. inputs_embeds: torch.FloatTensor | None = None,
  664. labels: torch.LongTensor | None = None,
  665. **kwargs: Unpack[TransformersKwargs],
  666. ) -> tuple | TokenClassifierOutput:
  667. r"""
  668. bbox (`torch.LongTensor` of shape `(batch_size, sequence_length, 4)`, *optional*):
  669. Bounding boxes of each input sequence tokens. Selected in the range `[0,
  670. config.max_2d_position_embeddings-1]`. Each bounding box should be a normalized version in (x0, y0, x1, y1)
  671. format, where (x0, y0) corresponds to the position of the upper left corner in the bounding box, and (x1,
  672. y1) represents the position of the lower right corner. See [Overview](#Overview) for normalization.
  673. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  674. Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
  675. Examples:
  676. ```python
  677. >>> from transformers import AutoTokenizer, LayoutLMForTokenClassification
  678. >>> import torch
  679. >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/layoutlm-base-uncased")
  680. >>> model = LayoutLMForTokenClassification.from_pretrained("microsoft/layoutlm-base-uncased")
  681. >>> words = ["Hello", "world"]
  682. >>> normalized_word_boxes = [637, 773, 693, 782], [698, 773, 733, 782]
  683. >>> token_boxes = []
  684. >>> for word, box in zip(words, normalized_word_boxes):
  685. ... word_tokens = tokenizer.tokenize(word)
  686. ... token_boxes.extend([box] * len(word_tokens))
  687. >>> # add bounding boxes of cls + sep tokens
  688. >>> token_boxes = [[0, 0, 0, 0]] + token_boxes + [[1000, 1000, 1000, 1000]]
  689. >>> encoding = tokenizer(" ".join(words), return_tensors="pt")
  690. >>> input_ids = encoding["input_ids"]
  691. >>> attention_mask = encoding["attention_mask"]
  692. >>> token_type_ids = encoding["token_type_ids"]
  693. >>> bbox = torch.tensor([token_boxes])
  694. >>> token_labels = torch.tensor([1, 1, 0, 0]).unsqueeze(0) # batch size of 1
  695. >>> outputs = model(
  696. ... input_ids=input_ids,
  697. ... bbox=bbox,
  698. ... attention_mask=attention_mask,
  699. ... token_type_ids=token_type_ids,
  700. ... labels=token_labels,
  701. ... )
  702. >>> loss = outputs.loss
  703. >>> logits = outputs.logits
  704. ```"""
  705. outputs = self.layoutlm(
  706. input_ids=input_ids,
  707. bbox=bbox,
  708. attention_mask=attention_mask,
  709. token_type_ids=token_type_ids,
  710. position_ids=position_ids,
  711. inputs_embeds=inputs_embeds,
  712. **kwargs,
  713. )
  714. sequence_output = outputs[0]
  715. sequence_output = self.dropout(sequence_output)
  716. logits = self.classifier(sequence_output)
  717. loss = None
  718. if labels is not None:
  719. loss_fct = CrossEntropyLoss()
  720. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  721. return TokenClassifierOutput(
  722. loss=loss,
  723. logits=logits,
  724. hidden_states=outputs.hidden_states,
  725. attentions=outputs.attentions,
  726. )
  727. @auto_docstring
  728. class LayoutLMForQuestionAnswering(LayoutLMPreTrainedModel):
  729. def __init__(self, config, has_visual_segment_embedding=True):
  730. r"""
  731. has_visual_segment_embedding (`bool`, *optional*, defaults to `True`):
  732. Whether or not to add visual segment embeddings.
  733. """
  734. super().__init__(config)
  735. self.num_labels = config.num_labels
  736. self.layoutlm = LayoutLMModel(config)
  737. self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
  738. # Initialize weights and apply final processing
  739. self.post_init()
  740. def get_input_embeddings(self):
  741. return self.layoutlm.embeddings.word_embeddings
  742. @can_return_tuple
  743. @auto_docstring
  744. def forward(
  745. self,
  746. input_ids: torch.LongTensor | None = None,
  747. bbox: torch.LongTensor | None = None,
  748. attention_mask: torch.FloatTensor | None = None,
  749. token_type_ids: torch.LongTensor | None = None,
  750. position_ids: torch.LongTensor | None = None,
  751. inputs_embeds: torch.FloatTensor | None = None,
  752. start_positions: torch.LongTensor | None = None,
  753. end_positions: torch.LongTensor | None = None,
  754. **kwargs: Unpack[TransformersKwargs],
  755. ) -> tuple | QuestionAnsweringModelOutput:
  756. r"""
  757. bbox (`torch.LongTensor` of shape `(batch_size, sequence_length, 4)`, *optional*):
  758. Bounding boxes of each input sequence tokens. Selected in the range `[0,
  759. config.max_2d_position_embeddings-1]`. Each bounding box should be a normalized version in (x0, y0, x1, y1)
  760. format, where (x0, y0) corresponds to the position of the upper left corner in the bounding box, and (x1,
  761. y1) represents the position of the lower right corner. See [Overview](#Overview) for normalization.
  762. Example:
  763. In the example below, we prepare a question + context pair for the LayoutLM model. It will give us a prediction
  764. of what it thinks the answer is (the span of the answer within the texts parsed from the image).
  765. ```python
  766. >>> from transformers import AutoTokenizer, LayoutLMForQuestionAnswering
  767. >>> from datasets import load_dataset
  768. >>> import torch
  769. >>> tokenizer = AutoTokenizer.from_pretrained("impira/layoutlm-document-qa", add_prefix_space=True)
  770. >>> model = LayoutLMForQuestionAnswering.from_pretrained("impira/layoutlm-document-qa", revision="1e3ebac")
  771. >>> dataset = load_dataset("nielsr/funsd", split="train")
  772. >>> example = dataset[0]
  773. >>> question = "what's his name?"
  774. >>> words = example["words"]
  775. >>> boxes = example["bboxes"]
  776. >>> encoding = tokenizer(
  777. ... question.split(), words, is_split_into_words=True, return_token_type_ids=True, return_tensors="pt"
  778. ... )
  779. >>> bbox = []
  780. >>> for i, s, w in zip(encoding.input_ids[0], encoding.sequence_ids(0), encoding.word_ids(0)):
  781. ... if s == 1:
  782. ... bbox.append(boxes[w])
  783. ... elif i == tokenizer.sep_token_id:
  784. ... bbox.append([1000] * 4)
  785. ... else:
  786. ... bbox.append([0] * 4)
  787. >>> encoding["bbox"] = torch.tensor([bbox])
  788. >>> word_ids = encoding.word_ids(0)
  789. >>> outputs = model(**encoding)
  790. >>> loss = outputs.loss
  791. >>> start_scores = outputs.start_logits
  792. >>> end_scores = outputs.end_logits
  793. >>> start, end = word_ids[start_scores.argmax(-1)], word_ids[end_scores.argmax(-1)]
  794. >>> print(" ".join(words[start : end + 1]))
  795. M. Hamann P. Harper, P. Martinez
  796. ```"""
  797. outputs = self.layoutlm(
  798. input_ids=input_ids,
  799. bbox=bbox,
  800. attention_mask=attention_mask,
  801. token_type_ids=token_type_ids,
  802. position_ids=position_ids,
  803. inputs_embeds=inputs_embeds,
  804. **kwargs,
  805. )
  806. sequence_output = outputs[0]
  807. logits = self.qa_outputs(sequence_output)
  808. start_logits, end_logits = logits.split(1, dim=-1)
  809. start_logits = start_logits.squeeze(-1).contiguous()
  810. end_logits = end_logits.squeeze(-1).contiguous()
  811. total_loss = None
  812. if start_positions is not None and end_positions is not None:
  813. # If we are on multi-GPU, split add a dimension
  814. if len(start_positions.size()) > 1:
  815. start_positions = start_positions.squeeze(-1)
  816. if len(end_positions.size()) > 1:
  817. end_positions = end_positions.squeeze(-1)
  818. # sometimes the start/end positions are outside our model inputs, we ignore these terms
  819. ignored_index = start_logits.size(1)
  820. start_positions = start_positions.clamp(0, ignored_index)
  821. end_positions = end_positions.clamp(0, ignored_index)
  822. loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
  823. start_loss = loss_fct(start_logits, start_positions)
  824. end_loss = loss_fct(end_logits, end_positions)
  825. total_loss = (start_loss + end_loss) / 2
  826. return QuestionAnsweringModelOutput(
  827. loss=total_loss,
  828. start_logits=start_logits,
  829. end_logits=end_logits,
  830. hidden_states=outputs.hidden_states,
  831. attentions=outputs.attentions,
  832. )
# Public API of this module: the names bound by `from <module> import *`.
__all__ = [
    "LayoutLMForMaskedLM",
    "LayoutLMForSequenceClassification",
    "LayoutLMForTokenClassification",
    "LayoutLMForQuestionAnswering",
    "LayoutLMModel",
    "LayoutLMPreTrainedModel",
]