trainer_seq2seq.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392
  1. # Copyright 2020 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import contextlib
  15. from collections.abc import Callable
  16. from copy import deepcopy
  17. from pathlib import Path
  18. from typing import TYPE_CHECKING, Any, Optional, Union
  19. import torch
  20. from torch import nn
  21. from torch.utils.data import Dataset
  22. from .generation.configuration_utils import GenerationConfig
  23. from .integrations.deepspeed import is_deepspeed_zero3_enabled
  24. from .integrations.fsdp import is_fsdp_managed_module
  25. from .trainer import Trainer
  26. from .utils import is_datasets_available, logging
  27. if torch.distributed.is_available():
  28. from torch.distributed.fsdp import FullyShardedDataParallel
  29. if is_datasets_available():
  30. import datasets
  31. if TYPE_CHECKING:
  32. from torch.utils.data import IterableDataset
  33. from .data.data_collator import DataCollator
  34. from .feature_extraction_utils import FeatureExtractionMixin
  35. from .image_processing_utils import BaseImageProcessor
  36. from .modeling_utils import PreTrainedModel
  37. from .processing_utils import ProcessorMixin
  38. from .tokenization_utils_base import PreTrainedTokenizerBase
  39. from .trainer_callback import TrainerCallback
  40. from .trainer_utils import EvalPrediction, PredictionOutput
  41. from .training_args import TrainingArguments
