| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761
77717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839 |
- # mypy: allow-untyped-defs
- import warnings
- from collections.abc import Callable
- from typing import Any, Optional, TYPE_CHECKING, TypeAlias, TypeVar
- from typing_extensions import ParamSpec
- import torch
- from torch import sym_float, Tensor
- from torch._prims_common import corresponding_real_dtype
- from torch.masked import _docs
- from torch.masked.maskedtensor.core import is_masked_tensor, MaskedTensor
- from torch.masked.maskedtensor.creation import as_masked_tensor
# Static type checkers see the precise aliases below; at runtime simplified
# stand-ins are bound to the same names instead.
if TYPE_CHECKING:
    from torch._prims_common import DimsType
    from torch.types import _dtype as DType

    DimOrDims: TypeAlias = DimsType | None
else:
    # The JIT doesn't understand Union, nor torch.dtype here
    DType = int
    DimOrDims = Optional[tuple[int, ...]]

# Names of the public functions of this module; the
# _apply_docstring_templates decorator appends to this list.
__all__: list[str] = []

# Type variables used to annotate the signature-preserving decorator.
_T = TypeVar("_T")
_P = ParamSpec("_P")
- # All masked reduction/normalization operations have the same
- # signatures. Here we introduce docstring templates that are applied
- # to docstrings of reduction/normalization functions via
- # _apply_docstring_templates decorator.
def _apply_docstring_templates(func: Callable[_P, _T]) -> Callable[_P, _T]:
    """Attach the pre-generated docstring to ``func`` and export its name.

    The docstring is looked up on the ``_docs`` module under the attribute
    ``<func.__name__>_docstring``. When no such attribute exists a warning
    is issued and ``func.__doc__`` is left untouched. In both cases the
    function name is appended to ``__all__`` and the very same function
    object is returned.
    """
    name = func.__name__
    generated_doc = getattr(_docs, f"{name}_docstring", None)
    if generated_doc is not None:
        func.__doc__ = generated_doc
    else:
        warnings.warn(
            f"No documentation string available for {name}."
            " PyTorch team should run `python tools/update_masked_docs.py`"
            " to generate the missing docstrings.",
            stacklevel=2,
        )

    # Expose function as public symbol
    __all__.append(name)
    return func
