# modeling_layoutlmv3.py (48 KB)
  1. # Copyright 2022 Microsoft Research and The HuggingFace Inc. team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch LayoutLMv3 model."""
  15. import collections
  16. import math
  17. import torch
  18. import torch.nn as nn
  19. import torch.nn.functional as F
  20. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  21. from ... import initialization as init
  22. from ...activations import ACT2FN
  23. from ...modeling_layers import GradientCheckpointingLayer
  24. from ...modeling_outputs import (
  25. BaseModelOutput,
  26. QuestionAnsweringModelOutput,
  27. SequenceClassifierOutput,
  28. TokenClassifierOutput,
  29. )
  30. from ...modeling_utils import PreTrainedModel
  31. from ...processing_utils import Unpack
  32. from ...pytorch_utils import apply_chunking_to_forward
  33. from ...utils import (
  34. auto_docstring,
  35. can_return_tuple,
  36. logging,
  37. torch_int,
  38. )
  39. from ...utils.generic import TransformersKwargs, merge_with_config_defaults
  40. from ...utils.output_capturing import capture_outputs
  41. from .configuration_layoutlmv3 import LayoutLMv3Config
  42. logger = logging.get_logger(__name__)
  43. class LayoutLMv3PatchEmbeddings(nn.Module):
  44. """LayoutLMv3 image (patch) embeddings. This class also automatically interpolates the position embeddings for varying
  45. image sizes."""
  46. def __init__(self, config):
  47. super().__init__()
  48. image_size = (
  49. config.input_size
  50. if isinstance(config.input_size, collections.abc.Iterable)
  51. else (config.input_size, config.input_size)
  52. )
  53. patch_size = (
  54. config.patch_size
  55. if isinstance(config.patch_size, collections.abc.Iterable)
  56. else (config.patch_size, config.patch_size)
  57. )
  58. self.patch_shape = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])
  59. self.proj = nn.Conv2d(config.num_channels, config.hidden_size, kernel_size=patch_size, stride=patch_size)
  60. def forward(self, pixel_values, position_embedding=None):
  61. embeddings = self.proj(pixel_values)
  62. if position_embedding is not None:
  63. # interpolate the position embedding to the corresponding size
  64. position_embedding = position_embedding.view(1, self.patch_shape[0], self.patch_shape[1], -1)
  65. position_embedding = position_embedding.permute(0, 3, 1, 2)
  66. patch_height, patch_width = embeddings.shape[2], embeddings.shape[3]
  67. position_embedding = F.interpolate(position_embedding, size=(patch_height, patch_width), mode="bicubic")
  68. embeddings = embeddings + position_embedding
  69. embeddings = embeddings.flatten(2).transpose(1, 2)
  70. return embeddings
  71. class LayoutLMv3TextEmbeddings(nn.Module):
  72. """
  73. LayoutLMv3 text embeddings. Same as `RobertaEmbeddings` but with added spatial (layout) embeddings.
  74. """
  75. def __init__(self, config):
  76. super().__init__()
  77. self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
  78. self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
  79. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  80. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  81. # position_ids (1, len position emb) is contiguous in memory and exported when serialized
  82. self.register_buffer(
  83. "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
  84. )
  85. self.padding_idx = config.pad_token_id
  86. self.position_embeddings = nn.Embedding(
  87. config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
  88. )
  89. self.x_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.coordinate_size)
  90. self.y_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.coordinate_size)
  91. self.h_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.shape_size)
  92. self.w_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.shape_size)
  93. def calculate_spatial_position_embeddings(self, bbox):
  94. try:
  95. left_position_embeddings = self.x_position_embeddings(bbox[:, :, 0])
  96. upper_position_embeddings = self.y_position_embeddings(bbox[:, :, 1])
  97. right_position_embeddings = self.x_position_embeddings(bbox[:, :, 2])
  98. lower_position_embeddings = self.y_position_embeddings(bbox[:, :, 3])
  99. except IndexError as e:
  100. raise IndexError("The `bbox` coordinate values should be within 0-1000 range.") from e
  101. h_position_embeddings = self.h_position_embeddings(torch.clip(bbox[:, :, 3] - bbox[:, :, 1], 0, 1023))
  102. w_position_embeddings = self.w_position_embeddings(torch.clip(bbox[:, :, 2] - bbox[:, :, 0], 0, 1023))
  103. # below is the difference between LayoutLMEmbeddingsV2 (torch.cat) and LayoutLMEmbeddingsV1 (add)
  104. spatial_position_embeddings = torch.cat(
  105. [
  106. left_position_embeddings,
  107. upper_position_embeddings,
  108. right_position_embeddings,
  109. lower_position_embeddings,
  110. h_position_embeddings,
  111. w_position_embeddings,
  112. ],
  113. dim=-1,
  114. )
  115. return spatial_position_embeddings
  116. def create_position_ids_from_input_ids(self, input_ids, padding_idx):
  117. """
  118. Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding
  119. symbols are ignored. This is modified from fairseq's `utils.make_positions`.
  120. """
  121. # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
  122. mask = input_ids.ne(padding_idx).int()
  123. incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask)) * mask
  124. return incremental_indices.long() + padding_idx
  125. def create_position_ids_from_inputs_embeds(self, inputs_embeds):
  126. """
  127. We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.
  128. """
  129. input_shape = inputs_embeds.size()[:-1]
  130. sequence_length = input_shape[1]
  131. position_ids = torch.arange(
  132. self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
  133. )
  134. return position_ids.unsqueeze(0).expand(input_shape)
  135. def forward(
  136. self,
  137. input_ids=None,
  138. bbox=None,
  139. token_type_ids=None,
  140. position_ids=None,
  141. inputs_embeds=None,
  142. ):
  143. if position_ids is None:
  144. if input_ids is not None:
  145. # Create the position ids from the input token ids. Any padded tokens remain padded.
  146. position_ids = self.create_position_ids_from_input_ids(input_ids, self.padding_idx).to(
  147. input_ids.device
  148. )
  149. else:
  150. position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
  151. if input_ids is not None:
  152. input_shape = input_ids.size()
  153. else:
  154. input_shape = inputs_embeds.size()[:-1]
  155. if token_type_ids is None:
  156. token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
  157. if inputs_embeds is None:
  158. inputs_embeds = self.word_embeddings(input_ids)
  159. token_type_embeddings = self.token_type_embeddings(token_type_ids)
  160. embeddings = inputs_embeds + token_type_embeddings
  161. position_embeddings = self.position_embeddings(position_ids)
  162. embeddings += position_embeddings
  163. spatial_position_embeddings = self.calculate_spatial_position_embeddings(bbox)
  164. embeddings = embeddings + spatial_position_embeddings
  165. embeddings = self.LayerNorm(embeddings)
  166. embeddings = self.dropout(embeddings)
  167. return embeddings
  168. class LayoutLMv3SelfAttention(nn.Module):
  169. def __init__(self, config):
  170. super().__init__()
  171. if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
  172. raise ValueError(
  173. f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
  174. f"heads ({config.num_attention_heads})"
  175. )
  176. self.num_attention_heads = config.num_attention_heads
  177. self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
  178. self.all_head_size = self.num_attention_heads * self.attention_head_size
  179. self.query = nn.Linear(config.hidden_size, self.all_head_size)
  180. self.key = nn.Linear(config.hidden_size, self.all_head_size)
  181. self.value = nn.Linear(config.hidden_size, self.all_head_size)
  182. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  183. self.has_relative_attention_bias = config.has_relative_attention_bias
  184. self.has_spatial_attention_bias = config.has_spatial_attention_bias
  185. def cogview_attention(self, attention_scores, alpha=32):
  186. """
  187. https://huggingface.co/papers/2105.13290 Section 2.4 Stabilization of training: Precision Bottleneck Relaxation
  188. (PB-Relax). A replacement of the original nn.Softmax(dim=-1)(attention_scores). Seems the new attention_probs
  189. will result in a slower speed and a little bias. Can use torch.allclose(standard_attention_probs,
  190. cogview_attention_probs, atol=1e-08) for comparison. The smaller atol (e.g., 1e-08), the better.
  191. """
  192. scaled_attention_scores = attention_scores / alpha
  193. max_value = scaled_attention_scores.amax(dim=(-1)).unsqueeze(-1)
  194. new_attention_scores = (scaled_attention_scores - max_value) * alpha
  195. return nn.Softmax(dim=-1)(new_attention_scores)
  196. def forward(
  197. self,
  198. hidden_states,
  199. attention_mask=None,
  200. rel_pos=None,
  201. rel_2d_pos=None,
  202. **kwargs: Unpack[TransformersKwargs],
  203. ):
  204. batch_size = hidden_states.shape[0]
  205. query_layer = (
  206. self.query(hidden_states)
  207. .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
  208. .transpose(1, 2)
  209. )
  210. key_layer = (
  211. self.key(hidden_states)
  212. .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
  213. .transpose(1, 2)
  214. )
  215. value_layer = (
  216. self.value(hidden_states)
  217. .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
  218. .transpose(1, 2)
  219. )
  220. # Take the dot product between "query" and "key" to get the raw attention scores.
  221. # The attention scores QT K/√d could be significantly larger than input elements, and result in overflow.
  222. # Changing the computational order into QT(K/√d) alleviates the problem. (https://huggingface.co/papers/2105.13290)
  223. attention_scores = torch.matmul(query_layer / math.sqrt(self.attention_head_size), key_layer.transpose(-1, -2))
  224. if self.has_relative_attention_bias and self.has_spatial_attention_bias:
  225. attention_scores += (rel_pos + rel_2d_pos) / math.sqrt(self.attention_head_size)
  226. elif self.has_relative_attention_bias:
  227. attention_scores += rel_pos / math.sqrt(self.attention_head_size)
  228. if attention_mask is not None:
  229. # Apply the attention mask is (precomputed for all layers in RobertaModel forward() function)
  230. attention_scores = attention_scores + attention_mask
  231. # Normalize the attention scores to probabilities.
  232. # Use the trick of the CogView paper to stabilize training
  233. attention_probs = self.cogview_attention(attention_scores)
  234. # This is actually dropping out entire tokens to attend to, which might
  235. # seem a bit unusual, but is taken from the original Transformer paper.
  236. attention_probs = self.dropout(attention_probs)
  237. context_layer = torch.matmul(attention_probs, value_layer)
  238. context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
  239. new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
  240. context_layer = context_layer.view(*new_context_layer_shape)
  241. return context_layer, attention_probs
  242. # Copied from transformers.models.roberta.modeling_roberta.RobertaSelfOutput
  243. class LayoutLMv3SelfOutput(nn.Module):
  244. def __init__(self, config):
  245. super().__init__()
  246. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  247. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  248. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  249. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  250. hidden_states = self.dense(hidden_states)
  251. hidden_states = self.dropout(hidden_states)
  252. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  253. return hidden_states
  254. # Copied from transformers.models.layoutlmv2.modeling_layoutlmv2.LayoutLMv2Attention with LayoutLMv2->LayoutLMv3
  255. class LayoutLMv3Attention(nn.Module):
  256. def __init__(self, config):
  257. super().__init__()
  258. self.self = LayoutLMv3SelfAttention(config)
  259. self.output = LayoutLMv3SelfOutput(config)
  260. def forward(
  261. self,
  262. hidden_states,
  263. attention_mask=None,
  264. rel_pos=None,
  265. rel_2d_pos=None,
  266. **kwargs: Unpack[TransformersKwargs],
  267. ):
  268. residual = hidden_states
  269. attention_output, _ = self.self(
  270. hidden_states,
  271. attention_mask,
  272. rel_pos=rel_pos,
  273. rel_2d_pos=rel_2d_pos,
  274. **kwargs,
  275. )
  276. attention_output = self.output(attention_output, residual)
  277. return attention_output
  278. # Copied from transformers.models.layoutlmv2.modeling_layoutlmv2.LayoutLMv2Layer with LayoutLMv2->LayoutLMv3
  279. class LayoutLMv3Layer(GradientCheckpointingLayer):
  280. def __init__(self, config):
  281. super().__init__()
  282. self.chunk_size_feed_forward = config.chunk_size_feed_forward
  283. self.seq_len_dim = 1
  284. self.attention = LayoutLMv3Attention(config)
  285. self.intermediate = LayoutLMv3Intermediate(config)
  286. self.output = LayoutLMv3Output(config)
  287. def forward(
  288. self,
  289. hidden_states,
  290. attention_mask=None,
  291. output_attentions=False,
  292. rel_pos=None,
  293. rel_2d_pos=None,
  294. **kwargs: Unpack[TransformersKwargs],
  295. ):
  296. attention_output = self.attention(
  297. hidden_states,
  298. attention_mask,
  299. rel_pos=rel_pos,
  300. rel_2d_pos=rel_2d_pos,
  301. )
  302. layer_output = apply_chunking_to_forward(
  303. self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
  304. )
  305. return layer_output
  306. def feed_forward_chunk(self, attention_output):
  307. intermediate_output = self.intermediate(attention_output)
  308. layer_output = self.output(intermediate_output, attention_output)
  309. return layer_output
  310. class LayoutLMv3Encoder(nn.Module):
  311. def __init__(self, config):
  312. super().__init__()
  313. self.config = config
  314. self.layer = nn.ModuleList([LayoutLMv3Layer(config) for _ in range(config.num_hidden_layers)])
  315. self.gradient_checkpointing = False
  316. self.has_relative_attention_bias = config.has_relative_attention_bias
  317. self.has_spatial_attention_bias = config.has_spatial_attention_bias
  318. if self.has_relative_attention_bias:
  319. self.rel_pos_bins = config.rel_pos_bins
  320. self.max_rel_pos = config.max_rel_pos
  321. self.rel_pos_bias = nn.Linear(self.rel_pos_bins, config.num_attention_heads, bias=False)
  322. if self.has_spatial_attention_bias:
  323. self.max_rel_2d_pos = config.max_rel_2d_pos
  324. self.rel_2d_pos_bins = config.rel_2d_pos_bins
  325. self.rel_pos_x_bias = nn.Linear(self.rel_2d_pos_bins, config.num_attention_heads, bias=False)
  326. self.rel_pos_y_bias = nn.Linear(self.rel_2d_pos_bins, config.num_attention_heads, bias=False)
  327. def relative_position_bucket(self, relative_position, bidirectional=True, num_buckets=32, max_distance=128):
  328. ret = 0
  329. if bidirectional:
  330. num_buckets //= 2
  331. ret += (relative_position > 0).long() * num_buckets
  332. n = torch.abs(relative_position)
  333. else:
  334. n = torch.max(-relative_position, torch.zeros_like(relative_position))
  335. # now n is in the range [0, inf)
  336. # half of the buckets are for exact increments in positions
  337. max_exact = num_buckets // 2
  338. is_small = n < max_exact
  339. # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
  340. val_if_large = max_exact + (
  341. torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
  342. ).to(torch.long)
  343. val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
  344. ret += torch.where(is_small, n, val_if_large)
  345. return ret
  346. def _cal_1d_pos_emb(self, position_ids):
  347. rel_pos_mat = position_ids.unsqueeze(-2) - position_ids.unsqueeze(-1)
  348. rel_pos = self.relative_position_bucket(
  349. rel_pos_mat,
  350. num_buckets=self.rel_pos_bins,
  351. max_distance=self.max_rel_pos,
  352. )
  353. # Since this is a simple indexing operation that is independent of the input,
  354. # no need to track gradients for this operation
  355. #
  356. # Without this no_grad context, training speed slows down significantly
  357. with torch.no_grad():
  358. rel_pos = self.rel_pos_bias.weight.t()[rel_pos].permute(0, 3, 1, 2)
  359. rel_pos = rel_pos.contiguous()
  360. return rel_pos
  361. def _cal_2d_pos_emb(self, bbox):
  362. position_coord_x = bbox[:, :, 0]
  363. position_coord_y = bbox[:, :, 3]
  364. rel_pos_x_2d_mat = position_coord_x.unsqueeze(-2) - position_coord_x.unsqueeze(-1)
  365. rel_pos_y_2d_mat = position_coord_y.unsqueeze(-2) - position_coord_y.unsqueeze(-1)
  366. rel_pos_x = self.relative_position_bucket(
  367. rel_pos_x_2d_mat,
  368. num_buckets=self.rel_2d_pos_bins,
  369. max_distance=self.max_rel_2d_pos,
  370. )
  371. rel_pos_y = self.relative_position_bucket(
  372. rel_pos_y_2d_mat,
  373. num_buckets=self.rel_2d_pos_bins,
  374. max_distance=self.max_rel_2d_pos,
  375. )
  376. # Since this is a simple indexing operation that is independent of the input,
  377. # no need to track gradients for this operation
  378. #
  379. # Without this no_grad context, training speed slows down significantly
  380. with torch.no_grad():
  381. rel_pos_x = self.rel_pos_x_bias.weight.t()[rel_pos_x].permute(0, 3, 1, 2)
  382. rel_pos_y = self.rel_pos_y_bias.weight.t()[rel_pos_y].permute(0, 3, 1, 2)
  383. rel_pos_x = rel_pos_x.contiguous()
  384. rel_pos_y = rel_pos_y.contiguous()
  385. rel_2d_pos = rel_pos_x + rel_pos_y
  386. return rel_2d_pos
  387. def forward(
  388. self,
  389. hidden_states,
  390. bbox=None,
  391. attention_mask=None,
  392. position_ids=None,
  393. patch_height=None,
  394. patch_width=None,
  395. **kwargs: Unpack[TransformersKwargs],
  396. ):
  397. rel_pos = self._cal_1d_pos_emb(position_ids) if self.has_relative_attention_bias else None
  398. rel_2d_pos = self._cal_2d_pos_emb(bbox) if self.has_spatial_attention_bias else None
  399. for layer_module in self.layer:
  400. hidden_states = layer_module(
  401. hidden_states,
  402. attention_mask,
  403. rel_pos=rel_pos,
  404. rel_2d_pos=rel_2d_pos,
  405. **kwargs,
  406. )
  407. return BaseModelOutput(last_hidden_state=hidden_states)
  408. # Copied from transformers.models.roberta.modeling_roberta.RobertaIntermediate
  409. class LayoutLMv3Intermediate(nn.Module):
  410. def __init__(self, config):
  411. super().__init__()
  412. self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
  413. if isinstance(config.hidden_act, str):
  414. self.intermediate_act_fn = ACT2FN[config.hidden_act]
  415. else:
  416. self.intermediate_act_fn = config.hidden_act
  417. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  418. hidden_states = self.dense(hidden_states)
  419. hidden_states = self.intermediate_act_fn(hidden_states)
  420. return hidden_states
  421. # Copied from transformers.models.roberta.modeling_roberta.RobertaOutput
  422. class LayoutLMv3Output(nn.Module):
  423. def __init__(self, config):
  424. super().__init__()
  425. self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
  426. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  427. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  428. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  429. hidden_states = self.dense(hidden_states)
  430. hidden_states = self.dropout(hidden_states)
  431. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  432. return hidden_states
  433. @auto_docstring
  434. class LayoutLMv3PreTrainedModel(PreTrainedModel):
  435. config: LayoutLMv3Config
  436. base_model_prefix = "layoutlmv3"
  437. input_modalities = ("image", "text")
  438. _can_record_outputs = {"hidden_states": LayoutLMv3Layer, "attentions": LayoutLMv3SelfAttention}
  439. @torch.no_grad()
  440. def _init_weights(self, module):
  441. """Initialize the weights"""
  442. super()._init_weights(module)
  443. if isinstance(module, LayoutLMv3Model):
  444. if self.config.visual_embed:
  445. init.zeros_(module.cls_token)
  446. init.zeros_(module.pos_embed)
  447. if hasattr(module, "visual_bbox"):
  448. init.copy_(module.visual_bbox, module.create_visual_bbox(image_size=(module.size, module.size)))
  449. elif isinstance(module, LayoutLMv3TextEmbeddings):
  450. init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
  451. @auto_docstring
  452. class LayoutLMv3Model(LayoutLMv3PreTrainedModel):
  453. def __init__(self, config):
  454. super().__init__(config)
  455. self.config = config
  456. if config.text_embed:
  457. self.embeddings = LayoutLMv3TextEmbeddings(config)
  458. if config.visual_embed:
  459. # use the default pre-training parameters for fine-tuning (e.g., input_size)
  460. # when the input_size is larger in fine-tuning, we will interpolate the position embeddings in forward
  461. self.patch_embed = LayoutLMv3PatchEmbeddings(config)
  462. self.size = int(config.input_size / config.patch_size)
  463. self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
  464. self.pos_embed = nn.Parameter(torch.zeros(1, self.size * self.size + 1, config.hidden_size))
  465. self.pos_drop = nn.Dropout(p=0.0)
  466. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  467. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  468. if self.config.has_relative_attention_bias or self.config.has_spatial_attention_bias:
  469. self.register_buffer(
  470. "visual_bbox", self.create_visual_bbox(image_size=(self.size, self.size)), persistent=False
  471. )
  472. self.norm = nn.LayerNorm(config.hidden_size, eps=1e-6)
  473. self.encoder = LayoutLMv3Encoder(config)
  474. self.post_init()
  475. def get_input_embeddings(self):
  476. return self.embeddings.word_embeddings
  477. def set_input_embeddings(self, value):
  478. self.embeddings.word_embeddings = value
  479. def create_visual_bbox(self, image_size=(14, 14), max_len=1000):
  480. """
  481. Create the bounding boxes for the visual (patch) tokens.
  482. """
  483. visual_bbox_x = torch.div(
  484. torch.arange(0, max_len * (image_size[1] + 1), max_len), image_size[1], rounding_mode="trunc"
  485. )
  486. visual_bbox_y = torch.div(
  487. torch.arange(0, max_len * (image_size[0] + 1), max_len), image_size[0], rounding_mode="trunc"
  488. )
  489. visual_bbox = torch.stack(
  490. [
  491. visual_bbox_x[:-1].repeat(image_size[0], 1),
  492. visual_bbox_y[:-1].repeat(image_size[1], 1).transpose(0, 1),
  493. visual_bbox_x[1:].repeat(image_size[0], 1),
  494. visual_bbox_y[1:].repeat(image_size[1], 1).transpose(0, 1),
  495. ],
  496. dim=-1,
  497. ).view(-1, 4)
  498. cls_token_box = torch.tensor([[0 + 1, 0 + 1, max_len - 1, max_len - 1]])
  499. return torch.cat([cls_token_box, visual_bbox], dim=0)
  500. def calculate_visual_bbox(self, device, dtype, batch_size):
  501. visual_bbox = self.visual_bbox.repeat(batch_size, 1, 1)
  502. visual_bbox = visual_bbox.to(device).type(dtype)
  503. return visual_bbox
  504. def forward_image(self, pixel_values):
  505. embeddings = self.patch_embed(pixel_values)
  506. # add [CLS] token
  507. batch_size, seq_len, _ = embeddings.size()
  508. cls_tokens = self.cls_token.expand(batch_size, -1, -1)
  509. embeddings = torch.cat((cls_tokens, embeddings), dim=1)
  510. # add position embeddings
  511. if self.pos_embed is not None:
  512. embeddings = embeddings + self.pos_embed
  513. embeddings = self.pos_drop(embeddings)
  514. embeddings = self.norm(embeddings)
  515. return embeddings
  516. @merge_with_config_defaults
  517. @capture_outputs
  518. @auto_docstring
  519. def forward(
  520. self,
  521. input_ids: torch.LongTensor | None = None,
  522. bbox: torch.LongTensor | None = None,
  523. attention_mask: torch.FloatTensor | None = None,
  524. token_type_ids: torch.LongTensor | None = None,
  525. position_ids: torch.LongTensor | None = None,
  526. inputs_embeds: torch.FloatTensor | None = None,
  527. pixel_values: torch.FloatTensor | None = None,
  528. **kwargs: Unpack[TransformersKwargs],
  529. ) -> tuple | BaseModelOutput:
  530. r"""
  531. input_ids (`torch.LongTensor` of shape `(batch_size, token_sequence_length)`):
  532. Indices of input sequence tokens in the vocabulary.
  533. Note that `sequence_length = token_sequence_length + patch_sequence_length + 1` where `1` is for [CLS]
  534. token. See `pixel_values` for `patch_sequence_length`.
  535. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  536. [`PreTrainedTokenizer.__call__`] for details.
  537. [What are input IDs?](../glossary#input-ids)
  538. bbox (`torch.LongTensor` of shape `(batch_size, token_sequence_length, 4)`, *optional*):
  539. Bounding boxes of each input sequence tokens. Selected in the range `[0,
  540. config.max_2d_position_embeddings-1]`. Each bounding box should be a normalized version in (x0, y0, x1, y1)
  541. format, where (x0, y0) corresponds to the position of the upper left corner in the bounding box, and (x1,
  542. y1) represents the position of the lower right corner.
  543. Note that `sequence_length = token_sequence_length + patch_sequence_length + 1` where `1` is for [CLS]
  544. token. See `pixel_values` for `patch_sequence_length`.
  545. token_type_ids (`torch.LongTensor` of shape `(batch_size, token_sequence_length)`, *optional*):
  546. Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
  547. 1]`:
  548. - 0 corresponds to a *sentence A* token,
  549. - 1 corresponds to a *sentence B* token.
  550. Note that `sequence_length = token_sequence_length + patch_sequence_length + 1` where `1` is for [CLS]
  551. token. See `pixel_values` for `patch_sequence_length`.
  552. [What are token type IDs?](../glossary#token-type-ids)
  553. position_ids (`torch.LongTensor` of shape `(batch_size, token_sequence_length)`, *optional*):
  554. Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
  555. config.max_position_embeddings - 1]`.
  556. Note that `sequence_length = token_sequence_length + patch_sequence_length + 1` where `1` is for [CLS]
  557. token. See `pixel_values` for `patch_sequence_length`.
  558. [What are position IDs?](../glossary#position-ids)
  559. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, token_sequence_length, hidden_size)`, *optional*):
  560. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
  561. is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
  562. model's internal embedding lookup matrix.
  563. Examples:
  564. ```python
  565. >>> from transformers import AutoProcessor, AutoModel
  566. >>> from datasets import load_dataset
  567. >>> processor = AutoProcessor.from_pretrained("microsoft/layoutlmv3-base", apply_ocr=False)
  568. >>> model = AutoModel.from_pretrained("microsoft/layoutlmv3-base")
  569. >>> dataset = load_dataset("nielsr/funsd-layoutlmv3", split="train")
  570. >>> example = dataset[0]
  571. >>> image = example["image"]
  572. >>> words = example["tokens"]
  573. >>> boxes = example["bboxes"]
  574. >>> encoding = processor(image, words, boxes=boxes, return_tensors="pt")
  575. >>> outputs = model(**encoding)
  576. >>> last_hidden_states = outputs.last_hidden_state
  577. ```"""
  578. if input_ids is not None:
  579. input_shape = input_ids.size()
  580. batch_size, seq_length = input_shape
  581. device = input_ids.device
  582. elif inputs_embeds is not None:
  583. input_shape = inputs_embeds.size()[:-1]
  584. batch_size, seq_length = input_shape
  585. device = inputs_embeds.device
  586. elif pixel_values is not None:
  587. batch_size = len(pixel_values)
  588. device = pixel_values.device
  589. else:
  590. raise ValueError("You have to specify either input_ids or inputs_embeds or pixel_values")
  591. if input_ids is not None or inputs_embeds is not None:
  592. if attention_mask is None:
  593. attention_mask = torch.ones(((batch_size, seq_length)), device=device)
  594. if token_type_ids is None:
  595. token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
  596. if bbox is None:
  597. bbox = torch.zeros(tuple(list(input_shape) + [4]), dtype=torch.long, device=device)
  598. embedding_output = self.embeddings(
  599. input_ids=input_ids,
  600. bbox=bbox,
  601. position_ids=position_ids,
  602. token_type_ids=token_type_ids,
  603. inputs_embeds=inputs_embeds,
  604. )
  605. final_bbox = final_position_ids = None
  606. patch_height = patch_width = None
  607. if pixel_values is not None:
  608. patch_height, patch_width = (
  609. torch_int(pixel_values.shape[2] / self.config.patch_size),
  610. torch_int(pixel_values.shape[3] / self.config.patch_size),
  611. )
  612. visual_embeddings = self.forward_image(pixel_values)
  613. visual_attention_mask = torch.ones(
  614. (batch_size, visual_embeddings.shape[1]), dtype=torch.long, device=device
  615. )
  616. if attention_mask is not None:
  617. attention_mask = torch.cat([attention_mask, visual_attention_mask], dim=1)
  618. else:
  619. attention_mask = visual_attention_mask
  620. if self.config.has_relative_attention_bias or self.config.has_spatial_attention_bias:
  621. if self.config.has_spatial_attention_bias:
  622. visual_bbox = self.calculate_visual_bbox(device, dtype=torch.long, batch_size=batch_size)
  623. if bbox is not None:
  624. final_bbox = torch.cat([bbox, visual_bbox], dim=1)
  625. else:
  626. final_bbox = visual_bbox
  627. visual_position_ids = torch.arange(
  628. 0, visual_embeddings.shape[1], dtype=torch.long, device=device
  629. ).repeat(batch_size, 1)
  630. if input_ids is not None or inputs_embeds is not None:
  631. position_ids = torch.arange(0, input_shape[1], device=device).unsqueeze(0)
  632. position_ids = position_ids.expand(input_shape)
  633. final_position_ids = torch.cat([position_ids, visual_position_ids], dim=1)
  634. else:
  635. final_position_ids = visual_position_ids
  636. if input_ids is not None or inputs_embeds is not None:
  637. embedding_output = torch.cat([embedding_output, visual_embeddings], dim=1)
  638. else:
  639. embedding_output = visual_embeddings
  640. embedding_output = self.LayerNorm(embedding_output)
  641. embedding_output = self.dropout(embedding_output)
  642. elif self.config.has_relative_attention_bias or self.config.has_spatial_attention_bias:
  643. if self.config.has_spatial_attention_bias:
  644. final_bbox = bbox
  645. if self.config.has_relative_attention_bias:
  646. position_ids = self.embeddings.position_ids[:, : input_shape[1]]
  647. position_ids = position_ids.expand_as(input_ids)
  648. final_position_ids = position_ids
  649. extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
  650. attention_mask, None, dtype=embedding_output.dtype
  651. )
  652. encoder_outputs = self.encoder(
  653. embedding_output,
  654. bbox=final_bbox,
  655. position_ids=final_position_ids,
  656. attention_mask=extended_attention_mask,
  657. patch_height=patch_height,
  658. patch_width=patch_width,
  659. **kwargs,
  660. )
  661. sequence_output = encoder_outputs.last_hidden_state
  662. return BaseModelOutput(
  663. last_hidden_state=sequence_output,
  664. )
  665. class LayoutLMv3ClassificationHead(nn.Module):
  666. """
  667. Head for sentence-level classification tasks. Reference: RobertaClassificationHead
  668. """
  669. def __init__(self, config, pool_feature=False):
  670. super().__init__()
  671. self.pool_feature = pool_feature
  672. if pool_feature:
  673. self.dense = nn.Linear(config.hidden_size * 3, config.hidden_size)
  674. else:
  675. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  676. classifier_dropout = (
  677. config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
  678. )
  679. self.dropout = nn.Dropout(classifier_dropout)
  680. self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
  681. def forward(self, x):
  682. x = self.dropout(x)
  683. x = self.dense(x)
  684. x = torch.tanh(x)
  685. x = self.dropout(x)
  686. x = self.out_proj(x)
  687. return x
  688. @auto_docstring(
  689. custom_intro="""
  690. LayoutLMv3 Model with a token classification head on top (a linear layer on top of the final hidden states) e.g.
  691. for sequence labeling (information extraction) tasks such as [FUNSD](https://guillaumejaume.github.io/FUNSD/),
  692. [SROIE](https://rrc.cvc.uab.es/?ch=13), [CORD](https://github.com/clovaai/cord) and
  693. [Kleister-NDA](https://github.com/applicaai/kleister-nda).
  694. """
  695. )
  696. class LayoutLMv3ForTokenClassification(LayoutLMv3PreTrainedModel):
  697. def __init__(self, config):
  698. super().__init__(config)
  699. self.num_labels = config.num_labels
  700. self.layoutlmv3 = LayoutLMv3Model(config)
  701. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  702. if config.num_labels < 10:
  703. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  704. else:
  705. self.classifier = LayoutLMv3ClassificationHead(config, pool_feature=False)
  706. self.post_init()
  707. def get_input_embeddings(self):
  708. return self.layoutlmv3.get_input_embeddings()
  709. def set_input_embeddings(self, value):
  710. self.layoutlmv3.set_input_embeddings(value)
  711. @can_return_tuple
  712. @auto_docstring
  713. def forward(
  714. self,
  715. input_ids: torch.LongTensor | None = None,
  716. bbox: torch.LongTensor | None = None,
  717. attention_mask: torch.FloatTensor | None = None,
  718. token_type_ids: torch.LongTensor | None = None,
  719. position_ids: torch.LongTensor | None = None,
  720. inputs_embeds: torch.FloatTensor | None = None,
  721. labels: torch.LongTensor | None = None,
  722. pixel_values: torch.LongTensor | None = None,
  723. **kwargs: Unpack[TransformersKwargs],
  724. ) -> tuple | TokenClassifierOutput:
  725. r"""
  726. bbox (`torch.LongTensor` of shape `(batch_size, sequence_length, 4)`, *optional*):
  727. Bounding boxes of each input sequence tokens. Selected in the range `[0,
  728. config.max_2d_position_embeddings-1]`. Each bounding box should be a normalized version in (x0, y0, x1, y1)
  729. format, where (x0, y0) corresponds to the position of the upper left corner in the bounding box, and (x1,
  730. y1) represents the position of the lower right corner.
  731. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  732. Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
  733. Examples:
  734. ```python
  735. >>> from transformers import AutoProcessor, AutoModelForTokenClassification
  736. >>> from datasets import load_dataset
  737. >>> processor = AutoProcessor.from_pretrained("microsoft/layoutlmv3-base", apply_ocr=False)
  738. >>> model = AutoModelForTokenClassification.from_pretrained("microsoft/layoutlmv3-base", num_labels=7)
  739. >>> dataset = load_dataset("nielsr/funsd-layoutlmv3", split="train")
  740. >>> example = dataset[0]
  741. >>> image = example["image"]
  742. >>> words = example["tokens"]
  743. >>> boxes = example["bboxes"]
  744. >>> word_labels = example["ner_tags"]
  745. >>> encoding = processor(image, words, boxes=boxes, word_labels=word_labels, return_tensors="pt")
  746. >>> outputs = model(**encoding)
  747. >>> loss = outputs.loss
  748. >>> logits = outputs.logits
  749. ```"""
  750. outputs = self.layoutlmv3(
  751. input_ids,
  752. bbox=bbox,
  753. attention_mask=attention_mask,
  754. token_type_ids=token_type_ids,
  755. position_ids=position_ids,
  756. inputs_embeds=inputs_embeds,
  757. pixel_values=pixel_values,
  758. **kwargs,
  759. )
  760. if input_ids is not None:
  761. input_shape = input_ids.size()
  762. else:
  763. input_shape = inputs_embeds.size()[:-1]
  764. seq_length = input_shape[1]
  765. # only take the text part of the output representations
  766. sequence_output = outputs[0][:, :seq_length]
  767. sequence_output = self.dropout(sequence_output)
  768. logits = self.classifier(sequence_output)
  769. loss = None
  770. if labels is not None:
  771. loss_fct = CrossEntropyLoss()
  772. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  773. return TokenClassifierOutput(
  774. loss=loss,
  775. logits=logits,
  776. hidden_states=outputs.hidden_states,
  777. attentions=outputs.attentions,
  778. )
  779. @auto_docstring
  780. class LayoutLMv3ForQuestionAnswering(LayoutLMv3PreTrainedModel):
  781. def __init__(self, config):
  782. super().__init__(config)
  783. self.num_labels = config.num_labels
  784. self.layoutlmv3 = LayoutLMv3Model(config)
  785. self.qa_outputs = LayoutLMv3ClassificationHead(config, pool_feature=False)
  786. self.post_init()
  787. def get_input_embeddings(self):
  788. return self.layoutlmv3.get_input_embeddings()
  789. def set_input_embeddings(self, value):
  790. self.layoutlmv3.set_input_embeddings(value)
  791. @can_return_tuple
  792. @auto_docstring
  793. def forward(
  794. self,
  795. input_ids: torch.LongTensor | None = None,
  796. attention_mask: torch.FloatTensor | None = None,
  797. token_type_ids: torch.LongTensor | None = None,
  798. position_ids: torch.LongTensor | None = None,
  799. inputs_embeds: torch.FloatTensor | None = None,
  800. start_positions: torch.LongTensor | None = None,
  801. end_positions: torch.LongTensor | None = None,
  802. bbox: torch.LongTensor | None = None,
  803. pixel_values: torch.LongTensor | None = None,
  804. **kwargs: Unpack[TransformersKwargs],
  805. ) -> tuple | QuestionAnsweringModelOutput:
  806. r"""
  807. bbox (`torch.LongTensor` of shape `(batch_size, sequence_length, 4)`, *optional*):
  808. Bounding boxes of each input sequence tokens. Selected in the range `[0,
  809. config.max_2d_position_embeddings-1]`. Each bounding box should be a normalized version in (x0, y0, x1, y1)
  810. format, where (x0, y0) corresponds to the position of the upper left corner in the bounding box, and (x1,
  811. y1) represents the position of the lower right corner.
  812. Examples:
  813. ```python
  814. >>> from transformers import AutoProcessor, AutoModelForQuestionAnswering
  815. >>> from datasets import load_dataset
  816. >>> import torch
  817. >>> processor = AutoProcessor.from_pretrained("microsoft/layoutlmv3-base", apply_ocr=False)
  818. >>> model = AutoModelForQuestionAnswering.from_pretrained("microsoft/layoutlmv3-base")
  819. >>> dataset = load_dataset("nielsr/funsd-layoutlmv3", split="train")
  820. >>> example = dataset[0]
  821. >>> image = example["image"]
  822. >>> question = "what's his name?"
  823. >>> words = example["tokens"]
  824. >>> boxes = example["bboxes"]
  825. >>> encoding = processor(image, question, words, boxes=boxes, return_tensors="pt")
  826. >>> start_positions = torch.tensor([1])
  827. >>> end_positions = torch.tensor([3])
  828. >>> outputs = model(**encoding, start_positions=start_positions, end_positions=end_positions)
  829. >>> loss = outputs.loss
  830. >>> start_scores = outputs.start_logits
  831. >>> end_scores = outputs.end_logits
  832. ```"""
  833. outputs: BaseModelOutput = self.layoutlmv3(
  834. input_ids,
  835. attention_mask=attention_mask,
  836. token_type_ids=token_type_ids,
  837. position_ids=position_ids,
  838. inputs_embeds=inputs_embeds,
  839. bbox=bbox,
  840. pixel_values=pixel_values,
  841. **kwargs,
  842. )
  843. sequence_output = outputs[0]
  844. logits = self.qa_outputs(sequence_output)
  845. start_logits, end_logits = logits.split(1, dim=-1)
  846. start_logits = start_logits.squeeze(-1).contiguous()
  847. end_logits = end_logits.squeeze(-1).contiguous()
  848. total_loss = None
  849. if start_positions is not None and end_positions is not None:
  850. # If we are on multi-GPU, split add a dimension
  851. if len(start_positions.size()) > 1:
  852. start_positions = start_positions.squeeze(-1)
  853. if len(end_positions.size()) > 1:
  854. end_positions = end_positions.squeeze(-1)
  855. # sometimes the start/end positions are outside our model inputs, we ignore these terms
  856. ignored_index = start_logits.size(1)
  857. start_positions = start_positions.clamp(0, ignored_index)
  858. end_positions = end_positions.clamp(0, ignored_index)
  859. loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
  860. start_loss = loss_fct(start_logits, start_positions)
  861. end_loss = loss_fct(end_logits, end_positions)
  862. total_loss = (start_loss + end_loss) / 2
  863. return QuestionAnsweringModelOutput(
  864. loss=total_loss,
  865. start_logits=start_logits,
  866. end_logits=end_logits,
  867. hidden_states=outputs.hidden_states,
  868. attentions=outputs.attentions,
  869. )
  870. @auto_docstring(
  871. custom_intro="""
  872. LayoutLMv3 Model with a sequence classification head on top (a linear layer on top of the final hidden state of the
  873. [CLS] token) e.g. for document image classification tasks such as the
  874. [RVL-CDIP](https://www.cs.cmu.edu/~aharley/rvl-cdip/) dataset.
  875. """
  876. )
  877. class LayoutLMv3ForSequenceClassification(LayoutLMv3PreTrainedModel):
  878. def __init__(self, config):
  879. super().__init__(config)
  880. self.num_labels = config.num_labels
  881. self.config = config
  882. self.layoutlmv3 = LayoutLMv3Model(config)
  883. self.classifier = LayoutLMv3ClassificationHead(config, pool_feature=False)
  884. self.post_init()
  885. def get_input_embeddings(self):
  886. return self.layoutlmv3.get_input_embeddings()
  887. def set_input_embeddings(self, value):
  888. self.layoutlmv3.set_input_embeddings(value)
  889. @can_return_tuple
  890. @auto_docstring
  891. def forward(
  892. self,
  893. input_ids: torch.LongTensor | None = None,
  894. attention_mask: torch.FloatTensor | None = None,
  895. token_type_ids: torch.LongTensor | None = None,
  896. position_ids: torch.LongTensor | None = None,
  897. inputs_embeds: torch.FloatTensor | None = None,
  898. labels: torch.LongTensor | None = None,
  899. bbox: torch.LongTensor | None = None,
  900. pixel_values: torch.LongTensor | None = None,
  901. **kwargs: Unpack[TransformersKwargs],
  902. ) -> tuple | SequenceClassifierOutput:
  903. r"""
  904. bbox (`torch.LongTensor` of shape `(batch_size, sequence_length, 4)`, *optional*):
  905. Bounding boxes of each input sequence tokens. Selected in the range `[0,
  906. config.max_2d_position_embeddings-1]`. Each bounding box should be a normalized version in (x0, y0, x1, y1)
  907. format, where (x0, y0) corresponds to the position of the upper left corner in the bounding box, and (x1,
  908. y1) represents the position of the lower right corner.
  909. Examples:
  910. ```python
  911. >>> from transformers import AutoProcessor, AutoModelForSequenceClassification
  912. >>> from datasets import load_dataset
  913. >>> import torch
  914. >>> processor = AutoProcessor.from_pretrained("microsoft/layoutlmv3-base", apply_ocr=False)
  915. >>> model = AutoModelForSequenceClassification.from_pretrained("microsoft/layoutlmv3-base")
  916. >>> dataset = load_dataset("nielsr/funsd-layoutlmv3", split="train")
  917. >>> example = dataset[0]
  918. >>> image = example["image"]
  919. >>> words = example["tokens"]
  920. >>> boxes = example["bboxes"]
  921. >>> encoding = processor(image, words, boxes=boxes, return_tensors="pt")
  922. >>> sequence_label = torch.tensor([1])
  923. >>> outputs = model(**encoding, labels=sequence_label)
  924. >>> loss = outputs.loss
  925. >>> logits = outputs.logits
  926. ```"""
  927. outputs: BaseModelOutput = self.layoutlmv3(
  928. input_ids,
  929. attention_mask=attention_mask,
  930. token_type_ids=token_type_ids,
  931. position_ids=position_ids,
  932. inputs_embeds=inputs_embeds,
  933. bbox=bbox,
  934. pixel_values=pixel_values,
  935. **kwargs,
  936. )
  937. sequence_output = outputs[0][:, 0, :]
  938. logits = self.classifier(sequence_output)
  939. loss = None
  940. if labels is not None:
  941. if self.config.problem_type is None:
  942. if self.num_labels == 1:
  943. self.config.problem_type = "regression"
  944. elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
  945. self.config.problem_type = "single_label_classification"
  946. else:
  947. self.config.problem_type = "multi_label_classification"
  948. if self.config.problem_type == "regression":
  949. loss_fct = MSELoss()
  950. if self.num_labels == 1:
  951. loss = loss_fct(logits.squeeze(), labels.squeeze())
  952. else:
  953. loss = loss_fct(logits, labels)
  954. elif self.config.problem_type == "single_label_classification":
  955. loss_fct = CrossEntropyLoss()
  956. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  957. elif self.config.problem_type == "multi_label_classification":
  958. loss_fct = BCEWithLogitsLoss()
  959. loss = loss_fct(logits, labels)
  960. return SequenceClassifierOutput(
  961. loss=loss,
  962. logits=logits,
  963. hidden_states=outputs.hidden_states,
  964. attentions=outputs.attentions,
  965. )
# Public names exported by this module (what `from <module> import *` exposes).
__all__ = [
    "LayoutLMv3ForQuestionAnswering",
    "LayoutLMv3ForSequenceClassification",
    "LayoutLMv3ForTokenClassification",
    "LayoutLMv3Model",
    "LayoutLMv3PreTrainedModel",
]