perplexity.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142
  1. # Copyright The Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import Optional
  15. import torch
  16. from torch import Tensor
  17. def _check_shape_and_type_consistency(preds: Tensor, target: Tensor) -> None:
  18. """Check shape and type consistency of input vectors.
  19. Args:
  20. preds:
  21. Logits or a unnormalized score assigned to each token in a sequence with shape [batch_size, seq_len,
  22. vocab_size]. Scores will be normalized internally using softmax.
  23. target:
  24. Ground truth values with a shape [batch_size, seq_len].
  25. Raises:
  26. ValueError:
  27. If ``preds`` tensor has no 3 dimensions.
  28. ValueError:
  29. If ``target`` tensor has no 2 dimensions.
  30. ValueError:
  31. If the first two dimensions of ``preds`` and ``target`` do not equal.
  32. TypeError:
  33. If ``preds`` dtype is not one of ``(torch.float16, torch.float32, torch.float64)``
  34. TypeError:
  35. If ``target`` is not of a type LongTensor (torch.int64)
  36. """
  37. if len(preds.shape) != 3:
  38. raise ValueError(
  39. "Input tensor `preds` is expected to have 3 dimensions, [batch_size, seq_len, vocab_size],"
  40. f" but got {len(preds.shape)}."
  41. )
  42. if len(target.shape) != 2:
  43. raise ValueError(
  44. "Input tensor `target` is expected to have 2 dimensions, [batch_size, seq_len],"
  45. f" but got {len(target.shape)}."
  46. )
  47. if preds.shape[:2] != target.shape:
  48. raise ValueError(
  49. "Input tensors `preds` and `target` are expected to have equaling first two dimensions,"
  50. f" [batch_size, seq_len], but got {preds.shape[:2]} and {target.shape}."
  51. )
  52. if not preds.is_floating_point():
  53. raise TypeError(f"Input tensor `preds` is expected to be of floating point type but got {preds.dtype}.")
  54. if target.dtype != torch.int64:
  55. raise TypeError(f"Input tensor `target` is expected to be of a type {torch.int64} but got {target.dtype}.")
  56. def _perplexity_update(preds: Tensor, target: Tensor, ignore_index: Optional[int] = None) -> tuple[Tensor, Tensor]:
  57. """Compute intermediate statistics for Perplexity.
  58. Args:
  59. preds:
  60. Logits or a unnormalized score assigned to each token in a sequence with shape [batch_size, seq_len,
  61. vocab_size]. Scores will be normalized internally using softmax.
  62. target:
  63. Ground truth values with a shape [batch_size, seq_len].
  64. ignore_index:
  65. Integer specifying a target class to ignore. If given, this class index does not contribute
  66. to the returned score.
  67. Returns:
  68. Log probabilities, summed over all samples
  69. Number of samples
  70. """
  71. _check_shape_and_type_consistency(preds, target)
  72. probs = torch.nn.functional.softmax(preds.reshape(-1, preds.shape[-1]), dim=1)
  73. target = target.reshape(-1)
  74. if ignore_index is not None:
  75. mask = target.ne(ignore_index)
  76. target = target.where(target != ignore_index, torch.tensor(0, device=target.device))
  77. else:
  78. mask = torch.ones_like(target, dtype=torch.bool)
  79. probs = probs[torch.arange(target.numel()), target][mask]
  80. total_log_probs = -probs.log().sum()
  81. count = mask.sum()
  82. return total_log_probs, count
  83. def _perplexity_compute(total: Tensor, count: Tensor) -> Tensor:
  84. """Compute the Perplexity.
  85. Args:
  86. total: Log probabilities, summed over all samples
  87. count: Number of samples
  88. Returns:
  89. Perplexity
  90. """
  91. return torch.exp(total / count)
  92. def perplexity(preds: Tensor, target: Tensor, ignore_index: Optional[int] = None) -> Tensor:
  93. """Perplexity measures how well a language model predicts a text sample.
  94. This metric is calculated as the average number of bits per word a model needs to represent the sample.
  95. Args:
  96. preds:
  97. Logits or a unnormalized score assigned to each token in a sequence with shape [batch_size, seq_len,
  98. vocab_size], which is the output of a language model. Scores will be normalized internally using softmax.
  99. target:
  100. Ground truth values with a shape [batch_size, seq_len].
  101. ignore_index:
  102. Integer specifying a target class to ignore. If given, this class index does not contribute
  103. to the returned score.
  104. Returns:
  105. Perplexity value
  106. Examples:
  107. >>> from torch import rand, randint
  108. >>> preds = rand(2, 8, 5)
  109. >>> target = randint(5, (2, 8))
  110. >>> target[0, 6:] = -100
  111. >>> perplexity(preds, target, ignore_index=-100)
  112. tensor(5.8540)
  113. """
  114. total, count = _perplexity_update(preds, target, ignore_index)
  115. return _perplexity_compute(total, count)