- def _generate_docstring(func):
- """A utility function called from tools/update_masked_docs.py
- script to update the module torch.masked._docs.py
- """
- docstring_templates = dict(
- reduction_signature="""\
- {function_name}(input, {operation_args}, *, {operation_kwargs}) -> Tensor""",
- reduction_descr="""\
- Returns {operation name} of all the elements in the :attr:`input`
- tensor along the given dimension(s) :attr:`dim` while the :attr:`input`
- elements are masked out according to the boolean tensor
- :attr:`mask`.""",
- reduction_args="""\
- If :attr:`keepdim` is ``True``, the output tensor is of the same size
- as :attr:`input` except in the dimension(s) :attr:`dim` where it is of
- size 1. Otherwise, :attr:`dim` is squeezed (see
- :func:`torch.squeeze`), resulting in the output tensor having 1 (or
- ``len(dim)``) fewer dimension(s).
- The boolean tensor :attr:`mask` defines the "validity" of
- :attr:`input` tensor elements: if :attr:`mask` element is True
- then the corresponding element in :attr:`input` tensor will be
- included in {operation name} computation, otherwise the element is
- ignored.
- When all elements of :attr:`input` along the given dimension
- :attr:`dim` are ignored (fully masked-out), the corresponding element
- of the output tensor will have undefined value: it may or may not
- correspond to the identity value of {operation name} operation; the
- choice may correspond to the value that leads to the most efficient
- storage of :attr:`output` tensor.
- The mask of the output tensor can be computed as
- ``torch.any(torch.broadcast_to(mask, input.shape), dim, keepdim=keepdim,
- dtype=torch.bool)``.
- The shapes of the :attr:`mask` tensor and the :attr:`input` tensor
- don't need to match, but they must be :ref:`broadcastable
- <broadcasting-semantics>` and the dimensionality of the :attr:`mask`
- tensor must not be greater than of the :attr:`input` tensor.
- Args:
- input (Tensor): the input tensor
- {args_declarations}
- Keyword args:
- {kwargs_declarations}""",
- reduction_example="""\
- Example::
- >>> input = {example_input}
- >>> input
- {indent_example_input}
- >>> mask = {example_mask}
- >>> mask
- {indent_example_mask}
- >>> {full_function_name}(input, {example_args}, mask=mask)
- {indent_example_output}
- """,
- reduction_identity="""\
- The identity value of {operation name} operation, which is used to start the reduction, is ``{identity_int32}``.""",
- reduction_identity_dtype="""\
- The identity value of {operation name} operation, which is used to start the
- reduction, depends on input dtype. For instance, for float32, uint8,
- and int32 dtypes, the identity values are ``{identity_float32}``, ``{identity_uint8}``, and ``{identity_int32}``, respectively.""",
- normalization_signature="""\
- {function_name}(input, {operation_args}, *, {operation_kwargs}) -> Tensor""",
- normalization_descr="""\
- Returns {operation name} of all the slices in the :attr:`input` tensor
- along :attr:`dim` while the :attr:`input` elements are masked out
- according to the boolean tensor :attr:`mask`.
- {definition}""",
- normalization_args="""\
- The boolean tensor :attr:`mask` defines the "validity" of
- :attr:`input` tensor elements: if :attr:`mask` element is True then
- the corresponding element in :attr:`input` tensor will be included in
- {operation name} computation, otherwise the element is ignored.
- The values of masked-out elements of the output tensor have undefined
- value: it may or may not be set to zero or nan; the choice may correspond to
- the value that leads to the most efficient storage of :attr:`output`
- tensor.
- The mask of the {operation name} output tensor can be computed as
- ``torch.broadcast_to(mask, input.shape)``.
- The shapes of the :attr:`mask` tensor and the :attr:`input` tensor
- don't need to match, but they must be :ref:`broadcastable
- <broadcasting-semantics>` and the dimensionality of the :attr:`mask`
- tensor must not be greater than of the :attr:`input` tensor.
- Args:
- input (Tensor): the input tensor
- {args_declarations}
- Keyword args:
- {kwargs_declarations}""",
- normalization_example="""\
- Example::
- >>> input = {example_input}
- >>> input
- {indent_example_input}
- >>> mask = {example_mask}
- >>> mask
- {indent_example_mask}
- >>> {full_function_name}(input, {example_args}, mask=mask)
- {indent_example_output}
- """,
- )
- args_and_kwargs = {
- # argument name sufficies separated by double underscore will
- # be removed in the final documentation string.
- "sum": (("dim",), ("keepdim=False", "dtype=None", "mask=None")),
- "prod": (("dim",), ("keepdim=False", "dtype=None", "mask=None")),
- "cumsum": (("dim__as_int",), ("dtype=None", "mask=None")),
- "cumprod": (("dim__as_int",), ("dtype=None", "mask=None")),
- "amin": (("dim",), ("keepdim=False", "dtype=None", "mask=None")),
- "amax": (("dim",), ("keepdim=False", "dtype=None", "mask=None")),
- "argmin": (("dim__as_int",), ("keepdim=False", "dtype=None", "mask=None")),
- "argmax": (("dim__as_int",), ("keepdim=False", "dtype=None", "mask=None")),
- "mean": (("dim",), ("keepdim=False", "dtype=None", "mask=None")),
- "median": (("dim__as_int",), ("keepdim=False", "dtype=None", "mask=None")),
- "norm": (
- (
- "ord",
- "dim",
- ),
- ("keepdim=False", "dtype=None", "mask=None"),
- ),
- "var": (("dim", "unbiased"), ("keepdim=False", "dtype=None", "mask=None")),
- "std": (("dim", "unbiased"), ("keepdim=False", "dtype=None", "mask=None")),
- "logsumexp": (("dim",), ("keepdim=False", "dtype=None", "mask=None")),
- "softmax": (("dim__as_int",), ("dtype=None", "mask=None")),
- "log_softmax": (("dim__as_int",), ("dtype=None", "mask=None")),
- "softmin": (("dim__as_int",), ("dtype=None", "mask=None")),
- "normalize": (
- (
- "ord__required",
- "dim__as_int",
- ),
- ("eps=1e-12", "dtype=None", "mask=None"),
- ),
- }
- argument_declarations = {
- "dim": """\
- dim (int or tuple of ints, optional): the dimension or dimensions to reduce.
- Default: None that is equivalent to ``tuple(range(input.ndim))``.""",
- "dim__as_int": """\
- dim (int): the dimension along which {operation name} is computed.""",
- "ord": """\
- ord (int, float, optional): the order of vector norm. Default: 2.
- See :func:`torch.linalg.vector_norm` for a list of supported norms.""",
- "ord__required": """\
- ord (int, float): the order of vector norm. Default: 2.
- See :func:`torch.linalg.vector_norm` for a list of supported norms.""",
- "unbiased": """\
- unbiased (bool): when True, use Bessel's correction, otherwise, compute
- the uncorrected sample variance.""",
- "eps": """\
- eps (float, optional): small value to avoid division by zero. Default: {default}.""",
- "keepdim": """\
- keepdim (bool, optional): whether the output tensor has
- :attr:`dim` retained or not. Default: {default}.""",
- "dtype": """\
- dtype (:class:`torch.dtype`, optional): the desired data type
- of returned tensor. If specified, the input tensor is
- casted to :attr:`dtype` before the operation is
- performed. Default: {default}.""",
- "mask": """\
- mask (:class:`torch.Tensor`, optional): the boolean tensor
- containing the binary mask of validity of input tensor
- elements.
- Default: None that is equivalent to ``torch.ones(input.shape, dtype=torch.bool)``.""",
- }
- definitions = {
- "softmax": """\
- Let ``x`` be a sequence of unmasked elements of one-dimensional slice
- of the :attr:`input` tensor. Softmax of i-th element in ``x`` is
- defined as ``exp(x[i])/sum(exp(x))``.""",
- "log_softmax": """\
- Let ``x`` be a sequence of unmasked elements of one-dimensional slice
- of the :attr:`input` tensor. LogSoftmax of i-th element in ``x`` is
- defined as ``log(exp(x[i])/sum(exp(x)))``.""",
- "softmin": """\
- Let ``x`` be a sequence of unmasked elements of one-dimensional slice
- of the :attr:`input` tensor. Softmin of i-th element in ``x`` is
- defined as ``exp(-x[i])/sum(exp(-x))``.""",
- "normalize": """\
- Let ``x`` be a sequence of unmasked elements of one-dimensional slice
- of the :attr:`input` tensor. Normalize of i-th element in ``x`` is
- defined as ``x[i]/max(norm(x, p), eps)``.""",
- "cumsum": """\
- Let ``x`` be a sequence of unmasked elements of one-dimensional slice
- of the :attr:`input` tensor. Cumsum of i-th element in ``x`` is
- defined as ``sum(x[:i])``.""",
- "cumprod": """\
- Let ``x`` be a sequence of unmasked elements of one-dimensional slice
- of the :attr:`input` tensor. Cumsum of i-th element in ``x`` is
- defined as ``prod(x[:i])``.""",
- }
- reduction_names = {
- "sum": "sum",
- "prod": "product",
- "amax": "maximum",
- "amin": "minimum",
- "argmax": "argmax",
- "argmin": "argmin",
- "mean": "mean",
- "median": "median",
- "norm": "norm",
- "var": "variance",
- "std": "standard_deviation",
- "logsumexp": "logsumexp",
- }
- normalization_names = {
- "softmax": "softmax",
- "log_softmax": "log_softmax",
- "softmin": "softmin",
- "normalize": "normalize",
- "cumsum": "cumulative_sum",
- "cumprod": "cumulative_prod",
- }
- operation_names = {}
- operation_names.update(reduction_names)
- operation_names.update(normalization_names)
- # Default example data:
- example_dim = 1
- example_input = torch.tensor([[-3, -2, -1], [0, 1, 2]])
- example_mask = torch.tensor([[True, False, True], [False, False, False]])
- example_args: tuple[Any, ...]
- if func.__name__ in {"norm", "normalize"}:
- example_args = (2.0, example_dim)
- example_input = example_input.to(dtype=torch.float32)
- elif func.__name__ in {"var", "std"}:
- example_args = (example_dim, False)
- elif func.__name__ == "median":
- example_args = (example_dim,)
- example_input = example_input.to(dtype=torch.float32)
- else:
- example_args = (example_dim,)
- operation_args: tuple[str, ...]
- operation_kwargs: tuple[str, ...]
- operation_args, operation_kwargs = args_and_kwargs[func.__name__]
- arg_declarations = [
- "\n ".join(
- argument_declarations.get(a, f"{a.split('__', 1)[0]}: TBD.").splitlines()
- )
- for a in operation_args
- ]
- kwarg_declarations = [
- "\n ".join(
- argument_declarations.get(
- a.split("=", 1)[0], f"{a.split('__', 1)[0]}: TBD."
- )
- .format(default=a.split("=", 1)[1])
- .splitlines()
- )
- for a in operation_kwargs
- ]
- if func.__name__ in reduction_names:
- op_kind = "reduction"
- doc_sections = ["signature", "descr", "identity", "args", "example"]
- elif func.__name__ in normalization_names:
- op_kind = "normalization"
- doc_sections = ["signature", "descr", "args", "example"]
- example_input = example_input.to(dtype=torch.float32)
- else:
- # add function name to operation names dictionaries
- raise AssertionError(f"unknown function {func.__name__}")
- example_output = func(example_input, *example_args, mask=example_mask)
- template_data = {
- "function_name": func.__name__,
- "full_function_name": func.__module__ + "." + func.__name__,
- "operation name": operation_names[func.__name__],
- "operation_args": ", ".join(a.split("__", 1)[0] for a in operation_args),
- "operation_kwargs": ", ".join(a.split("__", 1)[0] for a in operation_kwargs),
- # one-line representation of a tensor:
- "example_input": " ".join(str(example_input).split()),
- "example_args": ", ".join(map(str, example_args)),
- "example_mask": " ".join(str(example_mask).split()),
- # multi-line representation of a tensor with indent
- "indent_example_input": ("\n ").join(str(example_input).splitlines()),
- "indent_example_mask": ("\n ").join(str(example_mask).splitlines()),
- "indent_example_output": ("\n ").join(str(example_output).splitlines()),
- }
- if func.__name__ in reduction_names:
- template_data.update(
- identity_uint8=_reduction_identity(
- func.__name__, torch.tensor(0, dtype=torch.uint8)
- ),
- identity_int32=_reduction_identity(
- func.__name__, torch.tensor(0, dtype=torch.int32)
- ),
- identity_float32=_reduction_identity(
- func.__name__, torch.tensor(0, dtype=torch.float32)
- ),
- )
- if func.__name__ == "norm":
- template_data.update(
- identity_ord_ninf=_reduction_identity(
- func.__name__, torch.tensor(0, dtype=torch.float32), float("-inf")
- )
- )
- elif func.__name__ in normalization_names:
- template_data.update(definition=definitions[func.__name__])
- else:
- # add function name to operation names dictionaries
- raise AssertionError(f"unknown function {func.__name__}")
- template_data.update(
- args_declarations=("\n ".join(arg_declarations)).format_map(template_data)
- )
- template_data.update(
- kwargs_declarations=("\n ".join(kwarg_declarations)).format_map(
- template_data
- )
- )
- # Apply function name info to docstring templates:
- templates = {
- k: v.format_map(template_data)
- for k, v in docstring_templates.items()
- if k.startswith(op_kind)
- }
- templates.update(
- (k, v.format_map(template_data) if isinstance(v, str) else v)
- for k, v in template_data.items()
- )
- # Apply docstring templates to function doctring:
- if func.__doc__ is None:
- doc_template = "\n\n".join([f"{{{op_kind}_{sec}}}" for sec in doc_sections])
- else:
- doc_template = func.__doc__
- return doc_template.format_map(templates)
- def _reduction_identity(op_name: str, input: Tensor, *args):
- """Return identity value as scalar tensor of a reduction operation on
- given input, or None, if the identity value cannot be uniquely
- defined for the given input.
- The identity value of the operation is defined as the initial
- value to reduction operation that has a property ``op(op_identity,
- value) == value`` for any value in the domain of the operation.
- Or put it another way, including or excluding the identity value in
- a list of operands will not change the reduction result.
- See https://github.com/pytorch/rfcs/pull/27 for more information.
- """
- dtype: DType = input.dtype
- device = input.device
- op_name = op_name.rsplit(".", 1)[-1] # lstrip module name when present
- if op_name in {"sum", "cumsum"}:
- return torch.tensor(0, dtype=dtype, device=device)
- elif op_name in {"prod", "cumprod"}:
- return torch.tensor(1, dtype=dtype, device=device)
- elif op_name in {"amax", "argmax", "logaddexp"}:
- if torch.is_floating_point(input):
- return torch.tensor(-torch.inf, dtype=dtype, device=device)
- elif torch.is_signed(input) or dtype == torch.uint8:
- return torch.tensor(torch.iinfo(dtype).min, dtype=dtype, device=device)
- elif op_name == "logsumexp":
- if torch.is_floating_point(input):
- return torch.tensor(-torch.inf, dtype=dtype, device=device)
- elif torch.is_complex(input):
- return torch.tensor(-torch.inf + 0j, dtype=dtype, device=device)
- elif torch.is_signed(input) or dtype == torch.uint8:
- return torch.tensor(torch.iinfo(dtype).min, dtype=dtype, device=device)
- elif op_name in {"amin", "argmin"}:
- if torch.is_floating_point(input):
- return torch.tensor(torch.inf, dtype=dtype, device=device)
- elif torch.is_signed(input) or dtype == torch.uint8:
- return torch.tensor(torch.iinfo(dtype).max, dtype=dtype, device=device)
- elif op_name == "mean":
- # Strictly speaking, the identity value of the mean operation
- # is the mean of the input. Since the mean value depends on
- # the dim argument and it may be a non-scalar tensor, we
- # consider the identity value of the mean operation ambiguous.
- # Moreover, the mean value of empty input is undefined.
- return None
- elif op_name == "norm":
- ord = args[0] if args else 2
- if ord == float("-inf"):
- if not torch.is_floating_point(input):
- raise AssertionError(f"input must be floating point, got {input.dtype}")
- return torch.tensor(torch.inf, dtype=dtype, device=device)
- return torch.tensor(0, dtype=dtype, device=device)
- elif op_name == "median":
- # We use NaN for now because the implementation is currently using torch.nanmedian
- # and NaN is the identity for that function since it gets ignored
- dtype = input.dtype if torch.is_floating_point(input) else torch.float
- return torch.tensor(torch.nan, dtype=dtype, device=device)
- elif op_name in {"var", "std"}:
- return None
- raise NotImplementedError(f"identity of {op_name} on {dtype} input")
- def _canonical_dim(dim: DimOrDims, ndim: int) -> tuple[int, ...]:
- """Return dim argument as a tuple of sorted dim values."""
- dims: list[int] = []
- if dim == ():
- # Currently, `dim=()` in reductions operations means "reduce
- # over all dimensions" while in future, it will read "no
- # reduce". See https://github.com/pytorch/pytorch/issues/29137
- # When gh-29137 is resolved, this if-block must be deleted.
- dim = None
- if dim is None:
- return tuple(range(ndim))
- ndim = max(ndim, 1)
- dim_ = (dim,) if isinstance(dim, (int, torch.SymInt)) else dim
- for d in dim_:
- if d in dims:
- raise RuntimeError(f"dim={d} appears multiple times in the list of dims")
- if d >= ndim or d < -ndim:
- raise IndexError(
- f"Dimension out of range (expected to be in range of [{-ndim}, {ndim - 1}], but got {d})"
- )
- # pyrefly: ignore [bad-argument-type]
- dims.append(d % ndim)
- return tuple(sorted(dims))
- def _sparse_coo_flatten_indices(indices: Tensor, shape: tuple):
- # Flatted N-D indices to 1-D indices
- flat_indices = indices.new_zeros(indices.size(1))
- for d, sz in enumerate(shape):
- flat_indices.mul_(sz)
- flat_indices.add_(indices[d])
- return flat_indices
- def _any(input: Tensor, dim: tuple, keepdim: bool):
- # Support torch.any with tuple dim argument.
- # Workaround of https://github.com/pytorch/pytorch/issues/56586
- r = input
- for d in reversed(dim):
- r = r.any(dim=d, keepdim=keepdim)
- return r
def _sparse_coo_where(mask: Tensor, input: Tensor, fill_value: Tensor) -> Tensor:
    """Sparse variant of torch.where. Supports sparse COO and hybrid sparse COO tensors.

    _sparse_coo_where implements the following invariant:

      _sparse_coo_where(mask, input, fill_value).to_dense(fill_value) ==
        torch.where(mask.to_dense(), input.to_dense(), torch.full(input.shape, fill_value))

    where `a == b` means `assertEqual(a, b)`, mask is boolean sparse
    tensor, and `to_dense(fill_value)` is like `to_dense()` except
    that the unspecified elements are mapped to `fill_value` rather
    than to `0`.

    Returns a sparse COO tensor with the following features:

    - all specified elements correspond to masked-in elements that
      have the values of the input tensor. If there exists a masked-in
      element (as specified by mask) that is not specified in the
      input, in the result tensor, the corresponding element has value
      0. In the dense part of the sparse tensor, the masked-out
      elements are replaced with fill_value.

    - all unspecified elements correspond to masked-out elements.

    Args:
        mask: boolean sparse COO tensor with the same shape and the same
            number of dense dimensions as :attr:`input`. Its indices
            are read via ``mask.indices()``.
        input: sparse COO tensor; it is coalesced internally.
        fill_value: value used for masked-out elements within the dense
            part of specified elements.

    Raises:
        AssertionError: if the layouts, shapes or numbers of dense
            dimensions of :attr:`mask` and :attr:`input` are
            inconsistent.
    """
    if input.layout != torch.sparse_coo:
        raise AssertionError(f"input.layout must be sparse_coo, got {input.layout}")
    if mask.layout != input.layout:
        raise AssertionError(f"mask.layout must match input.layout, got {mask.layout}")
    if mask.shape != input.shape:
        raise AssertionError(
            f"mask.shape must match input.shape: {mask.shape} vs {input.shape}"
        )
    if mask.dense_dim() != input.dense_dim():
        # TODO: eliminate this restriction
        raise AssertionError(
            f"mask.dense_dim() must match input.dense_dim(): "
            f"{mask.dense_dim()} vs {input.dense_dim()}"
        )

    input = input.coalesce()

    # For set operations on sparse tensor indices, we'll convert
    # multi-dimensional indices to 1-D indices for efficiency.
    input_flat_indices = _sparse_coo_flatten_indices(
        input.indices(), input.shape[: input.sparse_dim()]
    )
    mask_flat_indices = _sparse_coo_flatten_indices(
        mask.indices(), mask.shape[: mask.sparse_dim()]
    )

    # the set of mask flat indices that define masked-in elements:
    if mask.dense_dim() > 0:
        # A specified element is masked-in when any item of its dense
        # part is True. The values of a hybrid tensor have one leading
        # dimension for the specified elements followed by the dense
        # dimensions, hence the reduction runs over dims
        # 1..dense_dim (not over the number of sparse dimensions).
        mask_values = _any(
            mask.values(), tuple(range(1, mask.dense_dim() + 1)), False
        )
    else:
        mask_values = mask.values()
    maskin_flat_indices = mask_flat_indices[mask_values.nonzero()[:, 0]]

    def intersection(i1, i2):
        # Returns the sorted union of i1 and i2 together with the
        # positions in that union of the items that occur more than once.
        union, counts = torch.cat([i1, i2]).unique(return_counts=True)
        return union, torch.where(counts.gt(1))

    def minus(i1, i2):
        # Keeps the items that occur exactly once in the concatenation
        # and then selects those of them that are also in i1.
        union, counts = torch.cat([i1, i2]).unique(return_counts=True)
        return intersection(union[torch.where(counts.eq(1))], i1)

    def _apply(a):
        # Resolves an (items, positions) pair into the selected items.
        obj, w = a
        return obj[w]

    # the set of input flat indices of specified and masked-in elements:
    maskin_input_flat_indices = _apply(
        intersection(maskin_flat_indices, input_flat_indices)
    )
    _, w = intersection(input_flat_indices, maskin_input_flat_indices)

    # the indices and values of masked-in elements
    where_input_indices = input.indices()[(slice(None),) + w]
    where_input_values = input.values()[w]

    if mask.dense_dim() > 0:
        # apply mask to the dense part of the input values:
        _, w1 = intersection(mask_flat_indices, maskin_input_flat_indices)
        where_mask_values = mask.values()[w1]
        where_input_values = torch.where(
            where_mask_values, where_input_values, fill_value
        )

    # the set of flat indices of unspecified input and masked-in elements:
    maskin_zero_flat_indices = _apply(
        minus(maskin_flat_indices, maskin_input_flat_indices)
    )

    # the indices of masked-in zero elements
    _, w = intersection(mask_flat_indices, maskin_zero_flat_indices)
    where_zero_indices = mask.indices()[(slice(None),) + w]

    # construct result
    n = where_zero_indices.size(1)
    if n == 0:
        # the input is coalesced, hence input_flat_indices are ordered
        # and the result is guaranteed to be coalesced:
        result = torch.sparse_coo_tensor(
            where_input_indices, where_input_values, input.shape
        )
        return result._coalesced_(True)

    where_indices = torch.cat([where_input_indices, where_zero_indices], dim=1)
    where_values = torch.cat(
        [
            where_input_values,
            where_input_values.new_zeros((n,) + where_input_values.shape[1:]),
        ]
    )
    result = torch.sparse_coo_tensor(where_indices, where_values, input.shape)

    # appending zero elements leads to uncoalesced sparse tensor
    return result.coalesce()
