TensorSubclassLikeUtils.h 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  3. #include <ATen/core/List.h>
  4. #include <ATen/core/Tensor.h>
  5. #include <c10/core/impl/TorchDispatchModeTLS.h>
  6. #ifndef AT_PER_OPERATOR_HEADERS
  7. #include <ATen/Functions.h>
  8. #else
  9. #include <ATen/ops/equal.h>
  10. #endif
  11. namespace at {
  12. // Note [Tensor-subclass-like Tensors]
  13. // Tensor-subclass-like is defined as:
  14. // - a Tensor subclass (via __torch_dispatch__ in Python or extending
  15. // TensorImpl in C++)
  16. // - anything else that shares the same perils as Tensor subclasses.
  17. // For example, many Tensor subclasses do not have storage and meta Tensors
  18. // do not have storage either, so meta Tensors belong here.
  19. //
// We should ensure that PyTorch internals support Tensor-subclass-like
// objects. In particular, Tensor-subclass-like objects struggle with two
// classes of operations that are problematic for Tensor subclasses:
// 1. Because some Tensor subclasses do not have storage, .item() or
// .data_ptr() calls should be avoided.
// 2. Certain in-place operations can discard the Tensor subclass type.
// For example:
// >>> torch.zeros(input.sizes(), grad.options()).diag().copy_(input)
// If input is a Tensor subclass, then the above ends up either erroring out
// or returning a regular non-Tensor-subclass Tensor!
  30. constexpr auto kFunctorchWrappedTensors = DispatchKeySet(
  31. {DispatchKey::FuncTorchGradWrapper,
  32. DispatchKey::FuncTorchBatched,
  33. DispatchKey::Functionalize});
  34. constexpr auto kTensorSubclassLike =
  35. kFunctorchWrappedTensors |
  36. DispatchKeySet(
  37. {// WARNING: DO NOT put combined backend component + functionality keys
  38. // here, you will incorrectly always match on the functionality key
  39. // no matter the backend component
  40. DispatchKey::Batched,
  41. DispatchKey::Sparse,
  42. DispatchKey::SparseCsr,
  43. DispatchKey::Python}) |
  44. DispatchKeySet(BackendComponent::MetaBit);
  45. inline bool isTensorSubclassLike(const Tensor& tensor) {
  46. if (c10::impl::dispatch_mode_enabled())
  47. return true;
  48. auto key_set = tensor.unsafeGetTensorImpl()->key_set();
  49. return !(key_set & kTensorSubclassLike).empty();
  50. }
  51. inline bool areAnyTensorSubclassLike(TensorList tensors) {
  52. if (c10::impl::dispatch_mode_enabled())
  53. return true;
  54. return std::any_of(tensors.begin(), tensors.end(), isTensorSubclassLike);
  55. }
  56. inline bool areAnyOptionalTensorSubclassLike(
  57. const c10::List<std::optional<Tensor>>& tensors) {
  58. if (c10::impl::dispatch_mode_enabled())
  59. return true;
  60. return std::any_of(
  61. tensors.begin(),
  62. tensors.end(),
  63. [](const std::optional<Tensor>& opt_tensor) {
  64. return (
  65. opt_tensor.has_value() && isTensorSubclassLike(opt_tensor.value()));
  66. });
  67. }
  68. // Helper function to deal testing truthfulness of a scalar tensor
  69. // in a Composite Compliant manner.
  70. // NOTE: This function expects a scalar tensor of boolean dtype.
  71. // Eg.
  72. // Non-Composite Compliant Pattern : (t == 0).all().item<bool>()
  73. // Composite Compliant Pattern : is_salar_tensor_true((t == 0).all())
  74. inline bool is_scalar_tensor_true(const Tensor& t) {
  75. TORCH_INTERNAL_ASSERT(t.dim() == 0)
  76. TORCH_INTERNAL_ASSERT(t.scalar_type() == kBool)
  77. return at::equal(t, t.new_ones({}, t.options()));
  78. }
  79. } // namespace at
  80. #else
  81. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  82. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)