unary.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194
  1. # mypy: allow-untyped-defs
  2. # Copyright (c) Meta Platforms, Inc. and affiliates
  3. import torch
  4. from .core import _map_mt_args_kwargs, _wrap_result
# With an empty __all__, a star-import of this module exports no names.
__all__ = []  # type: ignore[var-annotated]
  6. UNARY_NAMES = [
  7. "abs",
  8. "absolute",
  9. "acos",
  10. "arccos",
  11. "acosh",
  12. "arccosh",
  13. "angle",
  14. "asin",
  15. "arcsin",
  16. "asinh",
  17. "arcsinh",
  18. "atan",
  19. "arctan",
  20. "atanh",
  21. "arctanh",
  22. "bitwise_not",
  23. "ceil",
  24. "clamp",
  25. "clip",
  26. "conj_physical",
  27. "cos",
  28. "cosh",
  29. "deg2rad",
  30. "digamma",
  31. "erf",
  32. "erfc",
  33. "erfinv",
  34. "exp",
  35. "exp2",
  36. "expm1",
  37. "fix",
  38. "floor",
  39. "frac",
  40. "lgamma",
  41. "log",
  42. "log10",
  43. "log1p",
  44. "log2",
  45. "logit",
  46. "i0",
  47. "isnan",
  48. "nan_to_num",
  49. "neg",
  50. "negative",
  51. "positive",
  52. "pow",
  53. "rad2deg",
  54. "reciprocal",
  55. "round",
  56. "rsqrt",
  57. "sigmoid",
  58. "sign",
  59. "sgn",
  60. "signbit",
  61. "sin",
  62. "sinc",
  63. "sinh",
  64. "sqrt",
  65. "square",
  66. "tan",
  67. "tanh",
  68. "trunc",
  69. ]
  70. INPLACE_UNARY_NAMES = [
  71. n + "_"
  72. for n in (list(set(UNARY_NAMES) - {"angle", "positive", "signbit", "isnan"}))
  73. ]
  74. # Explicitly tracking functions we know are currently not supported
  75. # This might be due to missing code gen or because of complex semantics
  76. UNARY_NAMES_UNSUPPORTED = [
  77. "atan2",
  78. "arctan2",
  79. "bitwise_left_shift",
  80. "bitwise_right_shift",
  81. "copysign",
  82. "float_power",
  83. "fmod",
  84. "frexp",
  85. "gradient",
  86. "imag",
  87. "ldexp",
  88. "lerp",
  89. "logical_not",
  90. "hypot",
  91. "igamma",
  92. "igammac",
  93. "mvlgamma",
  94. "nextafter",
  95. "polygamma",
  96. "real",
  97. "remainder",
  98. "true_divide",
  99. "xlogy",
  100. ]
  101. def _unary_helper(fn, args, kwargs, inplace):
  102. if len(kwargs) != 0:
  103. raise ValueError(
  104. "MaskedTensor unary ops require that len(kwargs) == 0. "
  105. "If you need support for this, please open an issue on Github."
  106. )
  107. for a in args[1:]:
  108. if torch.is_tensor(a):
  109. raise TypeError(
  110. "MaskedTensor unary ops do not support additional Tensor arguments"
  111. )
  112. mask_args, _mask_kwargs = _map_mt_args_kwargs(
  113. args, kwargs, lambda x: x._masked_mask
  114. )
  115. data_args, _data_kwargs = _map_mt_args_kwargs(
  116. args, kwargs, lambda x: x._masked_data
  117. )
  118. if args[0].layout == torch.sparse_coo:
  119. data_args[0] = data_args[0].coalesce()
  120. s = data_args[0].size()
  121. i = data_args[0].indices()
  122. data_args[0] = data_args[0].coalesce().values()
  123. v = fn(*data_args)
  124. result_data = torch.sparse_coo_tensor(i, v, size=s)
  125. elif args[0].layout == torch.sparse_csr:
  126. crow = data_args[0].crow_indices()
  127. col = data_args[0].col_indices()
  128. data_args[0] = data_args[0].values()
  129. v = fn(*data_args)
  130. result_data = torch.sparse_csr_tensor(crow, col, v)
  131. else:
  132. result_data = fn(*data_args)
  133. if inplace:
  134. args[0]._set_data_mask(result_data, mask_args[0])
  135. return args[0]
  136. else:
  137. return _wrap_result(result_data, mask_args[0])
  138. def _torch_unary(fn_name):
  139. fn = getattr(torch.ops.aten, fn_name)
  140. def unary_fn(*args, **kwargs):
  141. return _unary_helper(fn, args, kwargs, inplace=False)
  142. return unary_fn
  143. def _torch_inplace_unary(fn_name):
  144. fn = getattr(torch.ops.aten, fn_name)
  145. def unary_fn(*args, **kwargs):
  146. return _unary_helper(fn, args, kwargs, inplace=True)
  147. return unary_fn
  148. NATIVE_UNARY_MAP = {
  149. getattr(torch.ops.aten, name): _torch_unary(name) for name in UNARY_NAMES
  150. }
  151. NATIVE_INPLACE_UNARY_MAP = {
  152. getattr(torch.ops.aten, name): _torch_inplace_unary(name)
  153. for name in INPLACE_UNARY_NAMES
  154. }
  155. NATIVE_UNARY_FNS = list(NATIVE_UNARY_MAP.keys())
  156. NATIVE_INPLACE_UNARY_FNS = list(NATIVE_INPLACE_UNARY_MAP.keys())
  157. def _is_native_unary(fn):
  158. return fn in NATIVE_UNARY_FNS or fn in NATIVE_INPLACE_UNARY_FNS
  159. def _apply_native_unary(fn, *args, **kwargs):
  160. if fn in NATIVE_UNARY_FNS:
  161. return NATIVE_UNARY_MAP[fn](*args, **kwargs)
  162. if fn in NATIVE_INPLACE_UNARY_FNS:
  163. return NATIVE_INPLACE_UNARY_MAP[fn](*args, **kwargs)
  164. return NotImplemented