_ops.py 66 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839
  1. # mypy: allow-untyped-defs
  2. import warnings
  3. from collections.abc import Callable
  4. from typing import Any, Optional, TYPE_CHECKING, TypeAlias, TypeVar
  5. from typing_extensions import ParamSpec
  6. import torch
  7. from torch import sym_float, Tensor
  8. from torch._prims_common import corresponding_real_dtype
  9. from torch.masked import _docs
  10. from torch.masked.maskedtensor.core import is_masked_tensor, MaskedTensor
  11. from torch.masked.maskedtensor.creation import as_masked_tensor
  12. if TYPE_CHECKING:
  13. from torch._prims_common import DimsType
  14. from torch.types import _dtype as DType
  15. DimOrDims: TypeAlias = DimsType | None
  16. else:
  17. # The JIT doesn't understand Union, nor torch.dtype here
  18. DType = int
  19. DimOrDims = Optional[tuple[int, ...]]
  20. __all__: list[str] = []
  21. _T = TypeVar("_T")
  22. _P = ParamSpec("_P")
  23. # All masked reduction/normalization operations have the same
  24. # signatures. Here we introduce docstring templates that are applied
  25. # to docstrings of reduction/normalization functions via
  26. # _apply_docstring_templates decorator.
  27. def _apply_docstring_templates(func: Callable[_P, _T]) -> Callable[_P, _T]:
  28. """Decorator that applies docstring templates to function docstring
  29. and returns the function instance.
  30. """
  31. doc_string = getattr(_docs, f"{func.__name__}_docstring", None)
  32. if doc_string is None:
  33. warnings.warn(
  34. f"No documentation string available for {func.__name__}."
  35. " PyTorch team should run `python tools/update_masked_docs.py`"
  36. " to generate the missing docstrings.",
  37. stacklevel=2,
  38. )
  39. else:
  40. func.__doc__ = doc_string
  41. # Expose function as public symbol
  42. __all__.append(func.__name__)
  43. return func
  44. def _generate_docstring(func):
  45. """A utility function called from tools/update_masked_docs.py
  46. script to update the module torch.masked._docs.py
  47. """
  48. docstring_templates = dict(
  49. reduction_signature="""\
  50. {function_name}(input, {operation_args}, *, {operation_kwargs}) -> Tensor""",
  51. reduction_descr="""\
  52. Returns {operation name} of all the elements in the :attr:`input`
  53. tensor along the given dimension(s) :attr:`dim` while the :attr:`input`
  54. elements are masked out according to the boolean tensor
  55. :attr:`mask`.""",
  56. reduction_args="""\
  57. If :attr:`keepdim` is ``True``, the output tensor is of the same size
  58. as :attr:`input` except in the dimension(s) :attr:`dim` where it is of
  59. size 1. Otherwise, :attr:`dim` is squeezed (see
  60. :func:`torch.squeeze`), resulting in the output tensor having 1 (or
  61. ``len(dim)``) fewer dimension(s).
  62. The boolean tensor :attr:`mask` defines the "validity" of
  63. :attr:`input` tensor elements: if :attr:`mask` element is True
  64. then the corresponding element in :attr:`input` tensor will be
  65. included in {operation name} computation, otherwise the element is
  66. ignored.
  67. When all elements of :attr:`input` along the given dimension
  68. :attr:`dim` are ignored (fully masked-out), the corresponding element
  69. of the output tensor will have undefined value: it may or may not
  70. correspond to the identity value of {operation name} operation; the
  71. choice may correspond to the value that leads to the most efficient
  72. storage of :attr:`output` tensor.
  73. The mask of the output tensor can be computed as
  74. ``torch.any(torch.broadcast_to(mask, input.shape), dim, keepdim=keepdim,
  75. dtype=torch.bool)``.
  76. The shapes of the :attr:`mask` tensor and the :attr:`input` tensor
  77. don't need to match, but they must be :ref:`broadcastable
  78. <broadcasting-semantics>` and the dimensionality of the :attr:`mask`
  79. tensor must not be greater than of the :attr:`input` tensor.
  80. Args:
  81. input (Tensor): the input tensor
  82. {args_declarations}
  83. Keyword args:
  84. {kwargs_declarations}""",
  85. reduction_example="""\
  86. Example::
  87. >>> input = {example_input}
  88. >>> input
  89. {indent_example_input}
  90. >>> mask = {example_mask}
  91. >>> mask
  92. {indent_example_mask}
  93. >>> {full_function_name}(input, {example_args}, mask=mask)
  94. {indent_example_output}
  95. """,
  96. reduction_identity="""\
  97. The identity value of {operation name} operation, which is used to start the reduction, is ``{identity_int32}``.""",
  98. reduction_identity_dtype="""\
  99. The identity value of {operation name} operation, which is used to start the
  100. reduction, depends on input dtype. For instance, for float32, uint8,
  101. and int32 dtypes, the identity values are ``{identity_float32}``, ``{identity_uint8}``, and ``{identity_int32}``, respectively.""",
  102. normalization_signature="""\
  103. {function_name}(input, {operation_args}, *, {operation_kwargs}) -> Tensor""",
  104. normalization_descr="""\
  105. Returns {operation name} of all the slices in the :attr:`input` tensor
  106. along :attr:`dim` while the :attr:`input` elements are masked out
  107. according to the boolean tensor :attr:`mask`.
  108. {definition}""",
  109. normalization_args="""\
  110. The boolean tensor :attr:`mask` defines the "validity" of
  111. :attr:`input` tensor elements: if :attr:`mask` element is True then
  112. the corresponding element in :attr:`input` tensor will be included in
  113. {operation name} computation, otherwise the element is ignored.
  114. The values of masked-out elements of the output tensor have undefined
  115. value: it may or may not be set to zero or nan; the choice may correspond to
  116. the value that leads to the most efficient storage of :attr:`output`
  117. tensor.
  118. The mask of the {operation name} output tensor can be computed as
  119. ``torch.broadcast_to(mask, input.shape)``.
  120. The shapes of the :attr:`mask` tensor and the :attr:`input` tensor
  121. don't need to match, but they must be :ref:`broadcastable
  122. <broadcasting-semantics>` and the dimensionality of the :attr:`mask`
  123. tensor must not be greater than of the :attr:`input` tensor.
  124. Args:
  125. input (Tensor): the input tensor
  126. {args_declarations}
  127. Keyword args:
  128. {kwargs_declarations}""",
  129. normalization_example="""\
  130. Example::
  131. >>> input = {example_input}
  132. >>> input
  133. {indent_example_input}
  134. >>> mask = {example_mask}
  135. >>> mask
  136. {indent_example_mask}
  137. >>> {full_function_name}(input, {example_args}, mask=mask)
  138. {indent_example_output}
  139. """,
  140. )
  141. args_and_kwargs = {
  142. # argument name sufficies separated by double underscore will
  143. # be removed in the final documentation string.
  144. "sum": (("dim",), ("keepdim=False", "dtype=None", "mask=None")),
  145. "prod": (("dim",), ("keepdim=False", "dtype=None", "mask=None")),
  146. "cumsum": (("dim__as_int",), ("dtype=None", "mask=None")),
  147. "cumprod": (("dim__as_int",), ("dtype=None", "mask=None")),
  148. "amin": (("dim",), ("keepdim=False", "dtype=None", "mask=None")),
  149. "amax": (("dim",), ("keepdim=False", "dtype=None", "mask=None")),
  150. "argmin": (("dim__as_int",), ("keepdim=False", "dtype=None", "mask=None")),
  151. "argmax": (("dim__as_int",), ("keepdim=False", "dtype=None", "mask=None")),
  152. "mean": (("dim",), ("keepdim=False", "dtype=None", "mask=None")),
  153. "median": (("dim__as_int",), ("keepdim=False", "dtype=None", "mask=None")),
  154. "norm": (
  155. (
  156. "ord",
  157. "dim",
  158. ),
  159. ("keepdim=False", "dtype=None", "mask=None"),
  160. ),
  161. "var": (("dim", "unbiased"), ("keepdim=False", "dtype=None", "mask=None")),
  162. "std": (("dim", "unbiased"), ("keepdim=False", "dtype=None", "mask=None")),
  163. "logsumexp": (("dim",), ("keepdim=False", "dtype=None", "mask=None")),
  164. "softmax": (("dim__as_int",), ("dtype=None", "mask=None")),
  165. "log_softmax": (("dim__as_int",), ("dtype=None", "mask=None")),
  166. "softmin": (("dim__as_int",), ("dtype=None", "mask=None")),
  167. "normalize": (
  168. (
  169. "ord__required",
  170. "dim__as_int",
  171. ),
  172. ("eps=1e-12", "dtype=None", "mask=None"),
  173. ),
  174. }
  175. argument_declarations = {
  176. "dim": """\
  177. dim (int or tuple of ints, optional): the dimension or dimensions to reduce.
  178. Default: None that is equivalent to ``tuple(range(input.ndim))``.""",
  179. "dim__as_int": """\
  180. dim (int): the dimension along which {operation name} is computed.""",
  181. "ord": """\
  182. ord (int, float, optional): the order of vector norm. Default: 2.
  183. See :func:`torch.linalg.vector_norm` for a list of supported norms.""",
  184. "ord__required": """\
  185. ord (int, float): the order of vector norm. Default: 2.
  186. See :func:`torch.linalg.vector_norm` for a list of supported norms.""",
  187. "unbiased": """\
  188. unbiased (bool): when True, use Bessel's correction, otherwise, compute
  189. the uncorrected sample variance.""",
  190. "eps": """\
  191. eps (float, optional): small value to avoid division by zero. Default: {default}.""",
  192. "keepdim": """\
  193. keepdim (bool, optional): whether the output tensor has
  194. :attr:`dim` retained or not. Default: {default}.""",
  195. "dtype": """\
  196. dtype (:class:`torch.dtype`, optional): the desired data type
  197. of returned tensor. If specified, the input tensor is
  198. casted to :attr:`dtype` before the operation is
  199. performed. Default: {default}.""",
  200. "mask": """\
  201. mask (:class:`torch.Tensor`, optional): the boolean tensor
  202. containing the binary mask of validity of input tensor
  203. elements.
  204. Default: None that is equivalent to ``torch.ones(input.shape, dtype=torch.bool)``.""",
  205. }
  206. definitions = {
  207. "softmax": """\
  208. Let ``x`` be a sequence of unmasked elements of one-dimensional slice
  209. of the :attr:`input` tensor. Softmax of i-th element in ``x`` is
  210. defined as ``exp(x[i])/sum(exp(x))``.""",
  211. "log_softmax": """\
  212. Let ``x`` be a sequence of unmasked elements of one-dimensional slice
  213. of the :attr:`input` tensor. LogSoftmax of i-th element in ``x`` is
  214. defined as ``log(exp(x[i])/sum(exp(x)))``.""",
  215. "softmin": """\
  216. Let ``x`` be a sequence of unmasked elements of one-dimensional slice
  217. of the :attr:`input` tensor. Softmin of i-th element in ``x`` is
  218. defined as ``exp(-x[i])/sum(exp(-x))``.""",
  219. "normalize": """\
  220. Let ``x`` be a sequence of unmasked elements of one-dimensional slice
  221. of the :attr:`input` tensor. Normalize of i-th element in ``x`` is
  222. defined as ``x[i]/max(norm(x, p), eps)``.""",
  223. "cumsum": """\
  224. Let ``x`` be a sequence of unmasked elements of one-dimensional slice
  225. of the :attr:`input` tensor. Cumsum of i-th element in ``x`` is
  226. defined as ``sum(x[:i])``.""",
  227. "cumprod": """\
  228. Let ``x`` be a sequence of unmasked elements of one-dimensional slice
  229. of the :attr:`input` tensor. Cumsum of i-th element in ``x`` is
  230. defined as ``prod(x[:i])``.""",
  231. }
  232. reduction_names = {
  233. "sum": "sum",
  234. "prod": "product",
  235. "amax": "maximum",
  236. "amin": "minimum",
  237. "argmax": "argmax",
  238. "argmin": "argmin",
  239. "mean": "mean",
  240. "median": "median",
  241. "norm": "norm",
  242. "var": "variance",
  243. "std": "standard_deviation",
  244. "logsumexp": "logsumexp",
  245. }
  246. normalization_names = {
  247. "softmax": "softmax",
  248. "log_softmax": "log_softmax",
  249. "softmin": "softmin",
  250. "normalize": "normalize",
  251. "cumsum": "cumulative_sum",
  252. "cumprod": "cumulative_prod",
  253. }
  254. operation_names = {}
  255. operation_names.update(reduction_names)
  256. operation_names.update(normalization_names)
  257. # Default example data:
  258. example_dim = 1
  259. example_input = torch.tensor([[-3, -2, -1], [0, 1, 2]])
  260. example_mask = torch.tensor([[True, False, True], [False, False, False]])
  261. example_args: tuple[Any, ...]
  262. if func.__name__ in {"norm", "normalize"}:
  263. example_args = (2.0, example_dim)
  264. example_input = example_input.to(dtype=torch.float32)
  265. elif func.__name__ in {"var", "std"}:
  266. example_args = (example_dim, False)
  267. elif func.__name__ == "median":
  268. example_args = (example_dim,)
  269. example_input = example_input.to(dtype=torch.float32)
  270. else:
  271. example_args = (example_dim,)
  272. operation_args: tuple[str, ...]
  273. operation_kwargs: tuple[str, ...]
  274. operation_args, operation_kwargs = args_and_kwargs[func.__name__]
  275. arg_declarations = [
  276. "\n ".join(
  277. argument_declarations.get(a, f"{a.split('__', 1)[0]}: TBD.").splitlines()
  278. )
  279. for a in operation_args
  280. ]
  281. kwarg_declarations = [
  282. "\n ".join(
  283. argument_declarations.get(
  284. a.split("=", 1)[0], f"{a.split('__', 1)[0]}: TBD."
  285. )
  286. .format(default=a.split("=", 1)[1])
  287. .splitlines()
  288. )
  289. for a in operation_kwargs
  290. ]
  291. if func.__name__ in reduction_names:
  292. op_kind = "reduction"
  293. doc_sections = ["signature", "descr", "identity", "args", "example"]
  294. elif func.__name__ in normalization_names:
  295. op_kind = "normalization"
  296. doc_sections = ["signature", "descr", "args", "example"]
  297. example_input = example_input.to(dtype=torch.float32)
  298. else:
  299. # add function name to operation names dictionaries
  300. raise AssertionError(f"unknown function {func.__name__}")
  301. example_output = func(example_input, *example_args, mask=example_mask)
  302. template_data = {
  303. "function_name": func.__name__,
  304. "full_function_name": func.__module__ + "." + func.__name__,
  305. "operation name": operation_names[func.__name__],
  306. "operation_args": ", ".join(a.split("__", 1)[0] for a in operation_args),
  307. "operation_kwargs": ", ".join(a.split("__", 1)[0] for a in operation_kwargs),
  308. # one-line representation of a tensor:
  309. "example_input": " ".join(str(example_input).split()),
  310. "example_args": ", ".join(map(str, example_args)),
  311. "example_mask": " ".join(str(example_mask).split()),
  312. # multi-line representation of a tensor with indent
  313. "indent_example_input": ("\n ").join(str(example_input).splitlines()),
  314. "indent_example_mask": ("\n ").join(str(example_mask).splitlines()),
  315. "indent_example_output": ("\n ").join(str(example_output).splitlines()),
  316. }
  317. if func.__name__ in reduction_names:
  318. template_data.update(
  319. identity_uint8=_reduction_identity(
  320. func.__name__, torch.tensor(0, dtype=torch.uint8)
  321. ),
  322. identity_int32=_reduction_identity(
  323. func.__name__, torch.tensor(0, dtype=torch.int32)
  324. ),
  325. identity_float32=_reduction_identity(
  326. func.__name__, torch.tensor(0, dtype=torch.float32)
  327. ),
  328. )
  329. if func.__name__ == "norm":
  330. template_data.update(
  331. identity_ord_ninf=_reduction_identity(
  332. func.__name__, torch.tensor(0, dtype=torch.float32), float("-inf")
  333. )
  334. )
  335. elif func.__name__ in normalization_names:
  336. template_data.update(definition=definitions[func.__name__])
  337. else:
  338. # add function name to operation names dictionaries
  339. raise AssertionError(f"unknown function {func.__name__}")
  340. template_data.update(
  341. args_declarations=("\n ".join(arg_declarations)).format_map(template_data)
  342. )
  343. template_data.update(
  344. kwargs_declarations=("\n ".join(kwarg_declarations)).format_map(
  345. template_data
  346. )
  347. )
  348. # Apply function name info to docstring templates:
  349. templates = {
  350. k: v.format_map(template_data)
  351. for k, v in docstring_templates.items()
  352. if k.startswith(op_kind)
  353. }
  354. templates.update(
  355. (k, v.format_map(template_data) if isinstance(v, str) else v)
  356. for k, v in template_data.items()
  357. )
  358. # Apply docstring templates to function doctring:
  359. if func.__doc__ is None:
  360. doc_template = "\n\n".join([f"{{{op_kind}_{sec}}}" for sec in doc_sections])
  361. else:
  362. doc_template = func.__doc__
  363. return doc_template.format_map(templates)
  364. def _reduction_identity(op_name: str, input: Tensor, *args):
  365. """Return identity value as scalar tensor of a reduction operation on
  366. given input, or None, if the identity value cannot be uniquely
  367. defined for the given input.
  368. The identity value of the operation is defined as the initial
  369. value to reduction operation that has a property ``op(op_identity,
  370. value) == value`` for any value in the domain of the operation.
  371. Or put it another way, including or excluding the identity value in
  372. a list of operands will not change the reduction result.
  373. See https://github.com/pytorch/rfcs/pull/27 for more information.
  374. """
  375. dtype: DType = input.dtype
  376. device = input.device
  377. op_name = op_name.rsplit(".", 1)[-1] # lstrip module name when present
  378. if op_name in {"sum", "cumsum"}:
  379. return torch.tensor(0, dtype=dtype, device=device)
  380. elif op_name in {"prod", "cumprod"}:
  381. return torch.tensor(1, dtype=dtype, device=device)
  382. elif op_name in {"amax", "argmax", "logaddexp"}:
  383. if torch.is_floating_point(input):
  384. return torch.tensor(-torch.inf, dtype=dtype, device=device)
  385. elif torch.is_signed(input) or dtype == torch.uint8:
  386. return torch.tensor(torch.iinfo(dtype).min, dtype=dtype, device=device)
  387. elif op_name == "logsumexp":
  388. if torch.is_floating_point(input):
  389. return torch.tensor(-torch.inf, dtype=dtype, device=device)
  390. elif torch.is_complex(input):
  391. return torch.tensor(-torch.inf + 0j, dtype=dtype, device=device)
  392. elif torch.is_signed(input) or dtype == torch.uint8:
  393. return torch.tensor(torch.iinfo(dtype).min, dtype=dtype, device=device)
  394. elif op_name in {"amin", "argmin"}:
  395. if torch.is_floating_point(input):
  396. return torch.tensor(torch.inf, dtype=dtype, device=device)
  397. elif torch.is_signed(input) or dtype == torch.uint8:
  398. return torch.tensor(torch.iinfo(dtype).max, dtype=dtype, device=device)
  399. elif op_name == "mean":
  400. # Strictly speaking, the identity value of the mean operation
  401. # is the mean of the input. Since the mean value depends on
  402. # the dim argument and it may be a non-scalar tensor, we
  403. # consider the identity value of the mean operation ambiguous.
  404. # Moreover, the mean value of empty input is undefined.
  405. return None
  406. elif op_name == "norm":
  407. ord = args[0] if args else 2
  408. if ord == float("-inf"):
  409. if not torch.is_floating_point(input):
  410. raise AssertionError(f"input must be floating point, got {input.dtype}")
  411. return torch.tensor(torch.inf, dtype=dtype, device=device)
  412. return torch.tensor(0, dtype=dtype, device=device)
  413. elif op_name == "median":
  414. # We use NaN for now because the implementation is currently using torch.nanmedian
  415. # and NaN is the identity for that function since it gets ignored
  416. dtype = input.dtype if torch.is_floating_point(input) else torch.float
  417. return torch.tensor(torch.nan, dtype=dtype, device=device)
  418. elif op_name in {"var", "std"}:
  419. return None
  420. raise NotImplementedError(f"identity of {op_name} on {dtype} input")
  421. def _canonical_dim(dim: DimOrDims, ndim: int) -> tuple[int, ...]:
  422. """Return dim argument as a tuple of sorted dim values."""
  423. dims: list[int] = []
  424. if dim == ():
  425. # Currently, `dim=()` in reductions operations means "reduce
  426. # over all dimensions" while in future, it will read "no
  427. # reduce". See https://github.com/pytorch/pytorch/issues/29137
  428. # When gh-29137 is resolved, this if-block must be deleted.
  429. dim = None
  430. if dim is None:
  431. return tuple(range(ndim))
  432. ndim = max(ndim, 1)
  433. dim_ = (dim,) if isinstance(dim, (int, torch.SymInt)) else dim
  434. for d in dim_:
  435. if d in dims:
  436. raise RuntimeError(f"dim={d} appears multiple times in the list of dims")
  437. if d >= ndim or d < -ndim:
  438. raise IndexError(
  439. f"Dimension out of range (expected to be in range of [{-ndim}, {ndim - 1}], but got {d})"
  440. )
  441. # pyrefly: ignore [bad-argument-type]
  442. dims.append(d % ndim)
  443. return tuple(sorted(dims))
  444. def _sparse_coo_flatten_indices(indices: Tensor, shape: tuple):
  445. # Flatted N-D indices to 1-D indices
  446. flat_indices = indices.new_zeros(indices.size(1))
  447. for d, sz in enumerate(shape):
  448. flat_indices.mul_(sz)
  449. flat_indices.add_(indices[d])
  450. return flat_indices
  451. def _any(input: Tensor, dim: tuple, keepdim: bool):
  452. # Support torch.any with tuple dim argument.
  453. # Workaround of https://github.com/pytorch/pytorch/issues/56586
  454. r = input
  455. for d in reversed(dim):
  456. r = r.any(dim=d, keepdim=keepdim)
  457. return r
  458. def _sparse_coo_where(mask: Tensor, input: Tensor, fill_value: Tensor) -> Tensor:
  459. """Sparse variant of torch.where. Supports sparse COO and hybrid sparse COO tensors.
  460. _sparse_coo_where implements the following invariant:
  461. _sparse_coo_where(mask, input, fill_value).to_dense(fill_value) ==
  462. torch.where(mask.to_dense(), input.to_dense(), torch.full(input.shape, fill_value))
  463. where `a == b` means `assertEqual(a, b)`, mask is boolean sparse
  464. tensor, and `to_dense(fill_value)` is like `to_dense()` except
  465. that the unspecified elements are mapped to `fill_value` rather
  466. than to `0`.
  467. Returns a sparse COO tensor with the following features:
  468. - all specified elements correspond to masked-in elements that
  469. have the values of the input tensor. If there exists a masked-in
  470. element (as specified by mask) that is not specified in the
  471. input, in the result tensor, the corresponding element has value
  472. 0. In the dense part of the sparse tensor, the masked-out
  473. elements are replaced with fill_value.
  474. - all unspecified elements correspond to masked-out elements.
  475. """
  476. if input.layout != torch.sparse_coo:
  477. raise AssertionError(f"input.layout must be sparse_coo, got {input.layout}")
  478. if mask.layout != input.layout:
  479. raise AssertionError(f"mask.layout must match input.layout, got {mask.layout}")
  480. if mask.shape != input.shape:
  481. raise AssertionError(
  482. f"mask.shape must match input.shape: {mask.shape} vs {input.shape}"
  483. )
  484. if mask.dense_dim() != input.dense_dim():
  485. # TODO: eliminate this restriction
  486. raise AssertionError(
  487. f"mask.dense_dim() must match input.dense_dim(): "
  488. f"{mask.dense_dim()} vs {input.dense_dim()}"
  489. )
  490. input = input.coalesce()
  491. # For set operations on sparse tensor indices, we'll convert
  492. # multi-dimensional indices to 1-D indices for efficiency.
  493. input_flat_indices = _sparse_coo_flatten_indices(
  494. input.indices(), input.shape[: input.sparse_dim()]
  495. )
  496. mask_flat_indices = _sparse_coo_flatten_indices(
  497. mask.indices(), mask.shape[: mask.sparse_dim()]
  498. )
  499. # the set of mask flat indices that define masked-in elements:
  500. if mask.dense_dim() > 0:
  501. mask_values = _any(
  502. mask.values(), tuple(range(1, input.sparse_dim() + 1)), False
  503. )
  504. else:
  505. mask_values = mask.values()
  506. maskin_flat_indices = mask_flat_indices[mask_values.nonzero()[:, 0]]
  507. def intersection(i1, i2):
  508. union, counts = torch.cat([i1, i2]).unique(return_counts=True)
  509. return union, torch.where(counts.gt(1))
  510. def minus(i1, i2):
  511. union, counts = torch.cat([i1, i2]).unique(return_counts=True)
  512. return intersection(union[torch.where(counts.eq(1))], i1)
  513. def _apply(a):
  514. obj, w = a
  515. return obj[w]
  516. # the set of input flat indices of specified and masked-in elements:
  517. maskin_input_flat_indices = _apply(
  518. intersection(maskin_flat_indices, input_flat_indices)
  519. )
  520. _, w = intersection(input_flat_indices, maskin_input_flat_indices)
  521. # the indices and values of masked-in elements
  522. where_input_indices = input.indices()[(slice(None),) + w]
  523. where_input_values = input.values()[w]
  524. if mask.dense_dim() > 0:
  525. # apply mask to the dense part of the input values:
  526. _, w1 = intersection(mask_flat_indices, maskin_input_flat_indices)
  527. where_mask_values = mask.values()[w1]
  528. where_input_values = torch.where(
  529. where_mask_values, where_input_values, fill_value
  530. )
  531. # the set of flat indices of unspecified input and masked-in elements:
  532. maskin_zero_flat_indices = _apply(
  533. minus(maskin_flat_indices, maskin_input_flat_indices)
  534. )
  535. # the indices of masked-in zero elements
  536. _, w = intersection(mask_flat_indices, maskin_zero_flat_indices)
  537. where_zero_indices = mask.indices()[(slice(None),) + w]
  538. # construct result
  539. n = where_zero_indices.size(1)
  540. if n == 0:
  541. # the input is coalesced, hence input_flat_indices are ordered
  542. # and the result is guaranteed to be coalesced:
  543. result = torch.sparse_coo_tensor(
  544. where_input_indices, where_input_values, input.shape
  545. )
  546. return result._coalesced_(True)
  547. where_indices = torch.cat([where_input_indices, where_zero_indices], dim=1)
  548. where_values = torch.cat(
  549. [
  550. where_input_values,
  551. where_input_values.new_zeros((n,) + where_input_values.shape[1:]),
  552. ]
  553. )
  554. result = torch.sparse_coo_tensor(where_indices, where_values, input.shape)
  555. # appending zero elements leads to uncoalesced sparse tensor
  556. return result.coalesce()
