loss.py 95 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956195719581959196019611962196319641965196619671968196919701971197219731974197519761977197819791980198119821983198419851986198719881989199019911992199319941995199619971998199920002001200220032004200520062007200820092010201120122013201420152016201720182019202020212022202320242025202620272028202920302031203220332034203520362037203820392040204120422043204420452046204720482049205020512052205320542055205620572058205920602061206220632064206520662067206820692070207120722073207420752076207720782079208020812082208320842085208620872088208920902091209220932094209520962097209820992100210121022103210421052106210721082109211021112112211321142115211621172118211921202121212221232124
  1. # mypy: allow-untyped-defs
  2. from collections.abc import Callable
  3. from typing_extensions import deprecated
  4. from torch import Tensor
  5. from torch.nn import _reduction as _Reduction, functional as F
  6. from .distance import PairwiseDistance
  7. from .module import Module
  8. __all__ = [
  9. "L1Loss",
  10. "NLLLoss",
  11. "NLLLoss2d",
  12. "PoissonNLLLoss",
  13. "GaussianNLLLoss",
  14. "KLDivLoss",
  15. "MSELoss",
  16. "BCELoss",
  17. "BCEWithLogitsLoss",
  18. "HingeEmbeddingLoss",
  19. "MultiLabelMarginLoss",
  20. "SmoothL1Loss",
  21. "HuberLoss",
  22. "SoftMarginLoss",
  23. "CrossEntropyLoss",
  24. "MultiLabelSoftMarginLoss",
  25. "CosineEmbeddingLoss",
  26. "MarginRankingLoss",
  27. "MultiMarginLoss",
  28. "TripletMarginLoss",
  29. "TripletMarginWithDistanceLoss",
  30. "CTCLoss",
  31. ]
  32. class _Loss(Module):
  33. reduction: str
  34. def __init__(self, size_average=None, reduce=None, reduction: str = "mean") -> None:
  35. super().__init__()
  36. if size_average is not None or reduce is not None:
  37. self.reduction: str = _Reduction.legacy_get_string(size_average, reduce)
  38. else:
  39. self.reduction = reduction
  40. class _WeightedLoss(_Loss):
  41. def __init__(
  42. self,
  43. weight: Tensor | None = None,
  44. size_average=None,
  45. reduce=None,
  46. reduction: str = "mean",
  47. ) -> None:
  48. super().__init__(size_average, reduce, reduction)
  49. self.register_buffer("weight", weight)
  50. self.weight: Tensor | None
  51. class L1Loss(_Loss):
  52. r"""Creates a criterion that measures the mean absolute error (MAE) between each element in
  53. the input :math:`x` and target :math:`y`.
  54. The unreduced (i.e. with :attr:`reduction` set to ``'none'``) loss can be described as:
  55. .. math::
  56. \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad
  57. l_n = \left| x_n - y_n \right|,
  58. where :math:`N` is the batch size. If :attr:`reduction` is not ``'none'``
  59. (default ``'mean'``), then:
  60. .. math::
  61. \ell(x, y) =
  62. \begin{cases}
  63. \operatorname{mean}(L), & \text{if reduction} = \text{`mean';}\\
  64. \operatorname{sum}(L), & \text{if reduction} = \text{`sum'.}
  65. \end{cases}
  66. :math:`x` and :math:`y` are tensors of arbitrary shapes with a total
  67. of :math:`N` elements each.
  68. The sum operation still operates over all the elements, and divides by :math:`N`.
  69. The division by :math:`N` can be avoided if one sets ``reduction = 'sum'``.
  70. Supports real-valued and complex-valued inputs.
  71. Args:
  72. size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
  73. the losses are averaged over each loss element in the batch. Note that for
  74. some losses, there are multiple elements per sample. If the field :attr:`size_average`
  75. is set to ``False``, the losses are instead summed for each minibatch. Ignored
  76. when :attr:`reduce` is ``False``. Default: ``True``
  77. reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the
  78. losses are averaged or summed over observations for each minibatch depending
  79. on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
  80. batch element instead and ignores :attr:`size_average`. Default: ``True``
  81. reduction (str, optional): Specifies the reduction to apply to the output:
  82. ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
  83. ``'mean'``: the sum of the output will be divided by the number of
  84. elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`
  85. and :attr:`reduce` are in the process of being deprecated, and in the meantime,
  86. specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``
  87. Shape:
  88. - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
  89. - Target: :math:`(*)`, same shape as the input.
  90. - Output: scalar. If :attr:`reduction` is ``'none'``, then
  91. :math:`(*)`, same shape as the input.
  92. Examples:
  93. >>> loss = nn.L1Loss()
  94. >>> input = torch.randn(3, 5, requires_grad=True)
  95. >>> target = torch.randn(3, 5)
  96. >>> output = loss(input, target)
  97. >>> output.backward()
  98. """
  99. __constants__ = ["reduction"]
  100. def forward(self, input: Tensor, target: Tensor) -> Tensor:
  101. """
  102. Runs the forward pass.
  103. """
  104. return F.l1_loss(input, target, reduction=self.reduction)
  105. class NLLLoss(_WeightedLoss):
  106. r"""The negative log likelihood loss. It is useful to train a classification
  107. problem with `C` classes.
  108. If provided, the optional argument :attr:`weight` should be a 1D Tensor assigning
  109. weight to each of the classes. This is particularly useful when you have an
  110. unbalanced training set.
  111. The `input` given through a forward call is expected to contain
  112. log-probabilities of each class. `input` has to be a Tensor of size either
  113. :math:`(minibatch, C)` or :math:`(minibatch, C, d_1, d_2, ..., d_K)`
  114. with :math:`K \geq 1` for the `K`-dimensional case. The latter is useful for
  115. higher dimension inputs, such as computing NLL loss per-pixel for 2D images.
  116. Obtaining log-probabilities in a neural network is easily achieved by
  117. adding a `LogSoftmax` layer in the last layer of your network.
  118. You may use `CrossEntropyLoss` instead, if you prefer not to add an extra
  119. layer.
  120. The `target` that this loss expects should be a class index in the range :math:`[0, C-1]`
  121. where `C = number of classes`; if `ignore_index` is specified, this loss also accepts
  122. this class index (this index may not necessarily be in the class range).
  123. The unreduced (i.e. with :attr:`reduction` set to ``'none'``) loss can be described as:
  124. .. math::
  125. \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \\
  126. l_n = - w_{y_n} x_{n,y_n}, \\
  127. w_{c} = \text{weight}[c] \cdot \mathbb{1}\{c \not= \text{ignore\_index}\},
  128. where :math:`x` is the input, :math:`y` is the target, :math:`w` is the weight, and
  129. :math:`N` is the batch size. If :attr:`reduction` is not ``'none'``
  130. (default ``'mean'``), then
  131. .. math::
  132. \ell(x, y) = \begin{cases}
  133. \sum_{n=1}^N \frac{1}{\sum_{n=1}^N w_{y_n}} l_n, &
  134. \text{if reduction} = \text{`mean';}\\
  135. \sum_{n=1}^N l_n, &
  136. \text{if reduction} = \text{`sum'.}
  137. \end{cases}
  138. Args:
  139. weight (Tensor, optional): a manual rescaling weight given to each
  140. class. If given, it has to be a Tensor of size `C`. Otherwise, it is
  141. treated as if having all ones.
  142. size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
  143. the losses are averaged over each loss element in the batch. Note that for
  144. some losses, there are multiple elements per sample. If the field :attr:`size_average`
  145. is set to ``False``, the losses are instead summed for each minibatch. Ignored
  146. when :attr:`reduce` is ``False``. Default: ``None``
  147. ignore_index (int, optional): Specifies a target value that is ignored
  148. and does not contribute to the input gradient. When
  149. :attr:`size_average` is ``True``, the loss is averaged over
  150. non-ignored targets.
  151. reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the
  152. losses are averaged or summed over observations for each minibatch depending
  153. on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
  154. batch element instead and ignores :attr:`size_average`. Default: ``None``
  155. reduction (str, optional): Specifies the reduction to apply to the output:
  156. ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will
  157. be applied, ``'mean'``: the weighted mean of the output is taken,
  158. ``'sum'``: the output will be summed. Note: :attr:`size_average`
  159. and :attr:`reduce` are in the process of being deprecated, and in
  160. the meantime, specifying either of those two args will override
  161. :attr:`reduction`. Default: ``'mean'``
  162. Shape::
  163. - Input: :math:`(N, C)` or :math:`(C)`, where `C = number of classes`, `N = batch size`, or
  164. :math:`(N, C, d_1, d_2, ..., d_K)` with :math:`K \geq 1`
  165. in the case of `K`-dimensional loss.
  166. - Target: :math:`(N)` or :math:`()`, where each value is
  167. :math:`0 \leq \text{targets}[i] \leq C-1`, or
  168. :math:`(N, d_1, d_2, ..., d_K)` with :math:`K \geq 1` in the case of
  169. K-dimensional loss.
  170. - Output: If :attr:`reduction` is ``'none'``, shape :math:`(N)` or
  171. :math:`(N, d_1, d_2, ..., d_K)` with :math:`K \geq 1` in the case of K-dimensional loss.
  172. Otherwise, scalar.
  173. Examples:
  174. >>> log_softmax = nn.LogSoftmax(dim=1)
  175. >>> loss_fn = nn.NLLLoss()
  176. >>> # input to NLLLoss is of size N x C = 3 x 5
  177. >>> input = torch.randn(3, 5, requires_grad=True)
  178. >>> # each element in target must have 0 <= value < C
  179. >>> target = torch.tensor([1, 0, 4])
  180. >>> loss = loss_fn(log_softmax(input), target)
  181. >>> loss.backward()
  182. >>>
  183. >>>
  184. >>> # 2D loss example (used, for example, with image inputs)
  185. >>> N, C = 5, 4
  186. >>> loss_fn = nn.NLLLoss()
  187. >>> data = torch.randn(N, 16, 10, 10)
  188. >>> conv = nn.Conv2d(16, C, (3, 3))
  189. >>> log_softmax = nn.LogSoftmax(dim=1)
  190. >>> # output of conv forward is of shape [N, C, 8, 8]
  191. >>> output = log_softmax(conv(data))
  192. >>> # each element in target must have 0 <= value < C
  193. >>> target = torch.empty(N, 8, 8, dtype=torch.long).random_(0, C)
  194. >>> # input to NLLLoss is of size N x C x height (8) x width (8)
  195. >>> loss = loss_fn(output, target)
  196. >>> loss.backward()
  197. """
  198. __constants__ = ["ignore_index", "reduction"]
  199. ignore_index: int
  200. def __init__(
  201. self,
  202. weight: Tensor | None = None,
  203. size_average=None,
  204. ignore_index: int = -100,
  205. reduce=None,
  206. reduction: str = "mean",
  207. ) -> None:
  208. super().__init__(weight, size_average, reduce, reduction)
  209. self.ignore_index = ignore_index
  210. def forward(self, input: Tensor, target: Tensor) -> Tensor:
  211. """
  212. Runs the forward pass.
  213. """
  214. return F.nll_loss(
  215. input,
  216. target,
  217. weight=self.weight,
  218. ignore_index=self.ignore_index,
  219. reduction=self.reduction,
  220. )
  221. @deprecated(
  222. "`NLLLoss2d` has been deprecated. "
  223. "Please use `NLLLoss` instead as a drop-in replacement and see "
  224. "https://pytorch.org/docs/main/nn.html#torch.nn.NLLLoss for more details.",
  225. category=FutureWarning,
  226. )
  227. class NLLLoss2d(NLLLoss):
  228. def __init__(
  229. self,
  230. weight: Tensor | None = None,
  231. size_average=None,
  232. ignore_index: int = -100,
  233. reduce=None,
  234. reduction: str = "mean",
  235. ) -> None:
  236. super().__init__(weight, size_average, ignore_index, reduce, reduction)
  237. class PoissonNLLLoss(_Loss):
  238. r"""Negative log likelihood loss with Poisson distribution of target.
  239. The loss can be described as:
  240. .. math::
  241. \text{target} \sim \mathrm{Poisson}(\text{input})
  242. \text{loss}(\text{input}, \text{target}) = \text{input} - \text{target} * \log(\text{input})
  243. + \log(\text{target!})
  244. The last term can be omitted or approximated with Stirling formula. The
  245. approximation is used for target values more than 1. For targets less or
  246. equal to 1 zeros are added to the loss.
  247. Args:
  248. log_input (bool, optional): if ``True`` the loss is computed as
  249. :math:`\exp(\text{input}) - \text{target}*\text{input}`, if ``False`` the loss is
  250. :math:`\text{input} - \text{target}*\log(\text{input}+\text{eps})`.
  251. full (bool, optional): whether to compute full loss, i. e. to add the
  252. Stirling approximation term
  253. .. math::
  254. \text{target}*\log(\text{target}) - \text{target} + 0.5 * \log(2\pi\text{target}).
  255. size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
  256. the losses are averaged over each loss element in the batch. Note that for
  257. some losses, there are multiple elements per sample. If the field :attr:`size_average`
  258. is set to ``False``, the losses are instead summed for each minibatch. Ignored
  259. when :attr:`reduce` is ``False``. Default: ``True``
  260. eps (float, optional): Small value to avoid evaluation of :math:`\log(0)` when
  261. :attr:`log_input = False`. Default: 1e-8
  262. reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the
  263. losses are averaged or summed over observations for each minibatch depending
  264. on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
  265. batch element instead and ignores :attr:`size_average`. Default: ``True``
  266. reduction (str, optional): Specifies the reduction to apply to the output:
  267. ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
  268. ``'mean'``: the sum of the output will be divided by the number of
  269. elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`
  270. and :attr:`reduce` are in the process of being deprecated, and in the meantime,
  271. specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``
  272. Examples:
  273. >>> loss = nn.PoissonNLLLoss()
  274. >>> log_input = torch.randn(5, 2, requires_grad=True)
  275. >>> target = torch.randn(5, 2)
  276. >>> output = loss(log_input, target)
  277. >>> output.backward()
  278. Shape:
  279. - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
  280. - Target: :math:`(*)`, same shape as the input.
  281. - Output: scalar by default. If :attr:`reduction` is ``'none'``, then :math:`(*)`,
  282. the same shape as the input.
  283. """
  284. __constants__ = ["log_input", "full", "eps", "reduction"]
  285. log_input: bool
  286. full: bool
  287. eps: float
  288. def __init__(
  289. self,
  290. log_input: bool = True,
  291. full: bool = False,
  292. size_average=None,
  293. eps: float = 1e-8,
  294. reduce=None,
  295. reduction: str = "mean",
  296. ) -> None:
  297. super().__init__(size_average, reduce, reduction)
  298. self.log_input = log_input
  299. self.full = full
  300. self.eps = eps
  301. def forward(self, log_input: Tensor, target: Tensor) -> Tensor:
  302. """
  303. Runs the forward pass.
  304. """
  305. return F.poisson_nll_loss(
  306. log_input,
  307. target,
  308. log_input=self.log_input,
  309. full=self.full,
  310. eps=self.eps,
  311. reduction=self.reduction,
  312. )
  313. class GaussianNLLLoss(_Loss):
  314. r"""Gaussian negative log likelihood loss.
  315. The targets are treated as samples from Gaussian distributions with
  316. expectations and variances predicted by the neural network. For a
  317. ``target`` tensor modelled as having Gaussian distribution with a tensor
  318. of expectations ``input`` and a tensor of positive variances ``var`` the loss is:
  319. .. math::
  320. \text{loss} = \frac{1}{2}\left(\log\left(\text{max}\left(\text{var},
  321. \ \text{eps}\right)\right) + \frac{\left(\text{input} - \text{target}\right)^2}
  322. {\text{max}\left(\text{var}, \ \text{eps}\right)}\right) + \text{const.}
  323. where :attr:`eps` is used for stability. By default, the constant term of
  324. the loss function is omitted unless :attr:`full` is ``True``. If ``var`` is not the same
  325. size as ``input`` (due to a homoscedastic assumption), it must either have a final dimension
  326. of 1 or have one fewer dimension (with all other sizes being the same) for correct broadcasting.
  327. Args:
  328. full (bool, optional): include the constant term in the loss
  329. calculation. Default: ``False``.
  330. eps (float, optional): value used to clamp ``var`` (see note below), for
  331. stability. Default: 1e-6.
  332. reduction (str, optional): specifies the reduction to apply to the
  333. output:``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction
  334. will be applied, ``'mean'``: the output is the average of all batch
  335. member losses, ``'sum'``: the output is the sum of all batch member
  336. losses. Default: ``'mean'``.
  337. Shape:
  338. - Input: :math:`(N, *)` or :math:`(*)` where :math:`*` means any number of additional
  339. dimensions
  340. - Target: :math:`(N, *)` or :math:`(*)`, same shape as the input, or same shape as the input
  341. but with one dimension equal to 1 (to allow for broadcasting)
  342. - Var: :math:`(N, *)` or :math:`(*)`, same shape as the input, or same shape as the input but
  343. with one dimension equal to 1, or same shape as the input but with one fewer
  344. dimension (to allow for broadcasting), or a scalar value
  345. - Output: scalar if :attr:`reduction` is ``'mean'`` (default) or
  346. ``'sum'``. If :attr:`reduction` is ``'none'``, then :math:`(N, *)`, same
  347. shape as the input
  348. Examples:
  349. >>> loss = nn.GaussianNLLLoss()
  350. >>> input = torch.randn(5, 2, requires_grad=True)
  351. >>> target = torch.randn(5, 2)
  352. >>> var = torch.ones(5, 2, requires_grad=True) # heteroscedastic
  353. >>> output = loss(input, target, var)
  354. >>> output.backward()
  355. >>> loss = nn.GaussianNLLLoss()
  356. >>> input = torch.randn(5, 2, requires_grad=True)
  357. >>> target = torch.randn(5, 2)
  358. >>> var = torch.ones(5, 1, requires_grad=True) # homoscedastic
  359. >>> output = loss(input, target, var)
  360. >>> output.backward()
  361. Note:
  362. The clamping of ``var`` is ignored with respect to autograd, and so the
  363. gradients are unaffected by it.
  364. Reference:
  365. Nix, D. A. and Weigend, A. S., "Estimating the mean and variance of the
  366. target probability distribution", Proceedings of 1994 IEEE International
  367. Conference on Neural Networks (ICNN'94), Orlando, FL, USA, 1994, pp. 55-60
  368. vol.1, doi: 10.1109/ICNN.1994.374138.
  369. """
  370. __constants__ = ["full", "eps", "reduction"]
  371. full: bool
  372. eps: float
  373. def __init__(
  374. self, *, full: bool = False, eps: float = 1e-6, reduction: str = "mean"
  375. ) -> None:
  376. super().__init__(None, None, reduction)
  377. self.full = full
  378. self.eps = eps
  379. def forward(self, input: Tensor, target: Tensor, var: Tensor | float) -> Tensor:
  380. """
  381. Runs the forward pass.
  382. """
  383. return F.gaussian_nll_loss(
  384. input, target, var, full=self.full, eps=self.eps, reduction=self.reduction
  385. )
  386. class KLDivLoss(_Loss):
  387. r"""The Kullback-Leibler divergence loss.
  388. For tensors of the same shape :math:`y_{\text{pred}},\ y_{\text{true}}`,
  389. where :math:`y_{\text{pred}}` is the :attr:`input` and :math:`y_{\text{true}}` is the
  390. :attr:`target`, we define the **pointwise KL-divergence** as
  391. .. math::
  392. L(y_{\text{pred}},\ y_{\text{true}})
  393. = y_{\text{true}} \cdot \log \frac{y_{\text{true}}}{y_{\text{pred}}}
  394. = y_{\text{true}} \cdot (\log y_{\text{true}} - \log y_{\text{pred}})
  395. To avoid underflow issues when computing this quantity, this loss expects the argument
  396. :attr:`input` in the log-space. The argument :attr:`target` may also be provided in the
  397. log-space if :attr:`log_target`\ `= True`.
  398. To summarise, this function is roughly equivalent to computing
  399. .. code-block:: python
  400. if not log_target: # default
  401. loss_pointwise = target * (target.log() - input)
  402. else:
  403. loss_pointwise = target.exp() * (target - input)
  404. and then reducing this result depending on the argument :attr:`reduction` as
  405. .. code-block:: python
  406. if reduction == "mean": # default
  407. loss = loss_pointwise.mean()
  408. elif reduction == "batchmean": # mathematically correct
  409. loss = loss_pointwise.sum() / input.size(0)
  410. elif reduction == "sum":
  411. loss = loss_pointwise.sum()
  412. else: # reduction == "none"
  413. loss = loss_pointwise
  414. .. note::
  415. As all the other losses in PyTorch, this function expects the first argument,
  416. :attr:`input`, to be the output of the model (e.g. the neural network)
  417. and the second, :attr:`target`, to be the observations in the dataset.
  418. This differs from the standard mathematical notation :math:`KL(P\ ||\ Q)` where
  419. :math:`P` denotes the distribution of the observations and :math:`Q` denotes the model.
  420. .. warning::
  421. :attr:`reduction`\ `= "mean"` doesn't return the true KL divergence value, please use
  422. :attr:`reduction`\ `= "batchmean"` which aligns with the mathematical definition.
  423. Args:
  424. size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
  425. the losses are averaged over each loss element in the batch. Note that for
  426. some losses, there are multiple elements per sample. If the field :attr:`size_average`
  427. is set to `False`, the losses are instead summed for each minibatch. Ignored
  428. when :attr:`reduce` is `False`. Default: `True`
  429. reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the
  430. losses are averaged or summed over observations for each minibatch depending
  431. on :attr:`size_average`. When :attr:`reduce` is `False`, returns a loss per
  432. batch element instead and ignores :attr:`size_average`. Default: `True`
  433. reduction (str, optional): Specifies the reduction to apply to the output. Default: `"mean"`
  434. log_target (bool, optional): Specifies whether `target` is the log space. Default: `False`
  435. Shape:
  436. - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
  437. - Target: :math:`(*)`, same shape as the input.
  438. - Output: scalar by default. If :attr:`reduction` is `'none'`, then :math:`(*)`,
  439. same shape as the input.
  440. Examples:
  441. >>> kl_loss = nn.KLDivLoss(reduction="batchmean")
  442. >>> # input should be a distribution in the log space
  443. >>> input = F.log_softmax(torch.randn(3, 5, requires_grad=True), dim=1)
  444. >>> # Sample a batch of distributions. Usually this would come from the dataset
  445. >>> target = F.softmax(torch.rand(3, 5), dim=1)
  446. >>> output = kl_loss(input, target)
  447. >>>
  448. >>> kl_loss = nn.KLDivLoss(reduction="batchmean", log_target=True)
  449. >>> log_target = F.log_softmax(torch.rand(3, 5), dim=1)
  450. >>> output = kl_loss(input, log_target)
  451. """
  452. __constants__ = ["reduction"]
  453. def __init__(
  454. self,
  455. size_average=None,
  456. reduce=None,
  457. reduction: str = "mean",
  458. log_target: bool = False,
  459. ) -> None:
  460. super().__init__(size_average, reduce, reduction)
  461. self.log_target = log_target
  462. def forward(self, input: Tensor, target: Tensor) -> Tensor:
  463. """
  464. Runs the forward pass.
  465. """
  466. return F.kl_div(
  467. input, target, reduction=self.reduction, log_target=self.log_target
  468. )
  469. class MSELoss(_Loss):
  470. r"""Creates a criterion that measures the mean squared error (squared L2 norm) between
  471. each element in the input :math:`x` and target :math:`y`.
  472. The unreduced (i.e. with :attr:`reduction` set to ``'none'``) loss can be described as:
  473. .. math::
  474. \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad
  475. l_n = \left( x_n - y_n \right)^2,
  476. where :math:`N` is the batch size. If :attr:`reduction` is not ``'none'``
  477. (default ``'mean'``), then:
  478. .. math::
  479. \ell(x, y) =
  480. \begin{cases}
  481. \operatorname{mean}(L), & \text{if reduction} = \text{`mean';}\\
  482. \operatorname{sum}(L), & \text{if reduction} = \text{`sum'.}
  483. \end{cases}
  484. :math:`x` and :math:`y` are tensors of arbitrary shapes with a total
  485. of :math:`N` elements each.
  486. The mean operation still operates over all the elements, and divides by :math:`N`.
  487. The division by :math:`N` can be avoided if one sets ``reduction = 'sum'``.
  488. Args:
  489. size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
  490. the losses are averaged over each loss element in the batch. Note that for
  491. some losses, there are multiple elements per sample. If the field :attr:`size_average`
  492. is set to ``False``, the losses are instead summed for each minibatch. Ignored
  493. when :attr:`reduce` is ``False``. Default: ``True``
  494. reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the
  495. losses are averaged or summed over observations for each minibatch depending
  496. on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
  497. batch element instead and ignores :attr:`size_average`. Default: ``True``
  498. reduction (str, optional): Specifies the reduction to apply to the output:
  499. ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
  500. ``'mean'``: the sum of the output will be divided by the number of
  501. elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`
  502. and :attr:`reduce` are in the process of being deprecated, and in the meantime,
  503. specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``
  504. Shape:
  505. - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
  506. - Target: :math:`(*)`, same shape as the input.
  507. Examples:
  508. >>> loss = nn.MSELoss()
  509. >>> input = torch.randn(3, 5, requires_grad=True)
  510. >>> target = torch.randn(3, 5)
  511. >>> output = loss(input, target)
  512. >>> output.backward()
  513. """
  514. __constants__ = ["reduction"]
  515. def forward(self, input: Tensor, target: Tensor) -> Tensor:
  516. """
  517. Runs the forward pass.
  518. """
  519. return F.mse_loss(input, target, reduction=self.reduction)
  520. class BCELoss(_WeightedLoss):
  521. r"""Creates a criterion that measures the Binary Cross Entropy between the target and
  522. the input probabilities:
  523. The unreduced (i.e. with :attr:`reduction` set to ``'none'``) loss can be described as:
  524. .. math::
  525. \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad
  526. l_n = - w_n \left[ y_n \cdot \log x_n + (1 - y_n) \cdot \log (1 - x_n) \right],
  527. where :math:`N` is the batch size. If :attr:`reduction` is not ``'none'``
  528. (default ``'mean'``), then
  529. .. math::
  530. \ell(x, y) = \begin{cases}
  531. \operatorname{mean}(L), & \text{if reduction} = \text{`mean';}\\
  532. \operatorname{sum}(L), & \text{if reduction} = \text{`sum'.}
  533. \end{cases}
  534. This is used for measuring the error of a reconstruction in for example
  535. an auto-encoder. Note that the targets :math:`y` should be numbers
  536. between 0 and 1.
  537. Notice that if :math:`x_n` is either 0 or 1, one of the log terms would be
  538. mathematically undefined in the above loss equation. PyTorch chooses to set
  539. :math:`\log (0) = -\infty`, since :math:`\lim_{x\to 0} \log (x) = -\infty`.
  540. However, an infinite term in the loss equation is not desirable for several reasons.
  541. For one, if either :math:`y_n = 0` or :math:`(1 - y_n) = 0`, then we would be
  542. multiplying 0 with infinity. Secondly, if we have an infinite loss value, then
  543. we would also have an infinite term in our gradient, since
  544. :math:`\lim_{x\to 0} \frac{d}{dx} \log (x) = \infty`.
  545. This would make BCELoss's backward method nonlinear with respect to :math:`x_n`,
  546. and using it for things like linear regression would not be straight-forward.
  547. Our solution is that BCELoss clamps its log function outputs to be greater than
  548. or equal to -100. This way, we can always have a finite loss value and a linear
  549. backward method.
  550. Args:
  551. weight (Tensor, optional): a manual rescaling weight given to the loss
  552. of each batch element. If given, has to be a Tensor of size `nbatch`.
  553. size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
  554. the losses are averaged over each loss element in the batch. Note that for
  555. some losses, there are multiple elements per sample. If the field :attr:`size_average`
  556. is set to ``False``, the losses are instead summed for each minibatch. Ignored
  557. when :attr:`reduce` is ``False``. Default: ``True``
  558. reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the
  559. losses are averaged or summed over observations for each minibatch depending
  560. on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
  561. batch element instead and ignores :attr:`size_average`. Default: ``True``
  562. reduction (str, optional): Specifies the reduction to apply to the output:
  563. ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
  564. ``'mean'``: the sum of the output will be divided by the number of
  565. elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`
  566. and :attr:`reduce` are in the process of being deprecated, and in the meantime,
  567. specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``
  568. Shape:
  569. - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
  570. - Target: :math:`(*)`, same shape as the input.
  571. - Output: scalar. If :attr:`reduction` is ``'none'``, then :math:`(*)`, same
  572. shape as input.
  573. Examples:
  574. >>> m = nn.Sigmoid()
  575. >>> loss = nn.BCELoss()
  576. >>> input = torch.randn(3, 2, requires_grad=True)
  577. >>> target = torch.rand(3, 2, requires_grad=False)
  578. >>> output = loss(m(input), target)
  579. >>> output.backward()
  580. """
  581. __constants__ = ["reduction"]
  582. def forward(self, input: Tensor, target: Tensor) -> Tensor:
  583. """
  584. Runs the forward pass.
  585. """
  586. return F.binary_cross_entropy(
  587. input, target, weight=self.weight, reduction=self.reduction
  588. )
  589. class BCEWithLogitsLoss(_Loss):
  590. r"""This loss combines a `Sigmoid` layer and the `BCELoss` in one single
  591. class. This version is more numerically stable than using a plain `Sigmoid`
  592. followed by a `BCELoss` as, by combining the operations into one layer,
  593. we take advantage of the log-sum-exp trick for numerical stability.
  594. The unreduced (i.e. with :attr:`reduction` set to ``'none'``) loss can be described as:
  595. .. math::
  596. \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad
  597. l_n = - w_n \left[ y_n \cdot \log \sigma(x_n)
  598. + (1 - y_n) \cdot \log (1 - \sigma(x_n)) \right],
  599. where :math:`N` is the batch size. If :attr:`reduction` is not ``'none'``
  600. (default ``'mean'``), then
  601. .. math::
  602. \ell(x, y) = \begin{cases}
  603. \operatorname{mean}(L), & \text{if reduction} = \text{`mean';}\\
  604. \operatorname{sum}(L), & \text{if reduction} = \text{`sum'.}
  605. \end{cases}
  606. This is used for measuring the error of a reconstruction in for example
  607. an auto-encoder. Note that the targets `t[i]` should be numbers
  608. between 0 and 1.
  609. It's possible to trade off recall and precision by adding weights to positive examples.
  610. In the case of multi-label classification the loss can be described as:
  611. .. math::
  612. \ell_c(x, y) = L_c = \{l_{1,c},\dots,l_{N,c}\}^\top, \quad
  613. l_{n,c} = - w_{n,c} \left[ p_c y_{n,c} \cdot \log \sigma(x_{n,c})
  614. + (1 - y_{n,c}) \cdot \log (1 - \sigma(x_{n,c})) \right],
  615. where :math:`c` is the class number (:math:`c > 1` for multi-label binary classification,
  616. :math:`c = 1` for single-label binary classification),
  617. :math:`n` is the number of the sample in the batch and
  618. :math:`p_c` is the weight of the positive answer for the class :math:`c`.
  619. :math:`p_c > 1` increases the recall, :math:`p_c < 1` increases the precision.
  620. For example, if a dataset contains 100 positive and 300 negative examples of a single class,
  621. then ``pos_weight`` for the class should be equal to :math:`\frac{300}{100}=3`.
  622. The loss would act as if the dataset contains :math:`3\times 100=300` positive examples.
  623. Examples:
  624. >>> target = torch.ones([10, 64], dtype=torch.float32) # 64 classes, batch size = 10
  625. >>> output = torch.full([10, 64], 1.5) # A prediction (logit)
  626. >>> pos_weight = torch.ones([64]) # All weights are equal to 1
  627. >>> criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
  628. >>> criterion(output, target) # -log(sigmoid(1.5))
  629. tensor(0.20...)
  630. In the above example, the ``pos_weight`` tensor's elements correspond to the 64 distinct classes
  631. in a multi-label binary classification scenario. Each element in ``pos_weight`` is designed to adjust the
  632. loss function based on the imbalance between negative and positive samples for the respective class.
  633. This approach is useful in datasets with varying levels of class imbalance, ensuring that the loss
  634. calculation accurately accounts for the distribution in each class.
  635. Args:
  636. weight (Tensor, optional): a manual rescaling weight given to the loss
  637. of each batch element. The dimension of weight supports :ref:`broadcasting to a common shape <broadcasting-semantics>`
  638. with respect to the output (and target) shape.
  639. size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
  640. the losses are averaged over each loss element in the batch. Note that for
  641. some losses, there are multiple elements per sample. If the field :attr:`size_average`
  642. is set to ``False``, the losses are instead summed for each minibatch. Ignored
  643. when :attr:`reduce` is ``False``. Default: ``True``
  644. reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the
  645. losses are averaged or summed over observations for each minibatch depending
  646. on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
  647. batch element instead and ignores :attr:`size_average`. Default: ``True``
  648. reduction (str, optional): Specifies the reduction to apply to the output:
  649. ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
  650. ``'mean'``: the sum of the output will be divided by the number of
  651. elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`
  652. and :attr:`reduce` are in the process of being deprecated, and in the meantime,
  653. specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``
  654. pos_weight (Tensor, optional): a weight of positive examples to be broadcasted with target.
  655. Must be a tensor with equal size along the class dimension to the number of classes.
  656. Pay close attention to PyTorch's broadcasting semantics in order to achieve the desired
  657. operations. For a target of size [B, C, H, W] (where B is batch size) pos_weight of
  658. size [B, C, H, W] will apply different pos_weights to each element of the batch or
  659. [C, H, W] the same pos_weights across the batch. To apply the same positive weight
  660. along all spatial dimensions for a 2D multi-class target [C, H, W] use: [C, 1, 1].
  661. Default: ``None``
  662. Shape:
  663. - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
  664. - Target: :math:`(*)`, same shape as the input.
  665. - Output: scalar. If :attr:`reduction` is ``'none'``, then :math:`(*)`, same
  666. shape as input.
  667. Examples:
  668. >>> loss = nn.BCEWithLogitsLoss()
  669. >>> input = torch.randn(3, requires_grad=True)
  670. >>> target = torch.empty(3).random_(2)
  671. >>> output = loss(input, target)
  672. >>> output.backward()
  673. """
  674. def __init__(
  675. self,
  676. weight: Tensor | None = None,
  677. size_average=None,
  678. reduce=None,
  679. reduction: str = "mean",
  680. pos_weight: Tensor | None = None,
  681. ) -> None:
  682. super().__init__(size_average, reduce, reduction)
  683. self.register_buffer("weight", weight)
  684. self.register_buffer("pos_weight", pos_weight)
  685. self.weight: Tensor | None
  686. self.pos_weight: Tensor | None
  687. def forward(self, input: Tensor, target: Tensor) -> Tensor:
  688. """Runs the forward pass."""
  689. return F.binary_cross_entropy_with_logits(
  690. input,
  691. target,
  692. self.weight,
  693. pos_weight=self.pos_weight,
  694. reduction=self.reduction,
  695. )
  696. class HingeEmbeddingLoss(_Loss):
  697. r"""Measures the loss given an input tensor :math:`x` and a labels tensor :math:`y`
  698. (containing 1 or -1).
  699. This is usually used for measuring whether two inputs are similar or
  700. dissimilar, e.g. using the L1 pairwise distance as :math:`x`, and is typically
  701. used for learning nonlinear embeddings or semi-supervised learning.
  702. The loss function for :math:`n`-th sample in the mini-batch is
  703. .. math::
  704. l_n = \begin{cases}
  705. x_n, & \text{if}\; y_n = 1,\\
  706. \max \{0, margin - x_n\}, & \text{if}\; y_n = -1,
  707. \end{cases}
  708. and the total loss functions is
  709. .. math::
  710. \ell(x, y) = \begin{cases}
  711. \operatorname{mean}(L), & \text{if reduction} = \text{`mean';}\\
  712. \operatorname{sum}(L), & \text{if reduction} = \text{`sum'.}
  713. \end{cases}
  714. where :math:`L = \{l_1,\dots,l_N\}^\top`.
  715. Args:
  716. margin (float, optional): Has a default value of `1`.
  717. size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
  718. the losses are averaged over each loss element in the batch. Note that for
  719. some losses, there are multiple elements per sample. If the field :attr:`size_average`
  720. is set to ``False``, the losses are instead summed for each minibatch. Ignored
  721. when :attr:`reduce` is ``False``. Default: ``True``
  722. reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the
  723. losses are averaged or summed over observations for each minibatch depending
  724. on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
  725. batch element instead and ignores :attr:`size_average`. Default: ``True``
  726. reduction (str, optional): Specifies the reduction to apply to the output:
  727. ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
  728. ``'mean'``: the sum of the output will be divided by the number of
  729. elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`
  730. and :attr:`reduce` are in the process of being deprecated, and in the meantime,
  731. specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``
  732. Shape:
  733. - Input: :math:`(*)` where :math:`*` means, any number of dimensions. The sum operation
  734. operates over all the elements.
  735. - Target: :math:`(*)`, same shape as the input
  736. - Output: scalar. If :attr:`reduction` is ``'none'``, then same shape as the input
  737. Examples:
  738. >>> loss = nn.HingeEmbeddingLoss()
  739. >>> input = torch.randn(3, 5, requires_grad=True)
  740. >>> target = torch.randn(3, 5).sign()
  741. >>> output = loss(input, target)
  742. >>> output.backward()
  743. """
  744. __constants__ = ["margin", "reduction"]
  745. margin: float
  746. def __init__(
  747. self,
  748. margin: float = 1.0,
  749. size_average=None,
  750. reduce=None,
  751. reduction: str = "mean",
  752. ) -> None:
  753. super().__init__(size_average, reduce, reduction)
  754. self.margin = margin
  755. def forward(self, input: Tensor, target: Tensor) -> Tensor:
  756. """Runs the forward pass."""
  757. return F.hinge_embedding_loss(
  758. input, target, margin=self.margin, reduction=self.reduction
  759. )
  760. class MultiLabelMarginLoss(_Loss):
  761. r"""Creates a criterion that optimizes a multi-class multi-classification
  762. hinge loss (margin-based loss) between input :math:`x` (a 2D mini-batch `Tensor`)
  763. and output :math:`y` (which is a 2D `Tensor` of target class indices).
  764. For each sample in the mini-batch:
  765. .. math::
  766. \text{loss}(x, y) = \sum_{ij}\frac{\max(0, 1 - (x[y[j]] - x[i]))}{\text{x.size}(0)}
  767. where :math:`x \in \left\{0, \; \cdots , \; \text{x.size}(0) - 1\right\}`, \
  768. :math:`y \in \left\{0, \; \cdots , \; \text{y.size}(0) - 1\right\}`, \
  769. :math:`0 \leq y[j] \leq \text{x.size}(0)-1`, \
  770. and :math:`i \neq y[j]` for all :math:`i` and :math:`j`.
  771. :math:`y` and :math:`x` must have the same size.
  772. The criterion only considers a contiguous block of non-negative targets that
  773. starts at the front.
  774. This allows for different samples to have variable amounts of target classes.
  775. Args:
  776. size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
  777. the losses are averaged over each loss element in the batch. Note that for
  778. some losses, there are multiple elements per sample. If the field :attr:`size_average`
  779. is set to ``False``, the losses are instead summed for each minibatch. Ignored
  780. when :attr:`reduce` is ``False``. Default: ``True``
  781. reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the
  782. losses are averaged or summed over observations for each minibatch depending
  783. on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
  784. batch element instead and ignores :attr:`size_average`. Default: ``True``
  785. reduction (str, optional): Specifies the reduction to apply to the output:
  786. ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
  787. ``'mean'``: the sum of the output will be divided by the number of
  788. elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`
  789. and :attr:`reduce` are in the process of being deprecated, and in the meantime,
  790. specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``
  791. Shape:
  792. - Input: :math:`(C)` or :math:`(N, C)` where `N` is the batch size and `C`
  793. is the number of classes.
  794. - Target: :math:`(C)` or :math:`(N, C)`, label targets padded by -1 ensuring same shape as the input.
  795. - Output: scalar. If :attr:`reduction` is ``'none'``, then :math:`(N)`.
  796. Examples:
  797. >>> loss = nn.MultiLabelMarginLoss()
  798. >>> x = torch.FloatTensor([[0.1, 0.2, 0.4, 0.8]])
  799. >>> # for target y, only consider labels 3 and 0, not after label -1
  800. >>> y = torch.LongTensor([[3, 0, -1, 1]])
  801. >>> # 0.25 * ((1-(0.1-0.2)) + (1-(0.1-0.4)) + (1-(0.8-0.2)) + (1-(0.8-0.4)))
  802. >>> loss(x, y)
  803. tensor(0.85...)
  804. """
  805. __constants__ = ["reduction"]
  806. def forward(self, input: Tensor, target: Tensor) -> Tensor:
  807. """Runs the forward pass."""
  808. return F.multilabel_margin_loss(input, target, reduction=self.reduction)
  809. class SmoothL1Loss(_Loss):
  810. r"""Creates a criterion that uses a squared term if the absolute
  811. element-wise error falls below beta and an L1 term otherwise.
  812. It is less sensitive to outliers than :class:`torch.nn.MSELoss` and in some cases
  813. prevents exploding gradients (e.g. see the paper `Fast R-CNN`_ by Ross Girshick).
  814. For a batch of size :math:`N`, the unreduced loss can be described as:
  815. .. math::
  816. \ell(x, y) = L = \{l_1, ..., l_N\}^T
  817. with
  818. .. math::
  819. l_n = \begin{cases}
  820. 0.5 (x_n - y_n)^2 / beta, & \text{if } |x_n - y_n| < beta \\
  821. |x_n - y_n| - 0.5 * beta, & \text{otherwise }
  822. \end{cases}
  823. If `reduction` is not `none`, then:
  824. .. math::
  825. \ell(x, y) =
  826. \begin{cases}
  827. \operatorname{mean}(L), & \text{if reduction} = \text{`mean';}\\
  828. \operatorname{sum}(L), & \text{if reduction} = \text{`sum'.}
  829. \end{cases}
  830. .. note::
  831. Smooth L1 loss can be seen as exactly :class:`L1Loss`, but with the :math:`|x - y| < beta`
  832. portion replaced with a quadratic function such that its slope is 1 at :math:`|x - y| = beta`.
  833. The quadratic segment smooths the L1 loss near :math:`|x - y| = 0`.
  834. .. note::
  835. Smooth L1 loss is closely related to :class:`HuberLoss`, being
  836. equivalent to :math:`huber(x, y) / beta` (note that Smooth L1's beta hyper-parameter is
  837. also known as delta for Huber). This leads to the following differences:
  838. * As beta -> 0, Smooth L1 loss converges to :class:`L1Loss`, while :class:`HuberLoss`
  839. converges to a constant 0 loss. When beta is 0, Smooth L1 loss is equivalent to L1 loss.
  840. * As beta -> :math:`+\infty`, Smooth L1 loss converges to a constant 0 loss, while
  841. :class:`HuberLoss` converges to :class:`MSELoss`.
  842. * For Smooth L1 loss, as beta varies, the L1 segment of the loss has a constant slope of 1.
  843. For :class:`HuberLoss`, the slope of the L1 segment is beta.
  844. .. _`Fast R-CNN`: https://arxiv.org/abs/1504.08083
  845. Args:
  846. size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
  847. the losses are averaged over each loss element in the batch. Note that for
  848. some losses, there are multiple elements per sample. If the field :attr:`size_average`
  849. is set to ``False``, the losses are instead summed for each minibatch. Ignored
  850. when :attr:`reduce` is ``False``. Default: ``True``
  851. reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the
  852. losses are averaged or summed over observations for each minibatch depending
  853. on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
  854. batch element instead and ignores :attr:`size_average`. Default: ``True``
  855. reduction (str, optional): Specifies the reduction to apply to the output:
  856. ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
  857. ``'mean'``: the sum of the output will be divided by the number of
  858. elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`
  859. and :attr:`reduce` are in the process of being deprecated, and in the meantime,
  860. specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``
  861. beta (float, optional): Specifies the threshold at which to change between L1 and L2 loss.
  862. The value must be non-negative. Default: 1.0
  863. Shape:
  864. - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
  865. - Target: :math:`(*)`, same shape as the input.
  866. - Output: scalar. If :attr:`reduction` is ``'none'``, then :math:`(*)`, same shape as the input.
  867. Examples:
  868. >>> loss = nn.SmoothL1Loss()
  869. >>> input = torch.randn(3, 5, requires_grad=True)
  870. >>> target = torch.randn(3, 5)
  871. >>> output = loss(input, target)
  872. >>> output.backward()
  873. """
  874. __constants__ = ["reduction"]
  875. def __init__(
  876. self, size_average=None, reduce=None, reduction: str = "mean", beta: float = 1.0
  877. ) -> None:
  878. super().__init__(size_average, reduce, reduction)
  879. self.beta = beta
  880. def forward(self, input: Tensor, target: Tensor) -> Tensor:
  881. """Runs the forward pass."""
  882. return F.smooth_l1_loss(input, target, reduction=self.reduction, beta=self.beta)
  883. class HuberLoss(_Loss):
  884. r"""Creates a criterion that uses a squared term if the absolute
  885. element-wise error falls below delta and a delta-scaled L1 term otherwise.
  886. This loss combines advantages of both :class:`L1Loss` and :class:`MSELoss`; the
  887. delta-scaled L1 region makes the loss less sensitive to outliers than :class:`MSELoss`,
  888. while the L2 region provides smoothness over :class:`L1Loss` near 0. See
  889. `Huber loss <https://en.wikipedia.org/wiki/Huber_loss>`_ for more information.
  890. For a batch of size :math:`N`, the unreduced loss can be described as:
  891. .. math::
  892. \ell(x, y) = L = \{l_1, ..., l_N\}^T
  893. with
  894. .. math::
  895. l_n = \begin{cases}
  896. 0.5 (x_n - y_n)^2, & \text{if } |x_n - y_n| < delta \\
  897. delta * (|x_n - y_n| - 0.5 * delta), & \text{otherwise }
  898. \end{cases}
  899. If `reduction` is not `none`, then:
  900. .. math::
  901. \ell(x, y) =
  902. \begin{cases}
  903. \operatorname{mean}(L), & \text{if reduction} = \text{`mean';}\\
  904. \operatorname{sum}(L), & \text{if reduction} = \text{`sum'.}
  905. \end{cases}
  906. .. note::
  907. When delta is set to 1, this loss is equivalent to :class:`SmoothL1Loss`.
  908. In general, this loss differs from :class:`SmoothL1Loss` by a factor of delta (AKA beta
  909. in Smooth L1).
  910. See :class:`SmoothL1Loss` for additional discussion on the differences in behavior
  911. between the two losses.
  912. Args:
  913. reduction (str, optional): Specifies the reduction to apply to the output:
  914. ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
  915. ``'mean'``: the sum of the output will be divided by the number of
  916. elements in the output, ``'sum'``: the output will be summed. Default: ``'mean'``
  917. delta (float, optional): Specifies the threshold at which to change between delta-scaled L1 and L2 loss.
  918. The value must be positive. Default: 1.0
  919. Shape:
  920. - Input: :math:`(*)` where :math:`*` means any number of dimensions.
  921. - Target: :math:`(*)`, same shape as the input.
  922. - Output: scalar. If :attr:`reduction` is ``'none'``, then :math:`(*)`, same shape as the input.
  923. Examples:
  924. >>> loss = nn.HuberLoss()
  925. >>> input = torch.randn(3, 5, requires_grad=True)
  926. >>> target = torch.randn(3, 5)
  927. >>> output = loss(input, target)
  928. >>> output.backward()
  929. """
  930. __constants__ = ["reduction", "delta"]
  931. def __init__(self, reduction: str = "mean", delta: float = 1.0) -> None:
  932. super().__init__(reduction=reduction)
  933. self.delta = delta
  934. def forward(self, input: Tensor, target: Tensor) -> Tensor:
  935. """Runs the forward pass."""
  936. return F.huber_loss(input, target, reduction=self.reduction, delta=self.delta)
  937. class SoftMarginLoss(_Loss):
  938. r"""Creates a criterion that optimizes a two-class classification
  939. logistic loss between input tensor :math:`x` and target tensor :math:`y`
  940. (containing 1 or -1).
  941. .. math::
  942. \text{loss}(x, y) = \sum_i \frac{\log(1 + \exp(-y[i]*x[i]))}{\text{x.nelement}()}
  943. Args:
  944. size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
  945. the losses are averaged over each loss element in the batch. Note that for
  946. some losses, there are multiple elements per sample. If the field :attr:`size_average`
  947. is set to ``False``, the losses are instead summed for each minibatch. Ignored
  948. when :attr:`reduce` is ``False``. Default: ``True``
  949. reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the
  950. losses are averaged or summed over observations for each minibatch depending
  951. on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
  952. batch element instead and ignores :attr:`size_average`. Default: ``True``
  953. reduction (str, optional): Specifies the reduction to apply to the output:
  954. ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
  955. ``'mean'``: the sum of the output will be divided by the number of
  956. elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`
  957. and :attr:`reduce` are in the process of being deprecated, and in the meantime,
  958. specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``
  959. Shape:
  960. - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
  961. - Target: :math:`(*)`, same shape as the input.
  962. - Output: scalar. If :attr:`reduction` is ``'none'``, then :math:`(*)`, same
  963. shape as input.
  964. Examples:
  965. >>> loss = nn.SoftMarginLoss()
  966. >>> input = torch.randn(3, 5, requires_grad=True)
  967. >>> target = torch.randn(3, 5).sign()
  968. >>> output = loss(input, target)
  969. >>> output.backward()
  970. """
  971. __constants__ = ["reduction"]
  972. def forward(self, input: Tensor, target: Tensor) -> Tensor:
  973. """Runs the forward pass."""
  974. return F.soft_margin_loss(input, target, reduction=self.reduction)
  975. class CrossEntropyLoss(_WeightedLoss):
  976. r"""This criterion computes the cross entropy loss between input logits
  977. and target.
  978. It is useful when training a classification problem with `C` classes.
  979. If provided, the optional argument :attr:`weight` should be a 1D `Tensor`
  980. assigning weight to each of the classes.
  981. This is particularly useful when you have an unbalanced training set.
  982. The `input` is expected to contain the unnormalized logits for each class (which do `not` need
  983. to be positive or sum to 1, in general).
  984. `input` has to be a Tensor of size :math:`(C)` for unbatched input,
  985. :math:`(minibatch, C)` or :math:`(minibatch, C, d_1, d_2, ..., d_K)` with :math:`K \geq 1` for the
  986. `K`-dimensional case. The last being useful for higher dimension inputs, such
  987. as computing cross entropy loss per-pixel for 2D images.
  988. The `target` that this criterion expects should contain either:
  989. - Class indices in the range :math:`[0, C)` where :math:`C` is the number of classes; if
  990. `ignore_index` is specified, this loss also accepts this class index (this index
  991. may not necessarily be in the class range). The unreduced (i.e. with :attr:`reduction`
  992. set to ``'none'``) loss for this case can be described as:
  993. .. math::
  994. \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad
  995. l_n = - w_{y_n} \log \frac{\exp(x_{n,y_n})}{\sum_{c=1}^C \exp(x_{n,c})}
  996. \cdot \mathbb{1}\{y_n \not= \text{ignore\_index}\}
  997. where :math:`x` is the input, :math:`y` is the target, :math:`w` is the weight,
  998. :math:`C` is the number of classes, and :math:`N` spans the minibatch dimension as well as
  999. :math:`d_1, ..., d_k` for the `K`-dimensional case. If
  1000. :attr:`reduction` is not ``'none'`` (default ``'mean'``), then
  1001. .. math::
  1002. \ell(x, y) = \begin{cases}
  1003. \sum_{n=1}^N \frac{1}{\sum_{n=1}^N w_{y_n} \cdot \mathbb{1}\{y_n \not= \text{ignore\_index}\}} l_n, &
  1004. \text{if reduction} = \text{`mean';}\\
  1005. \sum_{n=1}^N l_n, &
  1006. \text{if reduction} = \text{`sum'.}
  1007. \end{cases}
  1008. Note that this case is equivalent to applying :class:`~torch.nn.LogSoftmax`
  1009. on an input, followed by :class:`~torch.nn.NLLLoss`.
  1010. - Probabilities for each class; useful when labels beyond a single class per minibatch item
  1011. are required, such as for blended labels, label smoothing, etc. The unreduced (i.e. with
  1012. :attr:`reduction` set to ``'none'``) loss for this case can be described as:
  1013. .. math::
  1014. \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad
  1015. l_n = - \sum_{c=1}^C w_c \log \frac{\exp(x_{n,c})}{\sum_{i=1}^C \exp(x_{n,i})} y_{n,c}
  1016. where :math:`x` is the input, :math:`y` is the target, :math:`w` is the weight,
  1017. :math:`C` is the number of classes, and :math:`N` spans the minibatch dimension as well as
  1018. :math:`d_1, ..., d_k` for the `K`-dimensional case. If
  1019. :attr:`reduction` is not ``'none'`` (default ``'mean'``), then
  1020. .. math::
  1021. \ell(x, y) = \begin{cases}
  1022. \frac{\sum_{n=1}^N l_n}{N}, &
  1023. \text{if reduction} = \text{`mean';}\\
  1024. \sum_{n=1}^N l_n, &
  1025. \text{if reduction} = \text{`sum'.}
  1026. \end{cases}
  1027. .. note::
  1028. The performance of this criterion is generally better when `target` contains class
  1029. indices, as this allows for optimized computation. Consider providing `target` as
  1030. class probabilities only when a single class label per minibatch item is too restrictive.
  1031. Args:
  1032. weight (Tensor, optional): a manual rescaling weight given to each class.
  1033. If given, has to be a Tensor of size `C`.
  1034. size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
  1035. the losses are averaged over each loss element in the batch. Note that for
  1036. some losses, there are multiple elements per sample. If the field :attr:`size_average`
  1037. is set to ``False``, the losses are instead summed for each minibatch. Ignored
  1038. when :attr:`reduce` is ``False``. Default: ``True``
  1039. ignore_index (int, optional): Specifies a target value that is ignored
  1040. and does not contribute to the input gradient. When :attr:`size_average` is
  1041. ``True``, the loss is averaged over non-ignored targets. Note that
  1042. :attr:`ignore_index` is only applicable when the target contains class indices.
  1043. reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the
  1044. losses are averaged or summed over observations for each minibatch depending
  1045. on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
  1046. batch element instead and ignores :attr:`size_average`. Default: ``True``
  1047. reduction (str, optional): Specifies the reduction to apply to the output:
  1048. ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will
  1049. be applied, ``'mean'``: the weighted mean of the output is taken,
  1050. ``'sum'``: the output will be summed. Note: :attr:`size_average`
  1051. and :attr:`reduce` are in the process of being deprecated, and in
  1052. the meantime, specifying either of those two args will override
  1053. :attr:`reduction`. Default: ``'mean'``
  1054. label_smoothing (float, optional): A float in [0.0, 1.0]. Specifies the amount
  1055. of smoothing when computing the loss, where 0.0 means no smoothing. The targets
  1056. become a mixture of the original ground truth and a uniform distribution as described in
  1057. `Rethinking the Inception Architecture for Computer Vision <https://arxiv.org/abs/1512.00567>`__. Default: :math:`0.0`.
  1058. Shape:
  1059. - Input: Shape :math:`(C)`, :math:`(N, C)` or :math:`(N, C, d_1, d_2, ..., d_K)` with :math:`K \geq 1`
  1060. in the case of `K`-dimensional loss.
  1061. - Target: If containing class indices, shape :math:`()`, :math:`(N)` or :math:`(N, d_1, d_2, ..., d_K)` with
  1062. :math:`K \geq 1` in the case of K-dimensional loss where each value should be between :math:`[0, C)`. The
  1063. target data type is required to be long when using class indices. If containing class probabilities, the
  1064. target must be the same shape as the input, and each value should be between :math:`[0, 1]`. This means the target
  1065. data type is required to be float when using class probabilities. Note that PyTorch does not strictly enforce
  1066. probability constraints on the class probabilities and that it is the user's responsibility to ensure
  1067. ``target`` contains valid probability distributions (see below examples section for more details).
  1068. - Output: If reduction is 'none', shape :math:`()`, :math:`(N)` or :math:`(N, d_1, d_2, ..., d_K)` with :math:`K \geq 1`
  1069. in the case of K-dimensional loss, depending on the shape of the input. Otherwise, scalar.
  1070. where:
  1071. .. math::
  1072. \begin{aligned}
  1073. C ={} & \text{number of classes} \\
  1074. N ={} & \text{batch size} \\
  1075. \end{aligned}
  1076. Examples:
  1077. >>> # Example of target with class indices
  1078. >>> loss = nn.CrossEntropyLoss()
  1079. >>> input = torch.randn(3, 5, requires_grad=True)
  1080. >>> target = torch.empty(3, dtype=torch.long).random_(5)
  1081. >>> output = loss(input, target)
  1082. >>> output.backward()
  1083. >>>
  1084. >>> # Example of target with class probabilities
  1085. >>> input = torch.randn(3, 5, requires_grad=True)
  1086. >>> target = torch.randn(3, 5).softmax(dim=1)
  1087. >>> output = loss(input, target)
  1088. >>> output.backward()
  1089. .. note::
  1090. When ``target`` contains class probabilities, it should consist of soft labels—that is,
  1091. each ``target`` entry should represent a probability distribution over the possible classes for a given data sample,
  1092. with individual probabilities between ``[0,1]`` and the total distribution summing to 1.
  1093. This is why the :func:`softmax()` function is applied to the ``target`` in the class probabilities example above.
  1094. PyTorch does not validate whether the values provided in ``target`` lie in the range ``[0,1]``
  1095. or whether the distribution of each data sample sums to ``1``.
  1096. No warning will be raised and it is the user's responsibility
  1097. to ensure that ``target`` contains valid probability distributions.
  1098. Providing arbitrary values may yield misleading loss values and unstable gradients during training.
  1099. Examples:
  1100. >>> # xdoctest: +SKIP
  1101. >>> # Example of target with incorrectly specified class probabilities
  1102. >>> loss = nn.CrossEntropyLoss()
  1103. >>> torch.manual_seed(283)
  1104. >>> input = torch.randn(3, 5, requires_grad=True)
  1105. >>> target = torch.randn(3, 5)
  1106. >>> # Provided target class probabilities are not in range [0,1]
  1107. >>> target
  1108. tensor([[ 0.7105, 0.4446, 2.0297, 0.2671, -0.6075],
  1109. [-1.0496, -0.2753, -0.3586, 0.9270, 1.0027],
  1110. [ 0.7551, 0.1003, 1.3468, -0.3581, -0.9569]])
  1111. >>> # Provided target class probabilities do not sum to 1
  1112. >>> target.sum(axis=1)
  1113. tensor([2.8444, 0.2462, 0.8873])
  1114. >>> # No error message and possible misleading loss value
  1115. >>> loss(input, target).item()
  1116. 4.6379876136779785
  1117. >>>
  1118. >>> # Example of target with correctly specified class probabilities
  1119. >>> # Use .softmax() to ensure true probability distribution
  1120. >>> target_new = target.softmax(dim=1)
  1121. >>> # New target class probabilities all in range [0,1]
  1122. >>> target_new
  1123. tensor([[0.1559, 0.1195, 0.5830, 0.1000, 0.0417],
  1124. [0.0496, 0.1075, 0.0990, 0.3579, 0.3860],
  1125. [0.2607, 0.1355, 0.4711, 0.0856, 0.0471]])
  1126. >>> # New target class probabilities sum to 1
  1127. >>> target_new.sum(axis=1)
  1128. tensor([1.0000, 1.0000, 1.0000])
  1129. >>> loss(input, target_new).item()
  1130. 2.55349063873291
  1131. """
  1132. __constants__ = ["ignore_index", "reduction", "label_smoothing"]
  1133. ignore_index: int
  1134. label_smoothing: float
  1135. def __init__(
  1136. self,
  1137. weight: Tensor | None = None,
  1138. size_average=None,
  1139. ignore_index: int = -100,
  1140. reduce=None,
  1141. reduction: str = "mean",
  1142. label_smoothing: float = 0.0,
  1143. ) -> None:
  1144. super().__init__(weight, size_average, reduce, reduction)
  1145. self.ignore_index = ignore_index
  1146. self.label_smoothing = label_smoothing
  1147. def forward(self, input: Tensor, target: Tensor) -> Tensor:
  1148. """Runs the forward pass."""
  1149. return F.cross_entropy(
  1150. input,
  1151. target,
  1152. weight=self.weight,
  1153. ignore_index=self.ignore_index,
  1154. reduction=self.reduction,
  1155. label_smoothing=self.label_smoothing,
  1156. )
  1157. class MultiLabelSoftMarginLoss(_WeightedLoss):
  1158. r"""Creates a criterion that optimizes a multi-label one-versus-all
  1159. loss based on max-entropy, between input :math:`x` and target :math:`y` of size
  1160. :math:`(N, C)`.
  1161. For each sample in the minibatch:
  1162. .. math::
  1163. loss(x, y) = - \frac{1}{C} * \sum_i y[i] * \log((1 + \exp(-x[i]))^{-1})
  1164. + (1-y[i]) * \log\left(\frac{\exp(-x[i])}{(1 + \exp(-x[i]))}\right)
  1165. where :math:`i \in \left\{0, \; \cdots , \; \text{x.nElement}() - 1\right\}`,
  1166. :math:`y[i] \in \left\{0, \; 1\right\}`.
  1167. Args:
  1168. weight (Tensor, optional): a manual rescaling weight given to each
  1169. class. If given, it has to be a Tensor of size `C`. Otherwise, it is
  1170. treated as if having all ones.
  1171. size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
  1172. the losses are averaged over each loss element in the batch. Note that for
  1173. some losses, there are multiple elements per sample. If the field :attr:`size_average`
  1174. is set to ``False``, the losses are instead summed for each minibatch. Ignored
  1175. when :attr:`reduce` is ``False``. Default: ``True``
  1176. reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the
  1177. losses are averaged or summed over observations for each minibatch depending
  1178. on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
  1179. batch element instead and ignores :attr:`size_average`. Default: ``True``
  1180. reduction (str, optional): Specifies the reduction to apply to the output:
  1181. ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
  1182. ``'mean'``: the sum of the output will be divided by the number of
  1183. elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`
  1184. and :attr:`reduce` are in the process of being deprecated, and in the meantime,
  1185. specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``
  1186. Shape:
  1187. - Input: :math:`(N, C)` where `N` is the batch size and `C` is the number of classes.
  1188. - Target: :math:`(N, C)`, label targets must have the same shape as the input.
  1189. - Output: scalar. If :attr:`reduction` is ``'none'``, then :math:`(N)`.
  1190. Examples:
  1191. >>> loss = nn.MultiLabelSoftMarginLoss()
  1192. >>> input = torch.randn(3, 5, requires_grad=True)
  1193. >>> target = torch.empty(3, 5).random_(2)
  1194. >>> output = loss(input, target)
  1195. >>> output.backward()
  1196. """
  1197. __constants__ = ["reduction"]
  1198. def forward(self, input: Tensor, target: Tensor) -> Tensor:
  1199. """Runs the forward pass."""
  1200. return F.multilabel_soft_margin_loss(
  1201. input, target, weight=self.weight, reduction=self.reduction
  1202. )
  1203. class CosineEmbeddingLoss(_Loss):
  1204. r"""Creates a criterion that measures the loss given input tensors
  1205. :math:`x_1`, :math:`x_2` and a `Tensor` label :math:`y` with values 1 or -1.
  1206. Use (:math:`y=1`) to maximize the cosine similarity of two inputs, and (:math:`y=-1`) otherwise.
  1207. This is typically used for learning nonlinear
  1208. embeddings or semi-supervised learning.
  1209. The loss function for each sample is:
  1210. .. math::
  1211. \text{loss}(x, y) =
  1212. \begin{cases}
  1213. 1 - \cos(x_1, x_2), & \text{if } y = 1 \\
  1214. \max(0, \cos(x_1, x_2) - \text{margin}), & \text{if } y = -1
  1215. \end{cases}
  1216. Args:
  1217. margin (float, optional): Should be a number from :math:`-1` to :math:`1`,
  1218. :math:`0` to :math:`0.5` is suggested. If :attr:`margin` is missing, the
  1219. default value is :math:`0`.
  1220. size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
  1221. the losses are averaged over each loss element in the batch. Note that for
  1222. some losses, there are multiple elements per sample. If the field :attr:`size_average`
  1223. is set to ``False``, the losses are instead summed for each minibatch. Ignored
  1224. when :attr:`reduce` is ``False``. Default: ``True``
  1225. reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the
  1226. losses are averaged or summed over observations for each minibatch depending
  1227. on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
  1228. batch element instead and ignores :attr:`size_average`. Default: ``True``
  1229. reduction (str, optional): Specifies the reduction to apply to the output:
  1230. ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
  1231. ``'mean'``: the sum of the output will be divided by the number of
  1232. elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`
  1233. and :attr:`reduce` are in the process of being deprecated, and in the meantime,
  1234. specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``
  1235. Shape:
  1236. - Input1: :math:`(N, D)` or :math:`(D)`, where `N` is the batch size and `D` is the embedding dimension.
  1237. - Input2: :math:`(N, D)` or :math:`(D)`, same shape as Input1.
  1238. - Target: :math:`(N)` or :math:`()`.
  1239. - Output: If :attr:`reduction` is ``'none'``, then :math:`(N)`, otherwise scalar.
  1240. Examples:
  1241. >>> loss = nn.CosineEmbeddingLoss()
  1242. >>> input1 = torch.randn(3, 5, requires_grad=True)
  1243. >>> input2 = torch.randn(3, 5, requires_grad=True)
  1244. >>> target = torch.ones(3)
  1245. >>> output = loss(input1, input2, target)
  1246. >>> output.backward()
  1247. """
  1248. __constants__ = ["margin", "reduction"]
  1249. margin: float
  1250. def __init__(
  1251. self,
  1252. margin: float = 0.0,
  1253. size_average=None,
  1254. reduce=None,
  1255. reduction: str = "mean",
  1256. ) -> None:
  1257. super().__init__(size_average, reduce, reduction)
  1258. self.margin = margin
  1259. def forward(self, input1: Tensor, input2: Tensor, target: Tensor) -> Tensor:
  1260. """Runs the forward pass."""
  1261. return F.cosine_embedding_loss(
  1262. input1, input2, target, margin=self.margin, reduction=self.reduction
  1263. )
  1264. class MarginRankingLoss(_Loss):
  1265. r"""Creates a criterion that measures the loss given
  1266. inputs :math:`x1`, :math:`x2`, two 1D mini-batch or 0D `Tensors`,
  1267. and a label 1D mini-batch or 0D `Tensor` :math:`y` (containing 1 or -1).
  1268. If :math:`y = 1` then it assumed the first input should be ranked higher
  1269. (have a larger value) than the second input, and vice-versa for :math:`y = -1`.
  1270. The loss function for each pair of samples in the mini-batch is:
  1271. .. math::
  1272. \text{loss}(x1, x2, y) = \max(0, -y * (x1 - x2) + \text{margin})
  1273. Args:
  1274. margin (float, optional): Has a default value of :math:`0`.
  1275. size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
  1276. the losses are averaged over each loss element in the batch. Note that for
  1277. some losses, there are multiple elements per sample. If the field :attr:`size_average`
  1278. is set to ``False``, the losses are instead summed for each minibatch. Ignored
  1279. when :attr:`reduce` is ``False``. Default: ``True``
  1280. reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the
  1281. losses are averaged or summed over observations for each minibatch depending
  1282. on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
  1283. batch element instead and ignores :attr:`size_average`. Default: ``True``
  1284. reduction (str, optional): Specifies the reduction to apply to the output:
  1285. ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
  1286. ``'mean'``: the sum of the output will be divided by the number of
  1287. elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`
  1288. and :attr:`reduce` are in the process of being deprecated, and in the meantime,
  1289. specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``
  1290. Shape:
  1291. - Input1: :math:`(N)` or :math:`()` where `N` is the batch size.
  1292. - Input2: :math:`(N)` or :math:`()`, same shape as the Input1.
  1293. - Target: :math:`(N)` or :math:`()`, same shape as the inputs.
  1294. - Output: scalar. If :attr:`reduction` is ``'none'`` and Input size is not :math:`()`, then :math:`(N)`.
  1295. Examples:
  1296. >>> loss = nn.MarginRankingLoss()
  1297. >>> input1 = torch.randn(3, requires_grad=True)
  1298. >>> input2 = torch.randn(3, requires_grad=True)
  1299. >>> target = torch.randn(3).sign()
  1300. >>> output = loss(input1, input2, target)
  1301. >>> output.backward()
  1302. """
  1303. __constants__ = ["margin", "reduction"]
  1304. margin: float
  1305. def __init__(
  1306. self,
  1307. margin: float = 0.0,
  1308. size_average=None,
  1309. reduce=None,
  1310. reduction: str = "mean",
  1311. ) -> None:
  1312. super().__init__(size_average, reduce, reduction)
  1313. self.margin = margin
  1314. def forward(self, input1: Tensor, input2: Tensor, target: Tensor) -> Tensor:
  1315. """Runs the forward pass."""
  1316. return F.margin_ranking_loss(
  1317. input1, input2, target, margin=self.margin, reduction=self.reduction
  1318. )
  1319. class MultiMarginLoss(_WeightedLoss):
  1320. r"""Creates a criterion that optimizes a multi-class classification hinge
  1321. loss (margin-based loss) between input :math:`x` (a 2D mini-batch `Tensor`) and
  1322. output :math:`y` (which is a 1D tensor of target class indices,
  1323. :math:`0 \leq y \leq \text{x.size}(1)-1`):
  1324. For each mini-batch sample, the loss in terms of the 1D input :math:`x` and scalar
  1325. output :math:`y` is:
  1326. .. math::
  1327. \text{loss}(x, y) = \frac{\sum_i \max(0, \text{margin} - x[y] + x[i])^p}{\text{x.size}(0)}
  1328. where :math:`i \in \left\{0, \; \cdots , \; \text{x.size}(0) - 1\right\}`
  1329. and :math:`i \neq y`.
  1330. Optionally, you can give non-equal weighting on the classes by passing
  1331. a 1D :attr:`weight` tensor into the constructor.
  1332. The loss function then becomes:
  1333. .. math::
  1334. \text{loss}(x, y) = \frac{\sum_i w[y] * \max(0, \text{margin} - x[y] + x[i])^p}{\text{x.size}(0)}
  1335. Args:
  1336. p (int, optional): Has a default value of :math:`1`. :math:`1` and :math:`2`
  1337. are the only supported values.
  1338. margin (float, optional): Has a default value of :math:`1`.
  1339. weight (Tensor, optional): a manual rescaling weight given to each
  1340. class. If given, it has to be a Tensor of size `C`. Otherwise, it is
  1341. treated as if having all ones.
  1342. size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
  1343. the losses are averaged over each loss element in the batch. Note that for
  1344. some losses, there are multiple elements per sample. If the field :attr:`size_average`
  1345. is set to ``False``, the losses are instead summed for each minibatch. Ignored
  1346. when :attr:`reduce` is ``False``. Default: ``True``
  1347. reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the
  1348. losses are averaged or summed over observations for each minibatch depending
  1349. on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
  1350. batch element instead and ignores :attr:`size_average`. Default: ``True``
  1351. reduction (str, optional): Specifies the reduction to apply to the output:
  1352. ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
  1353. ``'mean'``: the sum of the output will be divided by the number of
  1354. elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`
  1355. and :attr:`reduce` are in the process of being deprecated, and in the meantime,
  1356. specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``
  1357. Shape:
  1358. - Input: :math:`(N, C)` or :math:`(C)`, where :math:`N` is the batch size and :math:`C` is the number of classes.
  1359. - Target: :math:`(N)` or :math:`()`, where each value is :math:`0 \leq \text{targets}[i] \leq C-1`.
  1360. - Output: scalar. If :attr:`reduction` is ``'none'``, then same shape as the target.
  1361. Examples:
  1362. >>> loss = nn.MultiMarginLoss()
  1363. >>> x = torch.tensor([[0.1, 0.2, 0.4, 0.8]])
  1364. >>> y = torch.tensor([3])
  1365. >>> # 0.25 * ((1-(0.8-0.1)) + (1-(0.8-0.2)) + (1-(0.8-0.4)))
  1366. >>> loss(x, y)
  1367. tensor(0.32...)
  1368. """
  1369. __constants__ = ["p", "margin", "reduction"]
  1370. margin: float
  1371. p: int
  1372. def __init__(
  1373. self,
  1374. p: int = 1,
  1375. margin: float = 1.0,
  1376. weight: Tensor | None = None,
  1377. size_average=None,
  1378. reduce=None,
  1379. reduction: str = "mean",
  1380. ) -> None:
  1381. super().__init__(weight, size_average, reduce, reduction)
  1382. if p != 1 and p != 2:
  1383. raise ValueError("only p == 1 and p == 2 supported")
  1384. if weight is not None and weight.dim() != 1:
  1385. raise ValueError(
  1386. f"MultiMarginLoss: expected weight to be None or 1D tensor, got {weight.dim()}D instead"
  1387. )
  1388. self.p = p
  1389. self.margin = margin
  1390. def forward(self, input: Tensor, target: Tensor) -> Tensor:
  1391. """Runs the forward pass."""
  1392. return F.multi_margin_loss(
  1393. input,
  1394. target,
  1395. p=self.p,
  1396. margin=self.margin,
  1397. weight=self.weight,
  1398. reduction=self.reduction,
  1399. )
  1400. class TripletMarginLoss(_Loss):
  1401. r"""Creates a criterion that measures the triplet loss given an input
  1402. tensors :math:`x1`, :math:`x2`, :math:`x3` and a margin with a value greater than :math:`0`.
  1403. This is used for measuring a relative similarity between samples. A triplet
  1404. is composed by `a`, `p` and `n` (i.e., `anchor`, `positive examples` and `negative
  1405. examples` respectively). The shapes of all input tensors should be
  1406. :math:`(N, D)`.
  1407. The distance swap is described in detail in the paper `Learning shallow
  1408. convolutional feature descriptors with triplet losses`_ by
  1409. V. Balntas, E. Riba et al.
  1410. The loss function for each sample in the mini-batch is:
  1411. .. math::
  1412. L(a, p, n) = \max \{d(a_i, p_i) - d(a_i, n_i) + {\rm margin}, 0\}
  1413. where
  1414. .. math::
  1415. d(x_i, y_i) = \left\lVert {\bf x}_i - {\bf y}_i \right\rVert_p
  1416. The norm is calculated using the specified p value and a small constant :math:`\varepsilon` is
  1417. added for numerical stability.
  1418. See also :class:`~torch.nn.TripletMarginWithDistanceLoss`, which computes the
  1419. triplet margin loss for input tensors using a custom distance function.
  1420. Args:
  1421. margin (float, optional): Default: :math:`1`.
  1422. p (int, optional): The norm degree for pairwise distance. Default: :math:`2`.
  1423. eps (float, optional): Small constant for numerical stability. Default: :math:`1e-6`.
  1424. swap (bool, optional): The distance swap is described in detail in the paper
  1425. `Learning shallow convolutional feature descriptors with triplet losses` by
  1426. V. Balntas, E. Riba et al. Default: ``False``.
  1427. size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
  1428. the losses are averaged over each loss element in the batch. Note that for
  1429. some losses, there are multiple elements per sample. If the field :attr:`size_average`
  1430. is set to ``False``, the losses are instead summed for each minibatch. Ignored
  1431. when :attr:`reduce` is ``False``. Default: ``True``
  1432. reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the
  1433. losses are averaged or summed over observations for each minibatch depending
  1434. on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
  1435. batch element instead and ignores :attr:`size_average`. Default: ``True``
  1436. reduction (str, optional): Specifies the reduction to apply to the output:
  1437. ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
  1438. ``'mean'``: the sum of the output will be divided by the number of
  1439. elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`
  1440. and :attr:`reduce` are in the process of being deprecated, and in the meantime,
  1441. specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``
  1442. Shape:
  1443. - Input: :math:`(N, D)` or :math:`(D)` where :math:`D` is the vector dimension.
  1444. - Output: A Tensor of shape :math:`(N)` if :attr:`reduction` is ``'none'`` and
  1445. input shape is :math:`(N, D)`; a scalar otherwise.
  1446. Examples:
  1447. >>> triplet_loss = nn.TripletMarginLoss(margin=1.0, p=2, eps=1e-7)
  1448. >>> anchor = torch.randn(100, 128, requires_grad=True)
  1449. >>> positive = torch.randn(100, 128, requires_grad=True)
  1450. >>> negative = torch.randn(100, 128, requires_grad=True)
  1451. >>> output = triplet_loss(anchor, positive, negative)
  1452. >>> output.backward()
  1453. .. _Learning shallow convolutional feature descriptors with triplet losses:
  1454. https://bmva-archive.org.uk/bmvc/2016/papers/paper119/index.html
  1455. """
  1456. __constants__ = ["margin", "p", "eps", "swap", "reduction"]
  1457. margin: float
  1458. p: float
  1459. eps: float
  1460. swap: bool
  1461. def __init__(
  1462. self,
  1463. margin: float = 1.0,
  1464. p: float = 2.0,
  1465. eps: float = 1e-6,
  1466. swap: bool = False,
  1467. size_average=None,
  1468. reduce=None,
  1469. reduction: str = "mean",
  1470. ) -> None:
  1471. super().__init__(size_average, reduce, reduction)
  1472. if margin <= 0:
  1473. raise ValueError(
  1474. f"TripletMarginLoss: expected margin to be greater than 0, got {margin} instead"
  1475. )
  1476. self.margin = margin
  1477. self.p = p
  1478. self.eps = eps
  1479. self.swap = swap
  1480. def forward(self, anchor: Tensor, positive: Tensor, negative: Tensor) -> Tensor:
  1481. """Runs the forward pass."""
  1482. return F.triplet_margin_loss(
  1483. anchor,
  1484. positive,
  1485. negative,
  1486. margin=self.margin,
  1487. p=self.p,
  1488. eps=self.eps,
  1489. swap=self.swap,
  1490. reduction=self.reduction,
  1491. )
  1492. class TripletMarginWithDistanceLoss(_Loss):
  1493. r"""Creates a criterion that measures the triplet loss given input
  1494. tensors :math:`a`, :math:`p`, and :math:`n` (representing anchor,
  1495. positive, and negative examples, respectively), and a nonnegative,
  1496. real-valued function ("distance function") used to compute the relationship
  1497. between the anchor and positive example ("positive distance") and the
  1498. anchor and negative example ("negative distance").
  1499. The unreduced loss (i.e., with :attr:`reduction` set to ``'none'``)
  1500. can be described as:
  1501. .. math::
  1502. \ell(a, p, n) = L = \{l_1,\dots,l_N\}^\top, \quad
  1503. l_i = \max \{d(a_i, p_i) - d(a_i, n_i) + {\rm margin}, 0\}
  1504. where :math:`N` is the batch size; :math:`d` is a nonnegative, real-valued function
  1505. quantifying the closeness of two tensors, referred to as the :attr:`distance_function`;
  1506. and :math:`margin` is a nonnegative margin representing the minimum difference
  1507. between the positive and negative distances that is required for the loss to
  1508. be 0. The input tensors have :math:`N` elements each and can be of any shape
  1509. that the distance function can handle.
  1510. If :attr:`reduction` is not ``'none'``
  1511. (default ``'mean'``), then:
  1512. .. math::
  1513. \ell(x, y) =
  1514. \begin{cases}
  1515. \operatorname{mean}(L), & \text{if reduction} = \text{`mean';}\\
  1516. \operatorname{sum}(L), & \text{if reduction} = \text{`sum'.}
  1517. \end{cases}
  1518. See also :class:`~torch.nn.TripletMarginLoss`, which computes the triplet
  1519. loss for input tensors using the :math:`l_p` distance as the distance function.
  1520. Args:
  1521. distance_function (Callable, optional): A nonnegative, real-valued function that
  1522. quantifies the closeness of two tensors. If not specified,
  1523. `nn.PairwiseDistance` will be used. Default: ``None``
  1524. margin (float, optional): A nonnegative margin representing the minimum difference
  1525. between the positive and negative distances required for the loss to be 0. Larger
  1526. margins penalize cases where the negative examples are not distant enough from the
  1527. anchors, relative to the positives. Default: :math:`1`.
  1528. swap (bool, optional): Whether to use the distance swap described in the paper
  1529. `Learning shallow convolutional feature descriptors with triplet losses` by
  1530. V. Balntas, E. Riba et al. If True, and if the positive example is closer to the
  1531. negative example than the anchor is, swaps the positive example and the anchor in
  1532. the loss computation. Default: ``False``.
  1533. reduction (str, optional): Specifies the (optional) reduction to apply to the output:
  1534. ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
  1535. ``'mean'``: the sum of the output will be divided by the number of
  1536. elements in the output, ``'sum'``: the output will be summed. Default: ``'mean'``
  1537. Shape:
  1538. - Input: :math:`(N, *)` where :math:`*` represents any number of additional dimensions
  1539. as supported by the distance function.
  1540. - Output: A Tensor of shape :math:`(N)` if :attr:`reduction` is ``'none'``, or a scalar
  1541. otherwise.
  1542. Examples:
  1543. >>> # Initialize embeddings
  1544. >>> embedding = nn.Embedding(1000, 128)
  1545. >>> anchor_ids = torch.randint(0, 1000, (1,))
  1546. >>> positive_ids = torch.randint(0, 1000, (1,))
  1547. >>> negative_ids = torch.randint(0, 1000, (1,))
  1548. >>> anchor = embedding(anchor_ids)
  1549. >>> positive = embedding(positive_ids)
  1550. >>> negative = embedding(negative_ids)
  1551. >>>
  1552. >>> # Built-in Distance Function
  1553. >>> triplet_loss = \
  1554. >>> nn.TripletMarginWithDistanceLoss(distance_function=nn.PairwiseDistance())
  1555. >>> output = triplet_loss(anchor, positive, negative)
  1556. >>> output.backward()
  1557. >>>
  1558. >>> # Custom Distance Function
  1559. >>> def l_infinity(x1, x2):
  1560. >>> return torch.max(torch.abs(x1 - x2), dim=1).values
  1561. >>>
  1562. >>> # xdoctest: +SKIP("FIXME: Would call backwards a second time")
  1563. >>> triplet_loss = (
  1564. >>> nn.TripletMarginWithDistanceLoss(distance_function=l_infinity, margin=1.5))
  1565. >>> output = triplet_loss(anchor, positive, negative)
  1566. >>> output.backward()
  1567. >>>
  1568. >>> # Custom Distance Function (Lambda)
  1569. >>> triplet_loss = (
  1570. >>> nn.TripletMarginWithDistanceLoss(
  1571. >>> distance_function=lambda x, y: 1.0 - F.cosine_similarity(x, y)))
  1572. >>> output = triplet_loss(anchor, positive, negative)
  1573. >>> output.backward()
  1574. Reference:
  1575. V. Balntas, et al.: Learning shallow convolutional feature descriptors with triplet losses:
  1576. https://bmva-archive.org.uk/bmvc/2016/papers/paper119/index.html
  1577. """
  1578. __constants__ = ["margin", "swap", "reduction"]
  1579. margin: float
  1580. swap: bool
  1581. def __init__(
  1582. self,
  1583. *,
  1584. distance_function: Callable[[Tensor, Tensor], Tensor] | None = None,
  1585. margin: float = 1.0,
  1586. swap: bool = False,
  1587. reduction: str = "mean",
  1588. ) -> None:
  1589. super().__init__(size_average=None, reduce=None, reduction=reduction)
  1590. if margin <= 0:
  1591. raise ValueError(
  1592. f"TripletMarginWithDistanceLoss: expected margin to be greater than 0, got {margin} instead"
  1593. )
  1594. self.distance_function: Callable[[Tensor, Tensor], Tensor] | None = (
  1595. # pyrefly: ignore [bad-assignment]
  1596. distance_function if distance_function is not None else PairwiseDistance()
  1597. )
  1598. self.margin = margin
  1599. self.swap = swap
  1600. def forward(self, anchor: Tensor, positive: Tensor, negative: Tensor) -> Tensor:
  1601. """Runs the forward pass."""
  1602. return F.triplet_margin_with_distance_loss(
  1603. anchor,
  1604. positive,
  1605. negative,
  1606. distance_function=self.distance_function,
  1607. margin=self.margin,
  1608. swap=self.swap,
  1609. reduction=self.reduction,
  1610. )
  1611. class CTCLoss(_Loss):
  1612. r"""The Connectionist Temporal Classification loss.
  1613. Calculates loss between a continuous (unsegmented) time series and a target sequence. CTCLoss sums over the
  1614. probability of possible alignments of input to target, producing a loss value which is differentiable
  1615. with respect to each input node. The alignment of input to target is assumed to be "many-to-one", which
  1616. limits the length of the target sequence such that it must be :math:`\leq` the input length.
  1617. Args:
  1618. blank (int, optional): blank label. Default :math:`0`.
  1619. reduction (str, optional): Specifies the reduction to apply to the output:
  1620. ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
  1621. ``'mean'``: the output losses will be divided by the target lengths and
  1622. then the mean over the batch is taken, ``'sum'``: the output losses will be summed.
  1623. Default: ``'mean'``
  1624. zero_infinity (bool, optional):
  1625. Whether to zero infinite losses and the associated gradients.
  1626. Default: ``False``
  1627. Infinite losses mainly occur when the inputs are too short
  1628. to be aligned to the targets.
  1629. Shape:
  1630. - Log_probs: Tensor of size :math:`(T, N, C)` or :math:`(T, C)`,
  1631. where :math:`T = \text{input length}`,
  1632. :math:`N = \text{batch size}`, and
  1633. :math:`C = \text{number of classes (including blank)}`.
  1634. The logarithmized probabilities of the outputs (e.g. obtained with
  1635. :func:`torch.nn.functional.log_softmax`).
  1636. - Targets: Tensor of size :math:`(N, S)` or
  1637. :math:`(\operatorname{sum}(\text{target\_lengths}))`,
  1638. where :math:`N = \text{batch size}` and
  1639. :math:`S = \text{max target length, if shape is } (N, S)`.
  1640. It represents the target sequences. Each element in the target
  1641. sequence is a class index. And the target index cannot be blank (default=0).
  1642. In the :math:`(N, S)` form, targets are padded to the
  1643. length of the longest sequence, and stacked.
  1644. In the :math:`(\operatorname{sum}(\text{target\_lengths}))` form,
  1645. the targets are assumed to be un-padded and
  1646. concatenated within 1 dimension.
  1647. - Input_lengths: Tuple or tensor of size :math:`(N)` or :math:`()`,
  1648. where :math:`N = \text{batch size}`. It represents the lengths of the
  1649. inputs (must each be :math:`\leq T`). And the lengths are specified
  1650. for each sequence to achieve masking under the assumption that sequences
  1651. are padded to equal lengths.
  1652. - Target_lengths: Tuple or tensor of size :math:`(N)` or :math:`()`,
  1653. where :math:`N = \text{batch size}`. It represents lengths of the targets.
  1654. Lengths are specified for each sequence to achieve masking under the
  1655. assumption that sequences are padded to equal lengths. If target shape is
  1656. :math:`(N,S)`, target_lengths are effectively the stop index
  1657. :math:`s_n` for each target sequence, such that ``target_n = targets[n,0:s_n]`` for
  1658. each target in a batch. Lengths must each be :math:`\leq S`
  1659. If the targets are given as a 1d tensor that is the concatenation of individual
  1660. targets, the target_lengths must add up to the total length of the tensor.
  1661. - Output: scalar if :attr:`reduction` is ``'mean'`` (default) or
  1662. ``'sum'``. If :attr:`reduction` is ``'none'``, then :math:`(N)` if input is batched or
  1663. :math:`()` if input is unbatched, where :math:`N = \text{batch size}`.
  1664. Examples:
  1665. >>> # Target are to be padded
  1666. >>> T = 50 # Input sequence length
  1667. >>> C = 20 # Number of classes (including blank)
  1668. >>> N = 16 # Batch size
  1669. >>> S = 30 # Target sequence length of longest target in batch (padding length)
  1670. >>> S_min = 10 # Minimum target length, for demonstration purposes
  1671. >>>
  1672. >>> # Initialize random batch of input vectors, for *size = (T,N,C)
  1673. >>> input = torch.randn(T, N, C).log_softmax(2).detach().requires_grad_()
  1674. >>>
  1675. >>> # Initialize random batch of targets (0 = blank, 1:C = classes)
  1676. >>> target = torch.randint(low=1, high=C, size=(N, S), dtype=torch.long)
  1677. >>>
  1678. >>> input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long)
  1679. >>> target_lengths = torch.randint(
  1680. ... low=S_min,
  1681. ... high=S,
  1682. ... size=(N,),
  1683. ... dtype=torch.long,
  1684. ... )
  1685. >>> ctc_loss = nn.CTCLoss()
  1686. >>> loss = ctc_loss(input, target, input_lengths, target_lengths)
  1687. >>> loss.backward()
  1688. >>>
  1689. >>>
  1690. >>> # Target are to be un-padded
  1691. >>> T = 50 # Input sequence length
  1692. >>> C = 20 # Number of classes (including blank)
  1693. >>> N = 16 # Batch size
  1694. >>>
  1695. >>> # Initialize random batch of input vectors, for *size = (T,N,C)
  1696. >>> input = torch.randn(T, N, C).log_softmax(2).detach().requires_grad_()
  1697. >>> input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long)
  1698. >>>
  1699. >>> # Initialize random batch of targets (0 = blank, 1:C = classes)
  1700. >>> target_lengths = torch.randint(low=1, high=T, size=(N,), dtype=torch.long)
  1701. >>> target = torch.randint(
  1702. ... low=1,
  1703. ... high=C,
  1704. ... size=(sum(target_lengths),),
  1705. ... dtype=torch.long,
  1706. ... )
  1707. >>> ctc_loss = nn.CTCLoss()
  1708. >>> loss = ctc_loss(input, target, input_lengths, target_lengths)
  1709. >>> loss.backward()
  1710. >>>
  1711. >>>
  1712. >>> # Target are to be un-padded and unbatched (effectively N=1)
  1713. >>> T = 50 # Input sequence length
  1714. >>> C = 20 # Number of classes (including blank)
  1715. >>>
  1716. >>> # Initialize random batch of input vectors, for *size = (T,C)
  1717. >>> # xdoctest: +SKIP("FIXME: error in doctest")
  1718. >>> input = torch.randn(T, C).log_softmax(1).detach().requires_grad_()
  1719. >>> input_lengths = torch.tensor(T, dtype=torch.long)
  1720. >>>
  1721. >>> # Initialize random batch of targets (0 = blank, 1:C = classes)
  1722. >>> target_lengths = torch.randint(low=1, high=T, size=(), dtype=torch.long)
  1723. >>> target = torch.randint(
  1724. ... low=1,
  1725. ... high=C,
  1726. ... size=(target_lengths,),
  1727. ... dtype=torch.long,
  1728. ... )
  1729. >>> ctc_loss = nn.CTCLoss()
  1730. >>> loss = ctc_loss(input, target, input_lengths, target_lengths)
  1731. >>> loss.backward()
  1732. Reference:
  1733. A. Graves et al.: Connectionist Temporal Classification:
  1734. Labelling Unsegmented Sequence Data with Recurrent Neural Networks:
  1735. https://www.cs.toronto.edu/~graves/icml_2006.pdf
  1736. Note:
  1737. In order to use CuDNN, the following must be satisfied: the :attr:`targets` must be
  1738. in concatenated format, all :attr:`input_lengths` must be `T`. :math:`blank=0`,
  1739. :attr:`target_lengths` :math:`\leq 256`, the integer arguments must be of
  1740. dtype :attr:`torch.int32`, and the :attr:`log_probs` itself must be of
  1741. dtype :attr:`torch.float32`.
  1742. The regular implementation uses the (more common in PyTorch) `torch.long` dtype.
  1743. Note:
  1744. In some circumstances when using the CUDA backend with CuDNN, this operator
  1745. may select a nondeterministic algorithm to increase performance. If this is
  1746. undesirable, you can try to make the operation deterministic (potentially at
  1747. a performance cost) by setting ``torch.backends.cudnn.deterministic =
  1748. True``.
  1749. Please see the notes on :doc:`/notes/randomness` for background.
  1750. """
  1751. __constants__ = ["blank", "reduction"]
  1752. blank: int
  1753. zero_infinity: bool
  1754. def __init__(
  1755. self, blank: int = 0, reduction: str = "mean", zero_infinity: bool = False
  1756. ) -> None:
  1757. super().__init__(reduction=reduction)
  1758. self.blank = blank
  1759. self.zero_infinity = zero_infinity
  1760. def forward(
  1761. self,
  1762. log_probs: Tensor,
  1763. targets: Tensor,
  1764. input_lengths: Tensor,
  1765. target_lengths: Tensor,
  1766. ) -> Tensor:
  1767. """Runs the forward pass."""
  1768. return F.ctc_loss(
  1769. log_probs,
  1770. targets,
  1771. input_lengths,
  1772. target_lengths,
  1773. self.blank,
  1774. self.reduction,
  1775. self.zero_infinity,
  1776. )
  1777. # TODO: L1HingeEmbeddingCriterion
  1778. # TODO: MSECriterion weight
  1779. # TODO: ClassSimplexCriterion