- def _sparse_coo_scatter_reduction_helper(
- op,
- mask_input: Tensor,
- dims: tuple[int, ...],
- keepdim: bool,
- dtype: DType | None = None,
- ) -> Tensor:
- reduce = op.__name__
- valid_reductions = ["sum", "prod", "amax", "amin"]
- if reduce not in valid_reductions:
- raise ValueError(
- f"op must be one of {' '.join(valid_reductions)}, but got {reduce} instead"
- )
- output_dtype = dtype
- values, indices = mask_input._values(), mask_input._indices()
- input_dims = mask_input.dim()
- num_sparse_dims = mask_input.sparse_dim()
- reduced_sparse_dims = []
- retained_sparse_dims = []
- reduced_dense_dims = []
- # promote dtype if specified
- if values.dtype != output_dtype:
- values = values.to(output_dtype)
- if keepdim:
- output_shape = tuple(
- 1 if i in dims else si for (i, si) in enumerate(mask_input.shape)
- )
- else:
- output_shape = tuple(
- si for (i, si) in enumerate(mask_input.shape) if i not in dims
- )
- for d in dims:
- if d >= input_dims:
- continue
- if d < num_sparse_dims:
- reduced_sparse_dims.append(d)
- else:
- reduced_dense_dims.append(d + 1 - num_sparse_dims)
- # Reduce dense dimensions
- if len(reduced_dense_dims) > 0:
- if reduce == "sum":
- new_values = values
- new_values = op(new_values, dim=reduced_dense_dims, keepdim=bool(keepdim))
- else:
- # FIXME: Implement reductions for dense dimensions for ops with non-zero reduction identities
- return NotImplemented
- else:
- new_values = values.clone()
- # Reduce sparse dimensions
- if len(reduced_sparse_dims) == num_sparse_dims:
- if reduce in {"amax", "amin"} and new_values.size(0) == 0:
- # IndexError: amax(): Expected reduction dim 0 to have non-zero size.
- # sum()/prod() return the reduction identity when dim has size 0 but amax()/amin() do not
- # See https://github.com/pytorch/pytorch/issues/61901
- new_values = _reduction_identity(reduce, new_values)
- else:
- new_values = op(new_values, dim=0)
- if keepdim:
- for _ in range(num_sparse_dims):
- new_values = new_values.unsqueeze(0)
- return new_values.to(dtype=output_dtype).to_sparse()
- else:
- new_indices = indices.clone()
- if keepdim:
- # zero out reduced sparse dimensions if keepdim = True
- # ensures that the call to torch.unique folds duplicated indices together while preserving the dimension
- new_indices[reduced_sparse_dims, :] = 0
- else:
- # remove reduced sparse dimensions if keepdim = False
- if len(reduced_sparse_dims) > 0:
- retained_sparse_dims = [
- i
- for i in range(num_sparse_dims)
- if i not in set(reduced_sparse_dims)
- ]
- new_indices = new_indices.index_select(
- 0, torch.tensor(retained_sparse_dims).to(mask_input.device)
- )
- # Use scatter_reduce to reduce items in the new_values tensor that correspond to the same indices in new_indices
- if new_indices.numel() > 0:
- # lexsort indices and get index tensor for scatter reduction
- new_indices, inverse_indices = torch.unique(
- new_indices, return_inverse=True, dim=1
- )
- out_shape = list(new_values.shape)
- out_shape[0] = new_indices.shape[1]
- for _ in range(new_values.ndim - 1):
- inverse_indices = inverse_indices.unsqueeze(-1)
- scatter_indices = inverse_indices.expand(new_values.shape)
- # FIXME: temporary workaround for issue with bfloat16/float16 remove when acctype is implemented for scatter_reduce
- if output_dtype in {torch.bfloat16, torch.float16}:
- new_values = new_values.to(torch.float)
- out = new_values.new_empty(out_shape)
- new_values = out.scatter_reduce_(
- 0, scatter_indices, new_values, reduce=reduce, include_self=False
- )
- new_values = new_values.to(dtype=output_dtype)
- else:
- out = new_values.new_empty(out_shape)
- new_values = out.scatter_reduce_(
- 0, scatter_indices, new_values, reduce=reduce, include_self=False
- )
- return torch.sparse_coo_tensor(
- new_indices,
- new_values,
- output_shape,
- dtype=output_dtype,
- device=mask_input.device,
- )
- def _sparse_csr_segment_reduction_helper(
- op,
- mask_input: Tensor,
- dims: tuple[int, ...],
- keepdim: bool,
- dtype: DType | None = None,
- ) -> Tensor:
- # Currently, while sparse CSR is always 2D with no dense dimensions keepdim must be True
- # FIXME: when dense dimensions are implemented for CSR tensors
- if not keepdim:
- raise AssertionError(
- "reduction operations on CSR tensors with keepdim=False is unsupported"
- )
- reduce = op.__name__
- valid_reductions = ["sum", "prod", "mean", "amax", "amin"]
- if reduce not in valid_reductions:
- raise ValueError(
- f"op must be one of {' '.join(valid_reductions)}, but got {reduce} instead"
- )
- device = mask_input.device
- output_dtype = dtype
- values, crow_indices, col_indices = (
- mask_input.values(),
- mask_input.crow_indices(),
- mask_input.col_indices(),
- )
- # promote dtype if specified
- if values.dtype != output_dtype:
- values = values.to(output_dtype)
- if len(dims) == 0:
- return mask_input
- if len(dims) == 1:
- if dims[0] == 0:
- new_col_indices, scatter_indices = torch.unique(
- col_indices, return_inverse=True
- )
- new_nnz = new_col_indices.shape[0]
- new_crow_indices = torch.tensor([0, new_nnz])
- new_values = values.new_empty(new_col_indices.shape)
- new_values.scatter_reduce_(
- 0, scatter_indices, values, reduce, include_self=False
- )
- new_shape = [1, mask_input.size(1)]
- else:
- if dims[0] != 1:
- raise AssertionError(
- "Sparse CSR tensors are 2D and only support reduction along dim 0 or 1."
- )
- # all intervals new_crow_indices[i] - new_crow_indices[i-1] are 1
- # except for where crow_indices[i] == crow_indices[i-1] where the interval remains as 0
- new_crow_indices = torch.cat(
- (
- crow_indices.new_zeros(1),
- torch.cumsum(torch.diff(crow_indices) != 0, 0),
- ),
- 0,
- )
- new_nnz = new_crow_indices[-1]
- new_col_indices = col_indices.new_zeros(new_nnz) # type: ignore[call-overload]
- new_values = torch._segment_reduce(values, reduce, offsets=crow_indices) # type: ignore[attr-defined]
- new_shape = [mask_input.size(0), 1]
- else:
- if len(dims) != 2:
- raise AssertionError(f"expected len(dims) == 2, got {len(dims)}")
- nnz = min(1, values.numel())
- if nnz == 1:
- op_kwargs = {"keepdim": True, "dtype": output_dtype}
- # amax and amin do not support dtype kwarg
- if reduce in ["amax", "amin"]:
- del op_kwargs["dtype"]
- new_values = op(values, 0, **op_kwargs)
- else:
- new_values = torch.empty(0, dtype=output_dtype)
- new_col_indices = col_indices.new_zeros(nnz)
- new_crow_indices = torch.tensor([0, nnz])
- new_shape = [1, nnz]
- return torch.sparse_csr_tensor(
- new_crow_indices,
- new_col_indices,
- new_values,
- new_shape,
- dtype=output_dtype,
- device=device,
- )
def _sparse_csr_where(mask: Tensor, input: Tensor, fill_value: Tensor) -> Tensor:
    """Sparse variant of torch.where. Supports sparse CSR tensors.

    The mask and the input are converted to sparse COO, combined with
    ``_sparse_coo_where`` and the result is converted back to sparse CSR.
    """
    # TODO: implement sparse CSR specific where operator for efficiency
    coo_mask = mask.to_sparse_coo()
    coo_input = input.to_sparse_coo()
    coo_result = _sparse_coo_where(coo_mask, coo_input, fill_value)
    return coo_result.to_sparse_csr()
