binary.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200
  1. # mypy: allow-untyped-defs
  2. # Copyright (c) Meta Platforms, Inc. and affiliates
  3. import torch
  4. from .core import (
  5. _map_mt_args_kwargs,
  6. _masks_match,
  7. _tensors_match,
  8. _wrap_result,
  9. is_masked_tensor,
  10. )
  __all__ = []  # type: ignore[var-annotated]

  # Names of ``torch.ops.aten`` binary ops that get an out-of-place MaskedTensor
  # wrapper (NATIVE_BINARY_MAP builds one per name via ``_torch_binary``).
  BINARY_NAMES = [
      "add",
      "atan2",
      "arctan2",
      "bitwise_and",
      "bitwise_or",
      "bitwise_xor",
      "bitwise_left_shift",
      "bitwise_right_shift",
      "div",
      "divide",
      "floor_divide",
      "fmod",
      "logaddexp",
      "logaddexp2",
      "mul",
      "multiply",
      "nextafter",
      "remainder",
      "sub",
      "subtract",
      "true_divide",
      "eq",
      "ne",
      "le",
      "ge",
      "greater",
      "greater_equal",
      "gt",
      "less_equal",
      "lt",
      "less",
      "maximum",
      "minimum",
      "fmax",
      "fmin",
      "not_equal",
  ]

  # Ops from BINARY_NAMES that get no in-place ("<name>_") wrapper. "equal" is not
  # currently in BINARY_NAMES, so excluding it has no effect; it is kept so that
  # adding "equal" above does not silently add an in-place variant.
  _NO_INPLACE_BINARY_NAMES = frozenset(
      {
          "logaddexp",
          "logaddexp2",
          "equal",
          "fmin",
          "minimum",
          "maximum",
          "fmax",
      }
  )

  # In-place variant names ("<name>_"), in the same order as BINARY_NAMES.
  # Filtering the list (instead of iterating a set difference) keeps the order
  # deterministic across interpreter runs rather than dependent on string hashing.
  INPLACE_BINARY_NAMES = [
      n + "_" for n in BINARY_NAMES if n not in _NO_INPLACE_BINARY_NAMES
  ]
  def _get_at_least_one_mask(a, b):
      """Return the mask to attach to the result of a binary op on ``a`` and ``b``.

      At least one operand must be a MaskedTensor, and ``_masks_match(a, b)``
      must hold. The mask is read from ``a`` when it is a MaskedTensor,
      otherwise from ``b``.

      Raises:
          TypeError: neither ``a`` nor ``b`` is a MaskedTensor.
          ValueError: ``_masks_match(a, b)`` is false.
      """
      lhs_is_masked = is_masked_tensor(a)
      if not (lhs_is_masked or is_masked_tensor(b)):
          raise TypeError("At least one of `a` and `b` must be a MaskedTensor")
      if not _masks_match(a, b):
          raise ValueError("a and b must have matching masks")
      mask_source = a if lhs_is_masked else b
      return mask_source.get_mask()
  def _binary_helper(fn, args, kwargs, inplace):
      """Apply the aten binary op ``fn`` to MaskedTensor operands.

      ``args[0]`` and ``args[1]`` are the lhs and rhs; any extra positional
      arguments must not be tensors, and ``kwargs`` must be empty. ``fn`` runs on
      the operands' unwrapped data. When the lhs data is ``sparse_coo`` or
      ``sparse_csr``, ``fn`` runs on its ``values()`` and the result is rebuilt
      with the lhs indices and size. A rhs of the same sparse layout must have
      matching indices and size, and it contributes its ``values()``.

      Args:
          fn: ``torch.ops.aten`` op applied to the unwrapped data.
          args: positional arguments of the op (lhs, rhs, then non-tensor extras).
          kwargs: keyword arguments of the op; must be empty.
          inplace: when True, pass the result data and ``args[0]``'s mask to
              ``args[0]._set_data_mask`` and return ``args[0]``.

      Returns:
          ``args[0]`` if ``inplace``. Otherwise, ``_wrap_result`` of the result
          data and the mask from ``_get_at_least_one_mask``; that mask is expanded
          to the result's shape when the lhs layout is strided.

      Raises:
          ValueError: non-empty ``kwargs``; input masks that do not match; or, for
              a sparse lhs and rhs of the same layout, mismatched indices or sizes.
          TypeError: a tensor appears in ``args[2:]``; or, on the out-of-place
              path, neither lhs nor rhs is a MaskedTensor.
      """
      if len(kwargs) != 0:
          raise ValueError("len(kwargs) must equal 0")
      if any(torch.is_tensor(a) for a in args[2:]):
          raise TypeError(
              "MaskedTensor binary ops do not support Tensor arguments aside from the lhs and rhs"
          )
      if not _masks_match(*args[:2]):
          raise ValueError(
              "Input masks must match. If you need support for this, please open an issue on Github."
          )

      data_args, _data_kwargs = _map_mt_args_kwargs(args, kwargs, lambda x: x.get_data())
      mask_args, _mask_kwargs = _map_mt_args_kwargs(args, kwargs, lambda x: x.get_mask())

      args0_layout = data_args[0].layout
      # The rhs need not be a tensor (e.g. a scalar), so only compare layouts
      # when it has one.
      same_layout = (
          torch.is_tensor(data_args[1]) or is_masked_tensor(data_args[1])
      ) and (args0_layout == data_args[1].layout)

      if args0_layout == torch.sparse_coo:
          if same_layout:
              if not _tensors_match(data_args[0].indices(), data_args[1].indices()):
                  raise ValueError(
                      "sparse_coo indices must match. If you need support for this, please open an issue on Github."
                  )
              if data_args[0].size() != data_args[1].size():
                  raise ValueError(
                      "input1 and input2 must have the same size for binary functions."
                  )
              data_args[1] = data_args[1].values()
          # Run fn on the stored values, then rebuild with the lhs sparsity pattern.
          i = data_args[0].indices()
          size = data_args[0].size()
          data_args[0] = data_args[0].values()
          v = fn(*data_args)
          result_data = torch.sparse_coo_tensor(i, v, size)
      elif args0_layout == torch.sparse_csr:
          if same_layout:
              if not (
                  _tensors_match(data_args[0].crow_indices(), data_args[1].crow_indices())
                  and _tensors_match(
                      data_args[0].col_indices(), data_args[1].col_indices()
                  )
              ):
                  raise ValueError(
                      "sparse_csr indices must match. If you need support for this, please open an issue on Github."
                  )
              # Matching crow/col indices do not pin down the number of columns,
              # so check sizes explicitly, as the sparse_coo branch does. Without
              # this, the result silently takes the lhs size.
              if data_args[0].size() != data_args[1].size():
                  raise ValueError(
                      "input1 and input2 must have the same size for binary functions."
                  )
              data_args[1] = data_args[1].values()
          # Run fn on the stored values, then rebuild with the lhs sparsity pattern.
          crow = data_args[0].crow_indices()
          col = data_args[0].col_indices()
          size = data_args[0].size()
          data_args[0] = data_args[0].values()
          v = fn(*data_args)
          result_data = torch.sparse_csr_tensor(crow, col, v, size)
      else:
          result_data = fn(*data_args)

      if inplace:
          args[0]._set_data_mask(result_data, mask_args[0])
          return args[0]

      result_mask = _get_at_least_one_mask(*args[:2])
      # sparse tensors don't have strides so we can only expand if the layout is strided
      if args0_layout == torch.strided:
          result_mask = result_mask.expand_as(result_data)
      return _wrap_result(result_data, result_mask)
  def _torch_binary(fn_name):
      """Build the out-of-place MaskedTensor wrapper for ``torch.ops.aten.<fn_name>``.

      The aten op is resolved once, when the wrapper is built. The returned
      callable forwards its positional and keyword arguments to
      ``_binary_helper`` with ``inplace=False``.
      """
      aten_op = getattr(torch.ops.aten, fn_name)

      def binary_fn(*args, **kwargs):
          return _binary_helper(aten_op, args, kwargs, inplace=False)

      return binary_fn
  def _torch_inplace_binary(fn_name):
      """Build the in-place MaskedTensor wrapper for ``torch.ops.aten.<fn_name>``.

      The aten op is resolved once, when the wrapper is built. The returned
      callable forwards its positional and keyword arguments to
      ``_binary_helper`` with ``inplace=True``.
      """
      aten_op = getattr(torch.ops.aten, fn_name)

      def binary_fn(*args, **kwargs):
          return _binary_helper(aten_op, args, kwargs, inplace=True)

      return binary_fn
  # Dispatch tables mapping each aten op to its MaskedTensor wrapper. There is one
  # entry per name in BINARY_NAMES / INPLACE_BINARY_NAMES respectively.
  NATIVE_BINARY_MAP = {
      getattr(torch.ops.aten, op_name): _torch_binary(op_name)
      for op_name in BINARY_NAMES
  }
  NATIVE_INPLACE_BINARY_MAP = {
      getattr(torch.ops.aten, op_name): _torch_inplace_binary(op_name)
      for op_name in INPLACE_BINARY_NAMES
  }
  # Key lists consulted by _is_native_binary and _apply_native_binary.
  NATIVE_BINARY_FNS = list(NATIVE_BINARY_MAP)
  NATIVE_INPLACE_BINARY_FNS = list(NATIVE_INPLACE_BINARY_MAP)
  def _is_native_binary(fn):
      """Return True if ``fn`` has an out-of-place or in-place MaskedTensor wrapper."""
      return any(fn in known for known in (NATIVE_BINARY_FNS, NATIVE_INPLACE_BINARY_FNS))
  def _apply_native_binary(fn, *args, **kwargs):
      """Run the MaskedTensor wrapper registered for ``fn``.

      The out-of-place table is checked before the in-place one. Returns
      ``NotImplemented`` when ``fn`` is in neither.
      """
      lookups = (
          (NATIVE_BINARY_FNS, NATIVE_BINARY_MAP),
          (NATIVE_INPLACE_BINARY_FNS, NATIVE_INPLACE_BINARY_MAP),
      )
      for known_fns, wrappers in lookups:
          if fn in known_fns:
              return wrappers[fn](*args, **kwargs)
      return NotImplemented