modeling_layoutlmv2.py 57 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298
  1. # Copyright 2021 Microsoft Research The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch LayoutLMv2 model."""
  15. import math
  16. import torch
  17. from torch import nn
  18. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  19. from ... import initialization as init
  20. from ...activations import ACT2FN
  21. from ...modeling_layers import GradientCheckpointingLayer
  22. from ...modeling_outputs import (
  23. BaseModelOutput,
  24. BaseModelOutputWithPooling,
  25. QuestionAnsweringModelOutput,
  26. SequenceClassifierOutput,
  27. TokenClassifierOutput,
  28. )
  29. from ...modeling_utils import PreTrainedModel
  30. from ...processing_utils import Unpack
  31. from ...pytorch_utils import apply_chunking_to_forward
  32. from ...utils import auto_docstring, is_detectron2_available, logging, requires_backends
  33. from ...utils.generic import TransformersKwargs, can_return_tuple, merge_with_config_defaults
  34. from ...utils.output_capturing import capture_outputs
  35. from .configuration_layoutlmv2 import LayoutLMv2Config
# soft dependency
# detectron2 is only imported when it is installed. The module-level names `detectron2` and
# `META_ARCH_REGISTRY` are therefore unbound otherwise, and any code below that touches them will raise NameError.
if is_detectron2_available():
    import detectron2
    from detectron2.modeling import META_ARCH_REGISTRY

    # This is needed as otherwise their overload will break sequential loading by overwriting buffer over and over. See
    # https://github.com/facebookresearch/detectron2/blob/9604f5995cc628619f0e4fd913453b4d7d61db3f/detectron2/layers/batch_norm.py#L83-L86
    # This is a class-level assignment: it replaces detectron2's `_load_from_state_dict` override with the default
    # `torch.nn.Module` implementation for every `FrozenBatchNorm2d` in the process, not only those used here.
    detectron2.layers.batch_norm.FrozenBatchNorm2d._load_from_state_dict = torch.nn.Module._load_from_state_dict