- def _where(mask: Tensor, input: Tensor, fill_value: Tensor) -> Tensor:
- """torch.where with sparse inputs support.
- _where implements the following invariant:
- _where(mask, input, fill_value).to_dense(fill_value) ==
- torch.where(mask.to_dense(), input.to_dense(), torch.full(input.shape, fill_value))
- where `a == b` means `assertEqual(a, b)`, mask is boolean sparse
- tensor, and `to_dense(fill_value)` is like `to_dense()` except
- that the unspecified elements are mapped to `fill_value` rather
- than to `0`.
- Returns a sparse tensor with the following features:
- - all specified elements correspond to masked-in elements that
- have the values of the input tensor. If there exists a masked-in
- element (as specified by mask) that is not specified in the
- input, in the result tensor, the corresponding element has value
- 0. In the dense part of the sparse tensor, the masked-out
- elements are replaced with fill_value.
- - all unspecified elements correspond to masked-out elements.
- """
- if mask.layout == torch.strided:
- return torch.where(mask, input, fill_value)
- elif mask.layout == torch.sparse_coo:
- return _sparse_coo_where(mask, input, fill_value)
- elif mask.layout == torch.sparse_csr:
- return _sparse_csr_where(mask, input, fill_value)
- else:
- raise ValueError(
- f"_where expects strided or sparse COO or sparse CSR tensor but got {mask.layout}"
- )
- def _input_mask(input: Tensor | MaskedTensor, *args, **kwargs) -> Tensor:
- """Return canonical input mask.
- A canonical input mask is defined as a boolean mask tensor that
- shape and layout matches with the shape and the layout of the
- input.
- The canonical input mask is computed from the :attr:`mask` tensor
- content to meet the following criteria:
- 1. The shape of the canonical input mask is the same as the shape
- of :attr:`input` tensor. If the mask tensor has a smaller shape
- than the shape of the :attr:`input`, broadcasting rules will be
- applied. Downcasting of mask is not supported.
- 2. The layout of the canonical input mask is the same as the
- layout of the :attr:`input` tensor. If the mask has different
- layout, it will be converted to the expected layout. In the
- case of sparse COO layout, the canonical input mask will be
- coalesced.
- 3. The dtype of the canonical input mask is torch.bool. If the
- mask dtype is not bool then it will be converted to bool dtype
- using `.to(dtype=bool)` method call.
- 4. The elements of the canonical input mask have boolean values
- copied from the content of the :attr:`mask` tensor (after
- possible broadcasting and dtype conversion transforms). In
- general, the sparsity pattern of the sparse canonical input
- mask need not to be the same as the sparsity pattern of the
- sparse :attr:`input` tensor.
- """
- if input.layout not in {torch.strided, torch.sparse_coo, torch.sparse_csr}:
- raise ValueError(
- f"_input_mask expects strided or sparse COO or sparse CSR tensor but got {input.layout}"
- )
- mask = kwargs.get("mask")
- # default mask
- if mask is None:
- raise ValueError("_input_mask requires explicit mask")
- # mask shape must match with input shape
- if mask.shape != input.shape:
- if mask.ndim > input.ndim:
- raise IndexError(
- "_input_mask expected broadcastable mask (got mask dimensionality higher than of the input)"
- )
- if mask.layout == torch.strided:
- mask = torch.broadcast_to(mask.clone(), input.shape).to(dtype=torch.bool)
- elif mask.layout == torch.sparse_coo:
- mask = torch._sparse_broadcast_to(mask, input.shape)
- else:
- if mask.layout != torch.sparse_csr:
- raise AssertionError(f"expected sparse_csr layout, got {mask.layout}")
- # Broadcasting of CSR tensors is not implemented. Working
- # around by using COO layout.
- mask = torch._sparse_broadcast_to(
- mask.to_sparse(), input.shape
- ).to_sparse_csr()
- # mask layout must match with input layout
- if mask.layout != input.layout:
- if input.layout == torch.strided:
- mask = mask.to_dense()
- elif input.layout == torch.sparse_coo:
- if mask.layout == torch.strided:
- mask = mask.to_sparse(input.sparse_dim())
- else:
- mask = mask.to_sparse()
- else:
- if input.layout != torch.sparse_csr:
- raise AssertionError(f"expected sparse_csr layout, got {input.layout}")
- mask = mask.to_sparse_csr()
- # sparse mask must be coalesced
- if mask.layout == torch.sparse_coo:
- mask = mask.coalesce()
- # mask is a boolean tensor
- mask = mask.to(dtype=torch.bool)
- return mask
- def _output_mask(op, input: Tensor, *args, **kwargs) -> Tensor:
- """Return output mask of masked operation applied to given arguments."""
- if callable(op):
- is_reduction = op.__name__ in {
- "sum",
- "prod",
- "amax",
- "amin",
- "argmax",
- "argmin",
- "mean",
- "median",
- "norm",
- "var",
- "std",
- "logsumexp",
- }
- is_normalization = op.__name__ in {
- "softmax",
- "log_softmax",
- "softmin",
- "normalize",
- "cumsum",
- "cumprod",
- }
- if is_reduction:
- if op.__name__ == "norm":
- if args:
- args = args[1:] # lstrip ord argument
- dim = args[0] if args else kwargs.get("dim")
- outmask = _input_mask(input, *args, **kwargs)
- keepdim = kwargs.get("keepdim", False)
- dim_ = _canonical_dim(dim, input.ndim)
- return _any(outmask, dim_, bool(keepdim))
- elif is_normalization:
- return _input_mask(input, *args, **kwargs)
- else:
- raise ValueError(
- f"_output_mask expected masked operation (got callable {op.__module__}.{op.__name__})"
- )
- else:
- raise ValueError(
- f"_output_mask expected masked operation (got {type(op).__name__} object)"
- )
- def _combine_input_and_mask(op, input: MaskedTensor | Tensor, mask, *args) -> Tensor:
- def helper(input, mask):
- if mask is None:
- return input
- canonical_mask = _input_mask(input, mask=mask)
- if callable(op):
- fill_value = _reduction_identity(op.__name__, input, *args)
- return _where(canonical_mask, input, fill_value)
- else:
- raise ValueError(
- f"_combine_input_and_mask expected masked operation (got {type(op).__name__} object)"
- )
- class Combine(torch.autograd.Function):
- @staticmethod
- # pyrefly: ignore [bad-override]
- def forward(ctx, input, mask):
- """Return input with masked-out elements eliminated for the given operations."""
- ctx.save_for_backward(mask)
- if mask is not None:
- ctx.mark_non_differentiable(mask)
- return helper(input, mask)
- @staticmethod
- # pyrefly: ignore [bad-override]
- def backward(ctx, grad_output):
- (mask,) = ctx.saved_tensors
- grad_data = (
- grad_output.get_data() if is_masked_tensor(grad_output) else grad_output
- )
- result = as_masked_tensor(grad_data, mask)
- return result, None
- return (
- Combine.apply(input.get_data(), input.get_mask()) # type: ignore[union-attr]
- if is_masked_tensor(input)
- else helper(input, mask)
- )
@_apply_docstring_templates
def sum(
    input: Tensor | MaskedTensor,
    dim: DimOrDims = None,
    *,
    keepdim: bool | None = False,
    dtype: DType | None = None,
    mask: Tensor | None = None,
) -> Tensor:
    # __doc__ is generated by _apply_docstring_templates decorator
    #
    # Masked sum over `dim`. Strided, sparse COO and sparse CSR inputs are
    # supported; any other layout raises ValueError.
    if dtype is None:
        # promote integer types to int64 when output dtype is not specified
        needs_promotion = input.dtype in {
            torch.uint8,
            torch.bool,
            torch.int8,
            torch.int16,
            torch.int32,
        }
        if input.layout == torch.sparse_csr:
            if needs_promotion:
                # csr.to(dtype=torch.int64) is not implemented, so
                # using coo.to on input to ensure the promoted dtype
                input = input.to_sparse_coo().to(dtype=torch.int64).to_sparse_csr()
            else:
                dtype = input.dtype
        else:
            dtype = torch.int64 if needs_promotion else input.dtype

    dim_ = _canonical_dim(dim, input.ndim)
    mask_input = _combine_input_and_mask(sum, input, mask)
    layout = mask_input.layout

    if layout == torch.strided:
        return torch.sum(mask_input, dim_, bool(keepdim), dtype=dtype)
    if layout == torch.sparse_coo:
        return _sparse_coo_scatter_reduction_helper(
            torch.sum, mask_input, dim_, bool(keepdim), dtype
        )
    if layout == torch.sparse_csr:
        return torch._sparse_csr_sum(
            mask_input, dim=list(dim_), keepdim=bool(keepdim), dtype=dtype
        )
    raise ValueError(
        f"masked sum expects strided, sparse_coo or sparse_csr tensor (got {mask_input.layout} tensor)"
    )
