infolm.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. import os
  15. from collections.abc import Sequence
  16. from typing import Any, ClassVar, List, Optional, Union
  17. import torch
  18. from torch import Tensor
  19. from torchmetrics.functional.text.helper_embedding_metric import _load_tokenizer_and_model
  20. from torchmetrics.functional.text.infolm import (
  21. _ALLOWED_INFORMATION_MEASURE_LITERAL,
  22. _get_dataloader,
  23. _get_special_tokens_map,
  24. _infolm_compute,
  25. _infolm_update,
  26. _InformationMeasure,
  27. )
  28. from torchmetrics.metric import Metric
  29. from torchmetrics.utilities.data import dim_zero_cat
  30. from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE, _TRANSFORMERS_GREATER_EQUAL_4_4
  31. from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE
  32. if not _MATPLOTLIB_AVAILABLE:
  33. __doctest_skip__ = ["InfoLM.plot"]
  34. if not _TRANSFORMERS_GREATER_EQUAL_4_4:
  35. __doctest_skip__ = ["InfoLM", "InfoLM.plot"]
  36. class InfoLM(Metric):
  37. """Calculate `InfoLM`_.
  38. InfoLM measures a distance/divergence between predicted and reference sentence discrete distribution using one of
  39. the following information measures:
  40. - `KL divergence`_
  41. - `alpha divergence`_
  42. - `beta divergence`_
  43. - `AB divergence`_
  44. - `Rényi divergence`_
  45. - L1 distance
  46. - L2 distance
  47. - L-infinity distance
  48. - `Fisher-Rao distance`_
  49. `InfoLM`_ is a family of untrained embedding-based metrics which addresses some famous flaws of standard
  50. string-based metrics thanks to the usage of pre-trained masked language models. This family of metrics is mainly
  51. designed for summarization and data-to-text tasks.
  52. The implementation of this metric is fully based HuggingFace ``transformers``' package.
  53. As input to ``forward`` and ``update`` the metric accepts the following input:
  54. - ``preds`` (:class:`~Sequence`): An iterable of hypothesis corpus
  55. - ``target`` (:class:`~Sequence`): An iterable of reference corpus
  56. As output of ``forward`` and ``compute`` the metric returns the following output:
  57. - ``infolm`` (:class:`~torch.Tensor`): If `return_sentence_level_score=True` return a tuple with a tensor
  58. with the corpus-level InfoLM score and a list of sentence-level InfoLM scores, else return a corpus-level
  59. InfoLM score
  60. Args:
  61. model_name_or_path:
  62. A name or a model path used to load ``transformers`` pretrained model.
  63. By default the `"bert-base-uncased"` model is used.
  64. temperature:
  65. A temperature for calibrating language modelling. For more information, please reference `InfoLM`_ paper.
  66. information_measure:
  67. A name of information measure to be used. Please use one of: ['kl_divergence', 'alpha_divergence',
  68. 'beta_divergence', 'ab_divergence', 'renyi_divergence', 'l1_distance', 'l2_distance', 'l_infinity_distance',
  69. 'fisher_rao_distance']
  70. idf:
  71. An indication of whether normalization using inverse document frequencies should be used.
  72. alpha:
  73. Alpha parameter of the divergence used for alpha, AB and Rényi divergence measures.
  74. beta:
  75. Beta parameter of the divergence used for beta and AB divergence measures.
  76. device:
  77. A device to be used for calculation.
  78. max_length:
  79. A maximum length of input sequences. Sequences longer than ``max_length`` are to be trimmed.
  80. batch_size:
  81. A batch size used for model processing.
  82. num_threads:
  83. A number of threads to use for a dataloader.
  84. verbose:
  85. An indication of whether a progress bar to be displayed during the embeddings calculation.
  86. return_sentence_level_score:
  87. An indication whether a sentence-level InfoLM score to be returned.
  88. Example:
  89. >>> from torchmetrics.text.infolm import InfoLM
  90. >>> preds = ['he read the book because he was interested in world history']
  91. >>> target = ['he was interested in world history because he read the book']
  92. >>> infolm = InfoLM('google/bert_uncased_L-2_H-128_A-2', idf=False)
  93. >>> infolm(preds, target)
  94. tensor(-0.1784)
  95. """
  96. is_differentiable = False
  97. preds_input_ids: List[Tensor]
  98. preds_attention_mask: List[Tensor]
  99. target_input_ids: List[Tensor]
  100. target_attention_mask: List[Tensor]
  101. _information_measure_higher_is_better: ClassVar = {
  102. # following values are <0
  103. "kl_divergence": True,
  104. "alpha_divergence": True,
  105. # following values are >0
  106. "beta_divergence": False,
  107. "ab_divergence": False,
  108. "renyi_divergence": False,
  109. "l1_distance": False,
  110. "l2_distance": False,
  111. "l_infinity_distance": False,
  112. "fisher_rao_distance": False,
  113. }
  114. def __init__(
  115. self,
  116. model_name_or_path: Union[str, os.PathLike] = "bert-base-uncased",
  117. temperature: float = 0.25,
  118. information_measure: _ALLOWED_INFORMATION_MEASURE_LITERAL = "kl_divergence",
  119. idf: bool = True,
  120. alpha: Optional[float] = None,
  121. beta: Optional[float] = None,
  122. device: Optional[Union[str, torch.device]] = None,
  123. max_length: Optional[int] = None,
  124. batch_size: int = 64,
  125. num_threads: int = 0,
  126. verbose: bool = True,
  127. return_sentence_level_score: bool = False,
  128. **kwargs: dict[str, Any],
  129. ) -> None:
  130. super().__init__(**kwargs)
  131. self.model_name_or_path = model_name_or_path
  132. self.temperature = temperature
  133. self.information_measure = information_measure
  134. self.idf = idf
  135. self.alpha = alpha
  136. self.beta = beta
  137. self._device = torch.device(device or "cpu")
  138. self.batch_size = batch_size
  139. self.num_threads = num_threads
  140. self.verbose = verbose
  141. self.return_sentence_level_score = return_sentence_level_score
  142. self.tokenizer, self.model = _load_tokenizer_and_model(model_name_or_path, device)
  143. self.information_measure_cls = _InformationMeasure(information_measure, alpha, beta)
  144. self.max_length = max_length or self.model.config.max_length
  145. self.special_tokens_map = _get_special_tokens_map(self.tokenizer)
  146. self.add_state("preds_input_ids", [], dist_reduce_fx="cat")
  147. self.add_state("preds_attention_mask", [], dist_reduce_fx="cat")
  148. self.add_state("target_input_ids", [], dist_reduce_fx="cat")
  149. self.add_state("target_attention_mask", [], dist_reduce_fx="cat")
  150. @property
  151. def higher_is_better(self) -> bool: # type: ignore[override]
  152. """Returns a bool indicating whether a higher value of the information measure is better.
  153. Done this way as depends on if the information measure is positive or negative.
  154. """
  155. return self._information_measure_higher_is_better[self.information_measure]
  156. def update(self, preds: Union[str, Sequence[str]], target: Union[str, Sequence[str]]) -> None:
  157. """Update state with predictions and targets."""
  158. preds_input_ids, preds_attention_mask, target_input_ids, target_attention_mask = _infolm_update(
  159. preds, target, self.tokenizer, self.max_length
  160. )
  161. self.preds_input_ids.append(preds_input_ids)
  162. self.preds_attention_mask.append(preds_attention_mask)
  163. self.target_input_ids.append(target_input_ids)
  164. self.target_attention_mask.append(target_attention_mask)
  165. def compute(self) -> Union[Tensor, tuple[Tensor, Tensor]]:
  166. """Calculate selected information measure using the pre-trained language model."""
  167. preds_dataloader = _get_dataloader(
  168. input_ids=dim_zero_cat(self.preds_input_ids),
  169. attention_mask=dim_zero_cat(self.preds_attention_mask),
  170. idf=self.idf,
  171. batch_size=self.batch_size,
  172. num_workers=self.num_threads,
  173. )
  174. target_dataloader = _get_dataloader(
  175. input_ids=dim_zero_cat(self.target_input_ids),
  176. attention_mask=dim_zero_cat(self.target_attention_mask),
  177. idf=self.idf,
  178. batch_size=self.batch_size,
  179. num_workers=self.num_threads,
  180. )
  181. info_lm_score = _infolm_compute(
  182. self.model,
  183. preds_dataloader,
  184. target_dataloader,
  185. self.temperature,
  186. self.idf,
  187. self.information_measure_cls,
  188. self.special_tokens_map,
  189. self.verbose,
  190. )
  191. if self.return_sentence_level_score:
  192. return info_lm_score.mean(), info_lm_score
  193. return info_lm_score.mean()
  194. def plot(
  195. self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
  196. ) -> _PLOT_OUT_TYPE:
  197. """Plot a single or multiple values from the metric.
  198. Args:
  199. val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
  200. If no value is provided, will automatically call `metric.compute` and plot that result.
  201. ax: An matplotlib axis object. If provided will add plot to that axis
  202. Returns:
  203. Figure and Axes object
  204. Raises:
  205. ModuleNotFoundError:
  206. If `matplotlib` is not installed
  207. .. plot::
  208. :scale: 75
  209. >>> # Example plotting a single value
  210. >>> from torchmetrics.text.infolm import InfoLM
  211. >>> metric = InfoLM('google/bert_uncased_L-2_H-128_A-2', idf=False)
  212. >>> preds = ['he read the book because he was interested in world history']
  213. >>> target = ['he was interested in world history because he read the book']
  214. >>> metric.update(preds, target)
  215. >>> fig_, ax_ = metric.plot()
  216. .. plot::
  217. :scale: 75
  218. >>> # Example plotting multiple values
  219. >>> from torchmetrics.text.infolm import InfoLM
  220. >>> metric = InfoLM('google/bert_uncased_L-2_H-128_A-2', idf=False)
  221. >>> preds = ["this is the prediction", "there is an other sample"]
  222. >>> target = ["this is the reference", "there is another one"]
  223. >>> values = [ ]
  224. >>> for _ in range(10):
  225. ... values.append(metric(preds, target))
  226. >>> fig_, ax_ = metric.plot(values)
  227. """
  228. return self._plot(val, ax)