adaptive.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339
  1. # mypy: allow-untyped-defs
  2. import itertools
  3. from collections import namedtuple
  4. from collections.abc import Sequence
  5. import torch
  6. import torch.nn.functional as F
  7. from torch import Tensor
  8. from .container import ModuleList, Sequential
  9. from .linear import Linear
  10. from .module import Module
  11. __all__ = ["AdaptiveLogSoftmaxWithLoss"]
  12. _ASMoutput = namedtuple("_ASMoutput", ["output", "loss"])
  13. class AdaptiveLogSoftmaxWithLoss(Module):
  14. (
  15. """Efficient softmax approximation.
  16. As described in
  17. `Efficient softmax approximation for GPUs by Edouard Grave, Armand Joulin,
  18. Moustapha Ciss\u00e9, David Grangier, and Herv\u00e9 J\u00e9gou
  19. <https://arxiv.org/abs/1609.04309>`__.
  20. """
  21. r"""
  22. Adaptive softmax is an approximate strategy for training models with large
  23. output spaces. It is most effective when the label distribution is highly
  24. imbalanced, for example in natural language modelling, where the word
  25. frequency distribution approximately follows the `Zipf's law`_.
  26. Adaptive softmax partitions the labels into several clusters, according to
  27. their frequency. These clusters may contain different number of targets
  28. each.
  29. Additionally, clusters containing less frequent labels assign lower
  30. dimensional embeddings to those labels, which speeds up the computation.
  31. For each minibatch, only clusters for which at least one target is
  32. present are evaluated.
  33. The idea is that the clusters which are accessed frequently
  34. (like the first one, containing most frequent labels), should also be cheap
  35. to compute -- that is, contain a small number of assigned labels.
  36. We highly recommend taking a look at the original paper for more details.
  37. * :attr:`cutoffs` should be an ordered Sequence of integers sorted
  38. in the increasing order.
  39. It controls number of clusters and the partitioning of targets into
  40. clusters. For example setting ``cutoffs = [10, 100, 1000]``
  41. means that first `10` targets will be assigned
  42. to the 'head' of the adaptive softmax, targets `11, 12, ..., 100` will be
  43. assigned to the first cluster, and targets `101, 102, ..., 1000` will be
  44. assigned to the second cluster, while targets
  45. `1001, 1002, ..., n_classes - 1` will be assigned
  46. to the last, third cluster.
  47. * :attr:`div_value` is used to compute the size of each additional cluster,
  48. which is given as
  49. :math:`\left\lfloor\frac{\texttt{in\_features}}{\texttt{div\_value}^{idx}}\right\rfloor`,
  50. where :math:`idx` is the cluster index (with clusters
  51. for less frequent words having larger indices,
  52. and indices starting from :math:`1`).
  53. * :attr:`head_bias` if set to True, adds a bias term to the 'head' of the
  54. adaptive softmax. See paper for details. Set to False in the official
  55. implementation.
  56. .. warning::
  57. Labels passed as inputs to this module should be sorted according to
  58. their frequency. This means that the most frequent label should be
  59. represented by the index `0`, and the least frequent
  60. label should be represented by the index `n_classes - 1`.
  61. .. note::
  62. This module returns a ``NamedTuple`` with ``output``
  63. and ``loss`` fields. See further documentation for details.
  64. .. note::
  65. To compute log-probabilities for all classes, the ``log_prob``
  66. method can be used.
  67. Args:
  68. in_features (int): Number of features in the input tensor
  69. n_classes (int): Number of classes in the dataset
  70. cutoffs (Sequence): Cutoffs used to assign targets to their buckets
  71. div_value (float, optional): value used as an exponent to compute sizes
  72. of the clusters. Default: 4.0
  73. head_bias (bool, optional): If ``True``, adds a bias term to the 'head' of the
  74. adaptive softmax. Default: ``False``
  75. Returns:
  76. ``NamedTuple`` with ``output`` and ``loss`` fields:
  77. * **output** is a Tensor of size ``N`` containing computed target
  78. log probabilities for each example
  79. * **loss** is a Scalar representing the computed negative
  80. log likelihood loss
  81. Shape:
  82. - input: :math:`(N, \texttt{in\_features})` or :math:`(\texttt{in\_features})`
  83. - target: :math:`(N)` or :math:`()` where each value satisfies :math:`0 <= \texttt{target[i]} <= \texttt{n\_classes}`
  84. - output1: :math:`(N)` or :math:`()`
  85. - output2: ``Scalar``
  86. .. _Zipf's law: https://en.wikipedia.org/wiki/Zipf%27s_law
  87. """
  88. )
  89. in_features: int
  90. n_classes: int
  91. cutoffs: list[int]
  92. div_value: float
  93. head_bias: bool
  94. head: Linear
  95. tail: ModuleList
  96. def __init__(
  97. self,
  98. in_features: int,
  99. n_classes: int,
  100. cutoffs: Sequence[int],
  101. div_value: float = 4.0,
  102. head_bias: bool = False,
  103. device=None,
  104. dtype=None,
  105. ) -> None:
  106. factory_kwargs = {"device": device, "dtype": dtype}
  107. super().__init__()
  108. cutoffs = list(cutoffs)
  109. if len(cutoffs) == 0:
  110. raise ValueError("cutoffs should be a sequence of length larger than 0")
  111. if (
  112. (cutoffs != sorted(cutoffs))
  113. or (min(cutoffs) <= 0)
  114. or (max(cutoffs) > (n_classes - 1))
  115. or (len(set(cutoffs)) != len(cutoffs))
  116. or any(int(c) != c for c in cutoffs)
  117. ):
  118. raise ValueError(
  119. "cutoffs should be a sequence of unique, positive "
  120. "integers sorted in an increasing order, where "
  121. "each value is between 1 and n_classes-1"
  122. )
  123. self.in_features = in_features
  124. self.n_classes = n_classes
  125. self.cutoffs = cutoffs + [n_classes]
  126. self.div_value = div_value
  127. self.head_bias = head_bias
  128. self.shortlist_size = self.cutoffs[0]
  129. self.n_clusters = len(self.cutoffs) - 1
  130. self.head_size = self.shortlist_size + self.n_clusters
  131. self.head = Linear(
  132. self.in_features, self.head_size, bias=self.head_bias, **factory_kwargs
  133. )
  134. self.tail = ModuleList()
  135. for i in range(self.n_clusters):
  136. hsz = int(self.in_features // (self.div_value ** (i + 1)))
  137. osz = self.cutoffs[i + 1] - self.cutoffs[i]
  138. projection = Sequential(
  139. Linear(self.in_features, hsz, bias=False, **factory_kwargs),
  140. Linear(hsz, osz, bias=False, **factory_kwargs),
  141. )
  142. self.tail.append(projection)
  143. def reset_parameters(self) -> None:
  144. """
  145. Resets parameters based on their initialization used in ``__init__``.
  146. """
  147. self.head.reset_parameters()
  148. for i2h, h2o in self.tail: # type: ignore[misc]
  149. i2h.reset_parameters() # type: ignore[has-type]
  150. h2o.reset_parameters() # type: ignore[has-type]
  151. def forward(self, input_: Tensor, target_: Tensor) -> _ASMoutput:
  152. """
  153. Runs the forward pass.
  154. """
  155. targ_dim = target_.dim()
  156. if targ_dim == 1:
  157. if input_.size(0) != target_.size(0):
  158. raise RuntimeError(
  159. "Input and target should have the same size in the batch dimension."
  160. )
  161. if input_.dim() != 2:
  162. raise RuntimeError(
  163. "1D target tensor expects 2D input tensors, "
  164. "but found inputs with size",
  165. input_.size(),
  166. )
  167. elif targ_dim == 0:
  168. if input_.dim() != 1:
  169. raise RuntimeError(
  170. "0D target tensor expects 1D input tensors, "
  171. "but found inputs with size",
  172. input_.size(),
  173. )
  174. else:
  175. raise RuntimeError(
  176. "0D or 1D target tensor expected, multi-target not supported"
  177. )
  178. is_batched = targ_dim > 0
  179. input = input_ if is_batched else input_.unsqueeze(0)
  180. target = target_ if is_batched else target_.unsqueeze(0)
  181. used_rows = 0
  182. batch_size = target.size(0)
  183. output = input.new_zeros(batch_size)
  184. gather_inds = target.new_empty(batch_size)
  185. cutoff_values = [0] + self.cutoffs
  186. for i in range(len(cutoff_values) - 1):
  187. low_idx = cutoff_values[i]
  188. high_idx = cutoff_values[i + 1]
  189. target_mask = (target >= low_idx) & (target < high_idx)
  190. row_indices = target_mask.nonzero().squeeze()
  191. if row_indices.numel() == 0:
  192. continue
  193. if i == 0:
  194. gather_inds.index_copy_(0, row_indices, target[target_mask])
  195. else:
  196. relative_target = target[target_mask] - low_idx
  197. input_subset = input.index_select(0, row_indices)
  198. cluster_output = self.tail[i - 1](input_subset)
  199. cluster_index = self.shortlist_size + i - 1
  200. gather_inds.index_fill_(0, row_indices, cluster_index)
  201. cluster_logprob = F.log_softmax(cluster_output, dim=1)
  202. local_logprob = cluster_logprob.gather(1, relative_target.unsqueeze(1))
  203. output.index_copy_(0, row_indices, local_logprob.squeeze(1))
  204. used_rows += row_indices.numel()
  205. if used_rows != batch_size:
  206. raise RuntimeError(
  207. f"Target values should be in [0, {self.n_classes - 1}], "
  208. f"but values in range [{target.min().item()}, {target.max().item()}] "
  209. "were found. "
  210. )
  211. head_output = self.head(input)
  212. head_logprob = F.log_softmax(head_output, dim=1)
  213. output += head_logprob.gather(1, gather_inds.unsqueeze(1)).squeeze()
  214. loss = (-output).mean()
  215. if not is_batched:
  216. output = output.squeeze(0)
  217. return _ASMoutput(output, loss)
  218. def _get_full_log_prob(self, input, head_output):
  219. """Given input tensor, and output of ``self.head``, compute the log of the full distribution."""
  220. out = input.new_empty((head_output.size(0), self.n_classes))
  221. head_logprob = F.log_softmax(head_output, dim=1)
  222. out[:, : self.shortlist_size] = head_logprob[:, : self.shortlist_size]
  223. for i, (start_idx, stop_idx) in enumerate(itertools.pairwise(self.cutoffs)):
  224. cluster_output = self.tail[i](input)
  225. cluster_logprob = F.log_softmax(cluster_output, dim=1)
  226. output_logprob = cluster_logprob + head_logprob[
  227. :, self.shortlist_size + i
  228. ].unsqueeze(1)
  229. out[:, start_idx:stop_idx] = output_logprob
  230. return out
  231. def log_prob(self, input: Tensor) -> Tensor:
  232. r"""Compute log probabilities for all :math:`\texttt{n\_classes}`.
  233. Args:
  234. input (Tensor): a minibatch of examples
  235. Returns:
  236. log-probabilities of for each class :math:`c`
  237. in range :math:`0 <= c <= \texttt{n\_classes}`, where :math:`\texttt{n\_classes}` is a
  238. parameter passed to ``AdaptiveLogSoftmaxWithLoss`` constructor.
  239. Shape:
  240. - Input: :math:`(N, \texttt{in\_features})`
  241. - Output: :math:`(N, \texttt{n\_classes})`
  242. """
  243. head_output = self.head(input)
  244. return self._get_full_log_prob(input, head_output)
  245. def predict(self, input: Tensor) -> Tensor:
  246. r"""Return the class with the highest probability for each example in the input minibatch.
  247. This is equivalent to ``self.log_prob(input).argmax(dim=1)``, but is more efficient in some cases.
  248. Args:
  249. input (Tensor): a minibatch of examples
  250. Returns:
  251. output (Tensor): a class with the highest probability for each example
  252. Shape:
  253. - Input: :math:`(N, \texttt{in\_features})`
  254. - Output: :math:`(N)`
  255. """
  256. head_output = self.head(input)
  257. output = torch.argmax(head_output, dim=1)
  258. not_in_shortlist = output >= self.shortlist_size
  259. all_in_shortlist = not (not_in_shortlist.any())
  260. if all_in_shortlist:
  261. return output
  262. elif not_in_shortlist.all():
  263. log_prob = self._get_full_log_prob(input, head_output)
  264. return torch.argmax(log_prob, dim=1)
  265. else:
  266. log_prob = self._get_full_log_prob(
  267. input[not_in_shortlist], head_output[not_in_shortlist]
  268. )
  269. output[not_in_shortlist] = torch.argmax(log_prob, dim=1)
  270. return output