@_apply_docstring_templates
def prod(
    input: Tensor | MaskedTensor,
    dim: DimOrDims = None,
    *,
    keepdim: bool | None = False,
    dtype: DType | None = None,
    mask: Tensor | None = None,
) -> Tensor:
    # __doc__ is generated by _apply_docstring_templates decorator
    #
    # Masked product over `dim`. Sparse COO and sparse CSR inputs require an
    # explicit mask; layouts other than strided/COO/CSR raise ValueError.
    if dtype is None:
        # promote integer types to int64 when output dtype is not specified
        needs_promotion = input.dtype in {
            torch.uint8,
            torch.bool,
            torch.int8,
            torch.int16,
            torch.int32,
        }
        if input.layout == torch.sparse_csr:
            if needs_promotion:
                # csr.to(dtype=torch.int64) is not implemented, so
                # using coo.to on input to ensure the promoted dtype
                input = input.to_sparse_coo().to(dtype=torch.int64).to_sparse_csr()
            else:
                dtype = input.dtype
        else:
            dtype = torch.int64 if needs_promotion else input.dtype

    dim_ = _canonical_dim(dim, input.ndim)
    mask_input = _combine_input_and_mask(prod, input, mask)
    layout = mask_input.layout

    if layout == torch.strided:
        # Workaround https://github.com/pytorch/pytorch/issues/56586
        # Reduce one dimension at a time, last listed dimension first.
        result = mask_input.to(dtype=dtype)
        for d in reversed(dim_):
            result = result.prod(dim=d, keepdim=bool(keepdim))
        return result

    if layout == torch.sparse_coo:
        if mask is None:
            # See comment in the sparse_csr branch, the same issue arises for sparse_coo tensors
            raise ValueError(
                "masked prod expects explicit mask for sparse_coo tensor input"
            )
        return _sparse_coo_scatter_reduction_helper(
            torch.prod, mask_input, dim_, bool(keepdim), dtype
        )

    if layout == torch.sparse_csr:
        if mask is None:
            # mask is None corresponds to all-True mask. The
            # unspecified elements in the CSR tensor correspond to
            # zero values. Hence, the prod reduction result is
            # automatically zero unless all elements are specified.
            # A semi-optimal way to take this into account is to use:
            #
            #   masked_prod(csr, ..., mask=None) == torch._sparse_csr_prod(csr, ...) * all(csr.nonzero(), ...)
            #
            # but that requires implementing `all` and `nonzero`
            # support for sparse csr tensors.
            raise ValueError(
                "masked prod expects explicit mask for sparse_csr tensor input"
            )
        return torch._sparse_csr_prod(
            mask_input, dim=list(dim_), keepdim=bool(keepdim), dtype=dtype
        )

    raise ValueError(
        f"masked prod expects strided, sparse_coo or sparse_csr tensor (got {mask_input.layout} tensor)"
    )
