modeling_markuplm.py 36 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918
  1. # Copyright 2022 Microsoft Research Asia and the HuggingFace Inc. team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch MarkupLM model."""
  15. from collections.abc import Callable
  16. import torch
  17. from torch import nn
  18. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  19. from ... import initialization as init
  20. from ...activations import ACT2FN
  21. from ...modeling_layers import GradientCheckpointingLayer
  22. from ...modeling_outputs import (
  23. BaseModelOutput,
  24. BaseModelOutputWithPooling,
  25. MaskedLMOutput,
  26. QuestionAnsweringModelOutput,
  27. SequenceClassifierOutput,
  28. TokenClassifierOutput,
  29. )
  30. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  31. from ...processing_utils import Unpack
  32. from ...pytorch_utils import apply_chunking_to_forward
  33. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
  34. from ...utils.generic import merge_with_config_defaults
  35. from ...utils.output_capturing import capture_outputs
  36. from .configuration_markuplm import MarkupLMConfig
# Module-level logger for this file, obtained from the library's logging utilities.
logger = logging.get_logger(__name__)
  38. class XPathEmbeddings(nn.Module):
  39. """Construct the embeddings from xpath tags and subscripts.
  40. We drop tree-id in this version, as its info can be covered by xpath.
  41. """
  42. def __init__(self, config):
  43. super().__init__()
  44. self.max_depth = config.max_depth
  45. self.xpath_unitseq2_embeddings = nn.Linear(config.xpath_unit_hidden_size * self.max_depth, config.hidden_size)
  46. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  47. self.activation = nn.ReLU()
  48. self.xpath_unitseq2_inner = nn.Linear(config.xpath_unit_hidden_size * self.max_depth, 4 * config.hidden_size)
  49. self.inner2emb = nn.Linear(4 * config.hidden_size, config.hidden_size)
  50. self.xpath_tag_sub_embeddings = nn.ModuleList(
  51. [
  52. nn.Embedding(config.max_xpath_tag_unit_embeddings, config.xpath_unit_hidden_size)
  53. for _ in range(self.max_depth)
  54. ]
  55. )
  56. self.xpath_subs_sub_embeddings = nn.ModuleList(
  57. [
  58. nn.Embedding(config.max_xpath_subs_unit_embeddings, config.xpath_unit_hidden_size)
  59. for _ in range(self.max_depth)
  60. ]
  61. )
  62. def forward(self, xpath_tags_seq=None, xpath_subs_seq=None):
  63. xpath_tags_embeddings = []
  64. xpath_subs_embeddings = []
  65. for i in range(self.max_depth):
  66. xpath_tags_embeddings.append(self.xpath_tag_sub_embeddings[i](xpath_tags_seq[:, :, i]))
  67. xpath_subs_embeddings.append(self.xpath_subs_sub_embeddings[i](xpath_subs_seq[:, :, i]))
  68. xpath_tags_embeddings = torch.cat(xpath_tags_embeddings, dim=-1)
  69. xpath_subs_embeddings = torch.cat(xpath_subs_embeddings, dim=-1)
  70. xpath_embeddings = xpath_tags_embeddings + xpath_subs_embeddings
  71. xpath_embeddings = self.inner2emb(self.dropout(self.activation(self.xpath_unitseq2_inner(xpath_embeddings))))
  72. return xpath_embeddings
  73. class MarkupLMEmbeddings(nn.Module):
  74. """Construct the embeddings from word, position and token_type embeddings."""
  75. def __init__(self, config):
  76. super().__init__()
  77. self.config = config
  78. self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
  79. self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
  80. self.max_depth = config.max_depth
  81. self.xpath_embeddings = XPathEmbeddings(config)
  82. self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
  83. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  84. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  85. self.register_buffer(
  86. "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
  87. )
  88. self.padding_idx = config.pad_token_id
  89. self.position_embeddings = nn.Embedding(
  90. config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
  91. )
  92. @staticmethod
  93. # Copied from transformers.models.roberta.modeling_roberta.RobertaEmbeddings.create_position_ids_from_inputs_embeds
  94. def create_position_ids_from_inputs_embeds(inputs_embeds, padding_idx):
  95. """
  96. We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.
  97. Args:
  98. inputs_embeds: torch.Tensor
  99. Returns: torch.Tensor
  100. """
  101. input_shape = inputs_embeds.size()[:-1]
  102. sequence_length = input_shape[1]
  103. position_ids = torch.arange(
  104. padding_idx + 1, sequence_length + padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
  105. )
  106. return position_ids.unsqueeze(0).expand(input_shape)
  107. @staticmethod
  108. # Copied from transformers.models.roberta.modeling_roberta.RobertaEmbeddings.create_position_ids_from_input_ids
  109. def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
  110. """
  111. Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
  112. are ignored. This is modified from fairseq's `utils.make_positions`.
  113. Args:
  114. x: torch.Tensor x:
  115. Returns: torch.Tensor
  116. """
  117. # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
  118. mask = input_ids.ne(padding_idx).int()
  119. incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask
  120. return incremental_indices.long() + padding_idx
  121. def forward(
  122. self,
  123. input_ids=None,
  124. xpath_tags_seq=None,
  125. xpath_subs_seq=None,
  126. token_type_ids=None,
  127. position_ids=None,
  128. inputs_embeds=None,
  129. ):
  130. if input_ids is not None:
  131. input_shape = input_ids.size()
  132. else:
  133. input_shape = inputs_embeds.size()[:-1]
  134. device = input_ids.device if input_ids is not None else inputs_embeds.device
  135. if position_ids is None:
  136. if input_ids is not None:
  137. # Create the position ids from the input token ids. Any padded tokens remain padded.
  138. position_ids = self.create_position_ids_from_input_ids(input_ids, self.padding_idx)
  139. else:
  140. position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds, self.padding_idx)
  141. if token_type_ids is None:
  142. token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
  143. if inputs_embeds is None:
  144. inputs_embeds = self.word_embeddings(input_ids)
  145. # prepare xpath seq
  146. if xpath_tags_seq is None:
  147. xpath_tags_seq = self.config.tag_pad_id * torch.ones(
  148. tuple(list(input_shape) + [self.max_depth]), dtype=torch.long, device=device
  149. )
  150. if xpath_subs_seq is None:
  151. xpath_subs_seq = self.config.subs_pad_id * torch.ones(
  152. tuple(list(input_shape) + [self.max_depth]), dtype=torch.long, device=device
  153. )
  154. words_embeddings = inputs_embeds
  155. position_embeddings = self.position_embeddings(position_ids)
  156. token_type_embeddings = self.token_type_embeddings(token_type_ids)
  157. xpath_embeddings = self.xpath_embeddings(xpath_tags_seq, xpath_subs_seq)
  158. embeddings = words_embeddings + position_embeddings + token_type_embeddings + xpath_embeddings
  159. embeddings = self.LayerNorm(embeddings)
  160. embeddings = self.dropout(embeddings)
  161. return embeddings
  162. # Copied from transformers.models.bert.modeling_bert.BertSelfOutput with Bert->MarkupLM
  163. class MarkupLMSelfOutput(nn.Module):
  164. def __init__(self, config):
  165. super().__init__()
  166. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  167. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  168. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  169. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  170. hidden_states = self.dense(hidden_states)
  171. hidden_states = self.dropout(hidden_states)
  172. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  173. return hidden_states
  174. # Copied from transformers.models.bert.modeling_bert.BertIntermediate
  175. class MarkupLMIntermediate(nn.Module):
  176. def __init__(self, config):
  177. super().__init__()
  178. self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
  179. if isinstance(config.hidden_act, str):
  180. self.intermediate_act_fn = ACT2FN[config.hidden_act]
  181. else:
  182. self.intermediate_act_fn = config.hidden_act
  183. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  184. hidden_states = self.dense(hidden_states)
  185. hidden_states = self.intermediate_act_fn(hidden_states)
  186. return hidden_states
  187. # Copied from transformers.models.bert.modeling_bert.BertOutput with Bert->MarkupLM
  188. class MarkupLMOutput(nn.Module):
  189. def __init__(self, config):
  190. super().__init__()
  191. self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
  192. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  193. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  194. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  195. hidden_states = self.dense(hidden_states)
  196. hidden_states = self.dropout(hidden_states)
  197. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  198. return hidden_states
  199. # Copied from transformers.models.bert.modeling_bert.BertPooler
  200. class MarkupLMPooler(nn.Module):
  201. def __init__(self, config):
  202. super().__init__()
  203. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  204. self.activation = nn.Tanh()
  205. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  206. # We "pool" the model by simply taking the hidden state corresponding
  207. # to the first token.
  208. first_token_tensor = hidden_states[:, 0]
  209. pooled_output = self.dense(first_token_tensor)
  210. pooled_output = self.activation(pooled_output)
  211. return pooled_output
  212. # Copied from transformers.models.bert.modeling_bert.BertPredictionHeadTransform with Bert->MarkupLM
  213. class MarkupLMPredictionHeadTransform(nn.Module):
  214. def __init__(self, config):
  215. super().__init__()
  216. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  217. if isinstance(config.hidden_act, str):
  218. self.transform_act_fn = ACT2FN[config.hidden_act]
  219. else:
  220. self.transform_act_fn = config.hidden_act
  221. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  222. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  223. hidden_states = self.dense(hidden_states)
  224. hidden_states = self.transform_act_fn(hidden_states)
  225. hidden_states = self.LayerNorm(hidden_states)
  226. return hidden_states
  227. # Copied from transformers.models.bert.modeling_bert.BertLMPredictionHead with Bert->MarkupLM
  228. class MarkupLMLMPredictionHead(nn.Module):
  229. def __init__(self, config):
  230. super().__init__()
  231. self.transform = MarkupLMPredictionHeadTransform(config)
  232. # The output weights are the same as the input embeddings, but there is
  233. # an output-only bias for each token.
  234. self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=True)
  235. self.bias = nn.Parameter(torch.zeros(config.vocab_size))
  236. def forward(self, hidden_states):
  237. hidden_states = self.transform(hidden_states)
  238. hidden_states = self.decoder(hidden_states)
  239. return hidden_states
  240. # Copied from transformers.models.bert.modeling_bert.BertOnlyMLMHead with Bert->MarkupLM
  241. class MarkupLMOnlyMLMHead(nn.Module):
  242. def __init__(self, config):
  243. super().__init__()
  244. self.predictions = MarkupLMLMPredictionHead(config)
  245. def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
  246. prediction_scores = self.predictions(sequence_output)
  247. return prediction_scores
  248. # Copied from transformers.models.align.modeling_align.eager_attention_forward
  249. def eager_attention_forward(
  250. module: nn.Module,
  251. query: torch.Tensor,
  252. key: torch.Tensor,
  253. value: torch.Tensor,
  254. attention_mask: torch.Tensor | None,
  255. scaling: float,
  256. dropout: float = 0.0,
  257. **kwargs,
  258. ):
  259. attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
  260. if attention_mask is not None:
  261. attn_weights = attn_weights + attention_mask
  262. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  263. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  264. attn_output = torch.matmul(attn_weights, value)
  265. attn_output = attn_output.transpose(1, 2).contiguous()
  266. return attn_output, attn_weights
  267. # Copied from transformers.models.align.modeling_align.AlignTextSelfAttention with AlignText->MarkupLM
  268. class MarkupLMSelfAttention(nn.Module):
  269. def __init__(self, config):
  270. super().__init__()
  271. if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
  272. raise ValueError(
  273. f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
  274. f"heads ({config.num_attention_heads})"
  275. )
  276. self.config = config
  277. self.num_attention_heads = config.num_attention_heads
  278. self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
  279. self.all_head_size = self.num_attention_heads * self.attention_head_size
  280. self.query = nn.Linear(config.hidden_size, self.all_head_size)
  281. self.key = nn.Linear(config.hidden_size, self.all_head_size)
  282. self.value = nn.Linear(config.hidden_size, self.all_head_size)
  283. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  284. self.attention_dropout = config.attention_probs_dropout_prob
  285. self.scaling = self.attention_head_size**-0.5
  286. def forward(
  287. self,
  288. hidden_states: torch.Tensor,
  289. attention_mask: torch.FloatTensor | None = None,
  290. **kwargs: Unpack[TransformersKwargs],
  291. ) -> tuple[torch.Tensor, torch.Tensor | None]:
  292. input_shape = hidden_states.shape[:-1]
  293. hidden_shape = (*input_shape, -1, self.attention_head_size)
  294. query_states = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
  295. key_states = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
  296. value_states = self.value(hidden_states).view(hidden_shape).transpose(1, 2)
  297. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  298. self.config._attn_implementation, eager_attention_forward
  299. )
  300. attn_output, attn_weights = attention_interface(
  301. self,
  302. query_states,
  303. key_states,
  304. value_states,
  305. attention_mask,
  306. dropout=0.0 if not self.training else self.attention_dropout,
  307. scaling=self.scaling,
  308. **kwargs,
  309. )
  310. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  311. return attn_output, attn_weights
  312. # Copied from transformers.models.align.modeling_align.AlignTextAttention with AlignText->MarkupLM
  313. class MarkupLMAttention(nn.Module):
  314. def __init__(self, config):
  315. super().__init__()
  316. self.self = MarkupLMSelfAttention(config)
  317. self.output = MarkupLMSelfOutput(config)
  318. def forward(
  319. self,
  320. hidden_states: torch.Tensor,
  321. attention_mask: torch.FloatTensor | None = None,
  322. **kwargs: Unpack[TransformersKwargs],
  323. ) -> torch.Tensor:
  324. residual = hidden_states
  325. hidden_states, _ = self.self(
  326. hidden_states,
  327. attention_mask=attention_mask,
  328. **kwargs,
  329. )
  330. hidden_states = self.output(hidden_states, residual)
  331. return hidden_states
  332. # Copied from transformers.models.align.modeling_align.AlignTextLayer with AlignText->MarkupLM
  333. class MarkupLMLayer(GradientCheckpointingLayer):
  334. def __init__(self, config):
  335. super().__init__()
  336. self.chunk_size_feed_forward = config.chunk_size_feed_forward
  337. self.seq_len_dim = 1
  338. self.attention = MarkupLMAttention(config)
  339. self.intermediate = MarkupLMIntermediate(config)
  340. self.output = MarkupLMOutput(config)
  341. def forward(
  342. self,
  343. hidden_states: torch.Tensor,
  344. attention_mask: torch.FloatTensor | None = None,
  345. **kwargs: Unpack[TransformersKwargs],
  346. ) -> torch.Tensor:
  347. hidden_states = self.attention(
  348. hidden_states,
  349. attention_mask=attention_mask,
  350. **kwargs,
  351. )
  352. hidden_states = apply_chunking_to_forward(
  353. self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, hidden_states
  354. )
  355. return hidden_states
  356. def feed_forward_chunk(self, attention_output):
  357. intermediate_output = self.intermediate(attention_output)
  358. layer_output = self.output(intermediate_output, attention_output)
  359. return layer_output
  360. # Copied from transformers.models.align.modeling_align.AlignTextEncoder with AlignText->MarkupLM
  361. class MarkupLMEncoder(nn.Module):
  362. def __init__(self, config):
  363. super().__init__()
  364. self.config = config
  365. self.layer = nn.ModuleList([MarkupLMLayer(config) for i in range(config.num_hidden_layers)])
  366. self.gradient_checkpointing = False
  367. def forward(
  368. self,
  369. hidden_states: torch.Tensor,
  370. attention_mask: torch.FloatTensor | None = None,
  371. **kwargs: Unpack[TransformersKwargs],
  372. ) -> BaseModelOutput:
  373. for layer_module in self.layer:
  374. hidden_states = layer_module(
  375. hidden_states,
  376. attention_mask,
  377. **kwargs,
  378. )
  379. return BaseModelOutput(
  380. last_hidden_state=hidden_states,
  381. )
  382. @auto_docstring
  383. class MarkupLMPreTrainedModel(PreTrainedModel):
  384. config: MarkupLMConfig
  385. base_model_prefix = "markuplm"
  386. _can_record_outputs = {
  387. "hidden_states": MarkupLMLayer,
  388. "attentions": MarkupLMSelfAttention,
  389. }
  390. @torch.no_grad()
  391. def _init_weights(self, module):
  392. """Initialize the weights"""
  393. super()._init_weights(module)
  394. if isinstance(module, MarkupLMLMPredictionHead):
  395. init.zeros_(module.bias)
  396. elif isinstance(module, MarkupLMEmbeddings):
  397. init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
  398. @auto_docstring
  399. class MarkupLMModel(MarkupLMPreTrainedModel):
  400. # Copied from transformers.models.clap.modeling_clap.ClapTextModel.__init__ with ClapText->MarkupLM
  401. def __init__(self, config, add_pooling_layer=True):
  402. r"""
  403. add_pooling_layer (bool, *optional*, defaults to `True`):
  404. Whether to add a pooling layer
  405. """
  406. super().__init__(config)
  407. self.config = config
  408. self.embeddings = MarkupLMEmbeddings(config)
  409. self.encoder = MarkupLMEncoder(config)
  410. self.pooler = MarkupLMPooler(config) if add_pooling_layer else None
  411. # Initialize weights and apply final processing
  412. self.post_init()
  413. def get_input_embeddings(self):
  414. return self.embeddings.word_embeddings
  415. def set_input_embeddings(self, value):
  416. self.embeddings.word_embeddings = value
  417. @merge_with_config_defaults
  418. @capture_outputs
  419. @auto_docstring
  420. def forward(
  421. self,
  422. input_ids: torch.LongTensor | None = None,
  423. xpath_tags_seq: torch.LongTensor | None = None,
  424. xpath_subs_seq: torch.LongTensor | None = None,
  425. attention_mask: torch.FloatTensor | None = None,
  426. token_type_ids: torch.LongTensor | None = None,
  427. position_ids: torch.LongTensor | None = None,
  428. inputs_embeds: torch.FloatTensor | None = None,
  429. **kwargs: Unpack[TransformersKwargs],
  430. ) -> tuple | BaseModelOutputWithPooling:
  431. r"""
  432. xpath_tags_seq (`torch.LongTensor` of shape `(batch_size, sequence_length, config.max_depth)`, *optional*):
  433. Tag IDs for each token in the input sequence, padded up to config.max_depth.
  434. xpath_subs_seq (`torch.LongTensor` of shape `(batch_size, sequence_length, config.max_depth)`, *optional*):
  435. Subscript IDs for each token in the input sequence, padded up to config.max_depth.
  436. Examples:
  437. ```python
  438. >>> from transformers import AutoProcessor, MarkupLMModel
  439. >>> processor = AutoProcessor.from_pretrained("microsoft/markuplm-base")
  440. >>> model = MarkupLMModel.from_pretrained("microsoft/markuplm-base")
  441. >>> html_string = "<html> <head> <title>Page Title</title> </head> </html>"
  442. >>> encoding = processor(html_string, return_tensors="pt")
  443. >>> outputs = model(**encoding)
  444. >>> last_hidden_states = outputs.last_hidden_state
  445. >>> list(last_hidden_states.shape)
  446. [1, 4, 768]
  447. ```"""
  448. if input_ids is not None and inputs_embeds is not None:
  449. raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
  450. elif input_ids is not None:
  451. self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
  452. input_shape = input_ids.size()
  453. elif inputs_embeds is not None:
  454. input_shape = inputs_embeds.size()[:-1]
  455. else:
  456. raise ValueError("You have to specify either input_ids or inputs_embeds")
  457. device = input_ids.device if input_ids is not None else inputs_embeds.device
  458. if attention_mask is None:
  459. attention_mask = torch.ones(input_shape, device=device)
  460. if token_type_ids is None:
  461. token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
  462. extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
  463. extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)
  464. extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
  465. embedding_output = self.embeddings(
  466. input_ids=input_ids,
  467. xpath_tags_seq=xpath_tags_seq,
  468. xpath_subs_seq=xpath_subs_seq,
  469. position_ids=position_ids,
  470. token_type_ids=token_type_ids,
  471. inputs_embeds=inputs_embeds,
  472. )
  473. encoder_outputs = self.encoder(
  474. embedding_output,
  475. extended_attention_mask,
  476. **kwargs,
  477. )
  478. sequence_output = encoder_outputs[0]
  479. pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
  480. return BaseModelOutputWithPooling(
  481. last_hidden_state=sequence_output,
  482. pooler_output=pooled_output,
  483. )
  484. @auto_docstring
  485. class MarkupLMForQuestionAnswering(MarkupLMPreTrainedModel):
  486. # Copied from transformers.models.bert.modeling_bert.BertForQuestionAnswering.__init__ with bert->markuplm, Bert->MarkupLM
  487. def __init__(self, config):
  488. super().__init__(config)
  489. self.num_labels = config.num_labels
  490. self.markuplm = MarkupLMModel(config, add_pooling_layer=False)
  491. self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
  492. # Initialize weights and apply final processing
  493. self.post_init()
  494. @can_return_tuple
  495. @auto_docstring
  496. def forward(
  497. self,
  498. input_ids: torch.Tensor | None = None,
  499. xpath_tags_seq: torch.Tensor | None = None,
  500. xpath_subs_seq: torch.Tensor | None = None,
  501. attention_mask: torch.Tensor | None = None,
  502. token_type_ids: torch.Tensor | None = None,
  503. position_ids: torch.Tensor | None = None,
  504. inputs_embeds: torch.Tensor | None = None,
  505. start_positions: torch.Tensor | None = None,
  506. end_positions: torch.Tensor | None = None,
  507. **kwargs: Unpack[TransformersKwargs],
  508. ) -> tuple[torch.Tensor] | QuestionAnsweringModelOutput:
  509. r"""
  510. xpath_tags_seq (`torch.LongTensor` of shape `(batch_size, sequence_length, config.max_depth)`, *optional*):
  511. Tag IDs for each token in the input sequence, padded up to config.max_depth.
  512. xpath_subs_seq (`torch.LongTensor` of shape `(batch_size, sequence_length, config.max_depth)`, *optional*):
  513. Subscript IDs for each token in the input sequence, padded up to config.max_depth.
  514. Examples:
  515. ```python
  516. >>> from transformers import AutoProcessor, MarkupLMForQuestionAnswering
  517. >>> import torch
  518. >>> processor = AutoProcessor.from_pretrained("microsoft/markuplm-base-finetuned-websrc")
  519. >>> model = MarkupLMForQuestionAnswering.from_pretrained("microsoft/markuplm-base-finetuned-websrc")
  520. >>> html_string = "<html> <head> <title>My name is Niels</title> </head> </html>"
  521. >>> question = "What's his name?"
  522. >>> encoding = processor(html_string, questions=question, return_tensors="pt")
  523. >>> with torch.no_grad():
  524. ... outputs = model(**encoding)
  525. >>> answer_start_index = outputs.start_logits.argmax()
  526. >>> answer_end_index = outputs.end_logits.argmax()
  527. >>> predict_answer_tokens = encoding.input_ids[0, answer_start_index : answer_end_index + 1]
  528. >>> processor.decode(predict_answer_tokens).strip()
  529. 'Niels'
  530. ```"""
  531. outputs = self.markuplm(
  532. input_ids,
  533. xpath_tags_seq=xpath_tags_seq,
  534. xpath_subs_seq=xpath_subs_seq,
  535. attention_mask=attention_mask,
  536. token_type_ids=token_type_ids,
  537. position_ids=position_ids,
  538. inputs_embeds=inputs_embeds,
  539. **kwargs,
  540. )
  541. sequence_output = outputs[0]
  542. logits = self.qa_outputs(sequence_output)
  543. start_logits, end_logits = logits.split(1, dim=-1)
  544. start_logits = start_logits.squeeze(-1).contiguous()
  545. end_logits = end_logits.squeeze(-1).contiguous()
  546. total_loss = None
  547. if start_positions is not None and end_positions is not None:
  548. # If we are on multi-GPU, split add a dimension
  549. if len(start_positions.size()) > 1:
  550. start_positions = start_positions.squeeze(-1)
  551. if len(end_positions.size()) > 1:
  552. end_positions = end_positions.squeeze(-1)
  553. # sometimes the start/end positions are outside our model inputs, we ignore these terms
  554. ignored_index = start_logits.size(1)
  555. start_positions.clamp_(0, ignored_index)
  556. end_positions.clamp_(0, ignored_index)
  557. loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
  558. start_loss = loss_fct(start_logits, start_positions)
  559. end_loss = loss_fct(end_logits, end_positions)
  560. total_loss = (start_loss + end_loss) / 2
  561. return QuestionAnsweringModelOutput(
  562. loss=total_loss,
  563. start_logits=start_logits,
  564. end_logits=end_logits,
  565. hidden_states=outputs.hidden_states,
  566. attentions=outputs.attentions,
  567. )
  @auto_docstring(
      custom_intro="""
      MarkupLM Model with a `token_classification` head on top.
      """
  )
  class MarkupLMForTokenClassification(MarkupLMPreTrainedModel):
      # Copied from transformers.models.bert.modeling_bert.BertForTokenClassification.__init__ with bert->markuplm, Bert->MarkupLM
      def __init__(self, config):
          super().__init__(config)
          self.num_labels = config.num_labels

          self.markuplm = MarkupLMModel(config, add_pooling_layer=False)
          classifier_dropout = (
              config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
          )
          self.dropout = nn.Dropout(classifier_dropout)
          self.classifier = nn.Linear(config.hidden_size, config.num_labels)

          # Initialize weights and apply final processing
          self.post_init()

      @can_return_tuple
      @auto_docstring
      def forward(
          self,
          input_ids: torch.Tensor | None = None,
          xpath_tags_seq: torch.Tensor | None = None,
          xpath_subs_seq: torch.Tensor | None = None,
          attention_mask: torch.Tensor | None = None,
          token_type_ids: torch.Tensor | None = None,
          position_ids: torch.Tensor | None = None,
          inputs_embeds: torch.Tensor | None = None,
          labels: torch.Tensor | None = None,
          **kwargs: Unpack[TransformersKwargs],
      ) -> tuple[torch.Tensor] | TokenClassifierOutput:
          r"""
          xpath_tags_seq (`torch.LongTensor` of shape `(batch_size, sequence_length, config.max_depth)`, *optional*):
              Tag IDs for each token in the input sequence, padded up to config.max_depth.
          xpath_subs_seq (`torch.LongTensor` of shape `(batch_size, sequence_length, config.max_depth)`, *optional*):
              Subscript IDs for each token in the input sequence, padded up to config.max_depth.
          labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
              Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
              Positions labeled `-100` are ignored by the loss (the default `ignore_index` of `CrossEntropyLoss`).

          Examples:

          ```python
          >>> from transformers import AutoProcessor, AutoModelForTokenClassification
          >>> import torch

          >>> processor = AutoProcessor.from_pretrained("microsoft/markuplm-base")
          >>> processor.parse_html = False
          >>> model = AutoModelForTokenClassification.from_pretrained("microsoft/markuplm-base", num_labels=7)

          >>> nodes = ["hello", "world"]
          >>> xpaths = ["/html/body/div/li[1]/div/span", "/html/body/div/li[1]/div/span"]
          >>> node_labels = [1, 2]
          >>> encoding = processor(nodes=nodes, xpaths=xpaths, node_labels=node_labels, return_tensors="pt")

          >>> with torch.no_grad():
          ...     outputs = model(**encoding)

          >>> loss = outputs.loss
          >>> logits = outputs.logits
          ```"""
          outputs = self.markuplm(
              input_ids,
              xpath_tags_seq=xpath_tags_seq,
              xpath_subs_seq=xpath_subs_seq,
              attention_mask=attention_mask,
              token_type_ids=token_type_ids,
              position_ids=position_ids,
              inputs_embeds=inputs_embeds,
              **kwargs,
          )

          # First element of the base-model output is the per-token hidden state sequence.
          sequence_output = outputs[0]
          # Apply the classification-head dropout built in __init__. Previously it was constructed but never
          # used, so `config.classifier_dropout` had no effect. nn.Dropout is the identity in eval mode,
          # so inference results are unchanged.
          sequence_output = self.dropout(sequence_output)
          prediction_scores = self.classifier(sequence_output)  # (batch_size, seq_length, num_labels)

          loss = None
          if labels is not None:
              loss_fct = CrossEntropyLoss()
              # Flatten batch and sequence dimensions so each token is scored as an independent example.
              loss = loss_fct(
                  prediction_scores.view(-1, self.config.num_labels),
                  labels.view(-1),
              )

          return TokenClassifierOutput(
              loss=loss,
              logits=prediction_scores,
              hidden_states=outputs.hidden_states,
              attentions=outputs.attentions,
          )
  @auto_docstring(
      custom_intro="""
      MarkupLM Model transformer with a sequence classification/regression head on top (a linear layer on top of the
      pooled output) e.g. for GLUE tasks.
      """
  )
  class MarkupLMForSequenceClassification(MarkupLMPreTrainedModel):
      # Copied from transformers.models.bert.modeling_bert.BertForSequenceClassification.__init__ with bert->markuplm, Bert->MarkupLM
      def __init__(self, config):
          super().__init__(config)
          self.num_labels = config.num_labels
          self.config = config

          self.markuplm = MarkupLMModel(config)
          classifier_dropout = (
              config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
          )
          self.dropout = nn.Dropout(classifier_dropout)
          self.classifier = nn.Linear(config.hidden_size, config.num_labels)

          # Initialize weights and apply final processing
          self.post_init()

      @can_return_tuple
      @auto_docstring
      def forward(
          self,
          input_ids: torch.Tensor | None = None,
          xpath_tags_seq: torch.Tensor | None = None,
          xpath_subs_seq: torch.Tensor | None = None,
          attention_mask: torch.Tensor | None = None,
          token_type_ids: torch.Tensor | None = None,
          position_ids: torch.Tensor | None = None,
          inputs_embeds: torch.Tensor | None = None,
          labels: torch.Tensor | None = None,
          **kwargs: Unpack[TransformersKwargs],
      ) -> tuple[torch.Tensor] | SequenceClassifierOutput:
          r"""
          xpath_tags_seq (`torch.LongTensor` of shape `(batch_size, sequence_length, config.max_depth)`, *optional*):
              Tag IDs for each token in the input sequence, padded up to config.max_depth.
          xpath_subs_seq (`torch.LongTensor` of shape `(batch_size, sequence_length, config.max_depth)`, *optional*):
              Subscript IDs for each token in the input sequence, padded up to config.max_depth.
          labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
              Targets for the sequence-level loss. Class indices should be in `[0, ..., config.num_labels - 1]`.
              When `config.problem_type` is unset it is inferred on first use: `config.num_labels == 1` selects a
              regression (Mean-Square) loss; `config.num_labels > 1` with integer labels selects a classification
              (Cross-Entropy) loss; any other case selects a multi-label (binary cross-entropy with logits) loss.

          Examples:

          ```python
          >>> from transformers import AutoProcessor, AutoModelForSequenceClassification
          >>> import torch

          >>> processor = AutoProcessor.from_pretrained("microsoft/markuplm-base")
          >>> model = AutoModelForSequenceClassification.from_pretrained("microsoft/markuplm-base", num_labels=7)

          >>> html_string = "<html> <head> <title>Page Title</title> </head> </html>"
          >>> encoding = processor(html_string, return_tensors="pt")

          >>> with torch.no_grad():
          ...     outputs = model(**encoding)

          >>> loss = outputs.loss
          >>> logits = outputs.logits
          ```"""
          encoder_outputs = self.markuplm(
              input_ids,
              xpath_tags_seq=xpath_tags_seq,
              xpath_subs_seq=xpath_subs_seq,
              attention_mask=attention_mask,
              token_type_ids=token_type_ids,
              position_ids=position_ids,
              inputs_embeds=inputs_embeds,
              **kwargs,
          )

          # Element 1 of the base-model output is the pooled representation fed to the classification head.
          logits = self.classifier(self.dropout(encoder_outputs[1]))

          loss = self._sequence_classification_loss(logits, labels) if labels is not None else None

          return SequenceClassifierOutput(
              loss=loss,
              logits=logits,
              hidden_states=encoder_outputs.hidden_states,
              attentions=encoder_outputs.attentions,
          )

      def _sequence_classification_loss(self, logits, labels):
          """Return the loss of `logits` against `labels` for the configured problem type.

          If `config.problem_type` is unset, it is inferred from `num_labels` and the label dtype and stored back
          on the config, so later calls reuse the same choice. Returns `None` for an unrecognised problem type.
          """
          if self.config.problem_type is None:
              if self.num_labels == 1:
                  self.config.problem_type = "regression"
              elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int):
                  self.config.problem_type = "single_label_classification"
              else:
                  self.config.problem_type = "multi_label_classification"

          kind = self.config.problem_type
          if kind == "regression":
              mse = MSELoss()
              # A single regression target: drop size-1 dimensions so predictions and targets align.
              if self.num_labels == 1:
                  return mse(logits.squeeze(), labels.squeeze())
              return mse(logits, labels)
          if kind == "single_label_classification":
              return CrossEntropyLoss()(logits.view(-1, self.num_labels), labels.view(-1))
          if kind == "multi_label_classification":
              return BCEWithLogitsLoss()(logits, labels)
          return None
  # Public API of this module: the names exported by `from <module> import *`.
  # NOTE(review): kept as a plain list of string literals; repository tooling may read this list
  # statically, so confirm before changing its form.
  __all__ = [
      "MarkupLMForQuestionAnswering",
      "MarkupLMForSequenceClassification",
      "MarkupLMForTokenClassification",
      "MarkupLMModel",
      "MarkupLMPreTrainedModel",
  ]