# Module-level logger obtained from the library's `utils.logging` helper, named after this module.
logger = logging.get_logger(__name__)
  43. class Seq2SeqTrainer(Trainer):
  44. def __init__(
  45. self,
  46. model: Union["PreTrainedModel", nn.Module] | None = None,
  47. args: Optional["TrainingArguments"] = None,
  48. data_collator: Optional["DataCollator"] = None,
  49. train_dataset: Union[Dataset, "IterableDataset", "datasets.Dataset"] | None = None,
  50. eval_dataset: Dataset | dict[str, Dataset] | None = None,
  51. processing_class: Union[
  52. "PreTrainedTokenizerBase", "BaseImageProcessor", "FeatureExtractionMixin", "ProcessorMixin"
  53. ]
  54. | None = None,
  55. model_init: Callable[[], "PreTrainedModel"] | None = None,
  56. compute_loss_func: Callable | None = None,
  57. compute_metrics: Callable[["EvalPrediction"], dict] | None = None,
  58. callbacks: list["TrainerCallback"] | None = None,
  59. optimizers: tuple[torch.optim.Optimizer | None, torch.optim.lr_scheduler.LambdaLR | None] = (None, None),
  60. preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
  61. ):
  62. super().__init__(
  63. model=model,
  64. args=args,
  65. data_collator=data_collator,
  66. train_dataset=train_dataset,
  67. eval_dataset=eval_dataset,
  68. processing_class=processing_class,
  69. model_init=model_init,
  70. compute_loss_func=compute_loss_func,
  71. compute_metrics=compute_metrics,
  72. callbacks=callbacks,
  73. optimizers=optimizers,
  74. preprocess_logits_for_metrics=preprocess_logits_for_metrics,
  75. )
  76. # Override self.model.generation_config if a GenerationConfig is specified in args.
  77. # Priority: args.generation_config > model.generation_config > default GenerationConfig.
  78. if self.args.generation_config is not None:
  79. gen_config = self.load_generation_config(self.args.generation_config)
  80. self.model.generation_config = gen_config
  81. @staticmethod
  82. def load_generation_config(gen_config_arg: str | GenerationConfig) -> GenerationConfig:
  83. """
  84. Loads a `~generation.GenerationConfig` from the `Seq2SeqTrainingArguments.generation_config` arguments.
  85. Args:
  86. gen_config_arg (`str` or [`~generation.GenerationConfig]`):
  87. `Seq2SeqTrainingArguments.generation_config` argument.
  88. Returns:
  89. A `~generation.GenerationConfig`.
  90. """
  91. # GenerationConfig provided, nothing to do
  92. if isinstance(gen_config_arg, GenerationConfig):
  93. gen_config = deepcopy(gen_config_arg)
  94. else:
  95. # str or Path
  96. pretrained_model_name = Path(gen_config_arg) if isinstance(gen_config_arg, str) else gen_config_arg
  97. config_file_name = None
  98. # Figuring if it is path pointing to a file, pointing to a directory or else a model id or URL
  99. # This step is required in order to determine config_file_name
  100. if pretrained_model_name.is_file():
  101. config_file_name = pretrained_model_name.name
  102. pretrained_model_name = pretrained_model_name.parent
  103. # dir path
  104. elif pretrained_model_name.is_dir():
  105. pass
  106. # model id or URL
  107. else:
  108. pretrained_model_name = gen_config_arg
  109. gen_config = GenerationConfig.from_pretrained(pretrained_model_name, config_file_name)
  110. # Strict validation to fail early. `GenerationConfig.save_pretrained()`, run at the end of training, throws
  111. # an exception if there are warnings at validation time.
  112. try:
  113. gen_config.validate(strict=True)
  114. except ValueError as exc:
  115. raise ValueError(str(exc) + "\n\nFix these issues to train your model.")
  116. return gen_config
  117. def evaluate(
  118. self,
  119. eval_dataset: Dataset | None = None,
  120. ignore_keys: list[str] | None = None,
  121. metric_key_prefix: str = "eval",
  122. **gen_kwargs,
  123. ) -> dict[str, float]:
  124. """
  125. Run evaluation and returns metrics.
  126. The calling script will be responsible for providing a method to compute metrics, as they are task-dependent
  127. (pass it to the init `compute_metrics` argument).
  128. You can also subclass and override this method to inject custom behavior.
  129. Args:
  130. eval_dataset (`Dataset`, *optional*):
  131. Pass a dataset if you wish to override `self.eval_dataset`. If it is an [`~datasets.Dataset`], columns
  132. not accepted by the `model.forward()` method are automatically removed. It must implement the `__len__`
  133. method.
  134. ignore_keys (`list[str]`, *optional*):
  135. A list of keys in the output of your model (if it is a dictionary) that should be ignored when
  136. gathering predictions.
  137. metric_key_prefix (`str`, *optional*, defaults to `"eval"`):
  138. An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named
  139. "eval_bleu" if the prefix is `"eval"` (default)
  140. max_length (`int`, *optional*):
  141. The maximum target length to use when predicting with the generate method.
  142. num_beams (`int`, *optional*):
  143. Number of beams for beam search that will be used when predicting with the generate method. 1 means no
  144. beam search.
  145. gen_kwargs:
  146. Additional `generate` specific kwargs.
  147. Returns:
  148. A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The
  149. dictionary also contains the epoch number which comes from the training state.
  150. """
  151. gen_kwargs = gen_kwargs.copy()
  152. # Use legacy argument setting if a) the option is not explicitly passed; and b) the argument is set in the
  153. # training args
  154. if (
  155. gen_kwargs.get("max_length") is None
  156. and gen_kwargs.get("max_new_tokens") is None
  157. and self.args.generation_max_length is not None
  158. ):
  159. gen_kwargs["max_length"] = self.args.generation_max_length
  160. if gen_kwargs.get("num_beams") is None and self.args.generation_num_beams is not None:
  161. gen_kwargs["num_beams"] = self.args.generation_num_beams
  162. # We don't want to drop samples in general
  163. self.gather_function = self.accelerator.gather
  164. self._gen_kwargs = gen_kwargs
  165. return super().evaluate(eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
  166. def predict(
  167. self,
  168. test_dataset: Dataset,
  169. ignore_keys: list[str] | None = None,
  170. metric_key_prefix: str = "test",
  171. **gen_kwargs,
  172. ) -> "PredictionOutput":
  173. """
  174. Run prediction and returns predictions and potential metrics.
  175. Depending on the dataset and your use case, your test dataset may contain labels. In that case, this method
  176. will also return metrics, like in `evaluate()`.
  177. Args:
  178. test_dataset (`Dataset`):
  179. Dataset to run the predictions on. If it is a [`~datasets.Dataset`], columns not accepted by the
  180. `model.forward()` method are automatically removed. Has to implement the method `__len__`
  181. ignore_keys (`list[str]`, *optional*):
  182. A list of keys in the output of your model (if it is a dictionary) that should be ignored when
  183. gathering predictions.
  184. metric_key_prefix (`str`, *optional*, defaults to `"eval"`):
  185. An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named
  186. "eval_bleu" if the prefix is `"eval"` (default)
  187. max_length (`int`, *optional*):
  188. The maximum target length to use when predicting with the generate method.
  189. num_beams (`int`, *optional*):
  190. Number of beams for beam search that will be used when predicting with the generate method. 1 means no
  191. beam search.
  192. gen_kwargs:
  193. Additional `generate` specific kwargs.
  194. <Tip>
  195. If your predictions or labels have different sequence lengths (for instance because you're doing dynamic
  196. padding in a token classification task) the predictions will be padded (on the right) to allow for
  197. concatenation into one array. The padding index is -100.
  198. </Tip>
  199. Returns: *NamedTuple* A namedtuple with the following keys:
  200. - predictions (`np.ndarray`): The predictions on `test_dataset`.
  201. - label_ids (`np.ndarray`, *optional*): The labels (if the dataset contained some).
  202. - metrics (`dict[str, float]`, *optional*): The potential dictionary of metrics (if the dataset contained
  203. labels).
  204. """
  205. gen_kwargs = gen_kwargs.copy()
  206. # Use legacy argument setting if a) the option is not explicitly passed; and b) the argument is set in the
  207. # training args
  208. if (
  209. gen_kwargs.get("max_length") is None
  210. and gen_kwargs.get("max_new_tokens") is None
  211. and self.args.generation_max_length is not None
  212. ):
  213. gen_kwargs["max_length"] = self.args.generation_max_length
  214. if gen_kwargs.get("num_beams") is None and self.args.generation_num_beams is not None:
  215. gen_kwargs["num_beams"] = self.args.generation_num_beams
  216. self.gather_function = self.accelerator.gather
  217. self._gen_kwargs = gen_kwargs
  218. return super().predict(test_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
  219. def prediction_step(
  220. self,
  221. model: nn.Module,
  222. inputs: dict[str, torch.Tensor | Any],
  223. prediction_loss_only: bool,
  224. ignore_keys: list[str] | None = None,
  225. **gen_kwargs,
  226. ) -> tuple[float | None, torch.Tensor | None, torch.Tensor | None]:
  227. """
  228. Perform an evaluation step on `model` using `inputs`.
  229. Subclass and override to inject custom behavior.
  230. Args:
  231. model (`nn.Module`):
  232. The model to evaluate.
  233. inputs (`dict[str, Union[torch.Tensor, Any]]`):
  234. The inputs and targets of the model.
  235. The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
  236. argument `labels`. Check your model's documentation for all accepted arguments.
  237. prediction_loss_only (`bool`):
  238. Whether or not to return the loss only.
  239. gen_kwargs:
  240. Additional `generate` specific kwargs.
  241. Return:
  242. tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, logits and
  243. labels (each being optional).
  244. """
  245. if not self.args.predict_with_generate or prediction_loss_only:
  246. return super().prediction_step(
  247. model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
  248. )
  249. has_labels = "labels" in inputs
  250. inputs = self._prepare_inputs(inputs)
  251. # Priority (handled in generate):
  252. # non-`None` gen_kwargs > model.generation_config > default GenerationConfig()
  253. if len(gen_kwargs) == 0 and hasattr(self, "_gen_kwargs"):
  254. gen_kwargs = self._gen_kwargs.copy()
  255. if "num_beams" in gen_kwargs and gen_kwargs["num_beams"] is None:
  256. gen_kwargs.pop("num_beams")
  257. if "max_length" in gen_kwargs and gen_kwargs["max_length"] is None:
  258. gen_kwargs.pop("max_length")
  259. default_synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self.model)
  260. gen_kwargs["synced_gpus"] = gen_kwargs.get("synced_gpus", default_synced_gpus)
  261. generation_inputs = inputs.copy()
  262. # If the `decoder_input_ids` was created from `labels`, evict the former, so that the model can freely generate
  263. # (otherwise, it would continue generating from the padded `decoder_input_ids`)
  264. if (
  265. "labels" in generation_inputs
  266. and "decoder_input_ids" in generation_inputs
  267. and generation_inputs["labels"].shape == generation_inputs["decoder_input_ids"].shape
  268. ):
  269. generation_inputs = {
  270. k: v for k, v in inputs.items() if k not in ("decoder_input_ids", "decoder_attention_mask")
  271. }
  272. summon_full_params_context = (
  273. FullyShardedDataParallel.summon_full_params(self.model)
  274. if torch.distributed.is_available() and isinstance(self.model, FullyShardedDataParallel)
  275. else contextlib.nullcontext()
  276. )
  277. with summon_full_params_context:
  278. generated_tokens = self.model.generate(**generation_inputs, **gen_kwargs)
  279. # Temporary hack to ensure the generation config is not initialized for each iteration of the evaluation loop
  280. # TODO: remove this hack when the legacy code that initializes generation_config from a model config is
  281. # removed in https://github.com/huggingface/transformers/blob/98d88b23f54e5a23e741833f1e973fdf600cc2c5/src/transformers/generation/utils.py#L1183
  282. if self.model.generation_config._from_model_config:
  283. self.model.generation_config._from_model_config = False
  284. # Retrieves GenerationConfig from model.generation_config
  285. # Update with defaults because earlier the generation config used to be init
  286. # with default values. Now we init it with `None` and keep defaults for BC
  287. gen_config = self.model.generation_config
  288. default_gen_config = gen_config._get_default_generation_params()
  289. gen_config.update(**default_gen_config, defaults_only=True)
  290. # in case the batch is shorter than max length, the output should be padded
  291. if generated_tokens.shape[-1] < gen_config.max_length:
  292. generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_config.max_length)
  293. elif gen_config.max_new_tokens is not None and generated_tokens.shape[-1] < gen_config.max_new_tokens + 1:
  294. generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_config.max_new_tokens + 1)
  295. with torch.no_grad():
  296. if has_labels:
  297. with self.compute_loss_context_manager():
  298. outputs = model(**inputs)
  299. if self.label_smoother is not None:
  300. loss = self.label_smoother(outputs, inputs["labels"]).detach().mean()
  301. else:
  302. loss = (outputs["loss"] if isinstance(outputs, dict) else outputs[0]).detach().mean()
  303. else:
  304. loss = None
  305. if self.args.prediction_loss_only:
  306. return loss, None, None
  307. if has_labels:
  308. labels = inputs["labels"]
  309. if labels.shape[-1] < gen_config.max_length:
  310. labels = self._pad_tensors_to_max_len(labels, gen_config.max_length)
  311. elif gen_config.max_new_tokens is not None and labels.shape[-1] < gen_config.max_new_tokens + 1:
  312. labels = self._pad_tensors_to_max_len(labels, gen_config.max_new_tokens + 1)
  313. else:
  314. labels = None
  315. return loss, generated_tokens, labels
  316. def _pad_tensors_to_max_len(self, tensor, max_length):
  317. if self.processing_class is not None and hasattr(self.processing_class, "pad_token_id"):
  318. # If PAD token is not defined at least EOS token has to be defined
  319. pad_token_id = (
  320. self.processing_class.pad_token_id
  321. if self.processing_class.pad_token_id is not None
  322. else self.processing_class.eos_token_id
  323. )
  324. else:
  325. if getattr(self.model.config, "pad_token_id", None) is not None:
  326. pad_token_id = self.model.config.pad_token_id
  327. else:
  328. raise ValueError("Pad_token_id must be set in the configuration of the model, in order to pad tensors")
  329. padded_tensor = pad_token_id * torch.ones(
  330. (tensor.shape[0], max_length), dtype=tensor.dtype, device=tensor.device
  331. )
  332. padded_tensor[:, : tensor.shape[-1]] = tensor
  333. return padded_tensor