@_apply_docstring_templates
def cumsum(
    input: Tensor,
    dim: int,
    *,
    dtype: DType | None = None,
    mask: Tensor | None = None,
) -> Tensor:
    # Masked cumulative sum along `dim`; only strided tensors are supported.
    # Masked-out elements are filled using the `sum` operation before the
    # cumulative sum is taken. `dtype` defaults to the input dtype.
    out_dtype = input.dtype if dtype is None else dtype
    dim_ = _canonical_dim(dim, input.ndim)[0]
    mask_input = _combine_input_and_mask(sum, input, mask)
    if mask_input.layout != torch.strided:
        raise ValueError(
            f"masked cumsum expects strided tensor (got {mask_input.layout} tensor)"
        )
    return torch.cumsum(mask_input, dim_, dtype=out_dtype).to(dtype=out_dtype)
@_apply_docstring_templates
def cumprod(
    input: Tensor,
    dim: int,
    *,
    dtype: DType | None = None,
    mask: Tensor | None = None,
) -> Tensor:
    # Masked cumulative product along `dim`; only strided tensors are
    # supported. Masked-out elements are filled using the `prod` operation
    # before the cumulative product is taken. `dtype` defaults to the input
    # dtype.
    out_dtype = input.dtype if dtype is None else dtype
    dim_ = _canonical_dim(dim, input.ndim)[0]
    mask_input = _combine_input_and_mask(prod, input, mask)
    if mask_input.layout != torch.strided:
        raise ValueError(
            f"masked cumprod expects strided tensor (got {mask_input.layout} tensor)"
        )
    return torch.cumprod(mask_input, dim_, dtype=out_dtype).to(dtype=out_dtype)
@_apply_docstring_templates
def amax(
    input: Tensor | MaskedTensor,
    dim: DimOrDims = None,
    *,
    keepdim: bool | None = False,
    dtype: DType | None = None,
    mask: Tensor | None = None,
) -> Tensor:
    """\
{reduction_signature}
{reduction_descr}
{reduction_identity_dtype}
{reduction_args}
{reduction_example}"""
    # Masked maximum over `dim`. Sparse COO and sparse CSR inputs require an
    # explicit mask. `dtype` defaults to the input dtype.
    out_dtype = input.dtype if dtype is None else dtype

    mask_input = _combine_input_and_mask(amax, input, mask)
    dim_ = _canonical_dim(dim, mask_input.ndim)
    layout = mask_input.layout

    if layout == torch.strided:
        return torch.amax(mask_input, dim_, bool(keepdim)).to(dtype=out_dtype)

    if layout == torch.sparse_coo:
        if mask is None:
            # See comment in the sparse_csr branch of prod, a similar issue arises here
            # where unspecified elements along a dimension may need to be reduced with the result
            raise ValueError(
                "masked amax expects explicit mask for sparse_coo tensor input"
            )
        return _sparse_coo_scatter_reduction_helper(
            torch.amax, mask_input, dim_, bool(keepdim), out_dtype
        )

    if layout == torch.sparse_csr:
        if mask is None:
            raise ValueError(
                "masked amax expects explicit mask for sparse_csr tensor input"
            )
        return _sparse_csr_segment_reduction_helper(
            torch.amax, mask_input, dim_, bool(keepdim), out_dtype
        )

    raise ValueError(
        f"masked amax expects strided, sparse_coo or sparse_csr tensor (got {mask_input.layout} tensor)"
    )
@_apply_docstring_templates
def amin(
    input: Tensor | MaskedTensor,
    dim: DimOrDims = None,
    *,
    keepdim: bool | None = False,
    dtype: DType | None = None,
    mask: Tensor | None = None,
) -> Tensor:
    """\
{reduction_signature}
{reduction_descr}
{reduction_identity_dtype}
{reduction_args}
{reduction_example}"""
    # Masked minimum over `dim`. Sparse COO and sparse CSR inputs require an
    # explicit mask; layouts other than strided/COO/CSR raise ValueError.
    # `dtype` defaults to the input dtype.
    if dtype is None:
        dtype = input.dtype

    mask_input = _combine_input_and_mask(amin, input, mask)
    dim_ = _canonical_dim(dim, mask_input.ndim)
    if mask_input.layout == torch.strided:
        return torch.amin(mask_input, dim_, bool(keepdim)).to(dtype=dtype)
    elif mask_input.layout == torch.sparse_coo:
        if mask is None:
            # See comment in the sparse_csr branch of prod, a similar issue arises here
            # where unspecified elements along a dimension may need to be reduced with the result
            # The message names amin, the operation that is actually failing.
            raise ValueError(
                "masked amin expects explicit mask for sparse_coo tensor input"
            )
        return _sparse_coo_scatter_reduction_helper(
            torch.amin, mask_input, dim_, bool(keepdim), dtype
        )
    elif mask_input.layout == torch.sparse_csr:
        if mask is None:
            raise ValueError(
                "masked amin expects explicit mask for sparse_csr tensor input"
            )
        return _sparse_csr_segment_reduction_helper(
            torch.amin, mask_input, dim_, bool(keepdim), dtype
        )
    else:
        raise ValueError(
            f"masked amin expects strided, sparse_coo or sparse_csr tensor (got {mask_input.layout} tensor)"
        )
@_apply_docstring_templates
def argmax(
    input: Tensor | MaskedTensor,
    dim: int | None = None,
    *,
    keepdim: bool | None = False,
    dtype: DType | None = None,
    mask: Tensor | None = None,
) -> Tensor:
    """\
{reduction_signature}
{reduction_descr}
{reduction_identity_dtype}
{reduction_args}
{reduction_example}"""
    # Masked argmax along `dim`; only strided tensors are supported.
    # The result is converted to `dtype`, which defaults to the input dtype.
    out_dtype = input.dtype if dtype is None else dtype
    mask_input = _combine_input_and_mask(argmax, input, mask)
    if mask_input.layout != torch.strided:
        raise ValueError(
            f"masked argmax expects strided tensor (got {mask_input.layout} tensor)"
        )
    return torch.argmax(mask_input, dim, bool(keepdim)).to(dtype=out_dtype)
