reductions.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176
  1. # mypy: allow-untyped-defs
  2. # Copyright (c) Meta Platforms, Inc. and affiliates
  3. import warnings
  4. import torch
  5. from .core import is_masked_tensor
  6. from .creation import as_masked_tensor, masked_tensor
# Empty export list: ``from <this module> import *`` picks up no names from
# this file.
__all__ = [] # type: ignore[var-annotated]
  8. def _masked_all_all(data, mask=None):
  9. if mask is None:
  10. return data.all()
  11. return data.masked_fill(~mask, True).all()
  12. def _masked_all_dim(data, dim, keepdim=False, mask=None):
  13. if mask is None:
  14. return torch.all(data, dim=dim, keepdim=keepdim)
  15. return torch.all(data.masked_fill(~mask, True), dim=dim, keepdim=keepdim)
  16. def _masked_all(*args, **kwargs):
  17. if len(args) == 1 and len(kwargs) == 1:
  18. return _masked_all_all(args[0], mask=kwargs["mask"])
  19. return _masked_all_dim(*args, **kwargs)
  20. def _multidim_any(mask, dim, keepdim):
  21. if isinstance(dim, int):
  22. return _multidim_any(mask, [dim], keepdim)
  23. for d in sorted(dim, reverse=True):
  24. mask = torch.any(mask, dim=d, keepdim=keepdim)
  25. return mask
  26. def _get_masked_fn(fn):
  27. if fn == "all":
  28. return _masked_all
  29. return getattr(torch.masked, fn)
  30. def _torch_reduce_all(fn):
  31. def reduce_all(self):
  32. masked_fn = _get_masked_fn(fn)
  33. data = self.get_data()
  34. mask = self.get_mask().values() if self.is_sparse else self.get_mask()
  35. # When reduction is "all", then torch.argmin/torch.argmax needs to return the index of the
  36. # element corresponding to the min/max, but this operation isn't supported correctly for sparse layouts.
  37. # Therefore, this implementation calculates it using the strides.
  38. if fn == "all":
  39. result_data = masked_fn(data, mask=mask)
  40. elif fn in {"argmin", "argmax"} and self.is_sparse_coo():
  41. sparse_idx = masked_fn(data.values(), mask=mask).to(dtype=torch.int)
  42. indices = (
  43. data.to_sparse_coo().indices()
  44. if not self.is_sparse_coo()
  45. else data.indices()
  46. )
  47. idx = indices.unbind(1)[sparse_idx]
  48. stride = data.size().numel() / torch.tensor(
  49. data.size(), device=data.device
  50. ).cumprod(0)
  51. result_data = torch.sum(idx * stride)
  52. # we simply pass in the values for sparse COO/CSR tensors
  53. elif self.is_sparse:
  54. result_data = masked_fn(masked_tensor(data.values(), mask))
  55. else:
  56. result_data = masked_fn(self, mask=mask)
  57. return as_masked_tensor(result_data, torch.any(mask))
  58. return reduce_all
  59. def _torch_reduce_dim(fn):
  60. def reduce_dim(self, dim, keepdim=False, dtype=None):
  61. if self.is_sparse:
  62. msg = (
  63. f"The sparse version of {fn} is not implemented in reductions.\n"
  64. "If you would like this operator to be supported, please file an issue for a feature request at "
  65. "https://github.com/pytorch/maskedtensor/issues with a minimal reproducible code snippet.\n"
  66. "In the case that the semantics for the operator are not trivial, it would be appreciated "
  67. "to also include a proposal for the semantics."
  68. )
  69. warnings.warn(msg, stacklevel=2)
  70. return NotImplemented
  71. if not is_masked_tensor(self):
  72. raise TypeError("Input to reduce_dim must be a MaskedTensor")
  73. masked_fn = _get_masked_fn(fn)
  74. data = self.get_data()
  75. mask = self.get_mask()
  76. if fn == "all":
  77. result_data = masked_fn(data, dim=dim, keepdim=keepdim, mask=mask)
  78. else:
  79. result_data = masked_fn(
  80. self, dim=dim, keepdim=keepdim, dtype=dtype, mask=self.get_mask()
  81. )
  82. return as_masked_tensor(result_data, _multidim_any(mask, dim, keepdim))
  83. return reduce_dim
  84. def _torch_reduce(fn):
  85. def reduce_fn(*args, **kwargs):
  86. if len(args) == 1 and len(kwargs) == 0:
  87. return _torch_reduce_all(fn)(args[0])
  88. return _torch_reduce_dim(fn)(*args, **kwargs)
  89. return reduce_fn
  90. def _reduce_dim_args(input, dim, keepdim=False, dtype=None):
  91. return input, dim, keepdim, dtype
  92. def _torch_grad_reduce(fn):
  93. def grad_reduce(*args, **kwargs):
  94. if len(args) == 1 and len(kwargs) == 0:
  95. return _torch_reduce_all(fn)(args[0])
  96. # TODO: autograd.Function doesn't support kwarg
  97. input, dim, keepdim, dtype = _reduce_dim_args(*args, **kwargs)
  98. return _torch_reduce_dim(fn)(input, dim, keepdim, dtype)
  99. return grad_reduce
  100. REDUCE_NAMES = [
  101. "sum",
  102. "mean",
  103. "amin",
  104. "amax",
  105. "argmin",
  106. "argmax",
  107. "prod",
  108. "all",
  109. "norm",
  110. "var",
  111. "std",
  112. ]
  113. NATIVE_REDUCE_MAP = {
  114. getattr(torch.ops.aten, name): _torch_reduce(name) for name in REDUCE_NAMES
  115. }
  116. TORCH_REDUCE_MAP = {
  117. getattr(torch, name): _torch_grad_reduce(name) for name in REDUCE_NAMES
  118. }
  119. TENSOR_REDUCE_MAP = {
  120. getattr(torch.Tensor, name): _torch_grad_reduce(name) for name in REDUCE_NAMES
  121. }
  122. NATIVE_REDUCE_FNS = list(NATIVE_REDUCE_MAP.keys())
  123. TORCH_REDUCE_FNS = list(TORCH_REDUCE_MAP.keys())
  124. TENSOR_REDUCE_FNS = list(TENSOR_REDUCE_MAP.keys())
  125. def _is_reduction(fn):
  126. return fn in NATIVE_REDUCE_MAP or fn in TORCH_REDUCE_MAP or fn in TENSOR_REDUCE_MAP
  127. def _apply_reduction(fn, *args, **kwargs):
  128. if fn in NATIVE_REDUCE_MAP:
  129. return NATIVE_REDUCE_MAP[fn](*args, **kwargs)
  130. if fn in TORCH_REDUCE_MAP:
  131. return TORCH_REDUCE_MAP[fn](*args, **kwargs)
  132. if fn in TENSOR_REDUCE_MAP:
  133. return TENSOR_REDUCE_MAP[fn](*args, **kwargs)
  134. return NotImplemented