def _sparse_coo_scatter_reduction_helper(
    op,
    mask_input: Tensor,
    dims: tuple[int, ...],
    keepdim: bool,
    dtype: DType | None = None,
) -> Tensor:
    """Reduce a sparse COO tensor over ``dims`` with ``op``.

    The reduction is selected by ``op.__name__`` and must be one of
    ``sum``, ``prod``, ``amax`` or ``amin``. Dense dimensions are reduced
    directly on the values tensor (only supported for ``sum``). Sparse
    dimensions are reduced by folding together the specified elements that
    share the same retained indices, using ``torch.unique`` followed by
    ``scatter_reduce_``.

    Args:
        op: reduction callable such as ``torch.sum``. Its ``__name__``
            selects the reduction, and it is also called directly on the
            values tensor.
        mask_input: sparse COO tensor to reduce. Its values and indices are
            read through ``_values()`` / ``_indices()``.
        dims: dimensions to reduce. Entries ``>= mask_input.dim()`` are
            skipped. NOTE(review): presumably already canonicalized
            (non-negative) by the caller.
        keepdim: keep reduced dimensions with size 1 in the output shape.
        dtype: output dtype. The values are converted to it before reducing.

    Returns:
        A sparse COO tensor with the reduced shape. Returns
        ``NotImplemented`` when dense dimensions are reduced with anything
        other than ``sum``.

    Raises:
        ValueError: if ``op`` is not one of the supported reductions.
    """
    reduce = op.__name__
    valid_reductions = ["sum", "prod", "amax", "amin"]
    if reduce not in valid_reductions:
        raise ValueError(
            f"op must be one of {' '.join(valid_reductions)}, but got {reduce} instead"
        )
    output_dtype = dtype
    values, indices = mask_input._values(), mask_input._indices()
    input_dims = mask_input.dim()
    num_sparse_dims = mask_input.sparse_dim()
    reduced_sparse_dims = []
    retained_sparse_dims = []
    reduced_dense_dims = []
    # promote dtype if specified
    # NOTE(review): the callers in this module always pass a concrete dtype;
    # with dtype=None this would call values.to(None).
    if values.dtype != output_dtype:
        values = values.to(output_dtype)
    if keepdim:
        output_shape = tuple(
            1 if i in dims else si for (i, si) in enumerate(mask_input.shape)
        )
    else:
        output_shape = tuple(
            si for (i, si) in enumerate(mask_input.shape) if i not in dims
        )
    # Split the requested dims into sparse dims and dense dims.
    for d in dims:
        if d >= input_dims:
            continue
        if d < num_sparse_dims:
            reduced_sparse_dims.append(d)
        else:
            # The values tensor has one leading dimension that enumerates the
            # specified elements, followed by the dense dimensions. Tensor
            # dim d therefore maps to values dim d - num_sparse_dims + 1.
            reduced_dense_dims.append(d + 1 - num_sparse_dims)
    # Reduce dense dimensions
    if len(reduced_dense_dims) > 0:
        if reduce == "sum":
            new_values = values
            new_values = op(new_values, dim=reduced_dense_dims, keepdim=bool(keepdim))
        else:
            # FIXME: Implement reductions for dense dimensions for ops with non-zero reduction identities
            return NotImplemented
    else:
        new_values = values.clone()
    # Reduce sparse dimensions
    if len(reduced_sparse_dims) == num_sparse_dims:
        # All sparse dims are reduced, so every specified element collapses
        # into a single entry: reduce along the element dimension (dim 0).
        if reduce in {"amax", "amin"} and new_values.size(0) == 0:
            # IndexError: amax(): Expected reduction dim 0 to have non-zero size.
            # sum()/prod() return the reduction identity when dim has size 0 but amax()/amin() do not
            # See https://github.com/pytorch/pytorch/issues/61901
            new_values = _reduction_identity(reduce, new_values)
        else:
            new_values = op(new_values, dim=0)
        if keepdim:
            for _ in range(num_sparse_dims):
                new_values = new_values.unsqueeze(0)
        return new_values.to(dtype=output_dtype).to_sparse()
    else:
        new_indices = indices.clone()
        if keepdim:
            # zero out reduced sparse dimensions if keepdim = True
            # ensures that the call to torch.unique folds duplicated indices together while preserving the dimension
            new_indices[reduced_sparse_dims, :] = 0
        else:
            # remove reduced sparse dimensions if keepdim = False
            if len(reduced_sparse_dims) > 0:
                retained_sparse_dims = [
                    i
                    for i in range(num_sparse_dims)
                    if i not in set(reduced_sparse_dims)
                ]
                new_indices = new_indices.index_select(
                    0, torch.tensor(retained_sparse_dims).to(mask_input.device)
                )
        # Use scatter_reduce to reduce items in the new_values tensor that correspond to the same indices in new_indices
        if new_indices.numel() > 0:
            # lexsort indices and get index tensor for scatter reduction
            new_indices, inverse_indices = torch.unique(
                new_indices, return_inverse=True, dim=1
            )
            out_shape = list(new_values.shape)
            out_shape[0] = new_indices.shape[1]
            # Broadcast the per-element group id over the dense dims so that
            # scatter_reduce_ moves whole value slices.
            for _ in range(new_values.ndim - 1):
                inverse_indices = inverse_indices.unsqueeze(-1)
            scatter_indices = inverse_indices.expand(new_values.shape)
            # FIXME: temporary workaround for issue with bfloat16/float16 remove when acctype is implemented for scatter_reduce
            if output_dtype in {torch.bfloat16, torch.float16}:
                # Accumulate in float32, then cast back to the requested dtype.
                new_values = new_values.to(torch.float)
                out = new_values.new_empty(out_shape)
                new_values = out.scatter_reduce_(
                    0, scatter_indices, new_values, reduce=reduce, include_self=False
                )
                new_values = new_values.to(dtype=output_dtype)
            else:
                out = new_values.new_empty(out_shape)
                new_values = out.scatter_reduce_(
                    0, scatter_indices, new_values, reduce=reduce, include_self=False
                )
        return torch.sparse_coo_tensor(
            new_indices,
            new_values,
            output_shape,
            dtype=output_dtype,
            device=mask_input.device,
        )
  666. def _sparse_csr_segment_reduction_helper(
  667. op,
  668. mask_input: Tensor,
  669. dims: tuple[int, ...],
  670. keepdim: bool,
  671. dtype: DType | None = None,
  672. ) -> Tensor:
  673. # Currently, while sparse CSR is always 2D with no dense dimensions keepdim must be True
  674. # FIXME: when dense dimensions are implemented for CSR tensors
  675. if not keepdim:
  676. raise AssertionError(
  677. "reduction operations on CSR tensors with keepdim=False is unsupported"
  678. )
  679. reduce = op.__name__
  680. valid_reductions = ["sum", "prod", "mean", "amax", "amin"]
  681. if reduce not in valid_reductions:
  682. raise ValueError(
  683. f"op must be one of {' '.join(valid_reductions)}, but got {reduce} instead"
  684. )
  685. device = mask_input.device
  686. output_dtype = dtype
  687. values, crow_indices, col_indices = (
  688. mask_input.values(),
  689. mask_input.crow_indices(),
  690. mask_input.col_indices(),
  691. )
  692. # promote dtype if specified
  693. if values.dtype != output_dtype:
  694. values = values.to(output_dtype)
  695. if len(dims) == 0:
  696. return mask_input
  697. if len(dims) == 1:
  698. if dims[0] == 0:
  699. new_col_indices, scatter_indices = torch.unique(
  700. col_indices, return_inverse=True
  701. )
  702. new_nnz = new_col_indices.shape[0]
  703. new_crow_indices = torch.tensor([0, new_nnz])
  704. new_values = values.new_empty(new_col_indices.shape)
  705. new_values.scatter_reduce_(
  706. 0, scatter_indices, values, reduce, include_self=False
  707. )
  708. new_shape = [1, mask_input.size(1)]
  709. else:
  710. if dims[0] != 1:
  711. raise AssertionError(
  712. "Sparse CSR tensors are 2D and only support reduction along dim 0 or 1."
  713. )
  714. # all intervals new_crow_indices[i] - new_crow_indices[i-1] are 1
  715. # except for where crow_indices[i] == crow_indices[i-1] where the interval remains as 0
  716. new_crow_indices = torch.cat(
  717. (
  718. crow_indices.new_zeros(1),
  719. torch.cumsum(torch.diff(crow_indices) != 0, 0),
  720. ),
  721. 0,
  722. )
  723. new_nnz = new_crow_indices[-1]
  724. new_col_indices = col_indices.new_zeros(new_nnz) # type: ignore[call-overload]
  725. new_values = torch._segment_reduce(values, reduce, offsets=crow_indices) # type: ignore[attr-defined]
  726. new_shape = [mask_input.size(0), 1]
  727. else:
  728. if len(dims) != 2:
  729. raise AssertionError(f"expected len(dims) == 2, got {len(dims)}")
  730. nnz = min(1, values.numel())
  731. if nnz == 1:
  732. op_kwargs = {"keepdim": True, "dtype": output_dtype}
  733. # amax and amin do not support dtype kwarg
  734. if reduce in ["amax", "amin"]:
  735. del op_kwargs["dtype"]
  736. new_values = op(values, 0, **op_kwargs)
  737. else:
  738. new_values = torch.empty(0, dtype=output_dtype)
  739. new_col_indices = col_indices.new_zeros(nnz)
  740. new_crow_indices = torch.tensor([0, nnz])
  741. new_shape = [1, nnz]
  742. return torch.sparse_csr_tensor(
  743. new_crow_indices,
  744. new_col_indices,
  745. new_values,
  746. new_shape,
  747. dtype=output_dtype,
  748. device=device,
  749. )
  750. def _sparse_csr_where(mask: Tensor, input: Tensor, fill_value: Tensor) -> Tensor:
  751. """Sparse variant of torch.where. Supports sparse CSR tensors."""
  752. # TODO: implement sparse CSR specific where operator for efficiency
  753. return _sparse_coo_where(
  754. mask.to_sparse_coo(), input.to_sparse_coo(), fill_value
  755. ).to_sparse_csr()
  756. def _where(mask: Tensor, input: Tensor, fill_value: Tensor) -> Tensor:
  757. """torch.where with sparse inputs support.
  758. _where implements the following invariant:
  759. _where(mask, input, fill_value).to_dense(fill_value) ==
  760. torch.where(mask.to_dense(), input.to_dense(), torch.full(input.shape, fill_value))
  761. where `a == b` means `assertEqual(a, b)`, mask is boolean sparse
  762. tensor, and `to_dense(fill_value)` is like `to_dense()` except
  763. that the unspecified elements are mapped to `fill_value` rather
  764. than to `0`.
  765. Returns a sparse tensor with the following features:
  766. - all specified elements correspond to masked-in elements that
  767. have the values of the input tensor. If there exists a masked-in
  768. element (as specified by mask) that is not specified in the
  769. input, in the result tensor, the corresponding element has value
  770. 0. In the dense part of the sparse tensor, the masked-out
  771. elements are replaced with fill_value.
  772. - all unspecified elements correspond to masked-out elements.
  773. """
  774. if mask.layout == torch.strided:
  775. return torch.where(mask, input, fill_value)
  776. elif mask.layout == torch.sparse_coo:
  777. return _sparse_coo_where(mask, input, fill_value)
  778. elif mask.layout == torch.sparse_csr:
  779. return _sparse_csr_where(mask, input, fill_value)
  780. else:
  781. raise ValueError(
  782. f"_where expects strided or sparse COO or sparse CSR tensor but got {mask.layout}"
  783. )
  784. def _input_mask(input: Tensor | MaskedTensor, *args, **kwargs) -> Tensor:
  785. """Return canonical input mask.
  786. A canonical input mask is defined as a boolean mask tensor that
  787. shape and layout matches with the shape and the layout of the
  788. input.
  789. The canonical input mask is computed from the :attr:`mask` tensor
  790. content to meet the following criteria:
  791. 1. The shape of the canonical input mask is the same as the shape
  792. of :attr:`input` tensor. If the mask tensor has a smaller shape
  793. than the shape of the :attr:`input`, broadcasting rules will be
  794. applied. Downcasting of mask is not supported.
  795. 2. The layout of the canonical input mask is the same as the
  796. layout of the :attr:`input` tensor. If the mask has different
  797. layout, it will be converted to the expected layout. In the
  798. case of sparse COO layout, the canonical input mask will be
  799. coalesced.
  800. 3. The dtype of the canonical input mask is torch.bool. If the
  801. mask dtype is not bool then it will be converted to bool dtype
  802. using `.to(dtype=bool)` method call.
  803. 4. The elements of the canonical input mask have boolean values
  804. copied from the content of the :attr:`mask` tensor (after
  805. possible broadcasting and dtype conversion transforms). In
  806. general, the sparsity pattern of the sparse canonical input
  807. mask need not to be the same as the sparsity pattern of the
  808. sparse :attr:`input` tensor.
  809. """
  810. if input.layout not in {torch.strided, torch.sparse_coo, torch.sparse_csr}:
  811. raise ValueError(
  812. f"_input_mask expects strided or sparse COO or sparse CSR tensor but got {input.layout}"
  813. )
  814. mask = kwargs.get("mask")
  815. # default mask
  816. if mask is None:
  817. raise ValueError("_input_mask requires explicit mask")
  818. # mask shape must match with input shape
  819. if mask.shape != input.shape:
  820. if mask.ndim > input.ndim:
  821. raise IndexError(
  822. "_input_mask expected broadcastable mask (got mask dimensionality higher than of the input)"
  823. )
  824. if mask.layout == torch.strided:
  825. mask = torch.broadcast_to(mask.clone(), input.shape).to(dtype=torch.bool)
  826. elif mask.layout == torch.sparse_coo:
  827. mask = torch._sparse_broadcast_to(mask, input.shape)
  828. else:
  829. if mask.layout != torch.sparse_csr:
  830. raise AssertionError(f"expected sparse_csr layout, got {mask.layout}")
  831. # Broadcasting of CSR tensors is not implemented. Working
  832. # around by using COO layout.
  833. mask = torch._sparse_broadcast_to(
  834. mask.to_sparse(), input.shape
  835. ).to_sparse_csr()
  836. # mask layout must match with input layout
  837. if mask.layout != input.layout:
  838. if input.layout == torch.strided:
  839. mask = mask.to_dense()
  840. elif input.layout == torch.sparse_coo:
  841. if mask.layout == torch.strided:
  842. mask = mask.to_sparse(input.sparse_dim())
  843. else:
  844. mask = mask.to_sparse()
  845. else:
  846. if input.layout != torch.sparse_csr:
  847. raise AssertionError(f"expected sparse_csr layout, got {input.layout}")
  848. mask = mask.to_sparse_csr()
  849. # sparse mask must be coalesced
  850. if mask.layout == torch.sparse_coo:
  851. mask = mask.coalesce()
  852. # mask is a boolean tensor
  853. mask = mask.to(dtype=torch.bool)
  854. return mask
  855. def _output_mask(op, input: Tensor, *args, **kwargs) -> Tensor:
  856. """Return output mask of masked operation applied to given arguments."""
  857. if callable(op):
  858. is_reduction = op.__name__ in {
  859. "sum",
  860. "prod",
  861. "amax",
  862. "amin",
  863. "argmax",
  864. "argmin",
  865. "mean",
  866. "median",
  867. "norm",
  868. "var",
  869. "std",
  870. "logsumexp",
  871. }
  872. is_normalization = op.__name__ in {
  873. "softmax",
  874. "log_softmax",
  875. "softmin",
  876. "normalize",
  877. "cumsum",
  878. "cumprod",
  879. }
  880. if is_reduction:
  881. if op.__name__ == "norm":
  882. if args:
  883. args = args[1:] # lstrip ord argument
  884. dim = args[0] if args else kwargs.get("dim")
  885. outmask = _input_mask(input, *args, **kwargs)
  886. keepdim = kwargs.get("keepdim", False)
  887. dim_ = _canonical_dim(dim, input.ndim)
  888. return _any(outmask, dim_, bool(keepdim))
  889. elif is_normalization:
  890. return _input_mask(input, *args, **kwargs)
  891. else:
  892. raise ValueError(
  893. f"_output_mask expected masked operation (got callable {op.__module__}.{op.__name__})"
  894. )
  895. else:
  896. raise ValueError(
  897. f"_output_mask expected masked operation (got {type(op).__name__} object)"
  898. )
  899. def _combine_input_and_mask(op, input: MaskedTensor | Tensor, mask, *args) -> Tensor:
  900. def helper(input, mask):
  901. if mask is None:
  902. return input
  903. canonical_mask = _input_mask(input, mask=mask)
  904. if callable(op):
  905. fill_value = _reduction_identity(op.__name__, input, *args)
  906. return _where(canonical_mask, input, fill_value)
  907. else:
  908. raise ValueError(
  909. f"_combine_input_and_mask expected masked operation (got {type(op).__name__} object)"
  910. )
  911. class Combine(torch.autograd.Function):
  912. @staticmethod
  913. # pyrefly: ignore [bad-override]
  914. def forward(ctx, input, mask):
  915. """Return input with masked-out elements eliminated for the given operations."""
  916. ctx.save_for_backward(mask)
  917. if mask is not None:
  918. ctx.mark_non_differentiable(mask)
  919. return helper(input, mask)
  920. @staticmethod
  921. # pyrefly: ignore [bad-override]
  922. def backward(ctx, grad_output):
  923. (mask,) = ctx.saved_tensors
  924. grad_data = (
  925. grad_output.get_data() if is_masked_tensor(grad_output) else grad_output
  926. )
  927. result = as_masked_tensor(grad_data, mask)
  928. return result, None
  929. return (
  930. Combine.apply(input.get_data(), input.get_mask()) # type: ignore[union-attr]
  931. if is_masked_tensor(input)
  932. else helper(input, mask)
  933. )
  934. @_apply_docstring_templates
  935. def sum(
  936. input: Tensor | MaskedTensor,
  937. dim: DimOrDims = None,
  938. *,
  939. keepdim: bool | None = False,
  940. dtype: DType | None = None,
  941. mask: Tensor | None = None,
  942. ) -> Tensor:
  943. # __doc__ is generated by _apply_docstring_templates decorator
  944. if dtype is None:
  945. # promote integer types to int64 when output dtype is not specified
  946. if input.layout == torch.sparse_csr:
  947. if input.dtype in {
  948. torch.uint8,
  949. torch.bool,
  950. torch.int8,
  951. torch.int16,
  952. torch.int32,
  953. }:
  954. # csr.to(dtype=torch.int64) is not implemented, so
  955. # using coo.to on input to ensure the promoted dtype
  956. input = input.to_sparse_coo().to(dtype=torch.int64).to_sparse_csr()
  957. else:
  958. dtype = input.dtype
  959. else:
  960. dtype = input.dtype
  961. if input.dtype in {
  962. torch.uint8,
  963. torch.bool,
  964. torch.int8,
  965. torch.int16,
  966. torch.int32,
  967. }:
  968. dtype = torch.int64
  969. dim_ = _canonical_dim(dim, input.ndim)
  970. mask_input = _combine_input_and_mask(sum, input, mask)
  971. if mask_input.layout == torch.strided:
  972. return torch.sum(mask_input, dim_, bool(keepdim), dtype=dtype)
  973. elif mask_input.layout == torch.sparse_coo:
  974. return _sparse_coo_scatter_reduction_helper(
  975. torch.sum, mask_input, dim_, bool(keepdim), dtype
  976. )
  977. elif mask_input.layout == torch.sparse_csr:
  978. return torch._sparse_csr_sum(
  979. mask_input, dim=list(dim_), keepdim=bool(keepdim), dtype=dtype
  980. )
  981. else:
  982. raise ValueError(
  983. f"masked sum expects strided, sparse_coo or sparse_csr tensor (got {mask_input.layout} tensor)"
  984. )
  985. @_apply_docstring_templates
  986. def prod(
  987. input: Tensor | MaskedTensor,
  988. dim: DimOrDims = None,
  989. *,
  990. keepdim: bool | None = False,
  991. dtype: DType | None = None,
  992. mask: Tensor | None = None,
  993. ) -> Tensor:
  994. # __doc__ is generated by _apply_docstring_templates decorator
  995. if dtype is None:
  996. # promote integer types to int64 when output dtype is not specified
  997. if input.layout == torch.sparse_csr:
  998. if input.dtype in {
  999. torch.uint8,
  1000. torch.bool,
  1001. torch.int8,
  1002. torch.int16,
  1003. torch.int32,
  1004. }:
  1005. # csr.to(dtype=torch.int64) is not implemented, so
  1006. # using coo.to on input to ensure the promoted dtype
  1007. input = input.to_sparse_coo().to(dtype=torch.int64).to_sparse_csr()
  1008. else:
  1009. dtype = input.dtype
  1010. else:
  1011. dtype = input.dtype
  1012. if input.dtype in {
  1013. torch.uint8,
  1014. torch.bool,
  1015. torch.int8,
  1016. torch.int16,
  1017. torch.int32,
  1018. }:
  1019. dtype = torch.int64
  1020. dim_ = _canonical_dim(dim, input.ndim)
  1021. mask_input = _combine_input_and_mask(prod, input, mask)
  1022. if mask_input.layout == torch.strided:
  1023. # Workaround https://github.com/pytorch/pytorch/issues/56586
  1024. result = mask_input
  1025. result = result.to(dtype=dtype)
  1026. for d in reversed(dim_):
  1027. result = result.prod(dim=d, keepdim=bool(keepdim))
  1028. return result
  1029. elif mask_input.layout == torch.sparse_coo:
  1030. if mask is None:
  1031. # See comment in the sparse_csr branch, the same issue arises for sparse_coo tensors
  1032. raise ValueError(
  1033. "masked prod expects explicit mask for sparse_coo tensor input"
  1034. )
  1035. return _sparse_coo_scatter_reduction_helper(
  1036. torch.prod, mask_input, dim_, bool(keepdim), dtype
  1037. )
  1038. elif mask_input.layout == torch.sparse_csr:
  1039. if mask is None:
  1040. # mask is None corresponds to all-True mask. The
  1041. # unspecified elements in the CSR tensor correspond to
  1042. # zero values. Hence, the prod reduction result is
  1043. # automatically zero unless all elements are specified.
  1044. # A semi-optimal way to take this into account is to use:
  1045. #
  1046. # masked_prod(csr, ..., mask=None) == torch._sparse_csr_prod(csr, ...) * all(csr.nonzero(), ...)
  1047. #
  1048. # but that requires implementing `all` and `nonzero`
  1049. # support for sparse csr tensors.
  1050. raise ValueError(
  1051. "masked prod expects explicit mask for sparse_csr tensor input"
  1052. )
  1053. return torch._sparse_csr_prod(
  1054. mask_input, dim=list(dim_), keepdim=bool(keepdim), dtype=dtype
  1055. )
  1056. else:
  1057. raise ValueError(
  1058. f"masked prod expects strided, sparse_coo or sparse_csr tensor (got {mask_input.layout} tensor)"
  1059. )
  1060. @_apply_docstring_templates
  1061. def cumsum(
  1062. input: Tensor,
  1063. dim: int,
  1064. *,
  1065. dtype: DType | None = None,
  1066. mask: Tensor | None = None,
  1067. ) -> Tensor:
  1068. if dtype is None:
  1069. dtype = input.dtype
  1070. dim_ = _canonical_dim(dim, input.ndim)[0]
  1071. mask_input = _combine_input_and_mask(sum, input, mask)
  1072. if mask_input.layout == torch.strided:
  1073. return torch.cumsum(mask_input, dim_, dtype=dtype).to(dtype=dtype)
  1074. else:
  1075. raise ValueError(
  1076. f"masked cumsum expects strided tensor (got {mask_input.layout} tensor)"
  1077. )
  1078. @_apply_docstring_templates
  1079. def cumprod(
  1080. input: Tensor,
  1081. dim: int,
  1082. *,
  1083. dtype: DType | None = None,
  1084. mask: Tensor | None = None,
  1085. ) -> Tensor:
  1086. if dtype is None:
  1087. dtype = input.dtype
  1088. dim_ = _canonical_dim(dim, input.ndim)[0]
  1089. mask_input = _combine_input_and_mask(prod, input, mask)
  1090. if mask_input.layout == torch.strided:
  1091. return torch.cumprod(mask_input, dim_, dtype=dtype).to(dtype=dtype)
  1092. else:
  1093. raise ValueError(
  1094. f"masked cumprod expects strided tensor (got {mask_input.layout} tensor)"
  1095. )
  1096. @_apply_docstring_templates
  1097. def amax(
  1098. input: Tensor | MaskedTensor,
  1099. dim: DimOrDims = None,
  1100. *,
  1101. keepdim: bool | None = False,
  1102. dtype: DType | None = None,
  1103. mask: Tensor | None = None,
  1104. ) -> Tensor:
  1105. """\
  1106. {reduction_signature}
  1107. {reduction_descr}
  1108. {reduction_identity_dtype}
  1109. {reduction_args}
  1110. {reduction_example}"""
  1111. if dtype is None:
  1112. dtype = input.dtype
  1113. mask_input = _combine_input_and_mask(amax, input, mask)
  1114. dim_ = _canonical_dim(dim, mask_input.ndim)
  1115. if mask_input.layout == torch.strided:
  1116. return torch.amax(mask_input, dim_, bool(keepdim)).to(dtype=dtype)
  1117. elif mask_input.layout == torch.sparse_coo:
  1118. if mask is None:
  1119. # See comment in the sparse_csr branch of prod, a similar issue arises here
  1120. # where unspecified elements along a dimension may need to be reduced with the result
  1121. raise ValueError(
  1122. "masked amax expects explicit mask for sparse_coo tensor input"
  1123. )
  1124. return _sparse_coo_scatter_reduction_helper(
  1125. torch.amax, mask_input, dim_, bool(keepdim), dtype
  1126. )
  1127. elif mask_input.layout == torch.sparse_csr:
  1128. if mask is None:
  1129. raise ValueError(
  1130. "masked amax expects explicit mask for sparse_csr tensor input"
  1131. )
  1132. return _sparse_csr_segment_reduction_helper(
  1133. torch.amax, mask_input, dim_, bool(keepdim), dtype
  1134. )
  1135. else:
  1136. raise ValueError(
  1137. f"masked amax expects strided, sparse_coo or sparse_csr tensor (got {mask_input.layout} tensor)"
  1138. )
  1139. @_apply_docstring_templates
  1140. def amin(
  1141. input: Tensor | MaskedTensor,
  1142. dim: DimOrDims = None,
  1143. *,
  1144. keepdim: bool | None = False,
  1145. dtype: DType | None = None,
  1146. mask: Tensor | None = None,
  1147. ) -> Tensor:
  1148. """\
  1149. {reduction_signature}
  1150. {reduction_descr}
  1151. {reduction_identity_dtype}
  1152. {reduction_args}
  1153. {reduction_example}"""
  1154. if dtype is None:
  1155. dtype = input.dtype
  1156. mask_input = _combine_input_and_mask(amin, input, mask)
  1157. dim_ = _canonical_dim(dim, mask_input.ndim)
  1158. if mask_input.layout == torch.strided:
  1159. return torch.amin(mask_input, dim_, bool(keepdim)).to(dtype=dtype)
  1160. elif mask_input.layout == torch.sparse_coo:
  1161. if mask is None:
  1162. # See comment in the sparse_csr branch of prod, a similar issue arises here
  1163. # where unspecified elements along a dimension may need to be reduced with the result
  1164. raise ValueError(
  1165. "masked amax expects explicit mask for sparse_coo tensor input"
  1166. )
  1167. return _sparse_coo_scatter_reduction_helper(
  1168. torch.amin, mask_input, dim_, bool(keepdim), dtype
  1169. )
  1170. elif mask_input.layout == torch.sparse_csr:
  1171. if mask is None:
  1172. raise ValueError(
  1173. "masked amin expects explicit mask for sparse_csr tensor input"
  1174. )
  1175. return _sparse_csr_segment_reduction_helper(
  1176. torch.amin, mask_input, dim_, bool(keepdim), dtype
  1177. )
  1178. else:
  1179. raise ValueError(
  1180. f"masked amin expects strided, sparse_coo or sparse_csr tensor (got {mask_input.layout} tensor)"
  1181. )
  1182. @_apply_docstring_templates
  1183. def argmax(
  1184. input: Tensor | MaskedTensor,
  1185. dim: int | None = None,
  1186. *,
  1187. keepdim: bool | None = False,
  1188. dtype: DType | None = None,
  1189. mask: Tensor | None = None,
  1190. ) -> Tensor:
  1191. """\
  1192. {reduction_signature}
  1193. {reduction_descr}
  1194. {reduction_identity_dtype}
  1195. {reduction_args}
  1196. {reduction_example}"""
  1197. if dtype is None:
  1198. dtype = input.dtype
  1199. mask_input = _combine_input_and_mask(argmax, input, mask)
  1200. if mask_input.layout == torch.strided:
  1201. return torch.argmax(mask_input, dim, bool(keepdim)).to(dtype=dtype)
  1202. else:
  1203. raise ValueError(
  1204. f"masked argmax expects strided tensor (got {mask_input.layout} tensor)"
  1205. )
  1206. @_apply_docstring_templates
  1207. def argmin(
  1208. input: Tensor | MaskedTensor,
  1209. dim: int | None = None,
  1210. *,
  1211. keepdim: bool | None = False,
  1212. dtype: DType | None = None,
  1213. mask: Tensor | None = None,
  1214. ) -> Tensor:
  1215. """\
  1216. {reduction_signature}
  1217. {reduction_descr}
  1218. {reduction_identity_dtype}
  1219. {reduction_args}
  1220. {reduction_example}"""
  1221. if dtype is None:
  1222. dtype = input.dtype
  1223. mask_input = _combine_input_and_mask(argmin, input, mask)
  1224. if mask_input.layout == torch.strided:
  1225. return torch.argmin(mask_input, dim, bool(keepdim)).to(dtype=dtype)
  1226. else:
  1227. raise ValueError(
  1228. f"masked argmin expects strided tensor (got {mask_input.layout} tensor)"
  1229. )
  1230. @_apply_docstring_templates
  1231. def mean(
  1232. input: Tensor | MaskedTensor,
  1233. dim: DimOrDims = None,
  1234. *,
  1235. keepdim: bool | None = False,
  1236. dtype: DType | None = None,
  1237. mask: Tensor | None = None,
  1238. ) -> Tensor:
  1239. """\
  1240. {reduction_signature}
  1241. {reduction_descr}
  1242. By definition, the identity value of a mean operation is the mean
  1243. value of the tensor. If all elements of the input tensor along given
  1244. dimension(s) :attr:`dim` are masked-out, the identity value of the
  1245. mean is undefined. Due to this ambiguity, the elements of output
  1246. tensor with strided layout, that correspond to fully masked-out
  1247. elements, have ``nan`` values.
  1248. {reduction_args}
  1249. {reduction_example}"""
  1250. dtype_source = "Optional"
  1251. if dtype is None:
  1252. dtype = input.dtype
  1253. dtype_source = "Input"
  1254. if not (dtype.is_floating_point or dtype.is_complex):
  1255. raise ValueError(
  1256. f"mean(): Could not infer output dtype. {dtype_source} dtype must be either "
  1257. f"a floating point or complex dtype. Got: {dtype}"
  1258. )
  1259. if input.layout == torch.strided:
  1260. if mask is None:
  1261. # TODO: compute count analytically
  1262. # pyrefly: ignore [no-matching-overload]
  1263. count = sum(
  1264. torch.ones(input.shape, dtype=torch.int64, device=input.device),
  1265. dim,
  1266. keepdim=keepdim,
  1267. )
  1268. # pyrefly: ignore [no-matching-overload]
  1269. total = sum(input, dim, keepdim=keepdim, dtype=dtype)
  1270. else:
  1271. inmask = _input_mask(input, mask=mask)
  1272. count = inmask.sum(dim=dim, keepdim=bool(keepdim))
  1273. # pyrefly: ignore [no-matching-overload]
  1274. total = sum(input, dim, keepdim=keepdim, dtype=dtype, mask=inmask)
  1275. return total / count
  1276. elif input.layout == torch.sparse_csr:
  1277. mask_input = _combine_input_and_mask(mean, input, mask)
  1278. dim_ = _canonical_dim(dim, mask_input.ndim)
  1279. if mask is None:
  1280. raise ValueError(
  1281. "masked mean expects explicit mask for sparse_csr tensor input"
  1282. )
  1283. return _sparse_csr_segment_reduction_helper(
  1284. torch.mean, mask_input, dim_, bool(keepdim), dtype
  1285. )
  1286. else:
  1287. raise ValueError(
  1288. f"masked mean expects strided or sparse_csr tensor (got {input.layout} tensor)"
  1289. )
  1290. @_apply_docstring_templates
  1291. def median(
  1292. input: Tensor | MaskedTensor,
  1293. dim: int = -1,
  1294. *,
  1295. keepdim: bool = False,
  1296. dtype: DType | None = None,
  1297. mask: Tensor | None = None,
  1298. ) -> Tensor:
  1299. """\
  1300. {reduction_signature}
  1301. {reduction_descr}
  1302. By definition, the identity value of a median operation is the median
  1303. value of the tensor. If all elements of the input tensor along given
  1304. dimension(s) :attr:`dim` are masked-out, the identity value of the
  1305. median is undefined. Due to this ambiguity, the elements of output
  1306. tensor with strided layout, that correspond to fully masked-out
  1307. elements, have ``nan`` values.
  1308. {reduction_args}
  1309. {reduction_example}"""
  1310. if dtype is None:
  1311. dtype = input.dtype
  1312. dim_ = _canonical_dim(dim, input.ndim)[0]
  1313. is_float = torch.is_floating_point(input)
  1314. if not is_float:
  1315. input = input.to(dtype=torch.float)
  1316. mask_input = _combine_input_and_mask(median, input, mask)
  1317. if mask_input.layout == torch.strided:
  1318. output = torch.nanmedian(mask_input, dim_, keepdim).values
  1319. if is_float:
  1320. return output
  1321. elif not is_float and not torch.isnan(output).any():
  1322. return output.to(dtype=dtype)
  1323. else:
  1324. raise ValueError(
  1325. "masked median expects no fully masked out rows if dtype is not floating point"
  1326. )
  1327. else:
  1328. raise ValueError(
  1329. f"masked median expects strided tensor (got {mask_input.layout} tensor)"
  1330. )
  1331. @_apply_docstring_templates
  1332. def logsumexp(
  1333. input: Tensor,
  1334. dim: DimOrDims = None,
  1335. *,
  1336. keepdim: bool = False,
  1337. dtype: DType | None = None,
  1338. mask: Tensor | None = None,
  1339. ) -> Tensor:
  1340. if dtype is None:
  1341. dtype = input.dtype
  1342. dim_ = _canonical_dim(dim, input.ndim)
  1343. mask_input = _combine_input_and_mask(logsumexp, input, mask)
  1344. if mask_input.layout == torch.strided:
  1345. return torch.logsumexp(mask_input, dim_, keepdim=keepdim).to(dtype=dtype)
  1346. else:
  1347. raise ValueError(
  1348. f"masked logsumexp expects strided tensor (got {mask_input.layout} tensor)"
  1349. )
  1350. # Cannot use _apply_docstring_templates as it is only set up for reductions and normalizations
  1351. def logaddexp(
  1352. input: Tensor | MaskedTensor,
  1353. other: Tensor | MaskedTensor,
  1354. *,
  1355. dtype: DType | None = None,
  1356. input_mask: Tensor | None = None,
  1357. other_mask: Tensor | None = None,
  1358. ) -> Tensor:
  1359. """logaddexp(input, other, *, dtype=None, input_mask=None, other_mask=None) -> Tensor
  1360. Returns logaddexp of all the elements in the :attr:`input` and the :attr:`other`
  1361. tensor. The :attr:`input` elements are masked out according to the boolean tensor
  1362. :attr:`input_mask` and the attr:`other` elements are masked out according to the boolean tensor
  1363. :attr:`other_mask`.
  1364. The shapes of a mask tensor and the tensor to be masked
  1365. don't need to match, but they must be :ref:`broadcastable
  1366. <broadcasting-semantics>` and the dimensionality of the mask
  1367. tensor must not be greater than of the tensor to be masked.
  1368. Args:
  1369. input (Tensor): the input tensor
  1370. other (Tensor): the second input tensor
  1371. Keyword args:
  1372. dtype (:class:`torch.dtype`, optional): the desired data type
  1373. of returned tensor. If specified, the output tensor is
  1374. casted to :attr:`dtype` after the operation is
  1375. performed. Default: None.
  1376. input_mask (:class:`torch.Tensor`, optional): the boolean tensor
  1377. containing the binary mask of validity of :attr:`input` tensor elements.
  1378. Default: None that is equivalent to ``torch.ones(input.shape, dtype=torch.bool)``.
  1379. other_mask (:class:`torch.Tensor`, optional): the boolean tensor
  1380. containing the binary mask of validity of :attr:`other` tensor elements.
  1381. Default: None that is equivalent to ``torch.ones(other.shape, dtype=torch.bool)``.
  1382. Example::
  1383. >>> input = torch.tensor([-100.0, -200, -300])
  1384. >>> input
  1385. tensor([-100., -200., -300.])
  1386. >>> other = torch.tensor([-1.0, -2, -3])
  1387. >>> other
  1388. tensor([-1., -2., -3.])
  1389. >>> mask = torch.tensor([True, False, True])
  1390. >>> mask
  1391. tensor([ True, False, True])
  1392. >>> torch.masked._ops.logaddexp(input, other, input_mask=mask, other_mask=mask)
  1393. tensor([-1., -inf, -3.])"""
  1394. if dtype is None:
  1395. dtype = input.dtype
  1396. if input.layout == torch.strided and other.layout == torch.strided:
  1397. mask_input = _combine_input_and_mask(logaddexp, input, input_mask)
  1398. mask_other = _combine_input_and_mask(logaddexp, other, other_mask)
  1399. return torch.logaddexp(mask_input, mask_other).to(dtype=dtype)
  1400. else:
  1401. raise ValueError(
  1402. f"masked logaddexp expects strided tensors (got {input.layout} tensor for input, {other.layout} for other)"
  1403. )
  1404. @_apply_docstring_templates
  1405. def norm(
  1406. input: Tensor | MaskedTensor,
  1407. ord: float | None = 2.0,
  1408. dim: DimOrDims = None,
  1409. *,
  1410. keepdim: bool | None = False,
  1411. dtype: DType | None = None,
  1412. mask: Tensor | None = None,
  1413. ) -> Tensor:
  1414. """\
  1415. {reduction_signature}
  1416. {reduction_descr}
  1417. The identity value of norm operation, which is used to start the
  1418. reduction, is ``{identity_float32}``, except for ``ord=-inf`` it is
  1419. ``{identity_ord_ninf}``.
  1420. {reduction_args}
  1421. {reduction_example}"""
  1422. if dtype is None:
  1423. dtype = input.dtype
  1424. mask_input = _combine_input_and_mask(norm, input, mask, ord)
  1425. if mask_input.layout == torch.strided:
  1426. dim_ = _canonical_dim(dim, input.ndim)
  1427. return torch.linalg.vector_norm(
  1428. mask_input, ord, dim_, bool(keepdim), dtype=dtype
  1429. )
  1430. else:
  1431. raise ValueError(
  1432. f"masked norm expects strided tensor (got {mask_input.layout} tensor)"
  1433. )
  1434. def _std_var(
  1435. input: Tensor | MaskedTensor,
  1436. dim: DimOrDims,
  1437. unbiased: bool | None,
  1438. *,
  1439. correction_opt: int | float | None,
  1440. keepdim: bool | None,
  1441. dtype: DType | None,
  1442. mask: Tensor | None,
  1443. take_sqrt: bool | None,
  1444. ) -> Tensor:
  1445. if unbiased is not None and correction_opt is not None:
  1446. raise AssertionError("Only one of unbiased and correction may be given")
  1447. correction = 1.0
  1448. if unbiased is not None:
  1449. correction = 1.0 if unbiased else 0.0
  1450. if correction_opt is not None:
  1451. correction = sym_float(correction_opt)
  1452. if dtype is None:
  1453. dtype = input.dtype
  1454. if not (dtype.is_floating_point or dtype.is_complex):
  1455. dtype = torch.float32
  1456. compute_dtype = dtype
  1457. if not (compute_dtype.is_floating_point or compute_dtype.is_complex):
  1458. compute_dtype = torch.float32
  1459. if input.layout == torch.strided:
  1460. if mask is None:
  1461. # TODO: compute count analytically
  1462. # pyrefly: ignore [no-matching-overload]
  1463. count = sum(
  1464. torch.ones(input.shape, dtype=torch.int64, device=input.device),
  1465. dim,
  1466. keepdim=True,
  1467. )
  1468. # pyrefly: ignore [no-matching-overload]
  1469. sample_total = sum(input, dim, keepdim=True, dtype=dtype)
  1470. else:
  1471. inmask = _input_mask(input, mask=mask)
  1472. count = inmask.sum(dim=dim, keepdim=True)
  1473. # pyrefly: ignore [no-matching-overload]
  1474. sample_total = sum(input, dim, keepdim=True, dtype=dtype, mask=inmask)
  1475. # TODO: replace torch.subtract/divide/square/maximum with
  1476. # masked subtract/divide/square/maximum when these will be
  1477. # available.
  1478. sample_mean = torch.divide(sample_total, count)
  1479. x = torch.subtract(input, sample_mean)
  1480. if mask is None:
  1481. # pyrefly: ignore [no-matching-overload]
  1482. total = sum(x * x.conj(), dim, keepdim=keepdim, dtype=compute_dtype)
  1483. else:
  1484. # pyrefly: ignore [no-matching-overload]
  1485. total = sum(
  1486. x * x.conj(),
  1487. dim,
  1488. keepdim=keepdim,
  1489. dtype=compute_dtype,
  1490. mask=inmask, # type: ignore[possibly-undefined]
  1491. )
  1492. if not keepdim:
  1493. count = count.reshape(total.shape)
  1494. if correction != 0:
  1495. real_dtype = (
  1496. corresponding_real_dtype(compute_dtype)
  1497. if compute_dtype.is_complex
  1498. else compute_dtype
  1499. )
  1500. count = count.to(real_dtype)
  1501. count = torch.subtract(count, correction)
  1502. count = torch.maximum(count, count.new_zeros([]))
  1503. output = torch.divide(total, count).to(dtype=dtype)
  1504. if take_sqrt:
  1505. output = torch.sqrt(output)
  1506. return output
  1507. else:
  1508. raise ValueError(
  1509. f"masked std/var expects strided tensor (got {input.layout} tensor)"
  1510. )
  1511. @_apply_docstring_templates
  1512. def var(
  1513. input: Tensor | MaskedTensor,
  1514. dim: DimOrDims = None,
  1515. unbiased: bool | None = None,
  1516. *,
  1517. correction: int | float | None = None,
  1518. keepdim: bool | None = False,
  1519. dtype: DType | None = None,
  1520. mask: Tensor | None = None,
  1521. ) -> Tensor:
  1522. """\
  1523. {reduction_signature}
  1524. {reduction_descr}
  1525. The identity value of sample variance operation is undefined. The
  1526. elements of output tensor with strided layout, that correspond to
  1527. fully masked-out elements, have ``nan`` values.
  1528. {reduction_args}
  1529. {reduction_example}"""
  1530. return _std_var(
  1531. input=input,
  1532. dim=dim,
  1533. unbiased=unbiased,
  1534. correction_opt=correction,
  1535. keepdim=keepdim,
  1536. dtype=dtype,
  1537. mask=mask,
  1538. take_sqrt=False,
  1539. )
  1540. @_apply_docstring_templates
  1541. def std(
  1542. input: Tensor | MaskedTensor,
  1543. dim: DimOrDims = None,
  1544. unbiased: bool | None = None,
  1545. *,
  1546. correction: int | None = None,
  1547. keepdim: bool | None = False,
  1548. dtype: DType | None = None,
  1549. mask: Tensor | None = None,
  1550. ) -> Tensor:
  1551. """\
  1552. {reduction_signature}
  1553. {reduction_descr}
  1554. The identity value of sample standard deviation operation is undefined. The
  1555. elements of output tensor with strided layout, that correspond to
  1556. fully masked-out elements, have ``nan`` values.
  1557. {reduction_args}
  1558. {reduction_example}"""
  1559. return _std_var(
  1560. input=input,
  1561. dim=dim,
  1562. unbiased=unbiased,
  1563. correction_opt=correction,
  1564. keepdim=keepdim,
  1565. dtype=dtype,
  1566. mask=mask,
  1567. take_sqrt=True,
  1568. )
  1569. @_apply_docstring_templates
  1570. def softmax(
  1571. input: Tensor | MaskedTensor,
  1572. dim: int,
  1573. *,
  1574. dtype: DType | None = None,
  1575. mask: Tensor | None = None,
  1576. ) -> Tensor:
  1577. if dtype is None:
  1578. dtype = input.dtype
  1579. dim_ = _canonical_dim(dim, input.ndim)[0]
  1580. mask_input = _combine_input_and_mask(amax, input, mask)
  1581. if mask_input.layout == torch.strided:
  1582. return torch.nn.functional.softmax(mask_input, dim_, dtype=dtype)
  1583. else:
  1584. raise ValueError(
  1585. f"masked softmax expects strided tensor (got {mask_input.layout} tensor)"
  1586. )
  1587. @_apply_docstring_templates
  1588. def log_softmax(
  1589. input: Tensor | MaskedTensor,
  1590. dim: int,
  1591. *,
  1592. dtype: DType | None = None,
  1593. mask: Tensor | None = None,
  1594. ) -> Tensor:
  1595. if dtype is None:
  1596. dtype = input.dtype
  1597. dim_ = _canonical_dim(dim, input.ndim)[0]
  1598. mask_input = _combine_input_and_mask(amax, input, mask)
  1599. if mask_input.layout == torch.strided:
  1600. return torch.nn.functional.log_softmax(mask_input, dim_, dtype=dtype)
  1601. else:
  1602. raise ValueError(
  1603. f"masked log_softmax expects strided tensor (got {mask_input.layout} tensor)"
  1604. )
  1605. @_apply_docstring_templates
  1606. def softmin(
  1607. input: Tensor | MaskedTensor,
  1608. dim: int,
  1609. *,
  1610. dtype: DType | None = None,
  1611. mask: Tensor | None = None,
  1612. ) -> Tensor:
  1613. if dtype is None:
  1614. dtype = input.dtype
  1615. dim_ = _canonical_dim(dim, input.ndim)[0]
  1616. mask_input = _combine_input_and_mask(amin, input, mask)
  1617. if mask_input.layout == torch.strided:
  1618. return torch.nn.functional.softmin(mask_input, dim_, dtype=dtype)
  1619. else:
  1620. raise ValueError(
  1621. f"masked softmin expects strided tensor (got {mask_input.layout} tensor)"
  1622. )
  1623. @_apply_docstring_templates
  1624. def normalize(
  1625. input: Tensor | MaskedTensor,
  1626. ord: float,
  1627. dim: int,
  1628. *,
  1629. eps: float = 1e-12,
  1630. dtype: DType | None = None,
  1631. mask: Tensor | None = None,
  1632. ) -> Tensor:
  1633. if dtype is None:
  1634. dtype = input.dtype
  1635. # TODO: eliminate mask_input as unnecessary when using masked divide.
  1636. mask_input = _combine_input_and_mask(sum, input, mask)
  1637. if mask_input.layout == torch.strided:
  1638. nrm_ = norm(input, ord, dim, keepdim=True, dtype=dtype, mask=mask)
  1639. # TODO: replace torch.maximum with masked maximum when available.
  1640. denom = torch.maximum(nrm_, nrm_.new_full([], eps))
  1641. # TODO: replace torch.divide with masked divide when available.
  1642. return torch.divide(mask_input, denom)
  1643. else:
  1644. raise ValueError(
  1645. f"masked normalize expects strided tensor (got {mask_input.layout} tensor)"
  1646. )