@_apply_docstring_templates
def argmin(
    input: Tensor | MaskedTensor,
    dim: int | None = None,
    *,
    keepdim: bool | None = False,
    dtype: DType | None = None,
    mask: Tensor | None = None,
) -> Tensor:
    """\
{reduction_signature}
{reduction_descr}
{reduction_identity_dtype}
{reduction_args}
{reduction_example}"""
    # Masked argmin along `dim`; only strided tensors are supported.
    # The result is converted to `dtype`, which defaults to the input dtype.
    out_dtype = input.dtype if dtype is None else dtype
    mask_input = _combine_input_and_mask(argmin, input, mask)
    if mask_input.layout != torch.strided:
        raise ValueError(
            f"masked argmin expects strided tensor (got {mask_input.layout} tensor)"
        )
    return torch.argmin(mask_input, dim, bool(keepdim)).to(dtype=out_dtype)
@_apply_docstring_templates
def mean(
    input: Tensor | MaskedTensor,
    dim: DimOrDims = None,
    *,
    keepdim: bool | None = False,
    dtype: DType | None = None,
    mask: Tensor | None = None,
) -> Tensor:
    """\
{reduction_signature}
{reduction_descr}
By definition, the identity value of a mean operation is the mean
value of the tensor. If all elements of the input tensor along given
dimension(s) :attr:`dim` are masked-out, the identity value of the
mean is undefined.  Due to this ambiguity, the elements of output
tensor with strided layout, that correspond to fully masked-out
elements, have ``nan`` values.
{reduction_args}
{reduction_example}"""
    # Remember where the dtype came from so the error message can name it.
    if dtype is None:
        dtype, dtype_source = input.dtype, "Input"
    else:
        dtype_source = "Optional"

    if not (dtype.is_floating_point or dtype.is_complex):
        raise ValueError(
            f"mean(): Could not infer output dtype. {dtype_source} dtype must be either "
            f"a floating point or complex dtype. Got: {dtype}"
        )

    if input.layout == torch.strided:
        # mean = masked sum of the values / number of masked-in elements
        if mask is None:
            # TODO: compute count analytically
            # pyrefly: ignore [no-matching-overload]
            count = sum(
                torch.ones(input.shape, dtype=torch.int64, device=input.device),
                dim,
                keepdim=keepdim,
            )
            # pyrefly: ignore [no-matching-overload]
            total = sum(input, dim, keepdim=keepdim, dtype=dtype)
        else:
            inmask = _input_mask(input, mask=mask)
            count = inmask.sum(dim=dim, keepdim=bool(keepdim))
            # pyrefly: ignore [no-matching-overload]
            total = sum(input, dim, keepdim=keepdim, dtype=dtype, mask=inmask)
        return total / count

    if input.layout == torch.sparse_csr:
        mask_input = _combine_input_and_mask(mean, input, mask)
        dim_ = _canonical_dim(dim, mask_input.ndim)
        if mask is None:
            raise ValueError(
                "masked mean expects explicit mask for sparse_csr tensor input"
            )
        return _sparse_csr_segment_reduction_helper(
            torch.mean, mask_input, dim_, bool(keepdim), dtype
        )

    raise ValueError(
        f"masked mean expects strided or sparse_csr tensor (got {input.layout} tensor)"
    )
@_apply_docstring_templates
def median(
    input: Tensor | MaskedTensor,
    dim: int = -1,
    *,
    keepdim: bool = False,
    dtype: DType | None = None,
    mask: Tensor | None = None,
) -> Tensor:
    """\
{reduction_signature}
{reduction_descr}
By definition, the identity value of a median operation is the median
value of the tensor. If all elements of the input tensor along given
dimension(s) :attr:`dim` are masked-out, the identity value of the
median is undefined.  Due to this ambiguity, the elements of output
tensor with strided layout, that correspond to fully masked-out
elements, have ``nan`` values.
{reduction_args}
{reduction_example}"""
    if dtype is None:
        dtype = input.dtype
    dim_ = _canonical_dim(dim, input.ndim)[0]

    # Non-floating-point inputs are computed in float and converted back
    # afterwards.
    is_float = torch.is_floating_point(input)
    if not is_float:
        input = input.to(dtype=torch.float)

    mask_input = _combine_input_and_mask(median, input, mask)
    if mask_input.layout != torch.strided:
        raise ValueError(
            f"masked median expects strided tensor (got {mask_input.layout} tensor)"
        )

    output = torch.nanmedian(mask_input, dim_, keepdim).values
    if is_float:
        return output
    # A nan in the result cannot be represented in a non-floating dtype.
    if torch.isnan(output).any():
        raise ValueError(
            "masked median expects no fully masked out rows if dtype is not floating point"
        )
    return output.to(dtype=dtype)
@_apply_docstring_templates
def logsumexp(
    input: Tensor,
    dim: DimOrDims = None,
    *,
    keepdim: bool = False,
    dtype: DType | None = None,
    mask: Tensor | None = None,
) -> Tensor:
    # Masked logsumexp over `dim`; only strided tensors are supported.
    # The result is converted to `dtype`, which defaults to the input dtype.
    out_dtype = input.dtype if dtype is None else dtype
    dim_ = _canonical_dim(dim, input.ndim)
    mask_input = _combine_input_and_mask(logsumexp, input, mask)
    if mask_input.layout != torch.strided:
        raise ValueError(
            f"masked logsumexp expects strided tensor (got {mask_input.layout} tensor)"
        )
    return torch.logsumexp(mask_input, dim_, keepdim=keepdim).to(dtype=out_dtype)
# Cannot use _apply_docstring_templates as it is only set up for reductions and normalizations
def logaddexp(
    input: Tensor | MaskedTensor,
    other: Tensor | MaskedTensor,
    *,
    dtype: DType | None = None,
    input_mask: Tensor | None = None,
    other_mask: Tensor | None = None,
) -> Tensor:
    """logaddexp(input, other, *, dtype=None, input_mask=None, other_mask=None) -> Tensor

    Returns logaddexp of all the elements in the :attr:`input` and the :attr:`other`
    tensor. The :attr:`input` elements are masked out according to the boolean tensor
    :attr:`input_mask` and the :attr:`other` elements are masked out according to the
    boolean tensor :attr:`other_mask`.

    The shapes of a mask tensor and the tensor to be masked
    don't need to match, but they must be :ref:`broadcastable
    <broadcasting-semantics>` and the dimensionality of the mask
    tensor must not be greater than of the tensor to be masked.

    Args:
        input (Tensor): the input tensor
        other (Tensor): the second input tensor

    Keyword args:
        dtype (:class:`torch.dtype`, optional): the desired data type
          of returned tensor.  If specified, the output tensor is
          cast to :attr:`dtype` after the operation is
          performed. Default: None.
        input_mask (:class:`torch.Tensor`, optional): the boolean tensor
          containing the binary mask of validity of :attr:`input` tensor elements.
          Default: None that is equivalent to ``torch.ones(input.shape, dtype=torch.bool)``.
        other_mask (:class:`torch.Tensor`, optional): the boolean tensor
          containing the binary mask of validity of :attr:`other` tensor elements.
          Default: None that is equivalent to ``torch.ones(other.shape, dtype=torch.bool)``.

    Example::

        >>> input = torch.tensor([-100.0, -200, -300])
        >>> input
        tensor([-100., -200., -300.])
        >>> other = torch.tensor([-1.0, -2, -3])
        >>> other
        tensor([-1., -2., -3.])
        >>> mask = torch.tensor([True, False, True])
        >>> mask
        tensor([ True, False,  True])
        >>> torch.masked._ops.logaddexp(input, other, input_mask=mask, other_mask=mask)
        tensor([-1., -inf, -3.])"""
    if dtype is None:
        dtype = input.dtype
    # Both operands must be strided.
    if input.layout != torch.strided or other.layout != torch.strided:
        raise ValueError(
            f"masked logaddexp expects strided tensors (got {input.layout} tensor for input, {other.layout} for other)"
        )
    mask_input = _combine_input_and_mask(logaddexp, input, input_mask)
    mask_other = _combine_input_and_mask(logaddexp, other, other_mask)
    return torch.logaddexp(mask_input, mask_other).to(dtype=dtype)
