data_collator.py 68 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462
  1. # Copyright 2020 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import multiprocessing as mp
  15. import warnings
  16. from collections.abc import Callable, Mapping
  17. from dataclasses import dataclass
  18. from random import randint
  19. from typing import Any
  20. import numpy as np
  21. from ..tokenization_utils_base import PreTrainedTokenizerBase
  22. from ..utils import PaddingStrategy
  23. InputDataClass = Any
  24. """
  25. A DataCollator is a function that takes a list of samples from a Dataset and collate them into a batch, as a dictionary
  26. of PyTorch tensors or NumPy arrays.
  27. """
  28. DataCollator = Callable[[list[InputDataClass]], dict[str, Any]]
  29. class DataCollatorMixin:
  30. def __call__(self, features, return_tensors: str | None = None):
  31. if return_tensors is None:
  32. return_tensors = self.return_tensors
  33. if return_tensors == "pt":
  34. return self.torch_call(features)
  35. elif return_tensors == "np":
  36. return self.numpy_call(features)
  37. else:
  38. raise ValueError(f"Framework '{return_tensors}' not recognized!")
  39. def pad_without_fast_tokenizer_warning(tokenizer, *pad_args, **pad_kwargs):
  40. """
  41. Pads without triggering the warning about how using the pad function is sub-optimal when using a fast tokenizer.
  42. """
  43. # To avoid errors when using Feature extractors
  44. if not hasattr(tokenizer, "deprecation_warnings"):
  45. return tokenizer.pad(*pad_args, **pad_kwargs)
  46. # Save the state of the warning, then disable it
  47. warning_state = tokenizer.deprecation_warnings.get("Asking-to-pad-a-fast-tokenizer", False)
  48. tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True
  49. try:
  50. padded = tokenizer.pad(*pad_args, **pad_kwargs)
  51. finally:
  52. # Restore the state of the warning.
  53. tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = warning_state
  54. return padded
  55. def default_data_collator(features: list[InputDataClass], return_tensors="pt") -> dict[str, Any]:
  56. """
  57. Very simple data collator that simply collates batches of dict-like objects and performs special handling for
  58. potential keys named:
  59. - `label`: handles a single value (int or float) per object
  60. - `label_ids`: handles a list of values per object
  61. Does not do any additional preprocessing: property names of the input object will be used as corresponding inputs
  62. to the model. See glue and ner for example of how it's useful.
  63. """
  64. # In this function we'll make the assumption that all `features` in the batch
  65. # have the same attributes.
  66. # So we will look at the first element as a proxy for what attributes exist
  67. # on the whole batch.
  68. if return_tensors == "pt":
  69. return torch_default_data_collator(features)
  70. elif return_tensors == "np":
  71. return numpy_default_data_collator(features)
  72. @dataclass
  73. class DefaultDataCollator(DataCollatorMixin):
  74. """
  75. Very simple data collator that simply collates batches of dict-like objects and performs special handling for
  76. potential keys named:
  77. - `label`: handles a single value (int or float) per object
  78. - `label_ids`: handles a list of values per object
  79. Does not do any additional preprocessing: property names of the input object will be used as corresponding inputs
  80. to the model. See glue and ner for example of how it's useful.
  81. This is an object (like other data collators) rather than a pure function like default_data_collator. This can be
  82. helpful if you need to set a return_tensors value at initialization.
  83. Args:
  84. return_tensors (`str`, *optional*, defaults to `"pt"`):
  85. The type of Tensor to return. Allowable values are "np", or "pt".
  86. """
  87. return_tensors: str = "pt"
  88. def __call__(self, features: list[dict[str, Any]], return_tensors=None) -> dict[str, Any]:
  89. if return_tensors is None:
  90. return_tensors = self.return_tensors
  91. return default_data_collator(features, return_tensors)
  92. def torch_default_data_collator(features: list[InputDataClass]) -> dict[str, Any]:
  93. import torch
  94. if not isinstance(features[0], Mapping):
  95. features = [vars(f) for f in features]
  96. first = features[0]
  97. batch = {}
  98. # Special handling for labels.
  99. # Ensure that tensor is created with the correct type
  100. # (it should be automatically the case, but let's make sure of it.)
  101. if "label" in first and first["label"] is not None:
  102. label = first["label"].item() if isinstance(first["label"], torch.Tensor) else first["label"]
  103. dtype = torch.long if isinstance(label, int) else torch.float
  104. batch["labels"] = torch.tensor([f["label"] for f in features], dtype=dtype)
  105. elif "label_ids" in first and first["label_ids"] is not None:
  106. if isinstance(first["label_ids"], torch.Tensor):
  107. batch["labels"] = torch.stack([f["label_ids"] for f in features])
  108. else:
  109. dtype = torch.long if isinstance(first["label_ids"][0], int) else torch.float
  110. batch["labels"] = torch.tensor([f["label_ids"] for f in features], dtype=dtype)
  111. # Handling of all other possible keys.
  112. # Again, we will use the first element to figure out which key/values are not None for this model.
  113. for k, v in first.items():
  114. if k not in ("label", "label_ids") and v is not None and not isinstance(v, str):
  115. if isinstance(v, torch.Tensor):
  116. batch[k] = torch.stack([f[k] for f in features])
  117. elif isinstance(v, np.ndarray):
  118. batch[k] = torch.from_numpy(np.stack([f[k] for f in features]))
  119. else:
  120. batch[k] = torch.tensor([f[k] for f in features])
  121. return batch
  122. def numpy_default_data_collator(features: list[InputDataClass]) -> dict[str, Any]:
  123. if not isinstance(features[0], Mapping):
  124. features = [vars(f) for f in features]
  125. first = features[0]
  126. batch = {}
  127. # Special handling for labels.
  128. # Ensure that tensor is created with the correct type
  129. # (it should be automatically the case, but let's make sure of it.)
  130. if "label" in first and first["label"] is not None:
  131. label = first["label"].item() if isinstance(first["label"], np.ndarray) else first["label"]
  132. dtype = np.int64 if isinstance(label, int) else np.float32
  133. batch["labels"] = np.array([f["label"] for f in features], dtype=dtype)
  134. elif "label_ids" in first and first["label_ids"] is not None:
  135. if isinstance(first["label_ids"], np.ndarray):
  136. batch["labels"] = np.stack([f["label_ids"] for f in features])
  137. else:
  138. dtype = np.int64 if isinstance(first["label_ids"][0], int) else np.float32
  139. batch["labels"] = np.array([f["label_ids"] for f in features], dtype=dtype)
  140. # Handling of all other possible keys.
  141. # Again, we will use the first element to figure out which key/values are not None for this model.
  142. for k, v in first.items():
  143. if k not in ("label", "label_ids") and v is not None and not isinstance(v, str):
  144. if isinstance(v, np.ndarray):
  145. batch[k] = np.stack([f[k] for f in features])
  146. else:
  147. batch[k] = np.array([f[k] for f in features])
  148. return batch
  149. @dataclass
  150. class DataCollatorWithPadding:
  151. """
  152. Data collator that will dynamically pad the inputs received.
  153. Args:
  154. tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]):
  155. The tokenizer used for encoding the data.
  156. padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
  157. Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
  158. among:
  159. - `True` or `'longest'` (default): Pad to the longest sequence in the batch (or no padding if only a single
  160. sequence is provided).
  161. - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
  162. acceptable input length for the model if that argument is not provided.
  163. - `False` or `'do_not_pad'`: No padding (i.e., can output a batch with sequences of different lengths).
  164. max_length (`int`, *optional*):
  165. Maximum length of the returned list and optionally padding length (see above).
  166. pad_to_multiple_of (`int`, *optional*):
  167. If set will pad the sequence to a multiple of the provided value.
  168. This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
  169. 7.0 (Volta).
  170. return_tensors (`str`, *optional*, defaults to `"pt"`):
  171. The type of Tensor to return. Allowable values are "np", or "pt".
  172. """
  173. tokenizer: PreTrainedTokenizerBase
  174. padding: bool | str | PaddingStrategy = True
  175. max_length: int | None = None
  176. pad_to_multiple_of: int | None = None
  177. return_tensors: str = "pt"
  178. def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]:
  179. batch = pad_without_fast_tokenizer_warning(
  180. self.tokenizer,
  181. features,
  182. padding=self.padding,
  183. max_length=self.max_length,
  184. pad_to_multiple_of=self.pad_to_multiple_of,
  185. return_tensors=self.return_tensors,
  186. )
  187. if "label" in batch:
  188. batch["labels"] = batch["label"]
  189. del batch["label"]
  190. if "label_ids" in batch:
  191. batch["labels"] = batch["label_ids"]
  192. del batch["label_ids"]
  193. return batch
  194. @dataclass
  195. class DataCollatorForTokenClassification(DataCollatorMixin):
  196. """
  197. Data collator that will dynamically pad the inputs received, as well as the labels.
  198. Args:
  199. tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]):
  200. The tokenizer used for encoding the data.
  201. padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
  202. Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
  203. among:
  204. - `True` or `'longest'` (default): Pad to the longest sequence in the batch (or no padding if only a single
  205. sequence is provided).
  206. - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
  207. acceptable input length for the model if that argument is not provided.
  208. - `False` or `'do_not_pad'`: No padding (i.e., can output a batch with sequences of different lengths).
  209. max_length (`int`, *optional*):
  210. Maximum length of the returned list and optionally padding length (see above).
  211. pad_to_multiple_of (`int`, *optional*):
  212. If set will pad the sequence to a multiple of the provided value.
  213. This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
  214. 7.0 (Volta).
  215. label_pad_token_id (`int`, *optional*, defaults to -100):
  216. The id to use when padding the labels (-100 will be automatically ignore by PyTorch loss functions).
  217. return_tensors (`str`, *optional*, defaults to `"pt"`):
  218. The type of Tensor to return. Allowable values are "np", or "pt".
  219. """
  220. tokenizer: PreTrainedTokenizerBase
  221. padding: bool | str | PaddingStrategy = True
  222. max_length: int | None = None
  223. pad_to_multiple_of: int | None = None
  224. label_pad_token_id: int = -100
  225. return_tensors: str = "pt"
  226. def torch_call(self, features):
  227. import torch
  228. label_name = "label" if "label" in features[0] else "labels"
  229. labels = [feature[label_name] for feature in features] if label_name in features[0] else None
  230. no_labels_features = [{k: v for k, v in feature.items() if k != label_name} for feature in features]
  231. batch = pad_without_fast_tokenizer_warning(
  232. self.tokenizer,
  233. no_labels_features,
  234. padding=self.padding,
  235. max_length=self.max_length,
  236. pad_to_multiple_of=self.pad_to_multiple_of,
  237. return_tensors="pt",
  238. )
  239. if labels is None:
  240. return batch
  241. sequence_length = batch["input_ids"].shape[1]
  242. padding_side = self.tokenizer.padding_side
  243. def to_list(tensor_or_iterable):
  244. if isinstance(tensor_or_iterable, torch.Tensor):
  245. return tensor_or_iterable.tolist()
  246. return list(tensor_or_iterable)
  247. if padding_side == "right":
  248. batch[label_name] = [
  249. to_list(label) + [self.label_pad_token_id] * (sequence_length - len(label)) for label in labels
  250. ]
  251. else:
  252. batch[label_name] = [
  253. [self.label_pad_token_id] * (sequence_length - len(label)) + to_list(label) for label in labels
  254. ]
  255. batch[label_name] = torch.tensor(batch[label_name], dtype=torch.int64)
  256. return batch
  257. def numpy_call(self, features):
  258. label_name = "label" if "label" in features[0] else "labels"
  259. labels = [feature[label_name] for feature in features] if label_name in features[0] else None
  260. batch = pad_without_fast_tokenizer_warning(
  261. self.tokenizer,
  262. features,
  263. padding=self.padding,
  264. max_length=self.max_length,
  265. pad_to_multiple_of=self.pad_to_multiple_of,
  266. # Conversion to tensors will fail if we have labels as they are not of the same length yet.
  267. return_tensors="np" if labels is None else None,
  268. )
  269. if labels is None:
  270. return batch
  271. sequence_length = np.array(batch["input_ids"]).shape[1]
  272. padding_side = self.tokenizer.padding_side
  273. if padding_side == "right":
  274. batch["labels"] = [
  275. list(label) + [self.label_pad_token_id] * (sequence_length - len(label)) for label in labels
  276. ]
  277. else:
  278. batch["labels"] = [
  279. [self.label_pad_token_id] * (sequence_length - len(label)) + list(label) for label in labels
  280. ]
  281. batch = {k: np.array(v, dtype=np.int64) for k, v in batch.items()}
  282. return batch
  283. def _torch_collate_batch(examples, tokenizer, pad_to_multiple_of: int | None = None):
  284. """Collate `examples` into a batch, using the information in `tokenizer` for padding if necessary."""
  285. import torch
  286. # Tensorize if necessary.
  287. if isinstance(examples[0], (list, tuple, np.ndarray)):
  288. examples = [torch.tensor(e, dtype=torch.long) for e in examples]
  289. length_of_first = examples[0].size(0)
  290. # Check if padding is necessary.
  291. are_tensors_same_length = all(x.size(0) == length_of_first for x in examples)
  292. if are_tensors_same_length and (pad_to_multiple_of is None or length_of_first % pad_to_multiple_of == 0):
  293. if not isinstance(examples, torch.Tensor):
  294. return torch.stack(examples, dim=0)
  295. # If yes, check if we have a `pad_token`.
  296. if tokenizer.pad_token is None:
  297. raise ValueError(
  298. "You are attempting to pad samples but the tokenizer you are using"
  299. f" ({tokenizer.__class__.__name__}) does not have a pad token."
  300. )
  301. # Creating the full tensor and filling it with our data.
  302. max_length = max(x.size(0) for x in examples)
  303. if pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
  304. max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
  305. result = examples[0].new_full([len(examples), max_length], tokenizer.pad_token_id)
  306. for i, example in enumerate(examples):
  307. if tokenizer.padding_side == "right":
  308. result[i, : example.shape[0]] = example
  309. else:
  310. result[i, -example.shape[0] :] = example
  311. return result
  312. def _numpy_collate_batch(examples, tokenizer, pad_to_multiple_of: int | None = None):
  313. """Collate `examples` into a batch, using the information in `tokenizer` for padding if necessary."""
  314. # Tensorize if necessary.
  315. if isinstance(examples[0], (list, tuple)):
  316. examples = [np.array(e, dtype=np.int64) for e in examples]
  317. # Check if padding is necessary.
  318. length_of_first = len(examples[0])
  319. are_tensors_same_length = all(len(x) == length_of_first for x in examples)
  320. if are_tensors_same_length and (pad_to_multiple_of is None or length_of_first % pad_to_multiple_of == 0):
  321. return np.stack(examples, axis=0)
  322. # If yes, check if we have a `pad_token`.
  323. if tokenizer.pad_token is None:
  324. raise ValueError(
  325. "You are attempting to pad samples but the tokenizer you are using"
  326. f" ({tokenizer.__class__.__name__}) does not have a pad token."
  327. )
  328. # Creating the full tensor and filling it with our data.
  329. max_length = max(len(x) for x in examples)
  330. if pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
  331. max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
  332. result = np.full(shape=(len(examples), max_length), fill_value=tokenizer.pad_token_id, dtype=examples[0].dtype)
  333. for i, example in enumerate(examples):
  334. if tokenizer.padding_side == "right":
  335. result[i, : example.shape[0]] = example
  336. else:
  337. result[i, -example.shape[0] :] = example
  338. return result
  339. @dataclass
  340. class DataCollatorForMultipleChoice(DataCollatorMixin):
  341. """
  342. Data collator that dynamically pads a batch of nested examples for multiple choice, so that all choices
  343. of all examples have the same length.
  344. Args:
  345. tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]):
  346. The tokenizer used for encoding the data.
  347. padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
  348. Select a strategy to pad the returned sequences according to the model's padding side and padding index
  349. among:
  350. - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single sequence
  351. is provided).
  352. - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
  353. acceptable input length for the model if that argument is not provided.
  354. - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
  355. lengths).
  356. max_length (`int`, *optional*):
  357. Maximum length of the returned list and optionally padding length (see above).
  358. pad_to_multiple_of (`int`, *optional*):
  359. Pad the sequence to a multiple of the provided value.
  360. This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
  361. 7.5 (Volta).
  362. return_tensors (`str`, *optional*, defaults to `"pt"`):
  363. The type of Tensor to return. Allowable values are "np", or "pt".
  364. """
  365. tokenizer: PreTrainedTokenizerBase
  366. padding: bool | str | PaddingStrategy = True
  367. max_length: int | None = None
  368. pad_to_multiple_of: int | None = None
  369. return_tensors: str = "pt"
  370. def torch_call(self, examples: list[dict[str, Any]]): # Refactored implementation from the docs.
  371. import torch
  372. # Take labels out of the examples beforehand, because they aren't nested.
  373. label_name = "label" if "label" in examples[0] else "labels"
  374. labels = [example.pop(label_name) for example in examples]
  375. batch_size = len(examples)
  376. num_choices = len(examples[0]["input_ids"])
  377. # Go from e.g. 2 examples of 2 choices [{input_ids: [[1], [2]]}, {input_ids: [[3], [4]]}]
  378. # to 4 examples [{input_ids: [1]}, {input_ids: [2]}] + [{input_ids: [3]}, {input_ids: [4]}]
  379. flat_examples = sum(
  380. ([{k: v[i] for k, v in example.items()} for i in range(num_choices)] for example in examples), start=[]
  381. )
  382. # Pad all choices of all examples as if you're padding any other batch of examples.
  383. batch = self.tokenizer.pad(
  384. flat_examples,
  385. padding=self.padding,
  386. max_length=self.max_length,
  387. pad_to_multiple_of=self.pad_to_multiple_of,
  388. return_tensors="pt",
  389. )
  390. # Reshape from B*C x L into B x C x L, and add the labels back in.
  391. batch = {k: v.view(batch_size, num_choices, -1) for k, v in batch.items()}
  392. batch["labels"] = torch.tensor(labels, dtype=torch.int64)
  393. return batch
