| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176 |
- # mypy: allow-untyped-defs
- # Copyright (c) Meta Platforms, Inc. and affiliates
- import warnings
- import torch
- from .core import is_masked_tensor
- from .creation import as_masked_tensor, masked_tensor
- __all__ = [] # type: ignore[var-annotated]
- def _masked_all_all(data, mask=None):
- if mask is None:
- return data.all()
- return data.masked_fill(~mask, True).all()
- def _masked_all_dim(data, dim, keepdim=False, mask=None):
- if mask is None:
- return torch.all(data, dim=dim, keepdim=keepdim)
- return torch.all(data.masked_fill(~mask, True), dim=dim, keepdim=keepdim)
- def _masked_all(*args, **kwargs):
- if len(args) == 1 and len(kwargs) == 1:
- return _masked_all_all(args[0], mask=kwargs["mask"])
- return _masked_all_dim(*args, **kwargs)
- def _multidim_any(mask, dim, keepdim):
- if isinstance(dim, int):
- return _multidim_any(mask, [dim], keepdim)
- for d in sorted(dim, reverse=True):
- mask = torch.any(mask, dim=d, keepdim=keepdim)
- return mask
- def _get_masked_fn(fn):
- if fn == "all":
- return _masked_all
- return getattr(torch.masked, fn)
- def _torch_reduce_all(fn):
- def reduce_all(self):
- masked_fn = _get_masked_fn(fn)
- data = self.get_data()
- mask = self.get_mask().values() if self.is_sparse else self.get_mask()
- # When reduction is "all", then torch.argmin/torch.argmax needs to return the index of the
- # element corresponding to the min/max, but this operation isn't supported correctly for sparse layouts.
- # Therefore, this implementation calculates it using the strides.
- if fn == "all":
- result_data = masked_fn(data, mask=mask)
- elif fn in {"argmin", "argmax"} and self.is_sparse_coo():
- sparse_idx = masked_fn(data.values(), mask=mask).to(dtype=torch.int)
- indices = (
- data.to_sparse_coo().indices()
- if not self.is_sparse_coo()
- else data.indices()
- )
- idx = indices.unbind(1)[sparse_idx]
- stride = data.size().numel() / torch.tensor(
- data.size(), device=data.device
- ).cumprod(0)
- result_data = torch.sum(idx * stride)
- # we simply pass in the values for sparse COO/CSR tensors
- elif self.is_sparse:
- result_data = masked_fn(masked_tensor(data.values(), mask))
- else:
- result_data = masked_fn(self, mask=mask)
- return as_masked_tensor(result_data, torch.any(mask))
- return reduce_all
- def _torch_reduce_dim(fn):
- def reduce_dim(self, dim, keepdim=False, dtype=None):
- if self.is_sparse:
- msg = (
- f"The sparse version of {fn} is not implemented in reductions.\n"
- "If you would like this operator to be supported, please file an issue for a feature request at "
- "https://github.com/pytorch/maskedtensor/issues with a minimal reproducible code snippet.\n"
- "In the case that the semantics for the operator are not trivial, it would be appreciated "
- "to also include a proposal for the semantics."
- )
- warnings.warn(msg, stacklevel=2)
- return NotImplemented
- if not is_masked_tensor(self):
- raise TypeError("Input to reduce_dim must be a MaskedTensor")
- masked_fn = _get_masked_fn(fn)
- data = self.get_data()
- mask = self.get_mask()
- if fn == "all":
- result_data = masked_fn(data, dim=dim, keepdim=keepdim, mask=mask)
- else:
- result_data = masked_fn(
- self, dim=dim, keepdim=keepdim, dtype=dtype, mask=self.get_mask()
- )
- return as_masked_tensor(result_data, _multidim_any(mask, dim, keepdim))
- return reduce_dim
- def _torch_reduce(fn):
- def reduce_fn(*args, **kwargs):
- if len(args) == 1 and len(kwargs) == 0:
- return _torch_reduce_all(fn)(args[0])
- return _torch_reduce_dim(fn)(*args, **kwargs)
- return reduce_fn
- def _reduce_dim_args(input, dim, keepdim=False, dtype=None):
- return input, dim, keepdim, dtype
- def _torch_grad_reduce(fn):
- def grad_reduce(*args, **kwargs):
- if len(args) == 1 and len(kwargs) == 0:
- return _torch_reduce_all(fn)(args[0])
- # TODO: autograd.Function doesn't support kwarg
- input, dim, keepdim, dtype = _reduce_dim_args(*args, **kwargs)
- return _torch_reduce_dim(fn)(input, dim, keepdim, dtype)
- return grad_reduce
- REDUCE_NAMES = [
- "sum",
- "mean",
- "amin",
- "amax",
- "argmin",
- "argmax",
- "prod",
- "all",
- "norm",
- "var",
- "std",
- ]
- NATIVE_REDUCE_MAP = {
- getattr(torch.ops.aten, name): _torch_reduce(name) for name in REDUCE_NAMES
- }
- TORCH_REDUCE_MAP = {
- getattr(torch, name): _torch_grad_reduce(name) for name in REDUCE_NAMES
- }
- TENSOR_REDUCE_MAP = {
- getattr(torch.Tensor, name): _torch_grad_reduce(name) for name in REDUCE_NAMES
- }
- NATIVE_REDUCE_FNS = list(NATIVE_REDUCE_MAP.keys())
- TORCH_REDUCE_FNS = list(TORCH_REDUCE_MAP.keys())
- TENSOR_REDUCE_FNS = list(TENSOR_REDUCE_MAP.keys())
- def _is_reduction(fn):
- return fn in NATIVE_REDUCE_MAP or fn in TORCH_REDUCE_MAP or fn in TENSOR_REDUCE_MAP
- def _apply_reduction(fn, *args, **kwargs):
- if fn in NATIVE_REDUCE_MAP:
- return NATIVE_REDUCE_MAP[fn](*args, **kwargs)
- if fn in TORCH_REDUCE_MAP:
- return TORCH_REDUCE_MAP[fn](*args, **kwargs)
- if fn in TENSOR_REDUCE_MAP:
- return TENSOR_REDUCE_MAP[fn](*args, **kwargs)
- return NotImplemented
|