# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import csv
import logging
import urllib
import urllib.request
from collections.abc import Iterator, Sequence
from contextlib import contextmanager
from typing import Any, Callable, List, Optional, Tuple, Union, cast

import torch
from torch import Tensor
from torch.nn import Module
from torch.utils.data import DataLoader

from torchmetrics.functional.text.helper_embedding_metric import (
    TextDataset,
    TokenizedDataset,
    _check_shape_of_model_output,
    _get_progress_bar,
    _input_data_collator,
    _output_data_collator,
    _process_attention_mask_for_special_tokens,
)
from torchmetrics.utilities import rank_zero_warn
from torchmetrics.utilities.checks import _SKIP_SLOW_DOCTEST, _try_proceed_with_timeout
from torchmetrics.utilities.imports import _TQDM_AVAILABLE, _TRANSFORMERS_GREATER_EQUAL_4_4
  36. @contextmanager
  37. def _ignore_log_warning() -> Iterator[None]:
  38. """Ignore irrelevant fine-tuning warning from transformers when loading the model for BertScore."""
  39. logger = logging.getLogger("transformers.modeling_utils")
  40. original_level = logger.getEffectiveLevel()
  41. try:
  42. logger.setLevel(logging.ERROR)
  43. yield
  44. finally:
  45. logger.setLevel(original_level)
  46. # Default model recommended in the original implementation.
  47. _DEFAULT_MODEL = "roberta-large"
  48. if _TRANSFORMERS_GREATER_EQUAL_4_4:
  49. from transformers import AutoModel, AutoTokenizer
  50. def _download_model_for_bert_score() -> None:
  51. """Download intensive operations."""
  52. with _ignore_log_warning():
  53. AutoTokenizer.from_pretrained(_DEFAULT_MODEL)
  54. AutoModel.from_pretrained(_DEFAULT_MODEL)
  55. if _SKIP_SLOW_DOCTEST and not _try_proceed_with_timeout(_download_model_for_bert_score):
  56. __doctest_skip__ = ["bert_score"]
  57. else:
  58. __doctest_skip__ = ["bert_score"]
  59. def _get_embeddings_and_idf_scale(
  60. dataloader: DataLoader,
  61. target_len: int,
  62. model: Module,
  63. device: Optional[Union[str, torch.device]] = None,
  64. num_layers: Optional[int] = None,
  65. all_layers: bool = False,
  66. idf: bool = False,
  67. verbose: bool = False,
  68. user_forward_fn: Optional[Callable[[Module, dict[str, Tensor]], Tensor]] = None,
  69. ) -> Tuple[Tensor, Tensor]:
  70. """Calculate sentence embeddings and the inverse-document-frequency scaling factor.
  71. Args:
  72. dataloader: dataloader instance.
  73. target_len: A length of the longest sequence in the data. Used for padding the model output.
  74. model: BERT model.
  75. device: A device to be used for calculation.
  76. num_layers: The layer of representation to use.
  77. all_layers: An indication whether representation from all model layers should be used for BERTScore.
  78. idf: An Indication whether normalization using inverse document frequencies should be used.
  79. verbose: An indication of whether a progress bar to be displayed during the embeddings' calculation.
  80. user_forward_fn:
  81. A user's own forward function used in a combination with ``user_model``. This function must
  82. take ``user_model`` and a python dictionary of containing ``"input_ids"`` and ``"attention_mask"``
  83. represented by :class:`~torch.Tensor` as an input and return the model's output represented by the single
  84. :class:`~torch.Tensor`.
  85. Return:
  86. A tuple of :class:`~torch.Tensor`s containing the model's embeddings and the normalized tokens IDF.
  87. When ``idf = False``, tokens IDF is not calculated, and a matrix of mean weights is returned instead.
  88. For a single sentence, ``mean_weight = 1/seq_len``, where ``seq_len`` is a sum over the corresponding
  89. ``attention_mask``.
  90. Raises:
  91. ValueError:
  92. If ``all_layers = True`` and a model, which is not from the ``transformers`` package, is used.
  93. """
  94. embeddings_list: List[Tensor] = []
  95. idf_scale_list: List[Tensor] = []
  96. for batch in _get_progress_bar(dataloader, verbose):
  97. with torch.no_grad():
  98. batch = _input_data_collator(batch, device)
  99. # Output shape: batch_size x num_layers OR 1 x sequence_length x bert_dim
  100. if not all_layers:
  101. if not user_forward_fn:
  102. out = model(batch["input_ids"], batch["attention_mask"], output_hidden_states=True)
  103. out = out.hidden_states[num_layers if num_layers is not None else -1]
  104. else:
  105. out = user_forward_fn(model, batch)
  106. _check_shape_of_model_output(out, batch["input_ids"])
  107. out = out.unsqueeze(1)
  108. else:
  109. if user_forward_fn:
  110. raise ValueError(
  111. "The option `all_layers=True` can be used only with default `transformers` models."
  112. )
  113. out = model(batch["input_ids"], batch["attention_mask"], output_hidden_states=True)
  114. out = torch.cat([o.unsqueeze(1) for o in out.hidden_states], dim=1)
  115. out /= out.norm(dim=-1).unsqueeze(-1) # normalize embeddings
  116. out, attention_mask = _output_data_collator(out, batch["attention_mask"], target_len)
  117. processed_attention_mask = _process_attention_mask_for_special_tokens(attention_mask)
  118. # Multiply embeddings with attention_mask (b=batch_size, l=num_layers, s=seq_len, d=emb_dim)
  119. out = torch.einsum("blsd, bs -> blsd", out, processed_attention_mask)
  120. embeddings_list.append(out.cpu())
  121. # Calculate weighted (w.r.t. sentence length) input_ids IDF matrix
  122. input_ids_idf = (
  123. batch["input_ids_idf"] * processed_attention_mask if idf else processed_attention_mask.type(out.dtype)
  124. )
  125. input_ids_idf /= input_ids_idf.sum(-1, keepdim=True)
  126. idf_scale_list.append(input_ids_idf.cpu())
  127. embeddings = torch.cat(embeddings_list)
  128. idf_scale = torch.cat(idf_scale_list)
  129. return embeddings, idf_scale
  130. def _get_scaled_precision_or_recall(cos_sim: Tensor, metric: str, idf_scale: Tensor) -> Tensor:
  131. """Calculate precision or recall, transpose it and scale it with idf_scale factor."""
  132. dim = 3 if metric == "precision" else 2
  133. res = cos_sim.max(dim=dim).values
  134. res = torch.einsum("bls, bs -> bls", res, idf_scale).sum(-1)
  135. # We transpose the results and squeeze if possible to match the format of the original BERTScore implementation
  136. return res.transpose(0, 1).squeeze()
  137. def _get_precision_recall_f1(
  138. preds_embeddings: Tensor, target_embeddings: Tensor, preds_idf_scale: Tensor, target_idf_scale: Tensor
  139. ) -> Tuple[Tensor, Tensor, Tensor]:
  140. """Calculate precision, recall and F1 score over candidate and reference sentences.
  141. Args:
  142. preds_embeddings: Embeddings of candidate sentences.
  143. target_embeddings: Embeddings of reference sentences.
  144. preds_idf_scale: An IDF scale factor for candidate sentences.
  145. target_idf_scale: An IDF scale factor for reference sentences.
  146. Return:
  147. Tensors containing precision, recall and F1 score, respectively.
  148. """
  149. # Dimensions: b = batch_size, l = num_layers, p = predictions_seq_len, r = references_seq_len, d = bert_dim
  150. cos_sim = torch.einsum("blpd, blrd -> blpr", preds_embeddings, target_embeddings)
  151. # Final metrics shape = (batch_size * num_layers | batch_size)
  152. precision = _get_scaled_precision_or_recall(cos_sim, "precision", preds_idf_scale)
  153. recall = _get_scaled_precision_or_recall(cos_sim, "recall", target_idf_scale)
  154. f1_score = 2 * precision * recall / (precision + recall)
  155. f1_score = f1_score.masked_fill(torch.isnan(f1_score), 0.0)
  156. return precision, recall, f1_score
  157. def _get_hash(model_name_or_path: Optional[str] = None, num_layers: Optional[int] = None, idf: bool = False) -> str:
  158. """Compute `BERT_score`_ (copied and adjusted)."""
  159. return f"{model_name_or_path}_L{num_layers}{'_idf' if idf else '_no-idf'}"
  160. def _read_csv_from_local_file(baseline_path: str) -> Tensor:
  161. """Read baseline from csv file from the local file.
  162. This method implemented to avoid `pandas` dependency.
  163. """
  164. with open(baseline_path) as fname:
  165. csv_file = csv.reader(fname)
  166. baseline_list = [[float(item) for item in row] for idx, row in enumerate(csv_file) if idx > 0]
  167. return torch.tensor(baseline_list)[:, 1:]
  168. def _read_csv_from_url(baseline_url: str) -> Tensor:
  169. """Read baseline from csv file from URL.
  170. This method is implemented to avoid `pandas` dependency.
  171. """
  172. with urllib.request.urlopen(baseline_url) as http_request:
  173. baseline_list = [
  174. [float(item) for item in row.strip().decode("utf-8").split(",")]
  175. for idx, row in enumerate(http_request)
  176. if idx > 0
  177. ]
  178. return torch.tensor(baseline_list)[:, 1:]
  179. def _load_baseline(
  180. lang: str = "en",
  181. model_name_or_path: Optional[str] = None,
  182. baseline_path: Optional[str] = None,
  183. baseline_url: Optional[str] = None,
  184. ) -> Optional[Tensor]:
  185. """Load a CSV file with the baseline values used for rescaling."""
  186. if baseline_path:
  187. baseline: Optional[Tensor] = _read_csv_from_local_file(baseline_path)
  188. elif baseline_url:
  189. baseline = _read_csv_from_url(baseline_url)
  190. # Read default baseline from the original `bert-score` package https://github.com/Tiiiger/bert_score
  191. elif lang and model_name_or_path:
  192. url_base = "https://raw.githubusercontent.com/Tiiiger/bert_score/master/bert_score/rescale_baseline"
  193. baseline_url = f"{url_base}/{lang}/{model_name_or_path}.tsv"
  194. baseline = _read_csv_from_url(baseline_url)
  195. else:
  196. rank_zero_warn("Baseline was not successfully loaded. No baseline is going to be used.")
  197. return None
  198. return baseline
  199. def _rescale_metrics_with_baseline(
  200. precision: Tensor,
  201. recall: Tensor,
  202. f1_score: Tensor,
  203. baseline: Tensor,
  204. num_layers: Optional[int] = None,
  205. all_layers: bool = False,
  206. ) -> Tuple[Tensor, Tensor, Tensor]:
  207. """Rescale the computed metrics with the pre-computed baseline."""
  208. if num_layers is None and all_layers is False:
  209. num_layers = -1
  210. all_metrics = torch.stack([precision, recall, f1_score], dim=-1)
  211. baseline_scale = baseline.unsqueeze(1) if all_layers else baseline[num_layers]
  212. all_metrics = (all_metrics - baseline_scale) / (1 - baseline_scale)
  213. return all_metrics[..., 0], all_metrics[..., 1], all_metrics[..., 2]
  214. def _preprocess_multiple_references(
  215. preds: List[str], target: List[Union[str, Sequence[str]]]
  216. ) -> Tuple[List[str], List[str], Optional[List[Tuple[int, int]]]]:
  217. """Preprocesses predictions and targets when dealing with multiple references.
  218. This function handles the case where a single prediction might have multiple
  219. reference targets (represented as a list/tuple of strings).
  220. Args:
  221. preds: A list of predictions
  222. target: A list of targets, where each item could be a string or a list/tuple of strings
  223. Returns:
  224. Tuple: (preds, target, ref_group_boundaries)
  225. - preds: Flattened list of `str`
  226. - target: Flattened list of `str`
  227. - ref_group_boundaries: List of tuples (start, end) indicating the boundaries
  228. of reference groups in the flattened lists or `None`
  229. """
  230. if not all(isinstance(item, str) for item in preds):
  231. raise ValueError("Invalid input provided.")
  232. has_nested_sequences = any(isinstance(item, (list, tuple)) for item in target)
  233. if has_nested_sequences:
  234. ref_group_boundaries: List[Tuple[int, int]] = []
  235. new_preds: List[str] = []
  236. new_target: List[str] = []
  237. count = 0
  238. for pred, ref_group in zip(preds, target):
  239. if isinstance(ref_group, (list, tuple)):
  240. new_preds.extend([pred] * len(ref_group))
  241. new_target.extend(cast(List[str], ref_group))
  242. ref_group_boundaries.append((count, count + len(ref_group)))
  243. count += len(ref_group)
  244. else:
  245. new_preds.append(pred)
  246. new_target.append(cast(str, ref_group))
  247. ref_group_boundaries.append((count, count + 1))
  248. count += 1
  249. return new_preds, new_target, ref_group_boundaries
  250. return preds, cast(List[str], target), None
  251. def _postprocess_multiple_references(
  252. precision: Tensor, recall: Tensor, f1_score: Tensor, ref_group_boundaries: List[Tuple[int, int]]
  253. ) -> Tuple[Tensor, Tensor, Tensor]:
  254. """Postprocesses metrics when dealing with multiple references.
  255. For each group of references that correspond to a single prediction,
  256. this function takes the maximum score among all references.
  257. Args:
  258. precision: Tensor of precision scores
  259. recall: Tensor of recall scores
  260. f1_score: Tensor of F1 scores
  261. ref_group_boundaries: List of tuples (start, end) indicating the boundaries
  262. of reference groups
  263. Returns:
  264. tuple: (precision, recall, f1_score) with updated metrics
  265. """
  266. max_precision, max_recall, max_f1 = [], [], []
  267. for start, end in ref_group_boundaries:
  268. if precision.dim() > 1: # all_layers=True case
  269. max_precision.append(precision[:, start:end].max(dim=1)[0])
  270. max_recall.append(recall[:, start:end].max(dim=1)[0])
  271. max_f1.append(f1_score[:, start:end].max(dim=1)[0])
  272. else: # standard case
  273. max_precision.append(precision[start:end].max())
  274. max_recall.append(recall[start:end].max())
  275. max_f1.append(f1_score[start:end].max())
  276. if precision.dim() > 1:
  277. precision = torch.stack(max_precision, dim=1)
  278. recall = torch.stack(max_recall, dim=1)
  279. f1_score = torch.stack(max_f1, dim=1)
  280. else:
  281. precision = torch.stack(max_precision)
  282. recall = torch.stack(max_recall)
  283. f1_score = torch.stack(max_f1)
  284. return precision, recall, f1_score
  285. def bert_score(
  286. preds: Union[str, Sequence[str], dict[str, Tensor]],
  287. target: Union[str, Sequence[str], Sequence[Sequence[str]], dict[str, Tensor]],
  288. model_name_or_path: Optional[str] = None,
  289. num_layers: Optional[int] = None,
  290. all_layers: bool = False,
  291. model: Optional[Module] = None,
  292. user_tokenizer: Any = None,
  293. user_forward_fn: Optional[Callable[[Module, dict[str, Tensor]], Tensor]] = None,
  294. verbose: bool = False,
  295. idf: bool = False,
  296. device: Optional[Union[str, torch.device]] = None,
  297. max_length: int = 512,
  298. batch_size: int = 64,
  299. num_threads: int = 0,
  300. return_hash: bool = False,
  301. lang: str = "en",
  302. rescale_with_baseline: bool = False,
  303. baseline_path: Optional[str] = None,
  304. baseline_url: Optional[str] = None,
  305. truncation: bool = False,
  306. ) -> dict[str, Union[Tensor, List[float], str]]:
  307. """`Bert_score Evaluating Text Generation`_ for text similirity matching.
  308. This metric leverages the pre-trained contextual embeddings from BERT and matches words in candidate and reference
  309. sentences by cosine similarity. It has been shown to correlate with human judgment on sentence-level and
  310. system-level evaluation. Moreover, BERTScore computes precision, recall, and F1 measure, which can be useful for
  311. evaluating different language generation tasks.
  312. This implementation follows the original implementation from `BERT_score`_.
  313. Args:
  314. preds (Union[str, Sequence[str]]): A single predicted sentence or a sequence of predicted sentences.
  315. target (Union[str, Sequence[str], Sequence[Sequence[str]]]): A single target sentence, a sequence of target
  316. sentences, or a sequence of sequences of target sentences for multiple references per prediction.
  317. model_name_or_path: A name or a model path used to load ``transformers`` pretrained model.
  318. num_layers: A layer of representation to use.
  319. all_layers:
  320. An indication of whether the representation from all model's layers should be used.
  321. If ``all_layers = True``, the argument ``num_layers`` is ignored.
  322. model: A user's own model.
  323. user_tokenizer:
  324. A user's own tokenizer used with the own model. This must be an instance with the ``__call__`` method.
  325. This method must take an iterable of sentences (``List[str]``) and must return a python dictionary
  326. containing ``"input_ids"`` and ``"attention_mask"`` represented by :class:`~torch.Tensor`.
  327. It is up to the user's model of whether ``"input_ids"`` is a :class:`~torch.Tensor` of input ids
  328. or embedding vectors. his tokenizer must prepend an equivalent of ``[CLS]`` token and append an equivalent
  329. of ``[SEP]`` token as `transformers` tokenizer does.
  330. user_forward_fn:
  331. A user's own forward function used in a combination with ``user_model``.
  332. This function must take ``user_model`` and a python dictionary of containing ``"input_ids"``
  333. and ``"attention_mask"`` represented by :class:`~torch.Tensor` as an input and return the model's output
  334. represented by the single :class:`~torch.Tensor`.
  335. verbose: An indication of whether a progress bar to be displayed during the embeddings' calculation.
  336. idf: An indication of whether normalization using inverse document frequencies should be used.
  337. device: A device to be used for calculation.
  338. max_length: A maximum length of input sequences. Sequences longer than ``max_length`` are to be trimmed.
  339. batch_size: A batch size used for model processing.
  340. num_threads: A number of threads to use for a dataloader.
  341. return_hash: An indication of whether the correspodning ``hash_code`` should be returned.
  342. lang: A language of input sentences. It is used when the scores are rescaled with a baseline.
  343. rescale_with_baseline:
  344. An indication of whether bertscore should be rescaled with a pre-computed baseline.
  345. When a pretrained model from ``transformers`` model is used, the corresponding baseline is downloaded
  346. from the original ``bert-score`` package from `BERT_score`_ if available.
  347. In other cases, please specify a path to the baseline csv/tsv file, which must follow the formatting
  348. of the files from `BERT_score`_
  349. baseline_path: A path to the user's own local csv/tsv file with the baseline scale.
  350. baseline_url: A url path to the user's own csv/tsv file with the baseline scale.
  351. truncation: An indication of whether the input sequences should be truncated to the maximum length.
  352. Returns:
  353. Python dictionary containing the keys ``precision``, ``recall`` and ``f1`` with corresponding values.
  354. Raises:
  355. ValueError:
  356. If ``len(preds) != len(target)``.
  357. ModuleNotFoundError:
  358. If `tqdm` package is required and not installed.
  359. ModuleNotFoundError:
  360. If ``transformers`` package is required and not installed.
  361. ValueError:
  362. If ``num_layer`` is larger than the number of the model layers.
  363. ValueError:
  364. If invalid input is provided.
  365. Example:
  366. >>> from pprint import pprint
  367. >>> from torchmetrics.functional.text.bert import bert_score
  368. >>> preds = ["hello there", "general kenobi"]
  369. >>> target = ["hello there", "master kenobi"]
  370. >>> pprint(bert_score(preds, target))
  371. {'f1': tensor([1.0000, 0.9961]), 'precision': tensor([1.0000, 0.9961]), 'recall': tensor([1.0000, 0.9961])}
  372. Example:
  373. >>> from pprint import pprint
  374. >>> from torchmetrics.functional.text.bert import bert_score
  375. >>> preds = ["hello there", "general kenobi"]
  376. >>> target = [["hello there", "master kenobi"], ["hello there", "master kenobi"]]
  377. >>> pprint(bert_score(preds, target))
  378. {'f1': tensor([1.0000, 0.9961]), 'precision': tensor([1.0000, 0.9961]), 'recall': tensor([1.0000, 0.9961])}
  379. """
  380. ref_group_boundaries: Optional[List[Tuple[int, int]]] = None
  381. if isinstance(preds, str):
  382. preds = [preds]
  383. if isinstance(target, str):
  384. target = [target]
  385. if not isinstance(preds, (list, dict)): # dict for BERTScore class compute call
  386. preds = list(preds)
  387. if not isinstance(target, (list, dict)): # dict for BERTScore class compute call
  388. target = list(target)
  389. if len(preds) != len(target):
  390. raise ValueError(
  391. "Expected number of predicted and reference sentences to be the same, but got"
  392. f"{len(preds)} and {len(target)}"
  393. )
  394. if isinstance(preds, list) and len(preds) > 0 and isinstance(target, list) and len(target) > 0:
  395. preds, target, ref_group_boundaries = _preprocess_multiple_references(preds, target)
  396. if not isinstance(idf, bool):
  397. raise ValueError(f"Expected argument `idf` to be a boolean, but got {idf}.")
  398. if verbose and (not _TQDM_AVAILABLE):
  399. raise ModuleNotFoundError(
  400. "An argument `verbose = True` requires `tqdm` package be installed. Install with `pip install tqdm`."
  401. )
  402. if model is None:
  403. if not _TRANSFORMERS_GREATER_EQUAL_4_4:
  404. raise ModuleNotFoundError(
  405. "`bert_score` metric with default models requires `transformers` package be installed."
  406. " Either install with `pip install transformers>=4.4` or `pip install torchmetrics[text]`."
  407. )
  408. if model_name_or_path is None:
  409. rank_zero_warn(
  410. "The argument `model_name_or_path` was not specified while it is required when default"
  411. " `transformers` model are used."
  412. f"It is, therefore, used the default recommended model - {_DEFAULT_MODEL}."
  413. )
  414. with _ignore_log_warning():
  415. tokenizer = AutoTokenizer.from_pretrained(model_name_or_path or _DEFAULT_MODEL)
  416. model = AutoModel.from_pretrained(model_name_or_path or _DEFAULT_MODEL)
  417. else:
  418. tokenizer = user_tokenizer
  419. model.eval()
  420. model.to(device)
  421. try:
  422. if hasattr(model.config, "num_hidden_layers") and isinstance(model.config.num_hidden_layers, int):
  423. if num_layers and num_layers > model.config.num_hidden_layers:
  424. raise ValueError(
  425. f"num_layers={num_layers} is forbidden for {model_name_or_path}."
  426. f" Please use num_layers <= {model.config.num_hidden_layers}"
  427. )
  428. else:
  429. rank_zero_warn(
  430. "Model config does not have `num_hidden_layers` as an integer attribute. "
  431. "Unable to validate `num_layers`."
  432. )
  433. except AttributeError:
  434. rank_zero_warn("It was not possible to retrieve the parameter `num_layers` from the model specification.")
  435. _are_empty_lists = all(isinstance(text, list) and len(text) == 0 for text in (preds, target))
  436. _are_valid_lists = all(
  437. isinstance(text, list) and len(text) > 0 and isinstance(text[0], str) for text in (preds, target)
  438. )
  439. _are_valid_tensors = all(
  440. isinstance(text, dict) and isinstance(text["input_ids"], Tensor) for text in (preds, target)
  441. )
  442. if _are_empty_lists:
  443. rank_zero_warn("Predictions and references are empty.")
  444. output_dict: dict[str, Union[Tensor, List[float], str]] = {
  445. "precision": [0.0],
  446. "recall": [0.0],
  447. "f1": [0.0],
  448. }
  449. if return_hash:
  450. output_dict.update({"hash": _get_hash(model_name_or_path, num_layers, idf)})
  451. return output_dict
  452. # Load baselines if needed
  453. baseline = _load_baseline(lang, model_name_or_path, baseline_path, baseline_url) if rescale_with_baseline else None
  454. # We ignore mypy typing below as the proper typing is ensured by conditions above, only mypy cannot infer that.
  455. if _are_valid_lists:
  456. target_dataset = TextDataset(target, tokenizer, max_length, idf=idf, truncation=truncation) # type: ignore
  457. preds_dataset = TextDataset(
  458. preds, # type: ignore
  459. tokenizer,
  460. max_length,
  461. idf=idf,
  462. tokens_idf=target_dataset.tokens_idf,
  463. truncation=truncation,
  464. )
  465. elif _are_valid_tensors:
  466. target_dataset = TokenizedDataset(**target, idf=idf) # type: ignore
  467. preds_dataset = TokenizedDataset(**preds, idf=idf, tokens_idf=target_dataset.tokens_idf) # type: ignore
  468. else:
  469. raise ValueError("Invalid input provided.")
  470. target_loader = DataLoader(target_dataset, batch_size=batch_size, num_workers=num_threads)
  471. preds_loader = DataLoader(preds_dataset, batch_size=batch_size, num_workers=num_threads)
  472. target_embeddings, target_idf_scale = _get_embeddings_and_idf_scale(
  473. target_loader, target_dataset.max_length, model, device, num_layers, all_layers, idf, verbose, user_forward_fn
  474. )
  475. preds_embeddings, preds_idf_scale = _get_embeddings_and_idf_scale(
  476. preds_loader, preds_dataset.max_length, model, device, num_layers, all_layers, idf, verbose, user_forward_fn
  477. )
  478. preds_embeddings = preds_embeddings[preds_loader.dataset.sorting_indices]
  479. target_embeddings = target_embeddings[target_loader.dataset.sorting_indices]
  480. preds_idf_scale = preds_idf_scale[preds_loader.dataset.sorting_indices]
  481. target_idf_scale = target_idf_scale[target_loader.dataset.sorting_indices]
  482. precision, recall, f1_score = _get_precision_recall_f1(
  483. preds_embeddings, target_embeddings, preds_idf_scale, target_idf_scale
  484. )
  485. if baseline is not None:
  486. precision, recall, f1_score = _rescale_metrics_with_baseline(
  487. precision, recall, f1_score, baseline, num_layers, all_layers
  488. )
  489. if ref_group_boundaries is not None:
  490. precision, recall, f1_score = _postprocess_multiple_references(
  491. precision, recall, f1_score, ref_group_boundaries
  492. )
  493. output_dict = {
  494. "precision": precision,
  495. "recall": recall,
  496. "f1": f1_score,
  497. }
  498. if return_hash:
  499. output_dict.update({"hash": _get_hash(model_name_or_path, num_layers, idf)})
  500. return output_dict