logger = logging.get_logger(__name__)
  44. class LayoutLMv2Embeddings(nn.Module):
  45. """Construct the embeddings from word, position and token_type embeddings."""
  46. def __init__(self, config):
  47. super().__init__()
  48. self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
  49. self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
  50. self.x_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.coordinate_size)
  51. self.y_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.coordinate_size)
  52. self.h_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.shape_size)
  53. self.w_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.shape_size)
  54. self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
  55. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  56. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  57. self.register_buffer(
  58. "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
  59. )
  60. def _calc_spatial_position_embeddings(self, bbox):
  61. try:
  62. left_position_embeddings = self.x_position_embeddings(bbox[:, :, 0])
  63. upper_position_embeddings = self.y_position_embeddings(bbox[:, :, 1])
  64. right_position_embeddings = self.x_position_embeddings(bbox[:, :, 2])
  65. lower_position_embeddings = self.y_position_embeddings(bbox[:, :, 3])
  66. except IndexError as e:
  67. raise IndexError("The `bbox` coordinate values should be within 0-1000 range.") from e
  68. h_position_embeddings = self.h_position_embeddings(bbox[:, :, 3] - bbox[:, :, 1])
  69. w_position_embeddings = self.w_position_embeddings(bbox[:, :, 2] - bbox[:, :, 0])
  70. spatial_position_embeddings = torch.cat(
  71. [
  72. left_position_embeddings,
  73. upper_position_embeddings,
  74. right_position_embeddings,
  75. lower_position_embeddings,
  76. h_position_embeddings,
  77. w_position_embeddings,
  78. ],
  79. dim=-1,
  80. )
  81. return spatial_position_embeddings
  82. class LayoutLMv2SelfAttention(nn.Module):
  83. def __init__(self, config):
  84. super().__init__()
  85. if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
  86. raise ValueError(
  87. f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
  88. f"heads ({config.num_attention_heads})"
  89. )
  90. self.fast_qkv = config.fast_qkv
  91. self.num_attention_heads = config.num_attention_heads
  92. self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
  93. self.all_head_size = self.num_attention_heads * self.attention_head_size
  94. self.has_relative_attention_bias = config.has_relative_attention_bias
  95. self.has_spatial_attention_bias = config.has_spatial_attention_bias
  96. if config.fast_qkv:
  97. self.qkv_linear = nn.Linear(config.hidden_size, 3 * self.all_head_size, bias=False)
  98. self.q_bias = nn.Parameter(torch.zeros(1, 1, self.all_head_size))
  99. self.v_bias = nn.Parameter(torch.zeros(1, 1, self.all_head_size))
  100. else:
  101. self.query = nn.Linear(config.hidden_size, self.all_head_size)
  102. self.key = nn.Linear(config.hidden_size, self.all_head_size)
  103. self.value = nn.Linear(config.hidden_size, self.all_head_size)
  104. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  105. def compute_qkv(self, hidden_states):
  106. if self.fast_qkv:
  107. qkv = self.qkv_linear(hidden_states)
  108. q, k, v = torch.chunk(qkv, 3, dim=-1)
  109. if q.ndimension() == self.q_bias.ndimension():
  110. q = q + self.q_bias
  111. v = v + self.v_bias
  112. else:
  113. _sz = (1,) * (q.ndimension() - 1) + (-1,)
  114. q = q + self.q_bias.view(*_sz)
  115. v = v + self.v_bias.view(*_sz)
  116. else:
  117. q = self.query(hidden_states)
  118. k = self.key(hidden_states)
  119. v = self.value(hidden_states)
  120. return q, k, v
  121. def forward(
  122. self,
  123. hidden_states,
  124. attention_mask=None,
  125. rel_pos=None,
  126. rel_2d_pos=None,
  127. **kwargs: Unpack[TransformersKwargs],
  128. ):
  129. batch_size = hidden_states.shape[0]
  130. query, key, value = self.compute_qkv(hidden_states)
  131. # (B, L, H*D) -> (B, H, L, D)
  132. query_layer = query.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2)
  133. key_layer = key.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2)
  134. value_layer = value.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2)
  135. query_layer = query_layer / math.sqrt(self.attention_head_size)
  136. # [BSZ, NAT, L, L]
  137. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
  138. if self.has_relative_attention_bias:
  139. attention_scores += rel_pos
  140. if self.has_spatial_attention_bias:
  141. attention_scores += rel_2d_pos
  142. attention_scores = attention_scores.float().masked_fill_(
  143. attention_mask.to(torch.bool), torch.finfo(attention_scores.dtype).min
  144. )
  145. attention_probs = nn.functional.softmax(attention_scores, dim=-1, dtype=torch.float32).type_as(value_layer)
  146. # This is actually dropping out entire tokens to attend to, which might
  147. # seem a bit unusual, but is taken from the original Transformer paper.
  148. attention_probs = self.dropout(attention_probs)
  149. context_layer = torch.matmul(attention_probs, value_layer)
  150. context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
  151. new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
  152. context_layer = context_layer.view(*new_context_layer_shape)
  153. return context_layer, attention_probs
  154. class LayoutLMv2Attention(nn.Module):
  155. def __init__(self, config):
  156. super().__init__()
  157. self.self = LayoutLMv2SelfAttention(config)
  158. self.output = LayoutLMv2SelfOutput(config)
  159. def forward(
  160. self,
  161. hidden_states,
  162. attention_mask=None,
  163. rel_pos=None,
  164. rel_2d_pos=None,
  165. **kwargs: Unpack[TransformersKwargs],
  166. ):
  167. residual = hidden_states
  168. attention_output, _ = self.self(
  169. hidden_states,
  170. attention_mask,
  171. rel_pos=rel_pos,
  172. rel_2d_pos=rel_2d_pos,
  173. **kwargs,
  174. )
  175. attention_output = self.output(attention_output, residual)
  176. return attention_output
  177. class LayoutLMv2SelfOutput(nn.Module):
  178. def __init__(self, config):
  179. super().__init__()
  180. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  181. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  182. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  183. def forward(self, hidden_states, input_tensor):
  184. hidden_states = self.dense(hidden_states)
  185. hidden_states = self.dropout(hidden_states)
  186. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  187. return hidden_states
  188. # Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->LayoutLMv2
  189. class LayoutLMv2Intermediate(nn.Module):
  190. def __init__(self, config):
  191. super().__init__()
  192. self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
  193. if isinstance(config.hidden_act, str):
  194. self.intermediate_act_fn = ACT2FN[config.hidden_act]
  195. else:
  196. self.intermediate_act_fn = config.hidden_act
  197. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  198. hidden_states = self.dense(hidden_states)
  199. hidden_states = self.intermediate_act_fn(hidden_states)
  200. return hidden_states
  201. # Copied from transformers.models.bert.modeling_bert.BertOutput with Bert->LayoutLM
  202. class LayoutLMv2Output(nn.Module):
  203. def __init__(self, config):
  204. super().__init__()
  205. self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
  206. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  207. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  208. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  209. hidden_states = self.dense(hidden_states)
  210. hidden_states = self.dropout(hidden_states)
  211. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  212. return hidden_states
  213. class LayoutLMv2Layer(GradientCheckpointingLayer):
  214. def __init__(self, config):
  215. super().__init__()
  216. self.chunk_size_feed_forward = config.chunk_size_feed_forward
  217. self.seq_len_dim = 1
  218. self.attention = LayoutLMv2Attention(config)
  219. self.intermediate = LayoutLMv2Intermediate(config)
  220. self.output = LayoutLMv2Output(config)
  221. def forward(
  222. self,
  223. hidden_states,
  224. attention_mask=None,
  225. output_attentions=False,
  226. rel_pos=None,
  227. rel_2d_pos=None,
  228. **kwargs: Unpack[TransformersKwargs],
  229. ):
  230. attention_output = self.attention(
  231. hidden_states,
  232. attention_mask,
  233. rel_pos=rel_pos,
  234. rel_2d_pos=rel_2d_pos,
  235. )
  236. layer_output = apply_chunking_to_forward(
  237. self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
  238. )
  239. return layer_output
  240. def feed_forward_chunk(self, attention_output):
  241. intermediate_output = self.intermediate(attention_output)
  242. layer_output = self.output(intermediate_output, attention_output)
  243. return layer_output
  244. def relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
  245. """
  246. Adapted from Mesh Tensorflow:
  247. https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
  248. Translate relative position to a bucket number for relative attention. The relative position is defined as
  249. memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
  250. position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for small
  251. absolute relative_position and larger buckets for larger absolute relative_positions. All relative positions
  252. >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. This should
  253. allow for more graceful generalization to longer sequences than the model has been trained on.
  254. Args:
  255. relative_position: an int32 Tensor
  256. bidirectional: a boolean - whether the attention is bidirectional
  257. num_buckets: an integer
  258. max_distance: an integer
  259. Returns:
  260. a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
  261. """
  262. ret = 0
  263. if bidirectional:
  264. num_buckets //= 2
  265. ret += (relative_position > 0).long() * num_buckets
  266. n = torch.abs(relative_position)
  267. else:
  268. n = torch.max(-relative_position, torch.zeros_like(relative_position))
  269. # now n is in the range [0, inf)
  270. # half of the buckets are for exact increments in positions
  271. max_exact = num_buckets // 2
  272. is_small = n < max_exact
  273. # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
  274. val_if_large = max_exact + (
  275. torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
  276. ).to(torch.long)
  277. val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
  278. ret += torch.where(is_small, n, val_if_large)
  279. return ret
  280. class LayoutLMv2Encoder(nn.Module):
  281. def __init__(self, config):
  282. super().__init__()
  283. self.config = config
  284. self.layer = nn.ModuleList([LayoutLMv2Layer(config) for _ in range(config.num_hidden_layers)])
  285. self.has_relative_attention_bias = config.has_relative_attention_bias
  286. self.has_spatial_attention_bias = config.has_spatial_attention_bias
  287. if self.has_relative_attention_bias:
  288. self.rel_pos_bins = config.rel_pos_bins
  289. self.max_rel_pos = config.max_rel_pos
  290. self.rel_pos_bias = nn.Linear(self.rel_pos_bins, config.num_attention_heads, bias=False)
  291. if self.has_spatial_attention_bias:
  292. self.max_rel_2d_pos = config.max_rel_2d_pos
  293. self.rel_2d_pos_bins = config.rel_2d_pos_bins
  294. self.rel_pos_x_bias = nn.Linear(self.rel_2d_pos_bins, config.num_attention_heads, bias=False)
  295. self.rel_pos_y_bias = nn.Linear(self.rel_2d_pos_bins, config.num_attention_heads, bias=False)
  296. self.gradient_checkpointing = False
  297. def _calculate_1d_position_embeddings(self, position_ids):
  298. rel_pos_mat = position_ids.unsqueeze(-2) - position_ids.unsqueeze(-1)
  299. rel_pos = relative_position_bucket(
  300. rel_pos_mat,
  301. num_buckets=self.rel_pos_bins,
  302. max_distance=self.max_rel_pos,
  303. )
  304. # Since this is a simple indexing operation that is independent of the input,
  305. # no need to track gradients for this operation
  306. #
  307. # Without this no_grad context, training speed slows down significantly
  308. with torch.no_grad():
  309. rel_pos = self.rel_pos_bias.weight.t()[rel_pos].permute(0, 3, 1, 2)
  310. rel_pos = rel_pos.contiguous()
  311. return rel_pos
  312. def _calculate_2d_position_embeddings(self, bbox):
  313. position_coord_x = bbox[:, :, 0]
  314. position_coord_y = bbox[:, :, 3]
  315. rel_pos_x_2d_mat = position_coord_x.unsqueeze(-2) - position_coord_x.unsqueeze(-1)
  316. rel_pos_y_2d_mat = position_coord_y.unsqueeze(-2) - position_coord_y.unsqueeze(-1)
  317. rel_pos_x = relative_position_bucket(
  318. rel_pos_x_2d_mat,
  319. num_buckets=self.rel_2d_pos_bins,
  320. max_distance=self.max_rel_2d_pos,
  321. )
  322. rel_pos_y = relative_position_bucket(
  323. rel_pos_y_2d_mat,
  324. num_buckets=self.rel_2d_pos_bins,
  325. max_distance=self.max_rel_2d_pos,
  326. )
  327. # Since this is a simple indexing operation that is independent of the input,
  328. # no need to track gradients for this operation
  329. #
  330. # Without this no_grad context, training speed slows down significantly
  331. with torch.no_grad():
  332. rel_pos_x = self.rel_pos_x_bias.weight.t()[rel_pos_x].permute(0, 3, 1, 2)
  333. rel_pos_y = self.rel_pos_y_bias.weight.t()[rel_pos_y].permute(0, 3, 1, 2)
  334. rel_pos_x = rel_pos_x.contiguous()
  335. rel_pos_y = rel_pos_y.contiguous()
  336. rel_2d_pos = rel_pos_x + rel_pos_y
  337. return rel_2d_pos
  338. def forward(
  339. self,
  340. hidden_states,
  341. attention_mask=None,
  342. bbox=None,
  343. position_ids=None,
  344. **kwargs: Unpack[TransformersKwargs],
  345. ):
  346. rel_pos = self._calculate_1d_position_embeddings(position_ids) if self.has_relative_attention_bias else None
  347. rel_2d_pos = self._calculate_2d_position_embeddings(bbox) if self.has_spatial_attention_bias else None
  348. for layer_module in self.layer:
  349. hidden_states = layer_module(
  350. hidden_states,
  351. attention_mask,
  352. rel_pos=rel_pos,
  353. rel_2d_pos=rel_2d_pos,
  354. **kwargs,
  355. )
  356. return BaseModelOutput(last_hidden_state=hidden_states)
  357. @auto_docstring
  358. class LayoutLMv2PreTrainedModel(PreTrainedModel):
  359. config: LayoutLMv2Config
  360. base_model_prefix = "layoutlmv2"
  361. input_modalities = ("image", "text")
  362. @torch.no_grad()
  363. def _init_weights(self, module):
  364. """Initialize the weights"""
  365. super()._init_weights(module)
  366. if isinstance(module, LayoutLMv2SelfAttention):
  367. if self.config.fast_qkv:
  368. init.zeros_(module.q_bias)
  369. init.zeros_(module.v_bias)
  370. elif isinstance(module, LayoutLMv2Embeddings):
  371. init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
  372. elif isinstance(module, LayoutLMv2VisualBackbone):
  373. num_channels = len(module.cfg.MODEL.PIXEL_MEAN)
  374. init.copy_(module.pixel_mean, torch.Tensor(module.cfg.MODEL.PIXEL_MEAN).view(num_channels, 1, 1))
  375. init.copy_(module.pixel_std, torch.Tensor(module.cfg.MODEL.PIXEL_STD).view(num_channels, 1, 1))
  376. elif isinstance(module, LayoutLMv2Model):
  377. if hasattr(module, "visual_segment_embedding"):
  378. init.normal_(module.visual_segment_embedding, mean=0.0, std=self.config.initializer_range)
  379. # We check the existence of each one since detectron2 seems to do weird things
  380. elif isinstance(module, detectron2.layers.FrozenBatchNorm2d):
  381. init.ones_(module.weight)
  382. init.zeros_(module.bias)
  383. init.zeros_(module.running_mean)
  384. init.constant_(module.running_var, 1.0 - module.eps)
  385. def my_convert_sync_batchnorm(module, process_group=None):
  386. # same as `nn.modules.SyncBatchNorm.convert_sync_batchnorm` but allowing converting from `detectron2.layers.FrozenBatchNorm2d`
  387. if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
  388. return nn.modules.SyncBatchNorm.convert_sync_batchnorm(module, process_group)
  389. module_output = module
  390. if isinstance(module, detectron2.layers.FrozenBatchNorm2d):
  391. module_output = torch.nn.SyncBatchNorm(
  392. num_features=module.num_features,
  393. eps=module.eps,
  394. affine=True,
  395. track_running_stats=True,
  396. process_group=process_group,
  397. )
  398. module_output.weight = torch.nn.Parameter(module.weight)
  399. module_output.bias = torch.nn.Parameter(module.bias)
  400. module_output.running_mean = module.running_mean
  401. module_output.running_var = module.running_var
  402. module_output.num_batches_tracked = torch.tensor(0, dtype=torch.long, device=module.running_mean.device)
  403. for name, child in module.named_children():
  404. module_output.add_module(name, my_convert_sync_batchnorm(child, process_group))
  405. del module
  406. return module_output
  407. class LayoutLMv2VisualBackbone(nn.Module):
  408. def __init__(self, config):
  409. super().__init__()
  410. self.cfg = config.get_detectron2_config()
  411. meta_arch = self.cfg.MODEL.META_ARCHITECTURE
  412. model = META_ARCH_REGISTRY.get(meta_arch)(self.cfg)
  413. assert isinstance(model.backbone, detectron2.modeling.backbone.FPN)
  414. self.backbone = model.backbone
  415. assert len(self.cfg.MODEL.PIXEL_MEAN) == len(self.cfg.MODEL.PIXEL_STD)
  416. num_channels = len(self.cfg.MODEL.PIXEL_MEAN)
  417. self.register_buffer(
  418. "pixel_mean",
  419. torch.Tensor(self.cfg.MODEL.PIXEL_MEAN).view(num_channels, 1, 1),
  420. persistent=False,
  421. )
  422. self.register_buffer(
  423. "pixel_std", torch.Tensor(self.cfg.MODEL.PIXEL_STD).view(num_channels, 1, 1), persistent=False
  424. )
  425. self.out_feature_key = "p2"
  426. if torch.are_deterministic_algorithms_enabled():
  427. logger.warning("using `AvgPool2d` instead of `AdaptiveAvgPool2d`")
  428. input_shape = (224, 224)
  429. backbone_stride = self.backbone.output_shape()[self.out_feature_key].stride
  430. self.pool = nn.AvgPool2d(
  431. (
  432. math.ceil(math.ceil(input_shape[0] / backbone_stride) / config.image_feature_pool_shape[0]),
  433. math.ceil(math.ceil(input_shape[1] / backbone_stride) / config.image_feature_pool_shape[1]),
  434. )
  435. )
  436. else:
  437. self.pool = nn.AdaptiveAvgPool2d(config.image_feature_pool_shape[:2])
  438. if len(config.image_feature_pool_shape) == 2:
  439. config.image_feature_pool_shape.append(self.backbone.output_shape()[self.out_feature_key].channels)
  440. assert self.backbone.output_shape()[self.out_feature_key].channels == config.image_feature_pool_shape[2]
  441. def forward(self, images):
  442. images_input = ((images if torch.is_tensor(images) else images.tensor) - self.pixel_mean) / self.pixel_std
  443. features = self.backbone(images_input)
  444. features = features[self.out_feature_key]
  445. features = self.pool(features).flatten(start_dim=2).transpose(1, 2).contiguous()
  446. return features
  447. def synchronize_batch_norm(self):
  448. if not (
  449. torch.distributed.is_available()
  450. and torch.distributed.is_initialized()
  451. and torch.distributed.get_rank() > -1
  452. ):
  453. raise RuntimeError("Make sure torch.distributed is set up properly.")
  454. self_rank = torch.distributed.get_rank()
  455. node_size = torch.cuda.device_count()
  456. world_size = torch.distributed.get_world_size()
  457. if not (world_size % node_size == 0):
  458. raise RuntimeError("Make sure the number of processes can be divided by the number of nodes")
  459. node_global_ranks = [list(range(i * node_size, (i + 1) * node_size)) for i in range(world_size // node_size)]
  460. sync_bn_groups = [
  461. torch.distributed.new_group(ranks=node_global_ranks[i]) for i in range(world_size // node_size)
  462. ]
  463. node_rank = self_rank // node_size
  464. self.backbone = my_convert_sync_batchnorm(self.backbone, process_group=sync_bn_groups[node_rank])
  465. class LayoutLMv2Pooler(nn.Module):
  466. def __init__(self, config):
  467. super().__init__()
  468. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  469. self.activation = nn.Tanh()
  470. def forward(self, hidden_states):
  471. # We "pool" the model by simply taking the hidden state corresponding
  472. # to the first token.
  473. first_token_tensor = hidden_states[:, 0]
  474. pooled_output = self.dense(first_token_tensor)
  475. pooled_output = self.activation(pooled_output)
  476. return pooled_output
  477. @auto_docstring
  478. class LayoutLMv2Model(LayoutLMv2PreTrainedModel):
  479. _can_record_outputs = {"hidden_states": LayoutLMv2Layer, "attentions": LayoutLMv2SelfAttention}
    def __init__(self, config):
        """Build the text embeddings, the visual backbone and projection, the encoder and the pooler.

        Args:
            config (`LayoutLMv2Config`): model configuration. NOTE: `LayoutLMv2VisualBackbone` may append the
                backbone channel count to `config.image_feature_pool_shape` in place.
        """
        # Checked before anything else is built. Presumably raises when detectron2 is not installed; confirm in
        # `requires_backends`.
        requires_backends(self, "detectron2")
        super().__init__(config)
        self.config = config

        self.has_visual_segment_embedding = config.has_visual_segment_embedding
        self.embeddings = LayoutLMv2Embeddings(config)

        # Order matters: the backbone may append the channel count to `config.image_feature_pool_shape`, and
        # `visual_proj` reads that last element just below.
        self.visual = LayoutLMv2VisualBackbone(config)
        self.visual_proj = nn.Linear(config.image_feature_pool_shape[-1], config.hidden_size)
        if self.has_visual_segment_embedding:
            # A single learned hidden_size vector, taken from a 1-row nn.Embedding weight.
            self.visual_segment_embedding = nn.Parameter(nn.Embedding(1, config.hidden_size).weight[0])
        self.visual_LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.visual_dropout = nn.Dropout(config.hidden_dropout_prob)

        self.encoder = LayoutLMv2Encoder(config)
        self.pooler = LayoutLMv2Pooler(config)

        # Initialize weights and apply final processing
        self.post_init()
  496. def get_input_embeddings(self):
  497. return self.embeddings.word_embeddings
  498. def set_input_embeddings(self, value):
  499. self.embeddings.word_embeddings = value
  500. def _calc_text_embeddings(self, input_ids, bbox, position_ids, token_type_ids, inputs_embeds=None):
  501. if input_ids is not None:
  502. input_shape = input_ids.size()
  503. else:
  504. input_shape = inputs_embeds.size()[:-1]
  505. seq_length = input_shape[1]
  506. if position_ids is None:
  507. position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
  508. position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
  509. if token_type_ids is None:
  510. token_type_ids = torch.zeros_like(input_ids)
  511. if inputs_embeds is None:
  512. inputs_embeds = self.embeddings.word_embeddings(input_ids)
  513. position_embeddings = self.embeddings.position_embeddings(position_ids)
  514. spatial_position_embeddings = self.embeddings._calc_spatial_position_embeddings(bbox)
  515. token_type_embeddings = self.embeddings.token_type_embeddings(token_type_ids)
  516. embeddings = inputs_embeds + position_embeddings + spatial_position_embeddings + token_type_embeddings
  517. embeddings = self.embeddings.LayerNorm(embeddings)
  518. embeddings = self.embeddings.dropout(embeddings)
  519. return embeddings
  520. def _calc_img_embeddings(self, image, bbox, position_ids):
  521. visual_embeddings = self.visual_proj(self.visual(image))
  522. position_embeddings = self.embeddings.position_embeddings(position_ids)
  523. spatial_position_embeddings = self.embeddings._calc_spatial_position_embeddings(bbox)
  524. embeddings = visual_embeddings + position_embeddings + spatial_position_embeddings
  525. if self.has_visual_segment_embedding:
  526. embeddings += self.visual_segment_embedding
  527. embeddings = self.visual_LayerNorm(embeddings)
  528. embeddings = self.visual_dropout(embeddings)
  529. return embeddings
  530. def _calc_visual_bbox(self, image_feature_pool_shape, bbox, device, final_shape):
  531. visual_bbox_x = torch.div(
  532. torch.arange(
  533. 0,
  534. 1000 * (image_feature_pool_shape[1] + 1),
  535. 1000,
  536. device=device,
  537. dtype=bbox.dtype,
  538. ),
  539. self.config.image_feature_pool_shape[1],
  540. rounding_mode="floor",
  541. )
  542. visual_bbox_y = torch.div(
  543. torch.arange(
  544. 0,
  545. 1000 * (self.config.image_feature_pool_shape[0] + 1),
  546. 1000,
  547. device=device,
  548. dtype=bbox.dtype,
  549. ),
  550. self.config.image_feature_pool_shape[0],
  551. rounding_mode="floor",
  552. )
  553. visual_bbox = torch.stack(
  554. [
  555. visual_bbox_x[:-1].repeat(image_feature_pool_shape[0], 1),
  556. visual_bbox_y[:-1].repeat(image_feature_pool_shape[1], 1).transpose(0, 1),
  557. visual_bbox_x[1:].repeat(image_feature_pool_shape[0], 1),
  558. visual_bbox_y[1:].repeat(image_feature_pool_shape[1], 1).transpose(0, 1),
  559. ],
  560. dim=-1,
  561. ).view(-1, bbox.size(-1))
  562. visual_bbox = visual_bbox.repeat(final_shape[0], 1, 1)
  563. return visual_bbox
  564. def _get_input_shape(self, input_ids=None, inputs_embeds=None):
  565. if input_ids is not None and inputs_embeds is not None:
  566. raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
  567. elif input_ids is not None:
  568. return input_ids.size()
  569. elif inputs_embeds is not None:
  570. return inputs_embeds.size()[:-1]
  571. else:
  572. raise ValueError("You have to specify either input_ids or inputs_embeds")
  573. @merge_with_config_defaults
  574. @capture_outputs
  575. @auto_docstring
  576. def forward(
  577. self,
  578. input_ids: torch.LongTensor | None = None,
  579. bbox: torch.LongTensor | None = None,
  580. image: torch.FloatTensor | None = None,
  581. attention_mask: torch.FloatTensor | None = None,
  582. token_type_ids: torch.LongTensor | None = None,
  583. position_ids: torch.LongTensor | None = None,
  584. inputs_embeds: torch.FloatTensor | None = None,
  585. **kwargs: Unpack[TransformersKwargs],
  586. ) -> tuple | BaseModelOutputWithPooling:
  587. r"""
  588. bbox (`torch.LongTensor` of shape `((batch_size, sequence_length), 4)`, *optional*):
  589. Bounding boxes of each input sequence tokens. Selected in the range `[0,
  590. config.max_2d_position_embeddings-1]`. Each bounding box should be a normalized version in (x0, y0, x1, y1)
  591. format, where (x0, y0) corresponds to the position of the upper left corner in the bounding box, and (x1,
  592. y1) represents the position of the lower right corner.
  593. image (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `detectron.structures.ImageList` whose `tensors` is of shape `(batch_size, num_channels, height, width)`):
  594. Batch of document images.
  595. Examples:
  596. ```python
  597. >>> from transformers import AutoProcessor, LayoutLMv2Model, set_seed
  598. >>> from PIL import Image
  599. >>> import torch
  600. >>> from datasets import load_dataset
  601. >>> set_seed(0)
  602. >>> processor = AutoProcessor.from_pretrained("microsoft/layoutlmv2-base-uncased")
  603. >>> model = LayoutLMv2Model.from_pretrained("microsoft/layoutlmv2-base-uncased")
  604. >>> dataset = load_dataset("hf-internal-testing/fixtures_docvqa")
  605. >>> image = dataset["test"][0]["image"]
  606. >>> encoding = processor(image, return_tensors="pt")
  607. >>> outputs = model(**encoding)
  608. >>> last_hidden_states = outputs.last_hidden_state
  609. >>> last_hidden_states.shape
  610. torch.Size([1, 342, 768])
  611. ```
  612. """
  613. input_shape = self._get_input_shape(input_ids, inputs_embeds)
  614. device = input_ids.device if input_ids is not None else inputs_embeds.device
  615. visual_shape = list(input_shape)
  616. visual_shape[1] = self.config.image_feature_pool_shape[0] * self.config.image_feature_pool_shape[1]
  617. visual_shape = torch.Size(visual_shape)
  618. # needs a new copy of input_shape for tracing. Otherwise wrong dimensions will occur
  619. final_shape = list(self._get_input_shape(input_ids, inputs_embeds))
  620. final_shape[1] += visual_shape[1]
  621. final_shape = torch.Size(final_shape)
  622. visual_bbox = self._calc_visual_bbox(self.config.image_feature_pool_shape, bbox, device, final_shape)
  623. final_bbox = torch.cat([bbox, visual_bbox], dim=1)
  624. if attention_mask is None:
  625. attention_mask = torch.ones(input_shape, device=device)
  626. visual_attention_mask = torch.ones(visual_shape, device=device)
  627. final_attention_mask = torch.cat([attention_mask, visual_attention_mask], dim=1)
  628. if token_type_ids is None:
  629. token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
  630. if position_ids is None:
  631. seq_length = input_shape[1]
  632. position_ids = self.embeddings.position_ids[:, :seq_length]
  633. position_ids = position_ids.expand(input_shape)
  634. visual_position_ids = torch.arange(0, visual_shape[1], dtype=torch.long, device=device).repeat(
  635. input_shape[0], 1
  636. )
  637. final_position_ids = torch.cat([position_ids, visual_position_ids], dim=1)
  638. if bbox is None:
  639. bbox = torch.zeros(tuple(list(input_shape) + [4]), dtype=torch.long, device=device)
  640. text_layout_emb = self._calc_text_embeddings(
  641. input_ids=input_ids,
  642. bbox=bbox,
  643. token_type_ids=token_type_ids,
  644. position_ids=position_ids,
  645. inputs_embeds=inputs_embeds,
  646. )
  647. visual_emb = self._calc_img_embeddings(
  648. image=image,
  649. bbox=visual_bbox,
  650. position_ids=visual_position_ids,
  651. )
  652. final_emb = torch.cat([text_layout_emb, visual_emb], dim=1)
  653. extended_attention_mask = final_attention_mask.unsqueeze(1).unsqueeze(2)
  654. extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)
  655. extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(self.dtype).min
  656. encoder_outputs = self.encoder(
  657. final_emb,
  658. extended_attention_mask,
  659. bbox=final_bbox,
  660. position_ids=final_position_ids,
  661. **kwargs,
  662. )
  663. sequence_output = encoder_outputs.last_hidden_state
  664. pooled_output = self.pooler(sequence_output)
  665. return BaseModelOutputWithPooling(
  666. last_hidden_state=sequence_output,
  667. pooler_output=pooled_output,
  668. )
  669. @auto_docstring(
  670. custom_intro="""
  671. LayoutLMv2 Model with a sequence classification head on top (a linear layer on top of the concatenation of the
  672. final hidden state of the [CLS] token, average-pooled initial visual embeddings and average-pooled final visual
  673. embeddings, e.g. for document image classification tasks such as the
  674. [RVL-CDIP](https://www.cs.cmu.edu/~aharley/rvl-cdip/) dataset.
  675. """
  676. )
  677. class LayoutLMv2ForSequenceClassification(LayoutLMv2PreTrainedModel):
  678. def __init__(self, config):
  679. super().__init__(config)
  680. self.num_labels = config.num_labels
  681. self.layoutlmv2 = LayoutLMv2Model(config)
  682. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  683. self.classifier = nn.Linear(config.hidden_size * 3, config.num_labels)
  684. # Initialize weights and apply final processing
  685. self.post_init()
  686. def get_input_embeddings(self):
  687. return self.layoutlmv2.embeddings.word_embeddings
  688. @can_return_tuple
  689. @auto_docstring
  690. def forward(
  691. self,
  692. input_ids: torch.LongTensor | None = None,
  693. bbox: torch.LongTensor | None = None,
  694. image: torch.FloatTensor | None = None,
  695. attention_mask: torch.FloatTensor | None = None,
  696. token_type_ids: torch.LongTensor | None = None,
  697. position_ids: torch.LongTensor | None = None,
  698. inputs_embeds: torch.FloatTensor | None = None,
  699. labels: torch.LongTensor | None = None,
  700. **kwargs: Unpack[TransformersKwargs],
  701. ) -> tuple | SequenceClassifierOutput:
  702. r"""
  703. input_ids (`torch.LongTensor` of shape `batch_size, sequence_length`):
  704. Indices of input sequence tokens in the vocabulary.
  705. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  706. [`PreTrainedTokenizer.__call__`] for details.
  707. [What are input IDs?](../glossary#input-ids)
  708. bbox (`torch.LongTensor` of shape `(batch_size, sequence_length, 4)`, *optional*):
  709. Bounding boxes of each input sequence tokens. Selected in the range `[0,
  710. config.max_2d_position_embeddings-1]`. Each bounding box should be a normalized version in (x0, y0, x1, y1)
  711. format, where (x0, y0) corresponds to the position of the upper left corner in the bounding box, and (x1,
  712. y1) represents the position of the lower right corner.
  713. image (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `detectron.structures.ImageList` whose `tensors` is of shape `(batch_size, num_channels, height, width)`):
  714. Batch of document images.
  715. token_type_ids (`torch.LongTensor` of shape `batch_size, sequence_length`, *optional*):
  716. Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
  717. 1]`:
  718. - 0 corresponds to a *sentence A* token,
  719. - 1 corresponds to a *sentence B* token.
  720. [What are token type IDs?](../glossary#token-type-ids)
  721. position_ids (`torch.LongTensor` of shape `batch_size, sequence_length`, *optional*):
  722. Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
  723. config.max_position_embeddings - 1]`.
  724. [What are position IDs?](../glossary#position-ids)
  725. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  726. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  727. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  728. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  729. Example:
  730. ```python
  731. >>> from transformers import AutoProcessor, LayoutLMv2ForSequenceClassification, set_seed
  732. >>> from PIL import Image
  733. >>> import torch
  734. >>> from datasets import load_dataset
  735. >>> set_seed(0)
  736. >>> dataset = load_dataset("aharley/rvl_cdip", split="train", streaming=True)
  737. >>> data = next(iter(dataset))
  738. >>> image = data["image"].convert("RGB")
  739. >>> processor = AutoProcessor.from_pretrained("microsoft/layoutlmv2-base-uncased")
  740. >>> model = LayoutLMv2ForSequenceClassification.from_pretrained(
  741. ... "microsoft/layoutlmv2-base-uncased", num_labels=dataset.info.features["label"].num_classes
  742. ... )
  743. >>> encoding = processor(image, return_tensors="pt")
  744. >>> sequence_label = torch.tensor([data["label"]])
  745. >>> outputs = model(**encoding, labels=sequence_label)
  746. >>> loss, logits = outputs.loss, outputs.logits
  747. >>> predicted_idx = logits.argmax(dim=-1).item()
  748. >>> predicted_answer = dataset.info.features["label"].names[4]
  749. >>> predicted_idx, predicted_answer # results are not good without further fine-tuning
  750. (7, 'advertisement')
  751. ```
  752. """
  753. if input_ids is not None and inputs_embeds is not None:
  754. raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
  755. elif input_ids is not None:
  756. self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
  757. input_shape = input_ids.size()
  758. elif inputs_embeds is not None:
  759. input_shape = inputs_embeds.size()[:-1]
  760. else:
  761. raise ValueError("You have to specify either input_ids or inputs_embeds")
  762. device = input_ids.device if input_ids is not None else inputs_embeds.device
  763. visual_shape = list(input_shape)
  764. visual_shape[1] = self.config.image_feature_pool_shape[0] * self.config.image_feature_pool_shape[1]
  765. visual_shape = torch.Size(visual_shape)
  766. final_shape = list(input_shape)
  767. final_shape[1] += visual_shape[1]
  768. final_shape = torch.Size(final_shape)
  769. visual_bbox = self.layoutlmv2._calc_visual_bbox(
  770. self.config.image_feature_pool_shape, bbox, device, final_shape
  771. )
  772. visual_position_ids = torch.arange(0, visual_shape[1], dtype=torch.long, device=device).repeat(
  773. input_shape[0], 1
  774. )
  775. initial_image_embeddings = self.layoutlmv2._calc_img_embeddings(
  776. image=image,
  777. bbox=visual_bbox,
  778. position_ids=visual_position_ids,
  779. )
  780. outputs: BaseModelOutputWithPooling = self.layoutlmv2(
  781. input_ids=input_ids,
  782. bbox=bbox,
  783. image=image,
  784. attention_mask=attention_mask,
  785. token_type_ids=token_type_ids,
  786. position_ids=position_ids,
  787. inputs_embeds=inputs_embeds,
  788. **kwargs,
  789. )
  790. if input_ids is not None:
  791. input_shape = input_ids.size()
  792. else:
  793. input_shape = inputs_embeds.size()[:-1]
  794. seq_length = input_shape[1]
  795. sequence_output, final_image_embeddings = (
  796. outputs.last_hidden_state[:, :seq_length],
  797. outputs.last_hidden_state[:, seq_length:],
  798. )
  799. cls_final_output = sequence_output[:, 0, :]
  800. # average-pool the visual embeddings
  801. pooled_initial_image_embeddings = initial_image_embeddings.mean(dim=1)
  802. pooled_final_image_embeddings = final_image_embeddings.mean(dim=1)
  803. # concatenate with cls_final_output
  804. sequence_output = torch.cat(
  805. [cls_final_output, pooled_initial_image_embeddings, pooled_final_image_embeddings], dim=1
  806. )
  807. sequence_output = self.dropout(sequence_output)
  808. logits = self.classifier(sequence_output)
  809. loss = None
  810. if labels is not None:
  811. if self.config.problem_type is None:
  812. if self.num_labels == 1:
  813. self.config.problem_type = "regression"
  814. elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
  815. self.config.problem_type = "single_label_classification"
  816. else:
  817. self.config.problem_type = "multi_label_classification"
  818. if self.config.problem_type == "regression":
  819. loss_fct = MSELoss()
  820. if self.num_labels == 1:
  821. loss = loss_fct(logits.squeeze(), labels.squeeze())
  822. else:
  823. loss = loss_fct(logits, labels)
  824. elif self.config.problem_type == "single_label_classification":
  825. loss_fct = CrossEntropyLoss()
  826. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  827. elif self.config.problem_type == "multi_label_classification":
  828. loss_fct = BCEWithLogitsLoss()
  829. loss = loss_fct(logits, labels)
  830. return SequenceClassifierOutput(
  831. loss=loss,
  832. logits=logits,
  833. hidden_states=outputs.hidden_states,
  834. attentions=outputs.attentions,
  835. )
  836. @auto_docstring(
  837. custom_intro="""
  838. LayoutLMv2 Model with a token classification head on top (a linear layer on top of the text part of the hidden
  839. states) e.g. for sequence labeling (information extraction) tasks such as
  840. [FUNSD](https://guillaumejaume.github.io/FUNSD/), [SROIE](https://rrc.cvc.uab.es/?ch=13),
  841. [CORD](https://github.com/clovaai/cord) and [Kleister-NDA](https://github.com/applicaai/kleister-nda).
  842. """
  843. )
  844. class LayoutLMv2ForTokenClassification(LayoutLMv2PreTrainedModel):
  845. def __init__(self, config):
  846. super().__init__(config)
  847. self.num_labels = config.num_labels
  848. self.layoutlmv2 = LayoutLMv2Model(config)
  849. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  850. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  851. # Initialize weights and apply final processing
  852. self.post_init()
  853. def get_input_embeddings(self):
  854. return self.layoutlmv2.embeddings.word_embeddings
  855. @can_return_tuple
  856. @auto_docstring
  857. def forward(
  858. self,
  859. input_ids: torch.LongTensor | None = None,
  860. bbox: torch.LongTensor | None = None,
  861. image: torch.FloatTensor | None = None,
  862. attention_mask: torch.FloatTensor | None = None,
  863. token_type_ids: torch.LongTensor | None = None,
  864. position_ids: torch.LongTensor | None = None,
  865. inputs_embeds: torch.FloatTensor | None = None,
  866. labels: torch.LongTensor | None = None,
  867. **kwargs: Unpack[TransformersKwargs],
  868. ) -> tuple | TokenClassifierOutput:
  869. r"""
  870. input_ids (`torch.LongTensor` of shape `batch_size, sequence_length`):
  871. Indices of input sequence tokens in the vocabulary.
  872. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  873. [`PreTrainedTokenizer.__call__`] for details.
  874. [What are input IDs?](../glossary#input-ids)
  875. bbox (`torch.LongTensor` of shape `(batch_size, sequence_length, 4)`, *optional*):
  876. Bounding boxes of each input sequence tokens. Selected in the range `[0,
  877. config.max_2d_position_embeddings-1]`. Each bounding box should be a normalized version in (x0, y0, x1, y1)
  878. format, where (x0, y0) corresponds to the position of the upper left corner in the bounding box, and (x1,
  879. y1) represents the position of the lower right corner.
  880. image (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `detectron.structures.ImageList` whose `tensors` is of shape `(batch_size, num_channels, height, width)`):
  881. Batch of document images.
  882. token_type_ids (`torch.LongTensor` of shape `batch_size, sequence_length`, *optional*):
  883. Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
  884. 1]`:
  885. - 0 corresponds to a *sentence A* token,
  886. - 1 corresponds to a *sentence B* token.
  887. [What are token type IDs?](../glossary#token-type-ids)
  888. position_ids (`torch.LongTensor` of shape `batch_size, sequence_length`, *optional*):
  889. Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
  890. config.max_position_embeddings - 1]`.
  891. [What are position IDs?](../glossary#position-ids)
  892. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  893. Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
  894. Example:
  895. ```python
  896. >>> from transformers import AutoProcessor, LayoutLMv2ForTokenClassification, set_seed
  897. >>> from PIL import Image
  898. >>> from datasets import load_dataset
  899. >>> set_seed(0)
  900. >>> datasets = load_dataset("nielsr/funsd", split="test")
  901. >>> labels = datasets.features["ner_tags"].feature.names
  902. >>> id2label = {v: k for v, k in enumerate(labels)}
  903. >>> processor = AutoProcessor.from_pretrained("microsoft/layoutlmv2-base-uncased", revision="no_ocr")
  904. >>> model = LayoutLMv2ForTokenClassification.from_pretrained(
  905. ... "microsoft/layoutlmv2-base-uncased", num_labels=len(labels)
  906. ... )
  907. >>> data = datasets[0]
  908. >>> image = Image.open(data["image_path"]).convert("RGB")
  909. >>> words = data["words"]
  910. >>> boxes = data["bboxes"] # make sure to normalize your bounding boxes
  911. >>> word_labels = data["ner_tags"]
  912. >>> encoding = processor(
  913. ... image,
  914. ... words,
  915. ... boxes=boxes,
  916. ... word_labels=word_labels,
  917. ... padding="max_length",
  918. ... truncation=True,
  919. ... return_tensors="pt",
  920. ... )
  921. >>> outputs = model(**encoding)
  922. >>> logits, loss = outputs.logits, outputs.loss
  923. >>> predicted_token_class_ids = logits.argmax(-1)
  924. >>> predicted_tokens_classes = [id2label[t.item()] for t in predicted_token_class_ids[0]]
  925. >>> predicted_tokens_classes[:5] # results are not good without further fine-tuning
  926. ['I-HEADER', 'I-HEADER', 'I-QUESTION', 'I-HEADER', 'I-QUESTION']
  927. ```
  928. """
  929. outputs: BaseModelOutputWithPooling = self.layoutlmv2(
  930. input_ids=input_ids,
  931. bbox=bbox,
  932. image=image,
  933. attention_mask=attention_mask,
  934. token_type_ids=token_type_ids,
  935. position_ids=position_ids,
  936. inputs_embeds=inputs_embeds,
  937. **kwargs,
  938. )
  939. if input_ids is not None:
  940. input_shape = input_ids.size()
  941. else:
  942. input_shape = inputs_embeds.size()[:-1]
  943. seq_length = input_shape[1]
  944. # only take the text part of the output representations
  945. sequence_output = outputs.last_hidden_state[:, :seq_length]
  946. sequence_output = self.dropout(sequence_output)
  947. logits = self.classifier(sequence_output)
  948. loss = None
  949. if labels is not None:
  950. loss_fct = CrossEntropyLoss()
  951. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  952. return TokenClassifierOutput(
  953. loss=loss,
  954. logits=logits,
  955. hidden_states=outputs.hidden_states,
  956. attentions=outputs.attentions,
  957. )
  958. @auto_docstring
  959. class LayoutLMv2ForQuestionAnswering(LayoutLMv2PreTrainedModel):
  960. def __init__(self, config, has_visual_segment_embedding=True):
  961. r"""
  962. has_visual_segment_embedding (`bool`, *optional*, defaults to `True`):
  963. Whether or not to add visual segment embeddings.
  964. """
  965. super().__init__(config)
  966. self.num_labels = config.num_labels
  967. config.has_visual_segment_embedding = has_visual_segment_embedding
  968. self.layoutlmv2 = LayoutLMv2Model(config)
  969. self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
  970. # Initialize weights and apply final processing
  971. self.post_init()
  972. def get_input_embeddings(self):
  973. return self.layoutlmv2.embeddings.word_embeddings
  974. @can_return_tuple
  975. @auto_docstring
  976. def forward(
  977. self,
  978. input_ids: torch.LongTensor | None = None,
  979. bbox: torch.LongTensor | None = None,
  980. image: torch.FloatTensor | None = None,
  981. attention_mask: torch.FloatTensor | None = None,
  982. token_type_ids: torch.LongTensor | None = None,
  983. position_ids: torch.LongTensor | None = None,
  984. inputs_embeds: torch.FloatTensor | None = None,
  985. start_positions: torch.LongTensor | None = None,
  986. end_positions: torch.LongTensor | None = None,
  987. **kwargs: Unpack[TransformersKwargs],
  988. ) -> tuple | QuestionAnsweringModelOutput:
  989. r"""
  990. input_ids (`torch.LongTensor` of shape `batch_size, sequence_length`):
  991. Indices of input sequence tokens in the vocabulary.
  992. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  993. [`PreTrainedTokenizer.__call__`] for details.
  994. [What are input IDs?](../glossary#input-ids)
  995. bbox (`torch.LongTensor` of shape `(batch_size, sequence_length, 4)`, *optional*):
  996. Bounding boxes of each input sequence tokens. Selected in the range `[0,
  997. config.max_2d_position_embeddings-1]`. Each bounding box should be a normalized version in (x0, y0, x1, y1)
  998. format, where (x0, y0) corresponds to the position of the upper left corner in the bounding box, and (x1,
  999. y1) represents the position of the lower right corner.
  1000. image (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `detectron.structures.ImageList` whose `tensors` is of shape `(batch_size, num_channels, height, width)`):
  1001. Batch of document images.
  1002. token_type_ids (`torch.LongTensor` of shape `batch_size, sequence_length`, *optional*):
  1003. Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
  1004. 1]`:
  1005. - 0 corresponds to a *sentence A* token,
  1006. - 1 corresponds to a *sentence B* token.
  1007. [What are token type IDs?](../glossary#token-type-ids)
  1008. position_ids (`torch.LongTensor` of shape `batch_size, sequence_length`, *optional*):
  1009. Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
  1010. config.max_position_embeddings - 1]`.
  1011. [What are position IDs?](../glossary#position-ids)
  1012. Example:
  1013. In this example below, we give the LayoutLMv2 model an image (of texts) and ask it a question. It will give us
  1014. a prediction of what it thinks the answer is (the span of the answer within the texts parsed from the image).
  1015. ```python
  1016. >>> from transformers import AutoProcessor, LayoutLMv2ForQuestionAnswering, set_seed
  1017. >>> import torch
  1018. >>> from PIL import Image
  1019. >>> from datasets import load_dataset
  1020. >>> set_seed(0)
  1021. >>> processor = AutoProcessor.from_pretrained("microsoft/layoutlmv2-base-uncased")
  1022. >>> model = LayoutLMv2ForQuestionAnswering.from_pretrained("microsoft/layoutlmv2-base-uncased")
  1023. >>> dataset = load_dataset("hf-internal-testing/fixtures_docvqa")
  1024. >>> image = dataset["test"][0]["image"]
  1025. >>> question = "When is coffee break?"
  1026. >>> encoding = processor(image, question, return_tensors="pt")
  1027. >>> outputs = model(**encoding)
  1028. >>> predicted_start_idx = outputs.start_logits.argmax(-1).item()
  1029. >>> predicted_end_idx = outputs.end_logits.argmax(-1).item()
  1030. >>> predicted_start_idx, predicted_end_idx
  1031. (30, 191)
  1032. >>> predicted_answer_tokens = encoding.input_ids.squeeze()[predicted_start_idx : predicted_end_idx + 1]
  1033. >>> predicted_answer = processor.tokenizer.decode(predicted_answer_tokens)
  1034. >>> predicted_answer # results are not good without further fine-tuning
  1035. '44 a. m. to 12 : 25 p. m. 12 : 25 to 12 : 58 p. m. 12 : 58 to 4 : 00 p. m. 2 : 00 to 5 : 00 p. m. coffee break coffee will be served for men and women in the lobby adjacent to exhibit area. please move into exhibit area. ( exhibits open ) trrf general session ( part | ) presiding : lee a. waller trrf vice president " introductory remarks " lee a. waller, trrf vice presi - dent individual interviews with trrf public board members and sci - entific advisory council mem - bers conducted by trrf treasurer philip g. kuehn to get answers which the public refrigerated warehousing industry is looking for. plus questions from'
  1036. ```
  1037. ```python
  1038. >>> target_start_index = torch.tensor([7])
  1039. >>> target_end_index = torch.tensor([14])
  1040. >>> outputs = model(**encoding, start_positions=target_start_index, end_positions=target_end_index)
  1041. >>> predicted_answer_span_start = outputs.start_logits.argmax(-1).item()
  1042. >>> predicted_answer_span_end = outputs.end_logits.argmax(-1).item()
  1043. >>> predicted_answer_span_start, predicted_answer_span_end
  1044. (30, 191)
  1045. ```
  1046. """
  1047. outputs: BaseModelOutputWithPooling = self.layoutlmv2(
  1048. input_ids=input_ids,
  1049. bbox=bbox,
  1050. image=image,
  1051. attention_mask=attention_mask,
  1052. token_type_ids=token_type_ids,
  1053. position_ids=position_ids,
  1054. inputs_embeds=inputs_embeds,
  1055. **kwargs,
  1056. )
  1057. if input_ids is not None:
  1058. input_shape = input_ids.size()
  1059. else:
  1060. input_shape = inputs_embeds.size()[:-1]
  1061. seq_length = input_shape[1]
  1062. # only take the text part of the output representations
  1063. sequence_output = outputs.last_hidden_state[:, :seq_length]
  1064. logits = self.qa_outputs(sequence_output)
  1065. start_logits, end_logits = logits.split(1, dim=-1)
  1066. start_logits = start_logits.squeeze(-1).contiguous()
  1067. end_logits = end_logits.squeeze(-1).contiguous()
  1068. total_loss = None
  1069. if start_positions is not None and end_positions is not None:
  1070. # If we are on multi-GPU, split add a dimension
  1071. if len(start_positions.size()) > 1:
  1072. start_positions = start_positions.squeeze(-1)
  1073. if len(end_positions.size()) > 1:
  1074. end_positions = end_positions.squeeze(-1)
  1075. # sometimes the start/end positions are outside our model inputs, we ignore these terms
  1076. ignored_index = start_logits.size(1)
  1077. start_positions = start_positions.clamp(0, ignored_index)
  1078. end_positions = end_positions.clamp(0, ignored_index)
  1079. loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
  1080. start_loss = loss_fct(start_logits, start_positions)
  1081. end_loss = loss_fct(end_logits, end_positions)
  1082. total_loss = (start_loss + end_loss) / 2
  1083. return QuestionAnsweringModelOutput(
  1084. loss=total_loss,
  1085. start_logits=start_logits,
  1086. end_logits=end_logits,
  1087. hidden_states=outputs.hidden_states,
  1088. attentions=outputs.attentions,
  1089. )
  1090. __all__ = [
  1091. "LayoutLMv2ForQuestionAnswering",
  1092. "LayoutLMv2ForSequenceClassification",
  1093. "LayoutLMv2ForTokenClassification",
  1094. "LayoutLMv2Layer",
  1095. "LayoutLMv2Model",
  1096. "LayoutLMv2PreTrainedModel",
  1097. ]