@_apply_docstring_templates
def norm(
    input: Tensor | MaskedTensor,
    ord: float | None = 2.0,
    dim: DimOrDims = None,
    *,
    keepdim: bool | None = False,
    dtype: DType | None = None,
    mask: Tensor | None = None,
) -> Tensor:
    """\
{reduction_signature}
{reduction_descr}
The identity value of norm operation, which is used to start the
reduction, is ``{identity_float32}``, except for ``ord=-inf`` it is
``{identity_ord_ninf}``.
{reduction_args}
{reduction_example}"""
    # Masked vector norm of order `ord` over `dim`; only strided tensors are
    # supported. `ord` is forwarded when combining input and mask.
    out_dtype = input.dtype if dtype is None else dtype
    mask_input = _combine_input_and_mask(norm, input, mask, ord)
    if mask_input.layout != torch.strided:
        raise ValueError(
            f"masked norm expects strided tensor (got {mask_input.layout} tensor)"
        )
    dim_ = _canonical_dim(dim, input.ndim)
    return torch.linalg.vector_norm(
        mask_input, ord, dim_, bool(keepdim), dtype=out_dtype
    )
- def _std_var(
- input: Tensor | MaskedTensor,
- dim: DimOrDims,
- unbiased: bool | None,
- *,
- correction_opt: int | float | None,
- keepdim: bool | None,
- dtype: DType | None,
- mask: Tensor | None,
- take_sqrt: bool | None,
- ) -> Tensor:
- if unbiased is not None and correction_opt is not None:
- raise AssertionError("Only one of unbiased and correction may be given")
- correction = 1.0
- if unbiased is not None:
- correction = 1.0 if unbiased else 0.0
- if correction_opt is not None:
- correction = sym_float(correction_opt)
- if dtype is None:
- dtype = input.dtype
- if not (dtype.is_floating_point or dtype.is_complex):
- dtype = torch.float32
- compute_dtype = dtype
- if not (compute_dtype.is_floating_point or compute_dtype.is_complex):
- compute_dtype = torch.float32
- if input.layout == torch.strided:
- if mask is None:
- # TODO: compute count analytically
- # pyrefly: ignore [no-matching-overload]
- count = sum(
- torch.ones(input.shape, dtype=torch.int64, device=input.device),
- dim,
- keepdim=True,
- )
- # pyrefly: ignore [no-matching-overload]
- sample_total = sum(input, dim, keepdim=True, dtype=dtype)
- else:
- inmask = _input_mask(input, mask=mask)
- count = inmask.sum(dim=dim, keepdim=True)
- # pyrefly: ignore [no-matching-overload]
- sample_total = sum(input, dim, keepdim=True, dtype=dtype, mask=inmask)
- # TODO: replace torch.subtract/divide/square/maximum with
- # masked subtract/divide/square/maximum when these will be
- # available.
- sample_mean = torch.divide(sample_total, count)
- x = torch.subtract(input, sample_mean)
- if mask is None:
- # pyrefly: ignore [no-matching-overload]
- total = sum(x * x.conj(), dim, keepdim=keepdim, dtype=compute_dtype)
- else:
- # pyrefly: ignore [no-matching-overload]
- total = sum(
- x * x.conj(),
- dim,
- keepdim=keepdim,
- dtype=compute_dtype,
- mask=inmask, # type: ignore[possibly-undefined]
- )
- if not keepdim:
- count = count.reshape(total.shape)
- if correction != 0:
- real_dtype = (
- corresponding_real_dtype(compute_dtype)
- if compute_dtype.is_complex
- else compute_dtype
- )
- count = count.to(real_dtype)
- count = torch.subtract(count, correction)
- count = torch.maximum(count, count.new_zeros([]))
- output = torch.divide(total, count).to(dtype=dtype)
- if take_sqrt:
- output = torch.sqrt(output)
- return output
- else:
- raise ValueError(
- f"masked std/var expects strided tensor (got {input.layout} tensor)"
- )
@_apply_docstring_templates
def var(
    input: Tensor | MaskedTensor,
    dim: DimOrDims = None,
    unbiased: bool | None = None,
    *,
    correction: int | float | None = None,
    keepdim: bool | None = False,
    dtype: DType | None = None,
    mask: Tensor | None = None,
) -> Tensor:
    """\
{reduction_signature}
{reduction_descr}
The identity value of sample variance operation is undefined. The
elements of output tensor with strided layout, that correspond to
fully masked-out elements, have ``nan`` values.
{reduction_args}
{reduction_example}"""
    # Delegates to the shared std/var implementation without the final sqrt.
    return _std_var(
        input,
        dim,
        unbiased,
        correction_opt=correction,
        keepdim=keepdim,
        dtype=dtype,
        mask=mask,
        take_sqrt=False,
    )
@_apply_docstring_templates
def std(
    input: Tensor | MaskedTensor,
    dim: DimOrDims = None,
    unbiased: bool | None = None,
    *,
    correction: int | float | None = None,
    keepdim: bool | None = False,
    dtype: DType | None = None,
    mask: Tensor | None = None,
) -> Tensor:
    """\
{reduction_signature}
{reduction_descr}
The identity value of sample standard deviation operation is undefined. The
elements of output tensor with strided layout, that correspond to
fully masked-out elements, have ``nan`` values.
{reduction_args}
{reduction_example}"""
    # Delegates to the shared std/var implementation and takes the square
    # root of the variance. `correction` is annotated like in `var` and
    # `_std_var`, which accept both int and float corrections.
    return _std_var(
        input=input,
        dim=dim,
        unbiased=unbiased,
        correction_opt=correction,
        keepdim=keepdim,
        dtype=dtype,
        mask=mask,
        take_sqrt=True,
    )
@_apply_docstring_templates
def softmax(
    input: Tensor | MaskedTensor,
    dim: int,
    *,
    dtype: DType | None = None,
    mask: Tensor | None = None,
) -> Tensor:
    # Masked softmax along `dim`; only strided tensors are supported.
    # Masked-out elements are filled using the `amax` operation before the
    # softmax is applied. `dtype` defaults to the input dtype.
    out_dtype = input.dtype if dtype is None else dtype
    dim_ = _canonical_dim(dim, input.ndim)[0]
    mask_input = _combine_input_and_mask(amax, input, mask)
    if mask_input.layout != torch.strided:
        raise ValueError(
            f"masked softmax expects strided tensor (got {mask_input.layout} tensor)"
        )
    return torch.nn.functional.softmax(mask_input, dim_, dtype=out_dtype)
@_apply_docstring_templates
def log_softmax(
    input: Tensor | MaskedTensor,
    dim: int,
    *,
    dtype: DType | None = None,
    mask: Tensor | None = None,
) -> Tensor:
    # Masked log-softmax along `dim`; only strided tensors are supported.
    # Masked-out elements are filled using the `amax` operation before the
    # log-softmax is applied. `dtype` defaults to the input dtype.
    out_dtype = input.dtype if dtype is None else dtype
    dim_ = _canonical_dim(dim, input.ndim)[0]
    mask_input = _combine_input_and_mask(amax, input, mask)
    if mask_input.layout != torch.strided:
        raise ValueError(
            f"masked log_softmax expects strided tensor (got {mask_input.layout} tensor)"
        )
    return torch.nn.functional.log_softmax(mask_input, dim_, dtype=out_dtype)
@_apply_docstring_templates
def softmin(
    input: Tensor | MaskedTensor,
    dim: int,
    *,
    dtype: DType | None = None,
    mask: Tensor | None = None,
) -> Tensor:
    # Masked softmin along `dim`; only strided tensors are supported.
    # Masked-out elements are filled using the `amin` operation before the
    # softmin is applied. `dtype` defaults to the input dtype.
    out_dtype = input.dtype if dtype is None else dtype
    dim_ = _canonical_dim(dim, input.ndim)[0]
    mask_input = _combine_input_and_mask(amin, input, mask)
    if mask_input.layout != torch.strided:
        raise ValueError(
            f"masked softmin expects strided tensor (got {mask_input.layout} tensor)"
        )
    return torch.nn.functional.softmin(mask_input, dim_, dtype=out_dtype)
@_apply_docstring_templates
def normalize(
    input: Tensor | MaskedTensor,
    ord: float,
    dim: int,
    *,
    eps: float = 1e-12,
    dtype: DType | None = None,
    mask: Tensor | None = None,
) -> Tensor:
    # Divides the masked input by its masked `ord`-norm along `dim`, with
    # the norm bounded below by `eps`. Only strided tensors are supported.
    out_dtype = input.dtype if dtype is None else dtype
    # TODO: eliminate mask_input as unnecessary when using masked divide.
    mask_input = _combine_input_and_mask(sum, input, mask)
    if mask_input.layout != torch.strided:
        raise ValueError(
            f"masked normalize expects strided tensor (got {mask_input.layout} tensor)"
        )
    nrm_ = norm(input, ord, dim, keepdim=True, dtype=out_dtype, mask=mask)
    # TODO: replace torch.maximum with masked maximum when available.
    eps_floor = nrm_.new_full([], eps)
    denom = torch.maximum(nrm_, eps_floor)
    # TODO: replace torch.divide with masked divide when available.
    return torch.divide(mask_input, denom)
|