@dataclass
class DataCollatorForSeq2Seq:
    """
    Data collator that will dynamically pad the inputs received, as well as the labels.

    The non-label inputs are padded by the tokenizer. Labels (read from `label` if the first feature has it, otherwise
    `labels`) are padded manually with `label_pad_token_id` on `tokenizer.padding_side`. The returned batch always
    has a `labels` key; it is `None` when no labels were provided.

    Args:
        tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]):
            The tokenizer used for encoding the data.
        model ([`PreTrainedModel`], *optional*):
            The model that is being trained. If set and has the *prepare_decoder_input_ids_from_labels*, use it to
            prepare the *decoder_input_ids*

            This is useful when using *label_smoothing* to avoid calculating loss twice.
        padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
            Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
            among:

            - `True` or `'longest'` (default): Pad to the longest sequence in the batch (or no padding if only a single
              sequence is provided).
            - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
              acceptable input length for the model if that argument is not provided.
            - `False` or `'do_not_pad'`: No padding (i.e., can output a batch with sequences of different lengths).
        max_length (`int`, *optional*):
            Maximum length of the returned list and optionally padding length (see above).
        pad_to_multiple_of (`int`, *optional*):
            If set will pad the sequence to a multiple of the provided value.

            This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
            7.0 (Volta).
        label_pad_token_id (`int`, *optional*, defaults to -100):
            The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions).
        return_tensors (`str`, *optional*, defaults to `"pt"`):
            The type of Tensor to return. Allowable values are "np", or "pt". Can be overridden per call.
    """

    tokenizer: PreTrainedTokenizerBase
    model: Any | None = None
    padding: bool | str | PaddingStrategy = True
    max_length: int | None = None
    pad_to_multiple_of: int | None = None
    label_pad_token_id: int = -100
    return_tensors: str = "pt"

    def __call__(self, features, return_tensors=None):
        """
        Collate `features` into a padded batch.

        Args:
            features: Dict-like samples; keys are taken from `features[0]`.
            return_tensors (`str`, *optional*): Overrides `self.return_tensors` for this call.

        Returns:
            The padded batch. `labels` is an int64 tensor (for "pt") or int64 array (otherwise), or `None` when
            there are no labels. `decoder_input_ids` is added when the model provides
            `prepare_decoder_input_ids_from_labels` and labels are present.
        """
        if return_tensors is None:
            return_tensors = self.return_tensors

        label_name = "label" if "label" in features[0] else "labels"
        labels = [feature[label_name] for feature in features] if label_name in features[0] else None
        # reconvert list[None] to None if necessary
        # this might occur when we pass {..., "labels": None}
        if labels is not None and all(label is None for label in labels):
            labels = None
        non_labels_features = [{k: v for k, v in feature.items() if k != label_name} for feature in features]

        # run through tokenizer without labels to ensure no side effects
        batch = pad_without_fast_tokenizer_warning(
            self.tokenizer,
            non_labels_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=return_tensors,
        )

        # we have to pad the labels manually as we cannot rely on `tokenizer.pad` and we need them to be of the same length to return tensors
        # Padding is considered disabled for `padding=False` or `PaddingStrategy.DO_NOT_PAD`.
        no_padding = self.padding is False or self.padding == PaddingStrategy.DO_NOT_PAD
        if labels is not None:
            if no_padding:
                # Labels are left at their original lengths.
                if isinstance(features[0][label_name], list):
                    batch["labels"] = list(labels)
                else:
                    # NOTE(review): concatenating with the empty list `[]` (float64 to NumPy) promotes integer labels
                    # to float64 arrays; the int64 conversion below casts them back.
                    batch["labels"] = [np.concatenate([label, []]) for label in labels]
            else:
                # Labels go to `self.max_length` only for the MAX_LENGTH strategy with an explicit max_length;
                # otherwise to the longest label in the batch.
                max_padding = self.padding == PaddingStrategy.MAX_LENGTH and self.max_length is not None
                max_label_length = max(len(l) for l in labels) if not max_padding else self.max_length
                if self.pad_to_multiple_of is not None:
                    # Round up to the next multiple of `pad_to_multiple_of`.
                    max_label_length = (
                        (max_label_length + self.pad_to_multiple_of - 1)
                        // self.pad_to_multiple_of
                        * self.pad_to_multiple_of
                    )

                # Labels are padded on the same side as the tokenizer pads the inputs.
                # NOTE(review): a label longer than `max_label_length` gets a negative repeat count, i.e. no padding
                # and no truncation, so the batch would stay ragged.
                padding_side = self.tokenizer.padding_side
                if isinstance(features[0][label_name], list):
                    batch["labels"] = [
                        label + [self.label_pad_token_id] * (max_label_length - len(label))
                        if padding_side == "right"
                        else [self.label_pad_token_id] * (max_label_length - len(label)) + label
                        for label in labels
                    ]
                else:
                    batch["labels"] = [
                        np.concatenate(
                            [
                                label,
                                np.array([self.label_pad_token_id] * (max_label_length - len(label)), dtype=np.int64),
                            ]
                        )
                        if padding_side == "right"
                        else np.concatenate(
                            [
                                np.array([self.label_pad_token_id] * (max_label_length - len(label)), dtype=np.int64),
                                label,
                            ]
                        )
                        for label in labels
                    ]

        # reintroduce side effects via tokenizer that return respective datatypes for the `return_tensors` argument
        # Any value other than "pt" yields a NumPy int64 array. Presumably this fails for ragged labels (the
        # no-padding case with labels of different lengths).
        if batch.get("labels", None) is not None:
            if return_tensors == "pt":
                import torch

                batch["labels"] = torch.tensor(batch["labels"], dtype=torch.int64)
            else:
                batch["labels"] = np.array(batch["labels"], dtype=np.int64)
        else:
            # Always expose the key, even when there are no labels.
            batch["labels"] = None

        # prepare decoder_input_ids
        # Only done when labels exist and the model provides `prepare_decoder_input_ids_from_labels`.
        if (
            labels is not None
            and self.model is not None
            and hasattr(self.model, "prepare_decoder_input_ids_from_labels")
        ):
            decoder_input_ids = self.model.prepare_decoder_input_ids_from_labels(labels=batch["labels"])
            batch["decoder_input_ids"] = decoder_input_ids

        return batch
  510. @dataclass
  511. class DataCollatorForLanguageModeling(DataCollatorMixin):
  512. """
  513. Data collator used for language modeling. Inputs are dynamically padded to the maximum length of a batch if they
  514. are not all of the same length.
  515. Args:
  516. tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]):
  517. The tokenizer used for encoding the data.
  518. mlm (`bool`, *optional*, defaults to `True`):
  519. Whether or not to use masked language modeling. If set to `False`, the labels are the same as the inputs
  520. with the padding tokens ignored (by setting them to -100). Otherwise, the labels are -100 for non-masked
  521. tokens and the value to predict for the masked token.
  522. whole_word_mask (`bool`, *optional*, defaults to `False`):
  523. Whether or not to mask whole words instead of individual tokens.
  524. mlm_probability (`float`, *optional*, defaults to 0.15):
  525. The probability with which to (randomly) mask tokens in the input, when `mlm` is set to `True`.
  526. mask_replace_prob (`float`, *optional*, defaults to 0.8):
  527. The probability with which masked tokens are replaced by the tokenizer's mask token (e.g., `[MASK]`).
  528. Defaults to 0.8, meaning 80% of the masked tokens will be replaced with `[MASK]`.
  529. Only works when `mlm` is set to `True`.
  530. random_replace_prob (`float`, *optional*, defaults to 0.1):
  531. The probability with which masked tokens are replaced by random tokens from the tokenizer's vocabulary.
  532. Defaults to 0.1, meaning 10% of the masked tokens will be replaced with random tokens. The remaining
  533. masked tokens (1 - mask_replace_prob - random_replace_prob) are left unchanged.
  534. Only works when `mlm` is set to `True`.
  535. pad_to_multiple_of (`int`, *optional*):
  536. If set, will pad the sequence to a multiple of the provided value.
  537. return_tensors (`str`):
  538. The type of Tensor to return. Allowable values are "np", or "pt".
  539. seed (`int`, *optional*):
  540. The seed to use for the random number generator for masking. If not provided, the global RNG will be used.
  541. <Tip>
  542. For best performance, this data collator should be used with a dataset having items that are dictionaries or
  543. BatchEncoding, with the `"special_tokens_mask"` key, as returned by a [`PreTrainedTokenizer`] or a
  544. [`PreTrainedTokenizerFast`] with the argument `return_special_tokens_mask=True`.
  545. <Example Options and Expectations>
  546. 1. Default Behavior:
  547. - `mask_replace_prob=0.8`, `random_replace_prob=0.1`.
  548. - Expect 80% of masked tokens replaced with `[MASK]`, 10% replaced with random tokens, and 10% left unchanged.
  549. 2. All masked tokens replaced by `[MASK]`:
  550. - `mask_replace_prob=1.0`, `random_replace_prob=0.0`.
  551. - Expect all masked tokens to be replaced with `[MASK]`. No tokens are left unchanged or replaced with random tokens.
  552. 3. No `[MASK]` replacement, only random tokens:
  553. - `mask_replace_prob=0.0`, `random_replace_prob=1.0`.
  554. - Expect all masked tokens to be replaced with random tokens. No `[MASK]` replacements or unchanged tokens.
  555. 4. Balanced replacement:
  556. - `mask_replace_prob=0.5`, `random_replace_prob=0.4`.
  557. - Expect 50% of masked tokens replaced with `[MASK]`, 40% replaced with random tokens, and 10% left unchanged.
  558. Note:
  559. The sum of `mask_replace_prob` and `random_replace_prob` must not exceed 1. If their sum is less than 1, the
  560. remaining proportion will consist of masked tokens left unchanged.
  561. </Tip>
  562. """
  563. tokenizer: PreTrainedTokenizerBase
  564. mlm: bool = True
  565. whole_word_mask: bool = False
  566. mlm_probability: float | None = 0.15
  567. mask_replace_prob: float = 0.8
  568. random_replace_prob: float = 0.1
  569. pad_to_multiple_of: int | None = None
  570. return_tensors: str = "pt"
  571. seed: int | None = None
  572. def __post_init__(self):
  573. if self.mlm:
  574. if self.tokenizer.mask_token is None:
  575. raise ValueError(
  576. "This tokenizer does not have a mask token which is necessary for masked language modeling. "
  577. "You should pass `mlm=False` to train on causal language modeling instead."
  578. )
  579. if self.mlm_probability is None or self.mlm_probability < 0 or self.mlm_probability > 1:
  580. raise ValueError("mlm_probability should be between 0 and 1.")
  581. self.mlm_probability = float(self.mlm_probability)
  582. elif self.whole_word_mask:
  583. raise ValueError(
  584. "Whole word masking can only be used with mlm=True."
  585. "If you want to use whole word masking, please set mlm=True."
  586. )
  587. if self.mask_replace_prob + self.random_replace_prob > 1:
  588. raise ValueError("The sum of mask_replace_prob and random_replace_prob should not exceed 1")
  589. if self.mask_replace_prob < 0 or self.mask_replace_prob > 1:
  590. raise ValueError("mask_replace_prob should be between 0 and 1.")
  591. if self.random_replace_prob < 0 or self.random_replace_prob > 1:
  592. raise ValueError("random_replace_prob should be between 0 and 1.")
  593. if self.whole_word_mask:
  594. if not self.tokenizer.is_fast:
  595. warnings.warn(
  596. "Whole word masking depends on offset mapping which is only natively available with fast tokenizers.",
  597. UserWarning,
  598. )
  599. if self.mask_replace_prob < 1:
  600. warnings.warn(
  601. "Random token replacement is not supported with whole word masking. "
  602. "Setting mask_replace_prob to 1.",
  603. )
  604. self.mask_replace_prob = 1
  605. self.random_replace_prob = 0
  606. self.mask_replace_prob = float(self.mask_replace_prob)
  607. self.random_replace_prob = float(self.random_replace_prob)
  608. self.generator = None
  609. def get_generator(self, seed):
  610. if self.return_tensors == "pt":
  611. import torch
  612. return torch.Generator().manual_seed(seed)
  613. else:
  614. return np.random.default_rng(seed)
  615. def create_rng(self):
  616. if mp.current_process().name == "MainProcess":
  617. # If we are in the main process, we create a generator object with the seed
  618. self.generator = self.get_generator(self.seed)
  619. else:
  620. # If we are in a worker process (i.e using multiprocessing), we need to set a unique seed for each
  621. # worker's generator, generated as the main seed + the worker's ID.
  622. # (https://pytorch.org/docs/stable/data.html#randomness-in-multi-process-data-loading)
  623. # Only PyTorch DataLoader allows us to access the worker ID, and so we check for this.
  624. import torch
  625. worker_info = torch.utils.data.get_worker_info()
  626. if worker_info is None:
  627. error_string = (
  628. "Worker process information is not available for seeding the generator. This may be because",
  629. "you are using multiprocessing without using a PyTorch DataLoader. The `seed` parameter can",
  630. "only be used when using multiprocessing with a PyTorch DataLoader. Please either use a",
  631. "single process or use a PyTorch DataLoader with multiple workers.",
  632. )
  633. raise ValueError(error_string)
  634. self.generator = self.get_generator(self.seed + worker_info.id)
  635. def torch_call(self, examples: list[list[int] | Any | dict[str, Any]]) -> dict[str, Any]:
  636. # Handle dict or lists with proper padding and conversion to tensor.
  637. if self.seed and self.generator is None:
  638. # If we have a seed, we need to create a generator object. Subsequent calls to this function will use the same generator.
  639. # If no seed supplied, we will use the global RNG
  640. self.create_rng()
  641. if isinstance(examples[0], Mapping):
  642. batch = pad_without_fast_tokenizer_warning(
  643. self.tokenizer, examples, return_tensors="pt", pad_to_multiple_of=self.pad_to_multiple_of
  644. )
  645. else:
  646. batch = {
  647. "input_ids": _torch_collate_batch(examples, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
  648. }
  649. # If special token mask has been preprocessed, pop it from the dict.
  650. special_tokens_mask = batch.pop("special_tokens_mask", None)
  651. offset_mapping = batch.pop("offset_mapping", None)
  652. if self.mlm:
  653. batch["input_ids"], batch["labels"] = self.torch_mask_tokens(
  654. batch["input_ids"], special_tokens_mask=special_tokens_mask, offset_mapping=offset_mapping
  655. )
  656. else:
  657. labels = batch["input_ids"].clone()
  658. if self.tokenizer.pad_token_id is not None:
  659. labels[labels == self.tokenizer.pad_token_id] = -100
  660. batch["labels"] = labels
  661. return batch
  662. def torch_mask_tokens(
  663. self, inputs: Any, special_tokens_mask: Any | None = None, offset_mapping: Any | None = None
  664. ) -> tuple[Any, Any]:
  665. """
  666. Prepare masked tokens inputs/labels for masked language modeling.
  667. """
  668. import torch
  669. labels = inputs.clone()
  670. # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
  671. probability_matrix = torch.full(labels.shape, self.mlm_probability)
  672. if special_tokens_mask is None:
  673. special_tokens_mask = [
  674. self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
  675. ]
  676. if self.whole_word_mask:
  677. word_ids, no_mask_mask = self._calc_word_ids_and_prob_mask(
  678. to_numpy(offset_mapping), to_numpy(special_tokens_mask)
  679. )
  680. no_mask_mask = torch.tensor(no_mask_mask, dtype=torch.bool)
  681. else:
  682. no_mask_mask = (
  683. special_tokens_mask.bool()
  684. if isinstance(special_tokens_mask, torch.Tensor)
  685. else torch.tensor(special_tokens_mask, dtype=torch.bool)
  686. )
  687. probability_matrix.masked_fill_(no_mask_mask, value=0.0)
  688. masked_indices = torch.bernoulli(probability_matrix, generator=self.generator).bool()
  689. if self.whole_word_mask:
  690. masked_indices = torch.BoolTensor(self._whole_word_mask(word_ids, masked_indices))
  691. labels[~masked_indices] = -100 # We only compute loss on masked tokens
  692. # mask_replace_prob% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
  693. indices_replaced = (
  694. torch.bernoulli(torch.full(labels.shape, self.mask_replace_prob), generator=self.generator).bool()
  695. & masked_indices
  696. )
  697. inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
  698. if self.mask_replace_prob == 1 or self.random_replace_prob == 0:
  699. return inputs, labels
  700. remaining_prob = 1 - self.mask_replace_prob
  701. # scaling the random_replace_prob to the remaining probability for example if
  702. # mask_replace_prob = 0.8 and random_replace_prob = 0.1,
  703. # then random_replace_prob_scaled = 0.1 / 0.2 = 0.5
  704. random_replace_prob_scaled = self.random_replace_prob / remaining_prob
  705. # random_replace_prob% of the time, we replace masked input tokens with random word
  706. indices_random = (
  707. torch.bernoulli(torch.full(labels.shape, random_replace_prob_scaled), generator=self.generator).bool()
  708. & masked_indices
  709. & ~indices_replaced
  710. )
  711. random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long, generator=self.generator)
  712. inputs[indices_random] = random_words[indices_random]
  713. # The rest of the time ((1-random_replace_prob-mask_replace_prob)% of the time) we keep the masked input tokens unchanged
  714. return inputs, labels
  715. def numpy_call(self, examples: list[list[int] | Any | dict[str, Any]]) -> dict[str, Any]:
  716. # Handle dict or lists with proper padding and conversion to tensor.
  717. if self.seed and self.generator is None:
  718. # If we have a seed, we need to create a generator object. Subsequent calls to this function will use the same generator.
  719. # If no seed supplied, we will use the global RNG
  720. self.create_rng()
  721. if isinstance(examples[0], Mapping):
  722. batch = pad_without_fast_tokenizer_warning(
  723. self.tokenizer, examples, return_tensors="np", pad_to_multiple_of=self.pad_to_multiple_of
  724. )
  725. else:
  726. batch = {
  727. "input_ids": _numpy_collate_batch(examples, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
  728. }
  729. # If special token mask has been preprocessed, pop it from the dict.
  730. special_tokens_mask = batch.pop("special_tokens_mask", None)
  731. offset_mapping = batch.pop("offset_mapping", None)
  732. if self.mlm:
  733. batch["input_ids"], batch["labels"] = self.numpy_mask_tokens(
  734. batch["input_ids"], special_tokens_mask=special_tokens_mask, offset_mapping=offset_mapping
  735. )
  736. else:
  737. labels = np.copy(batch["input_ids"])
  738. if self.tokenizer.pad_token_id is not None:
  739. labels[labels == self.tokenizer.pad_token_id] = -100
  740. batch["labels"] = labels
  741. return batch
  742. def numpy_mask_tokens(
  743. self,
  744. inputs: Any,
  745. special_tokens_mask: Any | None = None,
  746. offset_mapping: Any | None = None,
  747. ) -> tuple[Any, Any]:
  748. """
  749. Prepare masked tokens inputs/labels for masked language modeling.
  750. """
  751. labels = np.copy(inputs)
  752. # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
  753. probability_matrix = np.full(labels.shape, self.mlm_probability)
  754. if special_tokens_mask is None:
  755. special_tokens_mask = [
  756. self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
  757. ]
  758. if self.whole_word_mask:
  759. word_ids, no_mask_mask = self._calc_word_ids_and_prob_mask(
  760. to_numpy(offset_mapping), to_numpy(special_tokens_mask)
  761. )
  762. else:
  763. no_mask_mask = (
  764. special_tokens_mask.astype(bool)
  765. if isinstance(special_tokens_mask, np.ndarray)
  766. else np.array(special_tokens_mask, dtype=bool)
  767. )
  768. probability_matrix[no_mask_mask] = 0
  769. # Numpy doesn't have bernoulli, so we use a binomial with 1 trial
  770. if self.generator:
  771. masked_indices = self.generator.binomial(1, probability_matrix, size=probability_matrix.shape).astype(bool)
  772. else:
  773. masked_indices = np.random.binomial(1, probability_matrix, size=probability_matrix.shape).astype(bool)
  774. if self.whole_word_mask:
  775. masked_indices = self._whole_word_mask(word_ids, masked_indices)
  776. labels[~masked_indices] = -100 # We only compute loss on masked tokens
  777. # mask_replace_prob% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
  778. if self.generator:
  779. indices_replaced = (
  780. self.generator.binomial(1, self.mask_replace_prob, size=labels.shape).astype(bool) & masked_indices
  781. )
  782. else:
  783. indices_replaced = (
  784. np.random.binomial(1, self.mask_replace_prob, size=labels.shape).astype(bool) & masked_indices
  785. )
  786. inputs[indices_replaced] = self.tokenizer.mask_token_id
  787. if self.mask_replace_prob == 1 or self.random_replace_prob == 0:
  788. return inputs, labels
  789. remaining_prob = 1 - self.mask_replace_prob
  790. # scaling the random_replace_prob to the remaining probability for example if
  791. # mask_replace_prob = 0.8 and random_replace_prob = 0.1,
  792. # then random_replace_prob_scaled = 0.1 / 0.2 = 0.5
  793. random_replace_prob_scaled = self.random_replace_prob / remaining_prob
  794. if self.generator:
  795. indices_random = (
  796. self.generator.binomial(1, random_replace_prob_scaled, size=labels.shape).astype(bool)
  797. & masked_indices
  798. & ~indices_replaced
  799. )
  800. random_words = self.generator.integers(
  801. low=0, high=len(self.tokenizer), size=np.count_nonzero(indices_random), dtype=np.int64
  802. )
  803. else:
  804. indices_random = (
  805. np.random.binomial(1, random_replace_prob_scaled, size=labels.shape).astype(bool)
  806. & masked_indices
  807. & ~indices_replaced
  808. )
  809. random_words = np.random.randint(
  810. low=0, high=len(self.tokenizer), size=np.count_nonzero(indices_random), dtype=np.int64
  811. )
  812. inputs[indices_random] = random_words
  813. # The rest of the time (10% of the time) we keep the masked input tokens unchanged
  814. return inputs, labels
  815. @staticmethod
  816. def _calc_word_ids_and_prob_mask(
  817. offsets: np.ndarray[np.ndarray[tuple[int, int]]], special_tokens_mask: np.ndarray[np.ndarray[int]]
  818. ) -> tuple[np.ndarray[np.ndarray[int]], np.ndarray[np.ndarray[int]]]:
  819. """
  820. Map tokens to word ids and create mask of tokens to not mask.
  821. Tokens that are part of the same word will have the same word id and we will only
  822. set a mask probability for the first token of each word.
  823. """
  824. token_starts = offsets[:, :, 0]
  825. token_ends = offsets[:, :, 1]
  826. prev_token_ends = np.roll(token_ends, 1, axis=1)
  827. prev_token_ends[:, 0] = -1 # First token has no previous token
  828. prev_token_special = np.roll(special_tokens_mask, 1, axis=1)
  829. prev_token_special[:, 0] = 0
  830. # Not special token AND (gap from previous or previous token was special)
  831. special_tokens_mask = special_tokens_mask.astype(bool)
  832. is_new_word = (~special_tokens_mask) & ((token_starts != prev_token_ends) | (prev_token_special == 1))
  833. word_ids = np.cumsum(is_new_word, axis=1)
  834. word_ids[special_tokens_mask] = -1
  835. prob_mask = ~is_new_word
  836. return word_ids, prob_mask
  837. @staticmethod
  838. def _whole_word_mask(word_ids: np.ndarray[np.ndarray[int]], mask: Any) -> Any:
  839. """
  840. Mask whole words based on word ids and mask.
  841. """
  842. mask = to_numpy(mask)
  843. valid_ids = word_ids != -1
  844. # Create 3D mask where [batch, token_i, token_j] is True if token_i and token_j are the same word
  845. same_word = (word_ids[:, :, None] == word_ids[:, None, :]) & valid_ids[:, :, None] & valid_ids[:, None, :]
  846. # For each token, set True if any token in the same word is masked
  847. return np.any(same_word & mask[:, None, :], axis=2)
  848. @dataclass
  849. class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling):
  850. """
  851. Data collator used for language modeling that masks entire words.
  852. - collates batches of tensors, honoring their tokenizer's pad_token
  853. - preprocesses batches for masked language modeling
  854. """
  855. def __init__(self, *args, **kwargs):
  856. warnings.warn(
  857. "DataCollatorForWholeWordMask is deprecated and will be removed in a future version, you can now use "
  858. "DataCollatorForLanguageModeling with whole_word_mask=True instead.",
  859. FutureWarning,
  860. )
  861. super().__init__(*args, **kwargs)
  862. self.mlm = True # Force masked language modeling
  863. self.whole_word_mask = True # Force whole word masking
  864. def tolist(x) -> list[Any]:
  865. if isinstance(x, list):
  866. return x
  867. elif hasattr(x, "numpy"):
  868. x = x.numpy()
  869. return x.tolist()
  870. def to_numpy(x) -> np.ndarray[Any]:
  871. if isinstance(x, np.ndarray):
  872. return x
  873. elif hasattr(x, "detach"):
  874. return x.detach().cpu().numpy()
  875. else:
  876. return np.array(x)
  877. @dataclass
  878. class DataCollatorForSOP(DataCollatorForLanguageModeling):
  879. """
  880. Data collator used for sentence order prediction task.
  881. - collates batches of tensors, honoring their tokenizer's pad_token
  882. - preprocesses batches for both masked language modeling and sentence order prediction
  883. """
  884. def __init__(self, *args, **kwargs):
  885. warnings.warn(
  886. "DataCollatorForSOP is deprecated and will be removed in a future version, you can now use "
  887. "DataCollatorForLanguageModeling instead.",
  888. FutureWarning,
  889. )
  890. def __call__(self, examples: list[dict[str, Any]]) -> dict[str, Any]:
  891. import torch
  892. from torch.nn.utils.rnn import pad_sequence
  893. input_ids = [example["input_ids"] for example in examples]
  894. input_ids = _torch_collate_batch(input_ids, self.tokenizer)
  895. input_ids, labels, attention_mask = self.mask_tokens(input_ids)
  896. token_type_ids = [example["token_type_ids"] for example in examples]
  897. # size of segment_ids varied because randomness, padding zero to the end as the original implementation
  898. token_type_ids = pad_sequence(token_type_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)
  899. sop_label_list = [example["sentence_order_label"] for example in examples]
  900. sentence_order_label = torch.stack(sop_label_list)
  901. return {
  902. "input_ids": input_ids,
  903. "labels": labels,
  904. "attention_mask": attention_mask,
  905. "token_type_ids": token_type_ids,
  906. "sentence_order_label": sentence_order_label,
  907. }
  908. def mask_tokens(self, inputs: Any) -> tuple[Any, Any, Any]:
  909. """
  910. Prepare masked tokens inputs/labels/attention_mask for masked language modeling: 80% MASK, 10% random, 10%
  911. original. N-gram not applied yet.
  912. """
  913. import torch
  914. if self.tokenizer.mask_token is None:
  915. raise ValueError(
  916. "This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the"
  917. " --mlm flag if you want to use this tokenizer."
  918. )
  919. labels = inputs.clone()
  920. # We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)
  921. probability_matrix = torch.full(labels.shape, self.mlm_probability)
  922. special_tokens_mask = [
  923. self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
  924. ]
  925. probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
  926. if self.tokenizer.pad_token is not None:
  927. padding_mask = labels.eq(self.tokenizer.pad_token_id)
  928. probability_matrix.masked_fill_(padding_mask, value=0.0)
  929. masked_indices = torch.bernoulli(probability_matrix).bool()
  930. # probability be `1` (masked), however in albert model attention mask `0` means masked, revert the value
  931. attention_mask = (~masked_indices).float()
  932. if self.tokenizer.pad_token is not None:
  933. attention_padding_mask = labels.eq(self.tokenizer.pad_token_id)
  934. attention_mask.masked_fill_(attention_padding_mask, value=1.0)
  935. labels[~masked_indices] = -100 # We only compute loss on masked tokens, -100 is default for CE compute
  936. # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
  937. indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
  938. inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
  939. # 10% of the time, we replace masked input tokens with random word
  940. indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
  941. random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
  942. inputs[indices_random] = random_words[indices_random]
  943. # The rest of the time (10% of the time) we keep the masked input tokens unchanged
  944. return inputs, labels, attention_mask
  945. @dataclass
  946. class DataCollatorForPermutationLanguageModeling(DataCollatorMixin):
  947. """
  948. Data collator used for permutation language modeling.
  949. - collates batches of tensors, honoring their tokenizer's pad_token
  950. - preprocesses batches for permutation language modeling with procedures specific to XLNet
  951. """
  952. tokenizer: PreTrainedTokenizerBase
  953. plm_probability: float = 1 / 6
  954. max_span_length: int = 5 # maximum length of a span of masked tokens
  955. return_tensors: str = "pt"
  956. def torch_call(self, examples: list[list[int] | Any | dict[str, Any]]) -> dict[str, Any]:
  957. if isinstance(examples[0], Mapping):
  958. examples = [e["input_ids"] for e in examples]
  959. batch = _torch_collate_batch(examples, self.tokenizer)
  960. inputs, perm_mask, target_mapping, labels = self.torch_mask_tokens(batch)
  961. return {"input_ids": inputs, "perm_mask": perm_mask, "target_mapping": target_mapping, "labels": labels}
  962. def numpy_call(self, examples: list[list[int] | Any | dict[str, Any]]) -> dict[str, Any]:
  963. if isinstance(examples[0], Mapping):
  964. examples = [e["input_ids"] for e in examples]
  965. batch = _numpy_collate_batch(examples, self.tokenizer)
  966. inputs, perm_mask, target_mapping, labels = self.numpy_mask_tokens(batch)
  967. return {"input_ids": inputs, "perm_mask": perm_mask, "target_mapping": target_mapping, "labels": labels}
  968. def torch_mask_tokens(self, inputs: Any) -> tuple[Any, Any, Any, Any]:
  969. """
  970. The masked tokens to be predicted for a particular sequence are determined by the following algorithm:
  971. 0. Start from the beginning of the sequence by setting `cur_len = 0` (number of tokens processed so far).
  972. 1. Sample a `span_length` from the interval `[1, max_span_length]` (length of span of tokens to be masked)
  973. 2. Reserve a context of length `context_length = span_length / plm_probability` to surround span to be
  974. masked
  975. 3. Sample a starting point `start_index` from the interval `[cur_len, cur_len + context_length -
  976. span_length]` and mask tokens `start_index:start_index + span_length`
  977. 4. Set `cur_len = cur_len + context_length`. If `cur_len < max_len` (i.e. there are tokens remaining in the
  978. sequence to be processed), repeat from Step 1.
  979. """
  980. import torch
  981. if self.tokenizer.mask_token is None:
  982. raise ValueError(
  983. "This tokenizer does not have a mask token which is necessary for permutation language modeling."
  984. " Please add a mask token if you want to use this tokenizer."
  985. )
  986. if inputs.size(1) % 2 != 0:
  987. raise ValueError(
  988. "This collator requires that sequence lengths be even to create a leakage-free perm_mask. Please see"
  989. " relevant comments in source code for details."
  990. )
  991. labels = inputs.clone()
  992. # Creating the mask and target_mapping tensors
  993. masked_indices = torch.full(labels.shape, 0, dtype=torch.bool)
  994. target_mapping = torch.zeros((labels.size(0), labels.size(1), labels.size(1)), dtype=torch.float32)
  995. for i in range(labels.size(0)):
  996. # Start from the beginning of the sequence by setting `cur_len = 0` (number of tokens processed so far).
  997. cur_len = 0
  998. max_len = labels.size(1)
  999. while cur_len < max_len:
  1000. # Sample a `span_length` from the interval `[1, max_span_length]` (length of span of tokens to be masked)
  1001. span_length = torch.randint(1, self.max_span_length + 1, (1,)).item()
  1002. # Reserve a context of length `context_length = span_length / plm_probability` to surround the span to be masked
  1003. context_length = int(span_length / self.plm_probability)
  1004. # Sample a starting point `start_index` from the interval `[cur_len, cur_len + context_length - span_length]` and mask tokens `start_index:start_index + span_length`
  1005. start_index = cur_len + torch.randint(context_length - span_length + 1, (1,)).item()
  1006. masked_indices[i, start_index : start_index + span_length] = 1
  1007. # Set `cur_len = cur_len + context_length`
  1008. cur_len += context_length
  1009. # Since we're replacing non-masked tokens with -100 in the labels tensor instead of skipping them altogether,
  1010. # the i-th predict corresponds to the i-th token.
  1011. target_mapping[i] = torch.eye(labels.size(1))
  1012. special_tokens_mask = torch.tensor(
  1013. [self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()],
  1014. dtype=torch.bool,
  1015. )
  1016. masked_indices.masked_fill_(special_tokens_mask, value=0.0)
  1017. if self.tokenizer.pad_token is not None:
  1018. padding_mask = labels.eq(self.tokenizer.pad_token_id)
  1019. masked_indices.masked_fill_(padding_mask, value=0.0)
  1020. # Mask indicating non-functional tokens, where functional tokens are [SEP], [CLS], padding, etc.
  1021. non_func_mask = ~(padding_mask | special_tokens_mask)
  1022. inputs[masked_indices] = self.tokenizer.mask_token_id
  1023. labels[~masked_indices] = -100 # We only compute loss on masked tokens
  1024. perm_mask = torch.zeros((labels.size(0), labels.size(1), labels.size(1)), dtype=torch.float32)
  1025. for i in range(labels.size(0)):
  1026. # Generate permutation indices i.e. sample a random factorisation order for the sequence. This will
  1027. # determine which tokens a given token can attend to (encoded in `perm_mask`).
  1028. # Note: Length of token sequence being permuted has to be less than or equal to reused sequence length
  1029. # (see documentation for `mems`), otherwise information may leak through due to reuse. In this implementation,
  1030. # we assume that reused length is half of sequence length and permutation length is equal to reused length.
  1031. # This requires that the sequence length be even.
  1032. # Create a linear factorisation order
  1033. perm_index = torch.arange(labels.size(1))
  1034. # Split this into two halves, assuming that half the sequence is reused each time
  1035. perm_index = perm_index.reshape((-1, labels.size(1) // 2)).transpose(0, 1)
  1036. # Permute the two halves such that they do not cross over
  1037. perm_index = perm_index[torch.randperm(labels.size(1) // 2)]
  1038. # Flatten this out into the desired permuted factorisation order
  1039. perm_index = torch.flatten(perm_index.transpose(0, 1))
  1040. # Set the permutation indices of non-masked (non-functional) tokens to the
  1041. # smallest index (-1) so that:
  1042. # (1) They can be seen by all other positions
  1043. # (2) They cannot see masked positions, so there won't be information leak
  1044. perm_index.masked_fill_(~masked_indices[i] & non_func_mask[i], -1)
  1045. # The logic for whether the i-th token can attend on the j-th token based on the factorisation order:
  1046. # 0 (can attend): If perm_index[i] > perm_index[j] or j is neither masked nor a functional token
  1047. # 1 (cannot attend): If perm_index[i] <= perm_index[j] and j is either masked or a functional token
  1048. perm_mask[i] = (
  1049. perm_index.reshape((labels.size(1), 1)) <= perm_index.reshape((1, labels.size(1)))
  1050. ) & masked_indices[i]
  1051. return inputs.long(), perm_mask, target_mapping, labels.long()
  1052. def numpy_mask_tokens(self, inputs: Any) -> tuple[Any, Any, Any, Any]:
  1053. """
  1054. The masked tokens to be predicted for a particular sequence are determined by the following algorithm:
  1055. 0. Start from the beginning of the sequence by setting `cur_len = 0` (number of tokens processed so far).
  1056. 1. Sample a `span_length` from the interval `[1, max_span_length]` (length of span of tokens to be masked)
  1057. 2. Reserve a context of length `context_length = span_length / plm_probability` to surround span to be
  1058. masked
  1059. 3. Sample a starting point `start_index` from the interval `[cur_len, cur_len + context_length -
  1060. span_length]` and mask tokens `start_index:start_index + span_length`
  1061. 4. Set `cur_len = cur_len + context_length`. If `cur_len < max_len` (i.e. there are tokens remaining in the
  1062. sequence to be processed), repeat from Step 1.
  1063. """
  1064. if self.tokenizer.mask_token is None:
  1065. raise ValueError(
  1066. "This tokenizer does not have a mask token which is necessary for permutation language modeling."
  1067. " Please add a mask token if you want to use this tokenizer."
  1068. )
  1069. if inputs.shape[1] % 2 != 0:
  1070. raise ValueError(
  1071. "This collator requires that sequence lengths be even to create a leakage-free perm_mask. Please see"
  1072. " relevant comments in source code for details."
  1073. )
  1074. labels = np.copy(inputs)
  1075. # Creating the mask and target_mapping tensors
  1076. masked_indices = np.full(labels.shape, 0, dtype=bool)
  1077. target_mapping = np.zeros((labels.shape[0], labels.shape[1], labels.shape[1]), dtype=np.float32)
  1078. for i in range(labels.shape[0]):
  1079. # Start from the beginning of the sequence by setting `cur_len = 0` (number of tokens processed so far).
  1080. cur_len = 0
  1081. max_len = labels.shape[1]
  1082. while cur_len < max_len:
  1083. # Sample a `span_length` from the interval `[1, max_span_length]` (length of span of tokens to be masked)
  1084. span_length = randint(1, self.max_span_length + 1)
  1085. # Reserve a context of length `context_length = span_length / plm_probability` to surround the span to be masked
  1086. context_length = int(span_length / self.plm_probability)
  1087. # Sample a starting point `start_index` from the interval `[cur_len, cur_len + context_length - span_length]` and mask tokens `start_index:start_index + span_length`
  1088. start_index = cur_len + randint(0, context_length - span_length + 1)
  1089. masked_indices[i, start_index : start_index + span_length] = 1
  1090. # Set `cur_len = cur_len + context_length`
  1091. cur_len += context_length
  1092. # Since we're replacing non-masked tokens with -100 in the labels tensor instead of skipping them altogether,
  1093. # the i-th predict corresponds to the i-th token.
  1094. target_mapping[i] = np.eye(labels.shape[1])
  1095. special_tokens_mask = np.array(
  1096. [self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()],
  1097. dtype=bool,
  1098. )
  1099. masked_indices[special_tokens_mask] = 0
  1100. if self.tokenizer.pad_token is not None:
  1101. padding_mask = labels == self.tokenizer.pad_token_id
  1102. masked_indices[padding_mask] = 0.0
  1103. # Mask indicating non-functional tokens, where functional tokens are [SEP], [CLS], padding, etc.
  1104. non_func_mask = ~(padding_mask | special_tokens_mask)
  1105. inputs[masked_indices] = self.tokenizer.mask_token_id
  1106. labels[~masked_indices] = -100 # We only compute loss on masked tokens
  1107. perm_mask = np.zeros((labels.shape[0], labels.shape[1], labels.shape[1]), dtype=np.float32)
  1108. for i in range(labels.shape[0]):
  1109. # Generate permutation indices i.e. sample a random factorisation order for the sequence. This will
  1110. # determine which tokens a given token can attend to (encoded in `perm_mask`).
  1111. # Note: Length of token sequence being permuted has to be less than or equal to reused sequence length
  1112. # (see documentation for `mems`), otherwise information may leak through due to reuse. In this implementation,
  1113. # we assume that reused length is half of sequence length and permutation length is equal to reused length.
  1114. # This requires that the sequence length be even.
  1115. # Create a linear factorisation order
  1116. perm_index = np.arange(labels.shape[1])
  1117. # Split this into two halves, assuming that half the sequence is reused each time
  1118. perm_index = perm_index.reshape((-1, labels.shape[1] // 2)).T
  1119. # Permute the two halves such that they do not cross over
  1120. np.random.shuffle(perm_index)
  1121. # Flatten this out into the desired permuted factorisation order
  1122. perm_index = perm_index.T.flatten()
  1123. # Set the permutation indices of non-masked (non-functional) tokens to the
  1124. # smallest index (-1) so that:
  1125. # (1) They can be seen by all other positions
  1126. # (2) They cannot see masked positions, so there won't be information leak
  1127. perm_index[~masked_indices[i] & non_func_mask[i]] = -1
  1128. # The logic for whether the i-th token can attend on the j-th token based on the factorisation order:
  1129. # 0 (can attend): If perm_index[i] > perm_index[j] or j is neither masked nor a functional token
  1130. # 1 (cannot attend): If perm_index[i] <= perm_index[j] and j is either masked or a functional token
  1131. perm_mask[i] = (
  1132. perm_index.reshape((labels.shape[1], 1)) <= perm_index.reshape((1, labels.shape[1]))
  1133. ) & masked_indices[i]
  1134. return inputs.astype(np.int64), perm_mask, target_mapping, labels.astype(np.int64)
  1135. @dataclass
  1136. class DataCollatorWithFlattening(DefaultDataCollator):
  1137. """
  1138. Data collator used for padding free approach. Does the following:
  1139. - concatenates the entire mini batch into single long sequence of shape [1, total_tokens]
  1140. - uses `separator_id` to separate sequences within the concatenated `labels`, default value is -100
  1141. - no padding will be added, returns `input_ids`, `labels` and `position_ids` by default
  1142. - optionally returns the kwargs contained in FlashAttentionKwargs
  1143. - optionally returns seq_idx indicating which sequence each token belongs to
  1144. <Tip warning={true}>
  1145. Using `DataCollatorWithFlattening` will flatten the entire mini batch into single long sequence.
  1146. Make sure your attention computation is able to handle it!
  1147. </Tip>
  1148. """
  1149. def __init__(
  1150. self,
  1151. *args,
  1152. return_position_ids=True,
  1153. separator_id=-100,
  1154. return_flash_attn_kwargs=False,
  1155. return_seq_idx=False,
  1156. **kwargs,
  1157. ):
  1158. super().__init__(*args, **kwargs)
  1159. self.return_position_ids = return_position_ids
  1160. self.separator_id = separator_id
  1161. self.return_flash_attn_kwargs = return_flash_attn_kwargs
  1162. self.return_seq_idx = return_seq_idx
  1163. self._int_64_keys = {"labels", "position_ids", "input_ids"}
  1164. self._batch_dim_keys = {"labels", "position_ids", "input_ids", "seq_idx"}
  1165. self._py_int_keys = {"max_length_q", "max_length_k"}
  1166. def __call__(self, features, return_tensors=None, separator_id=None):
  1167. if return_tensors is None:
  1168. return_tensors = self.return_tensors
  1169. if separator_id is None:
  1170. separator_id = self.separator_id
  1171. is_labels_provided = "labels" in features[0]
  1172. batch = {"input_ids": [], "labels": []}
  1173. if self.return_position_ids:
  1174. batch.update({"position_ids": []})
  1175. if self.return_seq_idx:
  1176. batch.update({"seq_idx": []})
  1177. if self.return_flash_attn_kwargs:
  1178. cu_seq_lens = [0]
  1179. max_length = 0
  1180. for seq_idx, sample in enumerate(features):
  1181. input_ids = sample["input_ids"]
  1182. # Convert to list if tensor
  1183. if hasattr(input_ids, "tolist"):
  1184. input_ids = input_ids.tolist()
  1185. batch["input_ids"] += input_ids
  1186. if is_labels_provided:
  1187. labels = sample["labels"]
  1188. # Convert to list if tensor
  1189. if hasattr(labels, "tolist"):
  1190. labels = labels.tolist()
  1191. batch["labels"] += [separator_id] + labels[1:]
  1192. else:
  1193. batch["labels"] += [separator_id] + input_ids[1:]
  1194. if self.return_position_ids:
  1195. batch["position_ids"] += list(range(len(input_ids)))
  1196. if self.return_seq_idx:
  1197. batch["seq_idx"] += [seq_idx for _ in range(len(input_ids))]
  1198. if self.return_flash_attn_kwargs:
  1199. cu_seq_lens.append(cu_seq_lens[-1] + len(input_ids))
  1200. max_length = max(max_length, len(input_ids))
  1201. if self.return_flash_attn_kwargs:
  1202. batch["cu_seq_lens_q"] = batch["cu_seq_lens_k"] = cu_seq_lens
  1203. batch["max_length_q"] = batch["max_length_k"] = max_length
  1204. # FlashAttentionKwargs and seq_idx are expected to be int32s.
  1205. if return_tensors == "pt":
  1206. import torch
  1207. data_cls = torch.tensor
  1208. dtype_64 = torch.int64
  1209. dtype_32 = torch.int32
  1210. elif return_tensors == "np":
  1211. data_cls = np.array
  1212. dtype_64 = np.int64
  1213. dtype_32 = np.int32
  1214. else:
  1215. raise ValueError(f'return_tensors must be one of ("pt", "np"), {return_tensors=} not supported')
  1216. for k, v in batch.items():
  1217. if k in self._batch_dim_keys:
  1218. v = [v]
  1219. # Flash attention max_len_{q,k} are python ints
  1220. if k not in self._py_int_keys:
  1221. batch[k] = data_cls(v, dtype=dtype_64 if k in self._int_64_keys else dtype_32)
  1222. return batch